Compare commits

..

4 Commits

Author SHA1 Message Date
17d723ecca minor ignore 2026-05-09 14:22:29 +08:00
9fc8c14445 Fix task leases and agent timeouts 2026-05-08 21:36:07 +08:00
2ae22b3492 chore: add start script 2026-05-06 17:38:32 +08:00
3c624cc46d fix: close issues after merged PRs 2026-05-06 17:33:16 +08:00
7 changed files with 355 additions and 53 deletions

1
.gitignore vendored
View File

@@ -1,4 +1,5 @@
.agent-gitea/
.agent-output/
.env
.pytest_cache/
.ruff_cache/

1
run.sh Executable file
View File

@@ -0,0 +1 @@
uv run agent-gitea --config config.yaml worker

View File

@@ -7,8 +7,29 @@ from .models import AgentResult
class CommandRunner:
def run(self, command: list[str], cwd: str | Path, *, stdin: str | None = None) -> AgentResult:
result = subprocess.run(command, cwd=cwd, input=stdin, text=True, capture_output=True, check=False)
def run(
self,
command: list[str],
cwd: str | Path,
*,
stdin: str | None = None,
timeout_seconds: int | float | None = None,
) -> AgentResult:
try:
result = subprocess.run(
command,
cwd=cwd,
input=stdin,
text=True,
capture_output=True,
check=False,
timeout=timeout_seconds,
)
except subprocess.TimeoutExpired as exc:
stdout = exc.stdout if isinstance(exc.stdout, str) else (exc.stdout or b"").decode(errors="replace")
stderr = exc.stderr if isinstance(exc.stderr, str) else (exc.stderr or b"").decode(errors="replace")
message = f"command timed out after {timeout_seconds} seconds"
return AgentResult(exit_code=124, stdout=stdout or "", stderr=f"{stderr}\n{message}".strip())
return AgentResult(exit_code=result.returncode, stdout=result.stdout, stderr=result.stderr)

View File

@@ -24,6 +24,14 @@ class SchedulerConfig(BaseModel):
interval_seconds: int = Field(default=60, ge=1)
concurrency: int = Field(default=1, ge=1)
lease_seconds: int = Field(default=1800, ge=30)
lease_renewal_interval_seconds: int | None = Field(default=None, ge=1)
agent_timeout_seconds: int = Field(default=7200, ge=1)
@property
def effective_lease_renewal_interval_seconds(self) -> int:
if self.lease_renewal_interval_seconds is not None:
return self.lease_renewal_interval_seconds
return max(1, self.lease_seconds // 3)
class WorkspaceConfig(BaseModel):

View File

@@ -292,12 +292,23 @@ class Database:
SELECT t.*
FROM tasks t
JOIN issues i ON i.repo_id = t.repo_id AND i.issue_number = t.issue_number
WHERE t.state = ?
WHERE t.state != ?
AND t.pr_number IS NOT NULL
AND i.state = 'open'
AND (
t.state IN (?, ?)
OR t.lease_owner IS NULL
OR t.lease_expires_at IS NULL
OR t.lease_expires_at < ?
)
ORDER BY t.id
""",
(TaskState.HUMAN_REVIEW_READY.value,),
(
TaskState.CANCELLED.value,
TaskState.HUMAN_REVIEW_READY.value,
TaskState.FAILED.value,
dt_to_db(utcnow()),
),
).fetchall()
return [self._task(row) for row in rows]
@@ -430,49 +441,78 @@ class Database:
def claim_next_task(self, worker_id: str, lease_seconds: int) -> TaskRecord | None:
now = utcnow()
expires = now + timedelta(seconds=lease_seconds)
row = self.conn.execute(
"""
SELECT * FROM tasks
WHERE state = ?
OR (state IN (?, ?, ?, ?, ?, ?, ?) AND lease_expires_at IS NOT NULL AND lease_expires_at < ?)
ORDER BY created_at
LIMIT 1
""",
(
TaskState.DISCOVERED.value,
TaskState.CLAIMED.value,
TaskState.PLANNING.value,
TaskState.IMPLEMENTING.value,
TaskState.TESTING.value,
TaskState.PR_OPENED.value,
TaskState.REVIEWING.value,
TaskState.DISCOVERED.value,
dt_to_db(now),
),
).fetchone()
if row is None:
return None
task = self._task(row)
self.conn.execute(
"""
UPDATE tasks
SET state = ?, lease_owner = ?, lease_expires_at = ?, updated_at = ?
WHERE id = ?
""",
(
TaskState.CLAIMED.value,
worker_id,
dt_to_db(expires),
dt_to_db(now),
task.id,
),
)
self.conn.commit()
try:
self.conn.execute("BEGIN IMMEDIATE")
row = self.conn.execute(
"""
SELECT * FROM tasks
WHERE state = ?
OR (state IN (?, ?, ?, ?, ?, ?, ?) AND lease_expires_at IS NOT NULL AND lease_expires_at < ?)
ORDER BY created_at
LIMIT 1
""",
(
TaskState.DISCOVERED.value,
TaskState.CLAIMED.value,
TaskState.PLANNING.value,
TaskState.IMPLEMENTING.value,
TaskState.TESTING.value,
TaskState.PR_OPENED.value,
TaskState.REVIEWING.value,
TaskState.DISCOVERED.value,
dt_to_db(now),
),
).fetchone()
if row is None:
self.conn.commit()
return None
task = self._task(row)
self.conn.execute(
"""
UPDATE tasks
SET state = ?, lease_owner = ?, lease_expires_at = ?, updated_at = ?
WHERE id = ?
""",
(
TaskState.CLAIMED.value,
worker_id,
dt_to_db(expires),
dt_to_db(now),
task.id,
),
)
self.conn.commit()
except Exception:
self.conn.rollback()
raise
self.add_event(task.id, task.state, TaskState.CLAIMED, f"claimed by {worker_id}")
claimed = self.get_task(task.id)
assert claimed is not None
return claimed
def renew_task_lease(self, task_id: int, worker_id: str, lease_seconds: int) -> bool:
now = utcnow()
expires = now + timedelta(seconds=lease_seconds)
placeholders = ",".join("?" for _ in ACTIVE_STATES)
cursor = self.conn.execute(
f"""
UPDATE tasks
SET lease_expires_at = ?, updated_at = ?
WHERE id = ?
AND lease_owner = ?
AND state IN ({placeholders})
""",
(
dt_to_db(expires),
dt_to_db(now),
task_id,
worker_id,
*[state.value for state in ACTIVE_STATES],
),
)
self.conn.commit()
return cursor.rowcount == 1
def transition(
self,
task_id: int,

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import socket
import time
import logging
import threading
from dataclasses import dataclass
from pathlib import Path
@@ -10,7 +11,7 @@ from .agents import CommandRunner, read_report, render_command, write_prompt
from .config import AppConfig
from .db import Database, PullRequestFeedbackCursor
from .gitea import GiteaClient, GiteaComment, GiteaPullReview
from .models import IssueRecord, RepositoryRecord, TaskRecord, TaskState
from .models import ACTIVE_STATES, IssueRecord, RepositoryRecord, TaskRecord, TaskState
from .rendering import (
parse_review_report,
render_human_review_summary,
@@ -34,6 +35,46 @@ class PullRequestFeedbackSnapshot:
newest_cursor: PullRequestFeedbackCursor
class TaskLeaseRenewer:
def __init__(
self,
*,
db_path: Path,
task_id: int,
worker_id: str,
lease_seconds: int,
interval_seconds: int,
):
self.db_path = db_path
self.task_id = task_id
self.worker_id = worker_id
self.lease_seconds = lease_seconds
self.interval_seconds = interval_seconds
self._stop = threading.Event()
self._thread = threading.Thread(target=self._run, name=f"lease-renewer-{task_id}", daemon=True)
def __enter__(self) -> TaskLeaseRenewer:
self._thread.start()
return self
def __exit__(self, exc_type, exc, tb) -> None:
self._stop.set()
self._thread.join(timeout=max(1, self.interval_seconds))
def _run(self) -> None:
database = Database(self.db_path)
try:
while not self._stop.wait(self.interval_seconds):
renewed = database.renew_task_lease(self.task_id, self.worker_id, self.lease_seconds)
if not renewed:
logger.warning("stopping lease renewal for task %d; lease no longer belongs to %s", self.task_id, self.worker_id)
return
except Exception:
logger.exception("lease renewal failed for task %d", self.task_id)
finally:
database.close()
def sync_repositories(db: Database, config: AppConfig, client: GiteaClient) -> list[RepositoryRecord]:
synced: list[RepositoryRecord] = []
discovered = client.list_owned_repositories()
@@ -92,12 +133,12 @@ def close_issues_for_merged_pull_requests(db: Database, client: GiteaClient) ->
)
client.close_issue(owner=repo.owner, name=repo.name, issue_number=issue.issue_number)
db.update_issue_state(task.repo_id, task.issue_number, "closed")
db.add_event(
task.id,
task.state,
task.state,
f"closed issue #{issue.issue_number} after merged PR #{task.pr_number}",
)
message = f"closed issue #{issue.issue_number} after merged PR #{task.pr_number}"
if task.state in ACTIVE_STATES:
db.clear_pr_feedback_pending(task.id)
db.transition(task.id, TaskState.CANCELLED, message=message, clear_lease=True)
else:
db.add_event(task.id, task.state, task.state, message)
closed += 1
return closed
@@ -388,7 +429,7 @@ class TaskRunner:
issue_title=issue.title,
branch_name=branch_name,
)
result = self.command_runner.run(command, workspace, stdin=prompt)
result = self._run_agent_command(task, command, workspace, prompt)
report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md")
self.db.add_agent_run(
task_id=task.id,
@@ -428,7 +469,7 @@ class TaskRunner:
pr_number=pr_number,
branch_name=branch_name,
)
result = self.command_runner.run(command, workspace, stdin=prompt)
result = self._run_agent_command(task, command, workspace, prompt)
report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md")
self.db.add_agent_run(
task_id=task.id,
@@ -465,7 +506,7 @@ class TaskRunner:
issue_title=issue.title,
pr_number=pr_number,
)
result = self.command_runner.run(command, workspace, stdin=prompt)
result = self._run_agent_command(task, command, workspace, prompt)
report = read_report(output_dir / "AGENT_REVIEW_REPORT.md")
self.db.add_agent_run(
task_id=task.id,
@@ -481,6 +522,21 @@ class TaskRunner:
raise RuntimeError(f"reviewer failed with exit code {result.exit_code}")
return report
def _run_agent_command(self, task: TaskRecord, command: list[str], workspace: Path, prompt: str) -> AgentResult:
with TaskLeaseRenewer(
db_path=self.db.path,
task_id=task.id,
worker_id=self.worker_id,
lease_seconds=self.config.scheduler.lease_seconds,
interval_seconds=self.config.scheduler.effective_lease_renewal_interval_seconds,
):
return self.command_runner.run(
command,
workspace,
stdin=prompt,
timeout_seconds=self.config.scheduler.agent_timeout_seconds,
)
def _load_context(self, task: TaskRecord) -> tuple[RepositoryRecord, IssueRecord]:
return load_task_context(self.db, task)

View File

@@ -1,11 +1,16 @@
from __future__ import annotations
import json
import sys
import threading
import time
from pathlib import Path
import httpx
from agent_gitea.agents import CommandRunner
from agent_gitea.config import AppConfig
from agent_gitea.db import Database
from agent_gitea.gitea import GiteaClient
from agent_gitea.models import AgentResult, TaskState
from agent_gitea.service import (
@@ -208,7 +213,7 @@ class FakeRunner:
def __init__(self, *, fail_role: str | None = None):
self.fail_role = fail_role
def run(self, command, cwd, *, stdin=None):
def run(self, command, cwd, *, stdin=None, timeout_seconds=None):
role = command[0]
assert stdin
if role == self.fail_role:
@@ -250,6 +255,126 @@ def seed_task(db):
return db.create_task(repo.id, 1)
def test_claim_next_task_allows_only_one_worker_during_race(db, tmp_path):
seed_task(db)
db.conn.create_function("sleep_ms", 1, lambda ms: time.sleep(ms / 1000))
db.conn.execute(
"""
CREATE TRIGGER slow_claim_update
BEFORE UPDATE OF state ON tasks
WHEN NEW.state = 'CLAIMED'
BEGIN
SELECT sleep_ms(150);
END
"""
)
db.conn.commit()
db.close()
barrier = threading.Barrier(2)
results: dict[str, int | None] = {}
errors: list[BaseException] = []
def claim(worker_id: str) -> None:
database = Database(tmp_path / "state.sqlite3")
database.conn.create_function("sleep_ms", 1, lambda ms: time.sleep(ms / 1000))
try:
barrier.wait()
task = database.claim_next_task(worker_id, 60)
results[worker_id] = task.id if task else None
except BaseException as exc: # pragma: no cover - assertion below reports thread failures
errors.append(exc)
finally:
database.close()
threads = [threading.Thread(target=claim, args=(worker_id,)) for worker_id in ("worker-a", "worker-b")]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert errors == []
assert list(results.values()).count(1) == 1
assert list(results.values()).count(None) == 1
def test_command_runner_returns_timeout_result(tmp_path):
result = CommandRunner().run(
[sys.executable, "-c", "import time; time.sleep(5)"],
tmp_path,
timeout_seconds=0.1,
)
assert result.exit_code == 124
assert "timed out" in result.stderr
class ObservingSlowRunner:
def __init__(self, db, task_id: int):
self.db = db
self.task_id = task_id
self.lease_before = None
self.lease_during = None
def run(self, command, cwd, *, stdin=None, timeout_seconds=None):
role = command[0]
if role == "implementer":
self.lease_before = self.db.get_task(self.task_id).lease_expires_at # type: ignore[union-attr]
time.sleep(1.2)
self.lease_during = self.db.get_task(self.task_id).lease_expires_at # type: ignore[union-attr]
output_dir = Path(cwd) / ".agent-output"
output_dir.mkdir(exist_ok=True)
(output_dir / "AGENT_IMPLEMENTATION_REPORT.md").write_text(
"## Summary\nImplemented\n\n## Test commands run\npytest\n",
encoding="utf-8",
)
if role == "reviewer":
output_dir = Path(cwd) / ".agent-output"
output_dir.mkdir(exist_ok=True)
(output_dir / "AGENT_REVIEW_REPORT.md").write_text(
"## Verdict\nAPPROVE\n\n## Suggested PR Comment\nLooks good.\n",
encoding="utf-8",
)
return AgentResult(exit_code=0, stdout="ok", stderr="")
def test_task_runner_renews_lease_while_agent_runs(db, tmp_path):
config = make_config(
tmp_path,
scheduler={
"interval_seconds": 1,
"concurrency": 1,
"lease_seconds": 60,
"lease_renewal_interval_seconds": 1,
},
)
task = seed_task(db)
runner = ObservingSlowRunner(db, task.id)
def handler(request: httpx.Request) -> httpx.Response:
payload = json.loads(request.content.decode() or "{}")
if request.url.path == "/api/v1/repos/acme/service/pulls":
return httpx.Response(201, json={"number": 5, "state": "open", "merged": False})
if request.url.path == "/api/v1/repos/acme/service/issues/5/comments":
return httpx.Response(201, json={"id": 1, "body": payload.get("body", ""), "user": {"login": "agent-bot"}})
return httpx.Response(404)
finished = TaskRunner(
db=db,
config=config,
gitea=make_client(handler),
workspace_manager=FakeWorkspaceManager(tmp_path / "work"),
command_runner=runner,
worker_id="worker",
).run_once()
assert finished is not None
assert finished.state == TaskState.HUMAN_REVIEW_READY
assert runner.lease_before is not None
assert runner.lease_during is not None
assert runner.lease_during > runner.lease_before
def transition_to_human_review_ready(db, task_id: int, *, pr_number: int = 5, branch_name: str | None = None):
db.transition(task_id, TaskState.CLAIMED)
db.transition(task_id, TaskState.PLANNING)
@@ -656,6 +781,56 @@ def test_close_issues_for_merged_pull_requests_skips_unmerged_pr(db):
assert db.get_issue(repo.id, 1).state == "open" # type: ignore[union-attr]
def test_close_issues_for_merged_pull_requests_handles_queued_feedback_task(db):
repo = db.upsert_repository(
owner="acme",
name="service",
clone_url="https://gitea.test/acme/service.git",
default_branch="main",
enabled=True,
)
db.upsert_issue(
repo_id=repo.id,
issue_number=1,
title="Ready issue",
body="Body",
labels=["agent:ready"],
state="open",
html_url="https://gitea.test/acme/service/issues/1",
)
task = db.create_task(repo.id, 1)
task = transition_to_human_review_ready(db, task.id, pr_number=5, branch_name="agent/issue-1-ready-issue")
db.mark_pr_feedback_pending(task.id)
db.transition(
task.id,
TaskState.DISCOVERED,
message="queued PR feedback from 1 human comment(s)",
clear_lease=True,
)
requests: list[tuple[str, str, dict]] = []
def handler(request: httpx.Request) -> httpx.Response:
payload = json.loads(request.content.decode() or "{}")
requests.append((request.method, request.url.path, payload))
if request.url.path == "/api/v1/repos/acme/service/pulls/5":
return httpx.Response(200, json={"number": 5, "state": "closed", "merged": True})
if request.url.path == "/api/v1/repos/acme/service/issues/1/comments":
return httpx.Response(201, json={"id": 1})
if request.url.path == "/api/v1/repos/acme/service/issues/1":
return httpx.Response(200, json={"number": 1, "state": "closed"})
return httpx.Response(404)
closed = close_issues_for_merged_pull_requests(db, make_client(handler))
updated_task = db.get_task(task.id)
assert closed == 1
assert db.get_issue(repo.id, 1).state == "closed" # type: ignore[union-attr]
assert updated_task is not None
assert updated_task.state == TaskState.CANCELLED
assert not db.has_pending_pr_feedback(task.id)
assert ("PATCH", "/api/v1/repos/acme/service/issues/1", {"state": "closed"}) in requests
def test_run_task_no_diff_becomes_blocked(db, tmp_path):
config = make_config(tmp_path)
seed_task(db)