Fix task leases and agent timeouts

This commit is contained in:
2026-05-08 21:36:07 +08:00
parent 2ae22b3492
commit 9fc8c14445
5 changed files with 283 additions and 44 deletions

View File

@@ -7,8 +7,29 @@ from .models import AgentResult
class CommandRunner:
    """Execute agent commands as child processes and report their outcome."""

    def run(
        self,
        command: list[str],
        cwd: str | Path,
        *,
        stdin: str | None = None,
        timeout_seconds: int | float | None = None,
    ) -> AgentResult:
        """Run ``command`` in ``cwd`` and capture its output.

        Args:
            command: Argument vector handed to ``subprocess.run``.
            cwd: Working directory for the child process.
            stdin: Optional text fed to the child's standard input.
            timeout_seconds: Forwarded as ``timeout`` to ``subprocess.run``;
                ``None`` means no limit.

        Returns:
            An ``AgentResult`` with the exit code and captured streams. When
            the timeout elapses the exit code is 124 and a "command timed
            out" line is appended to whatever stderr was captured.
        """

        def as_text(captured: str | bytes | None) -> str:
            # The timeout exception's streams are accepted as str, bytes or None.
            if isinstance(captured, str):
                return captured
            return (captured or b"").decode(errors="replace")

        try:
            completed = subprocess.run(
                command,
                cwd=cwd,
                input=stdin,
                text=True,
                capture_output=True,
                check=False,
                timeout=timeout_seconds,
            )
        except subprocess.TimeoutExpired as timed_out:
            partial_out = as_text(timed_out.stdout)
            partial_err = as_text(timed_out.stderr)
            notice = f"command timed out after {timeout_seconds} seconds"
            return AgentResult(
                exit_code=124,
                stdout=partial_out,
                stderr=f"{partial_err}\n{notice}".strip(),
            )
        return AgentResult(
            exit_code=completed.returncode,
            stdout=completed.stdout,
            stderr=completed.stderr,
        )

View File

@@ -24,6 +24,14 @@ class SchedulerConfig(BaseModel):
interval_seconds: int = Field(default=60, ge=1) interval_seconds: int = Field(default=60, ge=1)
concurrency: int = Field(default=1, ge=1) concurrency: int = Field(default=1, ge=1)
lease_seconds: int = Field(default=1800, ge=30) lease_seconds: int = Field(default=1800, ge=30)
lease_renewal_interval_seconds: int | None = Field(default=None, ge=1)
agent_timeout_seconds: int = Field(default=7200, ge=1)
@property
def effective_lease_renewal_interval_seconds(self) -> int:
    """Return the interval, in seconds, at which task leases are renewed.

    An explicitly configured ``lease_renewal_interval_seconds`` is used when
    set; otherwise the interval defaults to one third of ``lease_seconds``.
    The result is always at least 1 and is capped at ``lease_seconds - 1``:
    an interval equal to or longer than the lease itself would let the lease
    expire before a renewal is ever attempted.
    """
    configured = self.lease_renewal_interval_seconds
    interval = configured if configured is not None else self.lease_seconds // 3
    # Keep the interval strictly inside the lease window, but never below 1.
    return max(1, min(interval, self.lease_seconds - 1))
class WorkspaceConfig(BaseModel): class WorkspaceConfig(BaseModel):

View File

@@ -441,6 +441,8 @@ class Database:
def claim_next_task(self, worker_id: str, lease_seconds: int) -> TaskRecord | None: def claim_next_task(self, worker_id: str, lease_seconds: int) -> TaskRecord | None:
now = utcnow() now = utcnow()
expires = now + timedelta(seconds=lease_seconds) expires = now + timedelta(seconds=lease_seconds)
try:
self.conn.execute("BEGIN IMMEDIATE")
row = self.conn.execute( row = self.conn.execute(
""" """
SELECT * FROM tasks SELECT * FROM tasks
@@ -462,6 +464,7 @@ class Database:
), ),
).fetchone() ).fetchone()
if row is None: if row is None:
self.conn.commit()
return None return None
task = self._task(row) task = self._task(row)
self.conn.execute( self.conn.execute(
@@ -479,11 +482,37 @@ class Database:
), ),
) )
self.conn.commit() self.conn.commit()
except Exception:
self.conn.rollback()
raise
self.add_event(task.id, task.state, TaskState.CLAIMED, f"claimed by {worker_id}") self.add_event(task.id, task.state, TaskState.CLAIMED, f"claimed by {worker_id}")
claimed = self.get_task(task.id) claimed = self.get_task(task.id)
assert claimed is not None assert claimed is not None
return claimed return claimed
def renew_task_lease(self, task_id: int, worker_id: str, lease_seconds: int) -> bool:
    """Extend the lease on a task that ``worker_id`` still holds.

    The expiry is moved to ``lease_seconds`` from now, but only when the task
    is still owned by ``worker_id`` and its state is one of ``ACTIVE_STATES``.

    Args:
        task_id: Identifier of the task whose lease should be extended.
        worker_id: Worker expected to be the current lease owner.
        lease_seconds: New lease duration, counted from the current time.

    Returns:
        ``True`` when exactly one row was updated, ``False`` otherwise.

    Raises:
        Exception: Any error from the update or commit is re-raised after
            the connection has been rolled back.
    """
    now = utcnow()
    expires = now + timedelta(seconds=lease_seconds)
    active_values = [state.value for state in ACTIVE_STATES]
    placeholders = ",".join("?" for _ in active_values)
    try:
        cursor = self.conn.execute(
            f"""
            UPDATE tasks
            SET lease_expires_at = ?, updated_at = ?
            WHERE id = ?
            AND lease_owner = ?
            AND state IN ({placeholders})
            """,
            (
                dt_to_db(expires),
                dt_to_db(now),
                task_id,
                worker_id,
                *active_values,
            ),
        )
        self.conn.commit()
    except Exception:
        # Same policy as claim_next_task: do not leave a half-finished
        # transaction open on this connection when the write fails.
        self.conn.rollback()
        raise
    return cursor.rowcount == 1
def transition( def transition(
self, self,
task_id: int, task_id: int,

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import socket import socket
import time import time
import logging import logging
import threading
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
@@ -34,6 +35,46 @@ class PullRequestFeedbackSnapshot:
newest_cursor: PullRequestFeedbackCursor newest_cursor: PullRequestFeedbackCursor
class TaskLeaseRenewer:
    """Context manager that renews a task lease from a background thread.

    Entering the context starts a daemon thread that calls
    ``Database.renew_task_lease`` every ``interval_seconds``; leaving it
    signals the thread to stop and waits briefly for it to finish. The thread
    opens its own ``Database`` from ``db_path`` (presumably because the
    caller's connection cannot be shared across threads — confirm).
    """

    def __init__(
        self,
        *,
        db_path: Path,
        task_id: int,
        worker_id: str,
        lease_seconds: int,
        interval_seconds: int,
    ):
        """Store the renewal parameters and prepare (but do not start) the thread.

        Args:
            db_path: Location passed to ``Database`` inside the renewal thread.
            task_id: Task whose lease is kept alive.
            worker_id: Worker that must still own the lease for renewal to succeed.
            lease_seconds: Lease duration requested on every renewal.
            interval_seconds: Delay between renewal attempts.
        """
        self.db_path = db_path
        self.task_id = task_id
        self.worker_id = worker_id
        self.lease_seconds = lease_seconds
        self.interval_seconds = interval_seconds
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, name=f"lease-renewer-{task_id}", daemon=True)

    def __enter__(self) -> TaskLeaseRenewer:
        """Start the renewal thread and return this renewer."""
        self._thread.start()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Signal the thread to stop and wait up to one interval (at least 1s) for it."""
        self._stop.set()
        self._thread.join(timeout=max(1, self.interval_seconds))

    def _run(self) -> None:
        """Renew the lease every interval until stopped or the lease is lost.

        A failed renewal attempt is logged and retried on the next interval
        rather than ending the thread, so a single error does not let the
        lease lapse for the remainder of a long-running command.
        """
        try:
            database = Database(self.db_path)
        except Exception:
            # Without a database handle no renewal is possible; record why and stop.
            logger.exception("lease renewal failed for task %d", self.task_id)
            return
        try:
            # Event.wait returns True once stop is requested, which ends the loop.
            while not self._stop.wait(self.interval_seconds):
                try:
                    renewed = database.renew_task_lease(self.task_id, self.worker_id, self.lease_seconds)
                except Exception:
                    # Keep the thread alive: the next interval may succeed
                    # (for example after a momentarily busy database).
                    logger.exception("lease renewal failed for task %d", self.task_id)
                    continue
                if not renewed:
                    logger.warning("stopping lease renewal for task %d; lease no longer belongs to %s", self.task_id, self.worker_id)
                    return
        finally:
            database.close()
def sync_repositories(db: Database, config: AppConfig, client: GiteaClient) -> list[RepositoryRecord]: def sync_repositories(db: Database, config: AppConfig, client: GiteaClient) -> list[RepositoryRecord]:
synced: list[RepositoryRecord] = [] synced: list[RepositoryRecord] = []
discovered = client.list_owned_repositories() discovered = client.list_owned_repositories()
@@ -388,7 +429,7 @@ class TaskRunner:
issue_title=issue.title, issue_title=issue.title,
branch_name=branch_name, branch_name=branch_name,
) )
result = self.command_runner.run(command, workspace, stdin=prompt) result = self._run_agent_command(task, command, workspace, prompt)
report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md") report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md")
self.db.add_agent_run( self.db.add_agent_run(
task_id=task.id, task_id=task.id,
@@ -428,7 +469,7 @@ class TaskRunner:
pr_number=pr_number, pr_number=pr_number,
branch_name=branch_name, branch_name=branch_name,
) )
result = self.command_runner.run(command, workspace, stdin=prompt) result = self._run_agent_command(task, command, workspace, prompt)
report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md") report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md")
self.db.add_agent_run( self.db.add_agent_run(
task_id=task.id, task_id=task.id,
@@ -465,7 +506,7 @@ class TaskRunner:
issue_title=issue.title, issue_title=issue.title,
pr_number=pr_number, pr_number=pr_number,
) )
result = self.command_runner.run(command, workspace, stdin=prompt) result = self._run_agent_command(task, command, workspace, prompt)
report = read_report(output_dir / "AGENT_REVIEW_REPORT.md") report = read_report(output_dir / "AGENT_REVIEW_REPORT.md")
self.db.add_agent_run( self.db.add_agent_run(
task_id=task.id, task_id=task.id,
@@ -481,6 +522,21 @@ class TaskRunner:
raise RuntimeError(f"reviewer failed with exit code {result.exit_code}") raise RuntimeError(f"reviewer failed with exit code {result.exit_code}")
return report return report
def _run_agent_command(self, task: TaskRecord, command: list[str], workspace: Path, prompt: str) -> AgentResult:
    """Run an agent command while a ``TaskLeaseRenewer`` is active for ``task``.

    Args:
        task: Task whose lease is renewed for the duration of the command.
        command: Argument vector passed to the command runner.
        workspace: Working directory for the command.
        prompt: Text supplied to the command on standard input.

    Returns:
        The ``AgentResult`` produced by the command runner.
    """
    scheduler = self.config.scheduler
    lease_keeper = TaskLeaseRenewer(
        db_path=self.db.path,
        task_id=task.id,
        worker_id=self.worker_id,
        lease_seconds=scheduler.lease_seconds,
        interval_seconds=scheduler.effective_lease_renewal_interval_seconds,
    )
    with lease_keeper:
        outcome = self.command_runner.run(
            command,
            workspace,
            stdin=prompt,
            timeout_seconds=scheduler.agent_timeout_seconds,
        )
    return outcome
def _load_context(self, task: TaskRecord) -> tuple[RepositoryRecord, IssueRecord]: def _load_context(self, task: TaskRecord) -> tuple[RepositoryRecord, IssueRecord]:
return load_task_context(self.db, task) return load_task_context(self.db, task)

View File

@@ -1,11 +1,16 @@
from __future__ import annotations from __future__ import annotations
import json import json
import sys
import threading
import time
from pathlib import Path from pathlib import Path
import httpx import httpx
from agent_gitea.agents import CommandRunner
from agent_gitea.config import AppConfig from agent_gitea.config import AppConfig
from agent_gitea.db import Database
from agent_gitea.gitea import GiteaClient from agent_gitea.gitea import GiteaClient
from agent_gitea.models import AgentResult, TaskState from agent_gitea.models import AgentResult, TaskState
from agent_gitea.service import ( from agent_gitea.service import (
@@ -208,7 +213,7 @@ class FakeRunner:
def __init__(self, *, fail_role: str | None = None): def __init__(self, *, fail_role: str | None = None):
self.fail_role = fail_role self.fail_role = fail_role
def run(self, command, cwd, *, stdin=None): def run(self, command, cwd, *, stdin=None, timeout_seconds=None):
role = command[0] role = command[0]
assert stdin assert stdin
if role == self.fail_role: if role == self.fail_role:
@@ -250,6 +255,126 @@ def seed_task(db):
return db.create_task(repo.id, 1) return db.create_task(repo.id, 1)
def test_claim_next_task_allows_only_one_worker_during_race(db, tmp_path):
    """Two workers racing for a single task: exactly one may claim it.

    A trigger that sleeps 150 ms before a task's state becomes CLAIMED is
    installed first (presumably to widen the window in which both workers
    overlap), then two threads with separate ``Database`` handles claim
    simultaneously.
    """

    def sleep_ms(ms):
        time.sleep(ms / 1000)

    seed_task(db)
    db.conn.create_function("sleep_ms", 1, sleep_ms)
    db.conn.execute(
        """
        CREATE TRIGGER slow_claim_update
        BEFORE UPDATE OF state ON tasks
        WHEN NEW.state = 'CLAIMED'
        BEGIN
        SELECT sleep_ms(150);
        END
        """
    )
    db.conn.commit()
    db.close()

    start_line = threading.Barrier(2)
    results: dict[str, int | None] = {}
    errors: list[BaseException] = []

    def claim(worker_id: str) -> None:
        worker_db = Database(tmp_path / "state.sqlite3")
        # The trigger calls sleep_ms, so each connection must register it.
        worker_db.conn.create_function("sleep_ms", 1, sleep_ms)
        try:
            start_line.wait()
            claimed = worker_db.claim_next_task(worker_id, 60)
            results[worker_id] = claimed.id if claimed else None
        except BaseException as exc:  # pragma: no cover - assertion below reports thread failures
            errors.append(exc)
        finally:
            worker_db.close()

    workers = []
    for worker_id in ("worker-a", "worker-b"):
        workers.append(threading.Thread(target=claim, args=(worker_id,)))
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    assert errors == []
    outcomes = list(results.values())
    assert outcomes.count(1) == 1
    assert outcomes.count(None) == 1
def test_command_runner_returns_timeout_result(tmp_path):
    """A command that outlives its timeout yields exit code 124 and a 'timed out' note."""
    sleeper = [sys.executable, "-c", "import time; time.sleep(5)"]
    outcome = CommandRunner().run(sleeper, tmp_path, timeout_seconds=0.1)

    assert outcome.exit_code == 124
    assert "timed out" in outcome.stderr
class ObservingSlowRunner:
    """Fake command runner that samples the task's lease expiry around a slow implementer run.

    For the "implementer" role it records ``lease_expires_at`` before and
    after a 1.2 second sleep and writes an implementation report; for the
    "reviewer" role it writes an approving review report. Every call returns
    a successful ``AgentResult``.
    """

    def __init__(self, db, task_id: int):
        self.db = db
        self.task_id = task_id
        self.lease_before = None
        self.lease_during = None

    def _lease_expiry(self):
        """Read the task's current ``lease_expires_at`` from the database."""
        return self.db.get_task(self.task_id).lease_expires_at  # type: ignore[union-attr]

    @staticmethod
    def _write_report(cwd, filename: str, text: str) -> None:
        """Write ``text`` to ``<cwd>/.agent-output/<filename>``, creating the directory if needed."""
        output_dir = Path(cwd) / ".agent-output"
        output_dir.mkdir(exist_ok=True)
        (output_dir / filename).write_text(text, encoding="utf-8")

    def run(self, command, cwd, *, stdin=None, timeout_seconds=None):
        role = command[0]
        if role == "implementer":
            self.lease_before = self._lease_expiry()
            time.sleep(1.2)
            self.lease_during = self._lease_expiry()
            self._write_report(
                cwd,
                "AGENT_IMPLEMENTATION_REPORT.md",
                "## Summary\nImplemented\n\n## Test commands run\npytest\n",
            )
        elif role == "reviewer":
            self._write_report(
                cwd,
                "AGENT_REVIEW_REPORT.md",
                "## Verdict\nAPPROVE\n\n## Suggested PR Comment\nLooks good.\n",
            )
        return AgentResult(exit_code=0, stdout="ok", stderr="")
def test_task_runner_renews_lease_while_agent_runs(db, tmp_path):
    """The lease expiry observed after the slow implementer run must be later than before it."""
    scheduler_settings = {
        "interval_seconds": 1,
        "concurrency": 1,
        "lease_seconds": 60,
        "lease_renewal_interval_seconds": 1,
    }
    config = make_config(tmp_path, scheduler=scheduler_settings)
    task = seed_task(db)
    runner = ObservingSlowRunner(db, task.id)

    def handler(request: httpx.Request) -> httpx.Response:
        payload = json.loads(request.content.decode() or "{}")
        path = request.url.path
        if path == "/api/v1/repos/acme/service/pulls":
            return httpx.Response(201, json={"number": 5, "state": "open", "merged": False})
        if path == "/api/v1/repos/acme/service/issues/5/comments":
            comment = {"id": 1, "body": payload.get("body", ""), "user": {"login": "agent-bot"}}
            return httpx.Response(201, json=comment)
        return httpx.Response(404)

    task_runner = TaskRunner(
        db=db,
        config=config,
        gitea=make_client(handler),
        workspace_manager=FakeWorkspaceManager(tmp_path / "work"),
        command_runner=runner,
        worker_id="worker",
    )
    finished = task_runner.run_once()

    assert finished is not None
    assert finished.state == TaskState.HUMAN_REVIEW_READY
    assert runner.lease_before is not None
    assert runner.lease_during is not None
    assert runner.lease_during > runner.lease_before
def transition_to_human_review_ready(db, task_id: int, *, pr_number: int = 5, branch_name: str | None = None): def transition_to_human_review_ready(db, task_id: int, *, pr_number: int = 5, branch_name: str | None = None):
db.transition(task_id, TaskState.CLAIMED) db.transition(task_id, TaskState.CLAIMED)
db.transition(task_id, TaskState.PLANNING) db.transition(task_id, TaskState.PLANNING)