Compare commits

..

4 Commits

Author SHA1 Message Date
17d723ecca minor ignore 2026-05-09 14:22:29 +08:00
9fc8c14445 Fix task leases and agent timeouts 2026-05-08 21:36:07 +08:00
2ae22b3492 chore: add start script 2026-05-06 17:38:32 +08:00
3c624cc46d fix: close issues after merged PRs 2026-05-06 17:33:16 +08:00
7 changed files with 355 additions and 53 deletions

1
.gitignore vendored
View File

@@ -1,4 +1,5 @@
.agent-gitea/ .agent-gitea/
.agent-output/
.env .env
.pytest_cache/ .pytest_cache/
.ruff_cache/ .ruff_cache/

1
run.sh Executable file
View File

@@ -0,0 +1 @@
uv run agent-gitea --config config.yaml worker

View File

@@ -7,8 +7,29 @@ from .models import AgentResult
class CommandRunner:
    """Run external commands and report the outcome as an ``AgentResult``."""

    def run(
        self,
        command: list[str],
        cwd: str | Path,
        *,
        stdin: str | None = None,
        timeout_seconds: int | float | None = None,
    ) -> AgentResult:
        """Execute ``command`` inside ``cwd`` and capture its output as text.

        Args:
            command: Argument vector handed to ``subprocess.run``.
            cwd: Working directory for the child process.
            stdin: Optional text fed to the child's standard input.
            timeout_seconds: Optional time limit; ``None`` means no limit.

        Returns:
            The child's exit code, stdout and stderr. When the time limit is
            hit, the exit code is 124, stdout holds whatever was captured and
            stderr ends with a "command timed out" note.
        """

        def as_text(stream: str | bytes | None) -> str:
            # The captured output on the timeout exception may be str, bytes
            # or None, so normalise all three to str.
            if isinstance(stream, str):
                return stream
            return (stream or b"").decode(errors="replace")

        try:
            completed = subprocess.run(
                command,
                cwd=cwd,
                input=stdin,
                text=True,
                capture_output=True,
                check=False,
                timeout=timeout_seconds,
            )
        except subprocess.TimeoutExpired as expired:
            partial_out = as_text(expired.stdout)
            partial_err = as_text(expired.stderr)
            note = f"command timed out after {timeout_seconds} seconds"
            return AgentResult(
                exit_code=124,
                stdout=partial_out,
                stderr="\n".join((partial_err, note)).strip(),
            )
        return AgentResult(
            exit_code=completed.returncode,
            stdout=completed.stdout,
            stderr=completed.stderr,
        )

View File

@@ -24,6 +24,14 @@ class SchedulerConfig(BaseModel):
interval_seconds: int = Field(default=60, ge=1) interval_seconds: int = Field(default=60, ge=1)
concurrency: int = Field(default=1, ge=1) concurrency: int = Field(default=1, ge=1)
lease_seconds: int = Field(default=1800, ge=30) lease_seconds: int = Field(default=1800, ge=30)
lease_renewal_interval_seconds: int | None = Field(default=None, ge=1)
agent_timeout_seconds: int = Field(default=7200, ge=1)
@property
def effective_lease_renewal_interval_seconds(self) -> int:
    """Return the number of seconds to wait between lease renewals.

    Uses ``lease_renewal_interval_seconds`` when it is set, otherwise one
    third of ``lease_seconds`` (at least 1). The result is always kept
    strictly below ``lease_seconds``: a renewal interval that is as long as
    the lease itself would let the lease expire before the first renewal.
    """
    configured = self.lease_renewal_interval_seconds
    if configured is None:
        return max(1, self.lease_seconds // 3)
    # Clamp a misconfigured interval so a renewal still fires before expiry.
    return max(1, min(configured, self.lease_seconds - 1))
class WorkspaceConfig(BaseModel): class WorkspaceConfig(BaseModel):

View File

@@ -292,12 +292,23 @@ class Database:
SELECT t.* SELECT t.*
FROM tasks t FROM tasks t
JOIN issues i ON i.repo_id = t.repo_id AND i.issue_number = t.issue_number JOIN issues i ON i.repo_id = t.repo_id AND i.issue_number = t.issue_number
WHERE t.state = ? WHERE t.state != ?
AND t.pr_number IS NOT NULL AND t.pr_number IS NOT NULL
AND i.state = 'open' AND i.state = 'open'
AND (
t.state IN (?, ?)
OR t.lease_owner IS NULL
OR t.lease_expires_at IS NULL
OR t.lease_expires_at < ?
)
ORDER BY t.id ORDER BY t.id
""", """,
(TaskState.HUMAN_REVIEW_READY.value,), (
TaskState.CANCELLED.value,
TaskState.HUMAN_REVIEW_READY.value,
TaskState.FAILED.value,
dt_to_db(utcnow()),
),
).fetchall() ).fetchall()
return [self._task(row) for row in rows] return [self._task(row) for row in rows]
@@ -430,6 +441,8 @@ class Database:
def claim_next_task(self, worker_id: str, lease_seconds: int) -> TaskRecord | None: def claim_next_task(self, worker_id: str, lease_seconds: int) -> TaskRecord | None:
now = utcnow() now = utcnow()
expires = now + timedelta(seconds=lease_seconds) expires = now + timedelta(seconds=lease_seconds)
try:
self.conn.execute("BEGIN IMMEDIATE")
row = self.conn.execute( row = self.conn.execute(
""" """
SELECT * FROM tasks SELECT * FROM tasks
@@ -451,6 +464,7 @@ class Database:
), ),
).fetchone() ).fetchone()
if row is None: if row is None:
self.conn.commit()
return None return None
task = self._task(row) task = self._task(row)
self.conn.execute( self.conn.execute(
@@ -468,11 +482,37 @@ class Database:
), ),
) )
self.conn.commit() self.conn.commit()
except Exception:
self.conn.rollback()
raise
self.add_event(task.id, task.state, TaskState.CLAIMED, f"claimed by {worker_id}") self.add_event(task.id, task.state, TaskState.CLAIMED, f"claimed by {worker_id}")
claimed = self.get_task(task.id) claimed = self.get_task(task.id)
assert claimed is not None assert claimed is not None
return claimed return claimed
def renew_task_lease(self, task_id: int, worker_id: str, lease_seconds: int) -> bool:
    """Push out the lease expiry of a task that ``worker_id`` still owns.

    Args:
        task_id: Row id of the task whose lease should be extended.
        worker_id: Expected current ``lease_owner`` of the task.
        lease_seconds: Length of the new lease, counted from now.

    Returns:
        ``True`` when exactly one row was updated, i.e. the task is still in
        one of ``ACTIVE_STATES`` and leased to ``worker_id``; ``False``
        otherwise.

    Raises:
        Exception: Whatever the database layer raises; the pending
            transaction is rolled back first.
    """
    now = utcnow()
    expires = now + timedelta(seconds=lease_seconds)
    active_values = [state.value for state in ACTIVE_STATES]
    placeholders = ",".join("?" for _ in active_values)
    try:
        cursor = self.conn.execute(
            f"""
            UPDATE tasks
            SET lease_expires_at = ?, updated_at = ?
            WHERE id = ?
            AND lease_owner = ?
            AND state IN ({placeholders})
            """,
            (
                dt_to_db(expires),
                dt_to_db(now),
                task_id,
                worker_id,
                *active_values,
            ),
        )
        self.conn.commit()
    except Exception:
        # Mirror claim_next_task: never leave a half-open transaction behind
        # on this connection when the write fails.
        self.conn.rollback()
        raise
    return cursor.rowcount == 1
def transition( def transition(
self, self,
task_id: int, task_id: int,

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import socket import socket
import time import time
import logging import logging
import threading
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
@@ -10,7 +11,7 @@ from .agents import CommandRunner, read_report, render_command, write_prompt
from .config import AppConfig from .config import AppConfig
from .db import Database, PullRequestFeedbackCursor from .db import Database, PullRequestFeedbackCursor
from .gitea import GiteaClient, GiteaComment, GiteaPullReview from .gitea import GiteaClient, GiteaComment, GiteaPullReview
from .models import IssueRecord, RepositoryRecord, TaskRecord, TaskState from .models import ACTIVE_STATES, IssueRecord, RepositoryRecord, TaskRecord, TaskState
from .rendering import ( from .rendering import (
parse_review_report, parse_review_report,
render_human_review_summary, render_human_review_summary,
@@ -34,6 +35,46 @@ class PullRequestFeedbackSnapshot:
newest_cursor: PullRequestFeedbackCursor newest_cursor: PullRequestFeedbackCursor
class TaskLeaseRenewer:
    """Context manager that keeps a task lease alive from a background thread.

    While the ``with`` block is active, a daemon thread calls
    ``Database.renew_task_lease`` every ``interval_seconds``. The thread
    opens its own ``Database`` on ``db_path`` rather than reusing the
    caller's (presumably because the caller's connection cannot be shared
    across threads; confirm against ``Database``).
    """

    def __init__(
        self,
        *,
        db_path: Path,
        task_id: int,
        worker_id: str,
        lease_seconds: int,
        interval_seconds: int,
    ):
        """Store the renewal parameters and prepare the (unstarted) thread.

        Args:
            db_path: Location passed to ``Database`` inside the thread.
            task_id: Task whose lease is renewed.
            worker_id: Lease owner the renewal is performed for.
            lease_seconds: Lease length requested on every renewal.
            interval_seconds: Pause between two renewal attempts.
        """
        self.db_path = db_path
        self.task_id = task_id
        self.worker_id = worker_id
        self.lease_seconds = lease_seconds
        self.interval_seconds = interval_seconds
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, name=f"lease-renewer-{task_id}", daemon=True)

    def __enter__(self) -> "TaskLeaseRenewer":
        """Start the renewal thread and return this renewer."""
        self._thread.start()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Signal the thread to stop and wait a bounded time for it to finish."""
        self._stop.set()
        self._thread.join(timeout=max(1, self.interval_seconds))

    def _run(self) -> None:
        """Renew the lease periodically until stopped or the lease is lost."""
        try:
            database = Database(self.db_path)
        except Exception:
            logger.exception("lease renewal failed for task %d", self.task_id)
            return
        try:
            # Event.wait returns False on timeout, so each timeout is one tick.
            while not self._stop.wait(self.interval_seconds):
                try:
                    renewed = database.renew_task_lease(self.task_id, self.worker_id, self.lease_seconds)
                except Exception:
                    # A single failed attempt must not end renewal for good,
                    # otherwise the lease lapses while the agent still runs.
                    # Log it and try again on the next tick.
                    logger.exception("lease renewal failed for task %d", self.task_id)
                    continue
                if not renewed:
                    logger.warning("stopping lease renewal for task %d; lease no longer belongs to %s", self.task_id, self.worker_id)
                    return
        finally:
            database.close()
def sync_repositories(db: Database, config: AppConfig, client: GiteaClient) -> list[RepositoryRecord]: def sync_repositories(db: Database, config: AppConfig, client: GiteaClient) -> list[RepositoryRecord]:
synced: list[RepositoryRecord] = [] synced: list[RepositoryRecord] = []
discovered = client.list_owned_repositories() discovered = client.list_owned_repositories()
@@ -92,12 +133,12 @@ def close_issues_for_merged_pull_requests(db: Database, client: GiteaClient) ->
) )
client.close_issue(owner=repo.owner, name=repo.name, issue_number=issue.issue_number) client.close_issue(owner=repo.owner, name=repo.name, issue_number=issue.issue_number)
db.update_issue_state(task.repo_id, task.issue_number, "closed") db.update_issue_state(task.repo_id, task.issue_number, "closed")
db.add_event( message = f"closed issue #{issue.issue_number} after merged PR #{task.pr_number}"
task.id, if task.state in ACTIVE_STATES:
task.state, db.clear_pr_feedback_pending(task.id)
task.state, db.transition(task.id, TaskState.CANCELLED, message=message, clear_lease=True)
f"closed issue #{issue.issue_number} after merged PR #{task.pr_number}", else:
) db.add_event(task.id, task.state, task.state, message)
closed += 1 closed += 1
return closed return closed
@@ -388,7 +429,7 @@ class TaskRunner:
issue_title=issue.title, issue_title=issue.title,
branch_name=branch_name, branch_name=branch_name,
) )
result = self.command_runner.run(command, workspace, stdin=prompt) result = self._run_agent_command(task, command, workspace, prompt)
report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md") report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md")
self.db.add_agent_run( self.db.add_agent_run(
task_id=task.id, task_id=task.id,
@@ -428,7 +469,7 @@ class TaskRunner:
pr_number=pr_number, pr_number=pr_number,
branch_name=branch_name, branch_name=branch_name,
) )
result = self.command_runner.run(command, workspace, stdin=prompt) result = self._run_agent_command(task, command, workspace, prompt)
report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md") report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md")
self.db.add_agent_run( self.db.add_agent_run(
task_id=task.id, task_id=task.id,
@@ -465,7 +506,7 @@ class TaskRunner:
issue_title=issue.title, issue_title=issue.title,
pr_number=pr_number, pr_number=pr_number,
) )
result = self.command_runner.run(command, workspace, stdin=prompt) result = self._run_agent_command(task, command, workspace, prompt)
report = read_report(output_dir / "AGENT_REVIEW_REPORT.md") report = read_report(output_dir / "AGENT_REVIEW_REPORT.md")
self.db.add_agent_run( self.db.add_agent_run(
task_id=task.id, task_id=task.id,
@@ -481,6 +522,21 @@ class TaskRunner:
raise RuntimeError(f"reviewer failed with exit code {result.exit_code}") raise RuntimeError(f"reviewer failed with exit code {result.exit_code}")
return report return report
def _run_agent_command(self, task: TaskRecord, command: list[str], workspace: Path, prompt: str) -> AgentResult:
    """Run an agent command while a background renewer keeps the task lease alive.

    The prompt is passed on stdin and the run is limited to the scheduler's
    ``agent_timeout_seconds``. Lease length and renewal interval also come
    from the scheduler configuration.
    """
    # NOTE(review): AgentResult does not appear in the visible `.models`
    # import of this module. Confirm it is imported; the annotation is only
    # evaluated lazily because of `from __future__ import annotations`.
    scheduler = self.config.scheduler
    lease_keeper = TaskLeaseRenewer(
        db_path=self.db.path,
        task_id=task.id,
        worker_id=self.worker_id,
        lease_seconds=scheduler.lease_seconds,
        interval_seconds=scheduler.effective_lease_renewal_interval_seconds,
    )
    with lease_keeper:
        outcome = self.command_runner.run(
            command,
            workspace,
            stdin=prompt,
            timeout_seconds=scheduler.agent_timeout_seconds,
        )
        return outcome
def _load_context(self, task: TaskRecord) -> tuple[RepositoryRecord, IssueRecord]: def _load_context(self, task: TaskRecord) -> tuple[RepositoryRecord, IssueRecord]:
return load_task_context(self.db, task) return load_task_context(self.db, task)

View File

@@ -1,11 +1,16 @@
from __future__ import annotations from __future__ import annotations
import json import json
import sys
import threading
import time
from pathlib import Path from pathlib import Path
import httpx import httpx
from agent_gitea.agents import CommandRunner
from agent_gitea.config import AppConfig from agent_gitea.config import AppConfig
from agent_gitea.db import Database
from agent_gitea.gitea import GiteaClient from agent_gitea.gitea import GiteaClient
from agent_gitea.models import AgentResult, TaskState from agent_gitea.models import AgentResult, TaskState
from agent_gitea.service import ( from agent_gitea.service import (
@@ -208,7 +213,7 @@ class FakeRunner:
def __init__(self, *, fail_role: str | None = None): def __init__(self, *, fail_role: str | None = None):
self.fail_role = fail_role self.fail_role = fail_role
def run(self, command, cwd, *, stdin=None): def run(self, command, cwd, *, stdin=None, timeout_seconds=None):
role = command[0] role = command[0]
assert stdin assert stdin
if role == self.fail_role: if role == self.fail_role:
@@ -250,6 +255,126 @@ def seed_task(db):
return db.create_task(repo.id, 1) return db.create_task(repo.id, 1)
def test_claim_next_task_allows_only_one_worker_during_race(db, tmp_path):
    """Two workers racing for a single task: exactly one may claim it.

    A trigger slows down the claiming UPDATE so both workers overlap inside
    ``claim_next_task``. The barrier and the joins carry timeouts so that a
    worker failing early makes the test fail instead of hanging.
    """
    seeded = seed_task(db)
    db.conn.create_function("sleep_ms", 1, lambda ms: time.sleep(ms / 1000))
    db.conn.execute(
        """
        CREATE TRIGGER slow_claim_update
        BEFORE UPDATE OF state ON tasks
        WHEN NEW.state = 'CLAIMED'
        BEGIN
            SELECT sleep_ms(150);
        END
        """
    )
    db.conn.commit()
    db.close()

    barrier = threading.Barrier(2, timeout=10)
    results: dict[str, int | None] = {}
    errors: list[BaseException] = []

    def claim(worker_id: str) -> None:
        database = None
        try:
            database = Database(tmp_path / "state.sqlite3")
            # The trigger calls sleep_ms, so every connection must define it.
            database.conn.create_function("sleep_ms", 1, lambda ms: time.sleep(ms / 1000))
            barrier.wait()
            task = database.claim_next_task(worker_id, 60)
            results[worker_id] = task.id if task else None
        except BaseException as exc:  # pragma: no cover - assertion below reports thread failures
            errors.append(exc)
            # Release the peer immediately if it is still waiting at the barrier.
            barrier.abort()
        finally:
            if database is not None:
                database.close()

    threads = [threading.Thread(target=claim, args=(worker_id,)) for worker_id in ("worker-a", "worker-b")]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join(timeout=30)

    assert not any(thread.is_alive() for thread in threads)
    assert errors == []
    claimed_ids = list(results.values())
    assert claimed_ids.count(seeded.id) == 1
    assert claimed_ids.count(None) == 1
def test_command_runner_returns_timeout_result(tmp_path):
    """A command that outlives its timeout yields exit code 124 and a timeout note."""
    sleeper = [sys.executable, "-c", "import time; time.sleep(5)"]
    outcome = CommandRunner().run(sleeper, tmp_path, timeout_seconds=0.1)

    assert outcome.exit_code == 124
    assert "timed out" in outcome.stderr
class ObservingSlowRunner:
    """Fake command runner that samples the task lease around a slow implementer run."""

    def __init__(self, db, task_id: int):
        self.db = db
        self.task_id = task_id
        # Lease expiry seen right before and right after the simulated work.
        self.lease_before = None
        self.lease_during = None

    def _lease_expiry(self):
        """Read the task's current ``lease_expires_at`` from the database."""
        return self.db.get_task(self.task_id).lease_expires_at  # type: ignore[union-attr]

    @staticmethod
    def _write_report(cwd, filename: str, content: str) -> None:
        """Write ``content`` to ``<cwd>/.agent-output/<filename>``."""
        report_dir = Path(cwd) / ".agent-output"
        report_dir.mkdir(exist_ok=True)
        (report_dir / filename).write_text(content, encoding="utf-8")

    def run(self, command, cwd, *, stdin=None, timeout_seconds=None):
        """Act as the implementer or reviewer named by ``command[0]``; always succeed."""
        role = command[0]
        if role == "implementer":
            self.lease_before = self._lease_expiry()
            # Sleep long enough for a renewal to happen in between.
            time.sleep(1.2)
            self.lease_during = self._lease_expiry()
            self._write_report(
                cwd,
                "AGENT_IMPLEMENTATION_REPORT.md",
                "## Summary\nImplemented\n\n## Test commands run\npytest\n",
            )
        elif role == "reviewer":
            self._write_report(
                cwd,
                "AGENT_REVIEW_REPORT.md",
                "## Verdict\nAPPROVE\n\n## Suggested PR Comment\nLooks good.\n",
            )
        return AgentResult(exit_code=0, stdout="ok", stderr="")
def test_task_runner_renews_lease_while_agent_runs(db, tmp_path):
    """The lease expiry observed after a slow agent run is later than before it."""
    scheduler_settings = {
        "interval_seconds": 1,
        "concurrency": 1,
        "lease_seconds": 60,
        "lease_renewal_interval_seconds": 1,
    }
    config = make_config(tmp_path, scheduler=scheduler_settings)
    task = seed_task(db)
    runner = ObservingSlowRunner(db, task.id)

    def handler(request: httpx.Request) -> httpx.Response:
        # Minimal Gitea stand-in: PR creation and PR comment endpoints only.
        payload = json.loads(request.content.decode() or "{}")
        path = request.url.path
        if path == "/api/v1/repos/acme/service/pulls":
            return httpx.Response(201, json={"number": 5, "state": "open", "merged": False})
        if path == "/api/v1/repos/acme/service/issues/5/comments":
            comment = {"id": 1, "body": payload.get("body", ""), "user": {"login": "agent-bot"}}
            return httpx.Response(201, json=comment)
        return httpx.Response(404)

    task_runner = TaskRunner(
        db=db,
        config=config,
        gitea=make_client(handler),
        workspace_manager=FakeWorkspaceManager(tmp_path / "work"),
        command_runner=runner,
        worker_id="worker",
    )
    finished = task_runner.run_once()

    assert finished is not None
    assert finished.state == TaskState.HUMAN_REVIEW_READY
    assert runner.lease_before is not None
    assert runner.lease_during is not None
    assert runner.lease_during > runner.lease_before
def transition_to_human_review_ready(db, task_id: int, *, pr_number: int = 5, branch_name: str | None = None): def transition_to_human_review_ready(db, task_id: int, *, pr_number: int = 5, branch_name: str | None = None):
db.transition(task_id, TaskState.CLAIMED) db.transition(task_id, TaskState.CLAIMED)
db.transition(task_id, TaskState.PLANNING) db.transition(task_id, TaskState.PLANNING)
@@ -656,6 +781,56 @@ def test_close_issues_for_merged_pull_requests_skips_unmerged_pr(db):
assert db.get_issue(repo.id, 1).state == "open" # type: ignore[union-attr] assert db.get_issue(repo.id, 1).state == "open" # type: ignore[union-attr]
def test_close_issues_for_merged_pull_requests_handles_queued_feedback_task(db):
    """A task re-queued for PR feedback is cancelled once its PR turns out merged."""
    repo = db.upsert_repository(
        owner="acme",
        name="service",
        clone_url="https://gitea.test/acme/service.git",
        default_branch="main",
        enabled=True,
    )
    db.upsert_issue(
        repo_id=repo.id,
        issue_number=1,
        title="Ready issue",
        body="Body",
        labels=["agent:ready"],
        state="open",
        html_url="https://gitea.test/acme/service/issues/1",
    )
    task = db.create_task(repo.id, 1)
    task = transition_to_human_review_ready(db, task.id, pr_number=5, branch_name="agent/issue-1-ready-issue")
    # Put the task back in the queue as if human feedback had arrived on the PR.
    db.mark_pr_feedback_pending(task.id)
    db.transition(
        task.id,
        TaskState.DISCOVERED,
        message="queued PR feedback from 1 human comment(s)",
        clear_lease=True,
    )
    requests: list[tuple[str, str, dict]] = []

    def handler(request: httpx.Request) -> httpx.Response:
        payload = json.loads(request.content.decode() or "{}")
        path = request.url.path
        requests.append((request.method, path, payload))
        if path == "/api/v1/repos/acme/service/pulls/5":
            return httpx.Response(200, json={"number": 5, "state": "closed", "merged": True})
        if path == "/api/v1/repos/acme/service/issues/1/comments":
            return httpx.Response(201, json={"id": 1})
        if path == "/api/v1/repos/acme/service/issues/1":
            return httpx.Response(200, json={"number": 1, "state": "closed"})
        return httpx.Response(404)

    closed = close_issues_for_merged_pull_requests(db, make_client(handler))
    updated_task = db.get_task(task.id)

    assert closed == 1
    assert db.get_issue(repo.id, 1).state == "closed"  # type: ignore[union-attr]
    assert updated_task is not None
    assert updated_task.state == TaskState.CANCELLED
    assert not db.has_pending_pr_feedback(task.id)
    expected_close_call = ("PATCH", "/api/v1/repos/acme/service/issues/1", {"state": "closed"})
    assert expected_close_call in requests
def test_run_task_no_diff_becomes_blocked(db, tmp_path): def test_run_task_no_diff_becomes_blocked(db, tmp_path):
config = make_config(tmp_path) config = make_config(tmp_path)
seed_task(db) seed_task(db)