Fix task leases and agent timeouts
This commit is contained in:
@@ -1,11 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
from agent_gitea.agents import CommandRunner
|
||||
from agent_gitea.config import AppConfig
|
||||
from agent_gitea.db import Database
|
||||
from agent_gitea.gitea import GiteaClient
|
||||
from agent_gitea.models import AgentResult, TaskState
|
||||
from agent_gitea.service import (
|
||||
@@ -208,7 +213,7 @@ class FakeRunner:
|
||||
    def __init__(self, *, fail_role: str | None = None):
        """Configure which agent role (if any) this fake runner should fail.

        fail_role: the role name (the first element of the command list)
        whose invocation is simulated as a failure; None means every role
        succeeds.
        """
        self.fail_role = fail_role
|
||||
|
||||
def run(self, command, cwd, *, stdin=None):
|
||||
def run(self, command, cwd, *, stdin=None, timeout_seconds=None):
|
||||
role = command[0]
|
||||
assert stdin
|
||||
if role == self.fail_role:
|
||||
@@ -250,6 +255,126 @@ def seed_task(db):
|
||||
return db.create_task(repo.id, 1)
|
||||
|
||||
|
||||
def test_claim_next_task_allows_only_one_worker_during_race(db, tmp_path):
    """Two workers racing to claim the same task: exactly one must win.

    A SQLite trigger injects a 150 ms sleep into the state update that marks
    a task CLAIMED, widening the race window so both worker threads are
    provably attempting the claim at the same time.
    """
    seed_task(db)
    # sleep_ms is a custom SQL function the trigger below calls; it must be
    # registered on every connection that can fire the trigger.
    db.conn.create_function("sleep_ms", 1, lambda ms: time.sleep(ms / 1000))
    db.conn.execute(
        """
        CREATE TRIGGER slow_claim_update
        BEFORE UPDATE OF state ON tasks
        WHEN NEW.state = 'CLAIMED'
        BEGIN
            SELECT sleep_ms(150);
        END
        """
    )
    db.conn.commit()
    # Close the fixture connection: each worker thread opens its own
    # connection against the same database file below.
    db.close()

    barrier = threading.Barrier(2)
    results: dict[str, int | None] = {}
    errors: list[BaseException] = []

    def claim(worker_id: str) -> None:
        database = Database(tmp_path / "state.sqlite3")
        # Register sleep_ms on this worker connection too, or the trigger
        # would fail with an unknown-function error.
        database.conn.create_function("sleep_ms", 1, lambda ms: time.sleep(ms / 1000))
        try:
            # Rendezvous so both threads issue their claim simultaneously.
            barrier.wait()
            task = database.claim_next_task(worker_id, 60)
            results[worker_id] = task.id if task else None
        except BaseException as exc:  # pragma: no cover - assertion below reports thread failures
            errors.append(exc)
        finally:
            database.close()

    threads = [threading.Thread(target=claim, args=(worker_id,)) for worker_id in ("worker-a", "worker-b")]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    assert errors == []
    # Exactly one worker got the (single) seeded task; the other got nothing.
    assert list(results.values()).count(1) == 1
    assert list(results.values()).count(None) == 1
|
||||
|
||||
|
||||
def test_command_runner_returns_timeout_result(tmp_path):
    """A command that outlives its timeout yields exit code 124 and a
    'timed out' message on stderr instead of raising."""
    sleep_forever = [sys.executable, "-c", "import time; time.sleep(5)"]
    result = CommandRunner().run(sleep_forever, tmp_path, timeout_seconds=0.1)

    # 124 mirrors the conventional exit status of coreutils `timeout`.
    assert result.exit_code == 124
    assert "timed out" in result.stderr
|
||||
|
||||
|
||||
class ObservingSlowRunner:
    """Command-runner double that stalls during the implementer step while
    snapshotting the task's lease expiry before and after the stall, so a
    test can prove the lease was renewed while the agent was running."""

    def __init__(self, db, task_id: int):
        self.db = db
        self.task_id = task_id
        # Lease-expiry snapshots captured around the slow implementer step.
        self.lease_before = None
        self.lease_during = None

    def _read_lease(self):
        # Re-fetch the task row so renewals made concurrently are observed.
        return self.db.get_task(self.task_id).lease_expires_at  # type: ignore[union-attr]

    def _write_report(self, cwd, filename: str, body: str) -> None:
        # Agents hand results back as markdown files under .agent-output.
        reports = Path(cwd) / ".agent-output"
        reports.mkdir(exist_ok=True)
        (reports / filename).write_text(body, encoding="utf-8")

    def run(self, command, cwd, *, stdin=None, timeout_seconds=None):
        """Simulate one agent invocation; command[0] names the agent role."""
        role = command[0]
        if role == "implementer":
            self.lease_before = self._read_lease()
            time.sleep(1.2)  # long enough for at least one lease-renewal tick
            self.lease_during = self._read_lease()
            self._write_report(
                cwd,
                "AGENT_IMPLEMENTATION_REPORT.md",
                "## Summary\nImplemented\n\n## Test commands run\npytest\n",
            )
        if role == "reviewer":
            self._write_report(
                cwd,
                "AGENT_REVIEW_REPORT.md",
                "## Verdict\nAPPROVE\n\n## Suggested PR Comment\nLooks good.\n",
            )
        return AgentResult(exit_code=0, stdout="ok", stderr="")
|
||||
|
||||
|
||||
def test_task_runner_renews_lease_while_agent_runs(db, tmp_path):
    """The TaskRunner must renew the task lease while a slow agent runs.

    ObservingSlowRunner sleeps 1.2 s inside the implementer step and records
    the lease expiry before and after; with a 1 s renewal interval the second
    snapshot must be strictly later than the first.
    """
    config = make_config(
        tmp_path,
        scheduler={
            "interval_seconds": 1,
            "concurrency": 1,
            "lease_seconds": 60,
            # Renew every second so a renewal fires during the 1.2 s sleep.
            "lease_renewal_interval_seconds": 1,
        },
    )
    task = seed_task(db)
    runner = ObservingSlowRunner(db, task.id)

    def handler(request: httpx.Request) -> httpx.Response:
        # Minimal Gitea API stub: create-PR and PR-comment endpoints only.
        payload = json.loads(request.content.decode() or "{}")
        if request.url.path == "/api/v1/repos/acme/service/pulls":
            return httpx.Response(201, json={"number": 5, "state": "open", "merged": False})
        if request.url.path == "/api/v1/repos/acme/service/issues/5/comments":
            return httpx.Response(201, json={"id": 1, "body": payload.get("body", ""), "user": {"login": "agent-bot"}})
        return httpx.Response(404)

    finished = TaskRunner(
        db=db,
        config=config,
        gitea=make_client(handler),
        workspace_manager=FakeWorkspaceManager(tmp_path / "work"),
        command_runner=runner,
        worker_id="worker",
    ).run_once()

    assert finished is not None
    assert finished.state == TaskState.HUMAN_REVIEW_READY
    assert runner.lease_before is not None
    assert runner.lease_during is not None
    # The lease expiry advanced mid-run, i.e. a renewal actually happened.
    assert runner.lease_during > runner.lease_before
|
||||
|
||||
|
||||
def transition_to_human_review_ready(db, task_id: int, *, pr_number: int = 5, branch_name: str | None = None):
|
||||
db.transition(task_id, TaskState.CLAIMED)
|
||||
db.transition(task_id, TaskState.PLANNING)
|
||||
|
||||
Reference in New Issue
Block a user