Fix task leases and agent timeouts
This commit is contained in:
@@ -7,8 +7,29 @@ from .models import AgentResult
|
||||
|
||||
|
||||
class CommandRunner:
    """Run an external command and report the outcome as an ``AgentResult``."""

    def run(
        self,
        command: list[str],
        cwd: str | Path,
        *,
        stdin: str | None = None,
        timeout_seconds: int | float | None = None,
    ) -> AgentResult:
        """Execute *command* in *cwd* and capture its output as text.

        Args:
            command: Argument vector handed to ``subprocess.run``.
            cwd: Working directory for the child process.
            stdin: Optional text fed to the child's standard input.
            timeout_seconds: Upper bound on the run time; ``None`` means no limit.

        Returns:
            An ``AgentResult`` carrying the child's return code, stdout and
            stderr. When the timeout elapses the result has exit code 124,
            whatever output was captured, and a "timed out" line appended to
            stderr.
        """
        try:
            completed = subprocess.run(
                command,
                cwd=cwd,
                input=stdin,
                text=True,
                capture_output=True,
                check=False,
                timeout=timeout_seconds,
            )
        except subprocess.TimeoutExpired as expired:
            # The output attached to the exception may be str, bytes or None,
            # so normalise each stream to str before building the result.
            captured: dict[str, str] = {}
            for stream in ("stdout", "stderr"):
                raw = getattr(expired, stream)
                if isinstance(raw, str):
                    captured[stream] = raw
                else:
                    captured[stream] = (raw or b"").decode(errors="replace")
            message = f"command timed out after {timeout_seconds} seconds"
            stderr_text = f"{captured['stderr']}\n{message}".strip()
            return AgentResult(exit_code=124, stdout=captured["stdout"] or "", stderr=stderr_text)
        return AgentResult(
            exit_code=completed.returncode,
            stdout=completed.stdout,
            stderr=completed.stderr,
        )
|
||||
|
||||
|
||||
|
||||
@@ -24,6 +24,14 @@ class SchedulerConfig(BaseModel):
|
||||
interval_seconds: int = Field(default=60, ge=1)
|
||||
concurrency: int = Field(default=1, ge=1)
|
||||
lease_seconds: int = Field(default=1800, ge=30)
|
||||
lease_renewal_interval_seconds: int | None = Field(default=None, ge=1)
|
||||
agent_timeout_seconds: int = Field(default=7200, ge=1)
|
||||
|
||||
@property
def effective_lease_renewal_interval_seconds(self) -> int:
    """Return how often, in seconds, a held task lease should be renewed.

    When ``lease_renewal_interval_seconds`` is unset the interval defaults to
    one third of ``lease_seconds`` (never less than one second). An explicit
    value is honoured, but clamped so that it stays strictly below
    ``lease_seconds``: an interval at or above the lease duration would let
    the lease expire before the first renewal could happen.
    """
    configured = self.lease_renewal_interval_seconds
    if configured is None:
        return max(1, self.lease_seconds // 3)
    # Keep at least one second of headroom before the lease runs out.
    longest_safe_interval = max(1, self.lease_seconds - 1)
    return min(configured, longest_safe_interval)
|
||||
|
||||
|
||||
class WorkspaceConfig(BaseModel):
|
||||
|
||||
@@ -441,49 +441,78 @@ class Database:
|
||||
def claim_next_task(self, worker_id: str, lease_seconds: int) -> TaskRecord | None:
    """Claim the oldest claimable task for *worker_id*, or return ``None``.

    A task is claimable when it is in the DISCOVERED state, or when it is in
    one of the in-progress states listed in the query and its lease has
    already expired. Selection and update happen inside a single
    ``BEGIN IMMEDIATE`` transaction so that two workers racing on the same
    row cannot both claim it.

    Args:
        worker_id: Identifier stored as the new ``lease_owner``.
        lease_seconds: Lease duration, counted from the current time.

    Returns:
        The freshly re-read task record after it was marked CLAIMED, or
        ``None`` when no task is claimable.

    Raises:
        Whatever the database layer raises; the open transaction is rolled
        back before the error propagates.
    """
    now = utcnow()
    expires = now + timedelta(seconds=lease_seconds)
    try:
        # Take the write lock up front so the SELECT and UPDATE below are
        # not interleaved with another worker's claim.
        self.conn.execute("BEGIN IMMEDIATE")
        row = self.conn.execute(
            """
            SELECT * FROM tasks
            WHERE state = ?
            OR (state IN (?, ?, ?, ?, ?, ?, ?) AND lease_expires_at IS NOT NULL AND lease_expires_at < ?)
            ORDER BY created_at
            LIMIT 1
            """,
            (
                TaskState.DISCOVERED.value,
                TaskState.CLAIMED.value,
                TaskState.PLANNING.value,
                TaskState.IMPLEMENTING.value,
                TaskState.TESTING.value,
                TaskState.PR_OPENED.value,
                TaskState.REVIEWING.value,
                TaskState.DISCOVERED.value,
                dt_to_db(now),
            ),
        ).fetchone()
        if row is None:
            # Nothing to claim: end the transaction so the lock is released.
            self.conn.commit()
            return None
        task = self._task(row)
        self.conn.execute(
            """
            UPDATE tasks
            SET state = ?, lease_owner = ?, lease_expires_at = ?, updated_at = ?
            WHERE id = ?
            """,
            (
                TaskState.CLAIMED.value,
                worker_id,
                dt_to_db(expires),
                dt_to_db(now),
                task.id,
            ),
        )
        self.conn.commit()
    except BaseException:
        # Also covers KeyboardInterrupt/SystemExit so an interrupted claim
        # never leaves the write transaction (and its lock) open.
        self.conn.rollback()
        raise
    # task.state is still the pre-claim state read by the SELECT above.
    self.add_event(task.id, task.state, TaskState.CLAIMED, f"claimed by {worker_id}")
    claimed = self.get_task(task.id)
    assert claimed is not None
    return claimed
|
||||
|
||||
def renew_task_lease(self, task_id: int, worker_id: str, lease_seconds: int) -> bool:
    """Push the lease expiry of *task_id* forward by *lease_seconds* from now.

    The row is only touched when *worker_id* is still its ``lease_owner`` and
    the task is in one of ``ACTIVE_STATES``.

    Returns:
        ``True`` when exactly one row was updated, ``False`` otherwise.
    """
    renewed_at = utcnow()
    new_expiry = renewed_at + timedelta(seconds=lease_seconds)
    active_values = [state.value for state in ACTIVE_STATES]
    # One "?" per active state, comma separated, for the IN (...) clause.
    placeholders = ",".join("?" * len(active_values))
    statement = f"""
        UPDATE tasks
        SET lease_expires_at = ?, updated_at = ?
        WHERE id = ?
        AND lease_owner = ?
        AND state IN ({placeholders})
        """
    parameters = (
        dt_to_db(new_expiry),
        dt_to_db(renewed_at),
        task_id,
        worker_id,
        *active_values,
    )
    cursor = self.conn.execute(statement, parameters)
    self.conn.commit()
    return cursor.rowcount == 1
|
||||
|
||||
def transition(
|
||||
self,
|
||||
task_id: int,
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import socket
|
||||
import time
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
@@ -34,6 +35,46 @@ class PullRequestFeedbackSnapshot:
|
||||
newest_cursor: PullRequestFeedbackCursor
|
||||
|
||||
|
||||
class TaskLeaseRenewer:
    """Context manager that keeps a task lease alive from a background thread.

    Entering the context starts a daemon thread that opens its own
    ``Database`` on ``db_path`` and calls ``renew_task_lease`` every
    ``interval_seconds``. Leaving the context signals the thread to stop and
    waits briefly for it to finish.
    """

    def __init__(
        self,
        *,
        db_path: Path,
        task_id: int,
        worker_id: str,
        lease_seconds: int,
        interval_seconds: int,
    ):
        """Store the renewal parameters and prepare (but do not start) the thread.

        Args:
            db_path: Path handed to ``Database`` inside the renewal thread.
            task_id: Task whose lease is renewed.
            worker_id: Owner that must still hold the lease for a renewal to succeed.
            lease_seconds: Lease duration requested on every renewal.
            interval_seconds: Pause between two renewal attempts.
        """
        self.db_path = db_path
        self.task_id = task_id
        self.worker_id = worker_id
        self.lease_seconds = lease_seconds
        self.interval_seconds = interval_seconds
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, name=f"lease-renewer-{task_id}", daemon=True)

    def __enter__(self) -> "TaskLeaseRenewer":
        """Start the renewal thread and return this renewer."""
        self._thread.start()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Signal the renewal thread to stop and wait a bounded time for it."""
        self._stop.set()
        self._thread.join(timeout=max(1, self.interval_seconds))
        if self._thread.is_alive():
            # The thread is a daemon, so it cannot block interpreter exit, but
            # a renewal that is still in flight is worth surfacing.
            logger.warning("lease renewal thread for task %d did not stop in time", self.task_id)

    def _run(self) -> None:
        """Renew the lease every ``interval_seconds`` until stopped or the lease is lost."""
        database = None
        try:
            # Opened inside the try block so that a failure to open the
            # database is logged like any other renewal failure.
            database = Database(self.db_path)
            # Event.wait returns True once stop is requested, ending the loop.
            while not self._stop.wait(self.interval_seconds):
                renewed = database.renew_task_lease(self.task_id, self.worker_id, self.lease_seconds)
                if not renewed:
                    logger.warning("stopping lease renewal for task %d; lease no longer belongs to %s", self.task_id, self.worker_id)
                    return
        except Exception:
            logger.exception("lease renewal failed for task %d", self.task_id)
        finally:
            if database is not None:
                database.close()
|
||||
|
||||
|
||||
def sync_repositories(db: Database, config: AppConfig, client: GiteaClient) -> list[RepositoryRecord]:
|
||||
synced: list[RepositoryRecord] = []
|
||||
discovered = client.list_owned_repositories()
|
||||
@@ -388,7 +429,7 @@ class TaskRunner:
|
||||
issue_title=issue.title,
|
||||
branch_name=branch_name,
|
||||
)
|
||||
result = self.command_runner.run(command, workspace, stdin=prompt)
|
||||
result = self._run_agent_command(task, command, workspace, prompt)
|
||||
report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md")
|
||||
self.db.add_agent_run(
|
||||
task_id=task.id,
|
||||
@@ -428,7 +469,7 @@ class TaskRunner:
|
||||
pr_number=pr_number,
|
||||
branch_name=branch_name,
|
||||
)
|
||||
result = self.command_runner.run(command, workspace, stdin=prompt)
|
||||
result = self._run_agent_command(task, command, workspace, prompt)
|
||||
report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md")
|
||||
self.db.add_agent_run(
|
||||
task_id=task.id,
|
||||
@@ -465,7 +506,7 @@ class TaskRunner:
|
||||
issue_title=issue.title,
|
||||
pr_number=pr_number,
|
||||
)
|
||||
result = self.command_runner.run(command, workspace, stdin=prompt)
|
||||
result = self._run_agent_command(task, command, workspace, prompt)
|
||||
report = read_report(output_dir / "AGENT_REVIEW_REPORT.md")
|
||||
self.db.add_agent_run(
|
||||
task_id=task.id,
|
||||
@@ -481,6 +522,21 @@ class TaskRunner:
|
||||
raise RuntimeError(f"reviewer failed with exit code {result.exit_code}")
|
||||
return report
|
||||
|
||||
def _run_agent_command(self, task: TaskRecord, command: list[str], workspace: Path, prompt: str) -> AgentResult:
    """Run an agent command while a ``TaskLeaseRenewer`` keeps the task lease fresh.

    The prompt is passed on stdin, and the run is bounded by the scheduler's
    ``agent_timeout_seconds``. Lease duration and renewal interval also come
    from the scheduler configuration.
    """
    scheduler = self.config.scheduler
    lease_keepalive = TaskLeaseRenewer(
        db_path=self.db.path,
        task_id=task.id,
        worker_id=self.worker_id,
        lease_seconds=scheduler.lease_seconds,
        interval_seconds=scheduler.effective_lease_renewal_interval_seconds,
    )
    with lease_keepalive:
        agent_result = self.command_runner.run(
            command,
            workspace,
            stdin=prompt,
            timeout_seconds=scheduler.agent_timeout_seconds,
        )
    return agent_result
|
||||
|
||||
def _load_context(self, task: TaskRecord) -> tuple[RepositoryRecord, IssueRecord]:
    """Return the repository and issue records for *task* via ``load_task_context``."""
    repository_and_issue = load_task_context(self.db, task)
    return repository_and_issue
|
||||
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
from agent_gitea.agents import CommandRunner
|
||||
from agent_gitea.config import AppConfig
|
||||
from agent_gitea.db import Database
|
||||
from agent_gitea.gitea import GiteaClient
|
||||
from agent_gitea.models import AgentResult, TaskState
|
||||
from agent_gitea.service import (
|
||||
@@ -208,7 +213,7 @@ class FakeRunner:
|
||||
def __init__(self, *, fail_role: str | None = None):
|
||||
self.fail_role = fail_role
|
||||
|
||||
def run(self, command, cwd, *, stdin=None):
|
||||
def run(self, command, cwd, *, stdin=None, timeout_seconds=None):
|
||||
role = command[0]
|
||||
assert stdin
|
||||
if role == self.fail_role:
|
||||
@@ -250,6 +255,126 @@ def seed_task(db):
|
||||
return db.create_task(repo.id, 1)
|
||||
|
||||
|
||||
def test_claim_next_task_allows_only_one_worker_during_race(db, tmp_path):
    """Two workers racing on one task: exactly one claims it, the other gets ``None``.

    A trigger that sleeps during the CLAIMED update widens the race window.
    The barrier and the joins are time-bounded so that a worker failing early
    makes the test fail instead of hanging the other worker forever.
    """
    seed_task(db)
    db.conn.create_function("sleep_ms", 1, lambda ms: time.sleep(ms / 1000))
    db.conn.execute(
        """
        CREATE TRIGGER slow_claim_update
        BEFORE UPDATE OF state ON tasks
        WHEN NEW.state = 'CLAIMED'
        BEGIN
        SELECT sleep_ms(150);
        END
        """
    )
    db.conn.commit()
    db.close()

    barrier = threading.Barrier(2, timeout=10)
    results: dict[str, int | None] = {}
    errors: list[BaseException] = []

    def claim(worker_id: str) -> None:
        database = None
        try:
            # Opened inside the try block so a connection failure is reported
            # through `errors` rather than silently killing the thread.
            database = Database(tmp_path / "state.sqlite3")
            # The trigger calls sleep_ms, so each connection registers it.
            database.conn.create_function("sleep_ms", 1, lambda ms: time.sleep(ms / 1000))
            barrier.wait()
            task = database.claim_next_task(worker_id, 60)
            results[worker_id] = task.id if task else None
        except BaseException as exc:  # pragma: no cover - assertion below reports thread failures
            errors.append(exc)
            # Release a peer that may still be waiting at the barrier.
            barrier.abort()
        finally:
            if database is not None:
                database.close()

    threads = [threading.Thread(target=claim, args=(worker_id,)) for worker_id in ("worker-a", "worker-b")]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join(timeout=30)

    assert not any(thread.is_alive() for thread in threads)
    assert errors == []
    assert list(results.values()).count(1) == 1
    assert list(results.values()).count(None) == 1
|
||||
|
||||
|
||||
def test_command_runner_returns_timeout_result(tmp_path):
    """A command that outlives its timeout yields exit code 124 and a "timed out" stderr."""
    sleeping_command = [sys.executable, "-c", "import time; time.sleep(5)"]

    outcome = CommandRunner().run(sleeping_command, tmp_path, timeout_seconds=0.1)

    assert outcome.exit_code == 124
    assert "timed out" in outcome.stderr
|
||||
|
||||
|
||||
class ObservingSlowRunner:
    """Fake command runner that records the task lease before and after a slow implementer run."""

    def __init__(self, db, task_id: int):
        """Remember the database handle and the task whose lease is observed."""
        self.db = db
        self.task_id = task_id
        self.lease_before = None
        self.lease_during = None

    def _current_lease_expiry(self):
        """Read the task's current ``lease_expires_at`` from the database."""
        return self.db.get_task(self.task_id).lease_expires_at  # type: ignore[union-attr]

    @staticmethod
    def _write_report(cwd, filename: str, content: str) -> None:
        """Write *content* to ``.agent-output/<filename>`` under *cwd*, creating the directory."""
        output_dir = Path(cwd) / ".agent-output"
        output_dir.mkdir(exist_ok=True)
        (output_dir / filename).write_text(content, encoding="utf-8")

    def run(self, command, cwd, *, stdin=None, timeout_seconds=None):
        """Emulate the agent named by ``command[0]`` and always report success.

        The implementer role samples the lease, sleeps 1.2 seconds, samples it
        again and writes an implementation report; the reviewer role writes an
        approving review report. Any other role writes nothing.
        """
        role = command[0]
        if role == "implementer":
            self.lease_before = self._current_lease_expiry()
            time.sleep(1.2)
            self.lease_during = self._current_lease_expiry()
            self._write_report(
                cwd,
                "AGENT_IMPLEMENTATION_REPORT.md",
                "## Summary\nImplemented\n\n## Test commands run\npytest\n",
            )
        elif role == "reviewer":
            self._write_report(
                cwd,
                "AGENT_REVIEW_REPORT.md",
                "## Verdict\nAPPROVE\n\n## Suggested PR Comment\nLooks good.\n",
            )
        return AgentResult(exit_code=0, stdout="ok", stderr="")
|
||||
|
||||
|
||||
def test_task_runner_renews_lease_while_agent_runs(db, tmp_path):
    """The lease expiry observed after the slow implementer run is later than before it."""
    scheduler_settings = {
        "interval_seconds": 1,
        "concurrency": 1,
        "lease_seconds": 60,
        "lease_renewal_interval_seconds": 1,
    }
    config = make_config(tmp_path, scheduler=scheduler_settings)
    task = seed_task(db)
    runner = ObservingSlowRunner(db, task.id)

    def handler(request: httpx.Request) -> httpx.Response:
        # Stub of the two endpoints this test answers; everything else is a 404.
        payload = json.loads(request.content.decode() or "{}")
        path = request.url.path
        if path == "/api/v1/repos/acme/service/pulls":
            return httpx.Response(201, json={"number": 5, "state": "open", "merged": False})
        if path == "/api/v1/repos/acme/service/issues/5/comments":
            comment = {"id": 1, "body": payload.get("body", ""), "user": {"login": "agent-bot"}}
            return httpx.Response(201, json=comment)
        return httpx.Response(404)

    task_runner = TaskRunner(
        db=db,
        config=config,
        gitea=make_client(handler),
        workspace_manager=FakeWorkspaceManager(tmp_path / "work"),
        command_runner=runner,
        worker_id="worker",
    )
    finished = task_runner.run_once()

    assert finished is not None
    assert finished.state == TaskState.HUMAN_REVIEW_READY
    assert runner.lease_before is not None
    assert runner.lease_during is not None
    assert runner.lease_during > runner.lease_before
|
||||
|
||||
|
||||
def transition_to_human_review_ready(db, task_id: int, *, pr_number: int = 5, branch_name: str | None = None):
|
||||
db.transition(task_id, TaskState.CLAIMED)
|
||||
db.transition(task_id, TaskState.PLANNING)
|
||||
|
||||
Reference in New Issue
Block a user