diff --git a/src/agent_gitea/agents.py b/src/agent_gitea/agents.py index 1f4539d..2437878 100644 --- a/src/agent_gitea/agents.py +++ b/src/agent_gitea/agents.py @@ -7,8 +7,29 @@ from .models import AgentResult class CommandRunner: - def run(self, command: list[str], cwd: str | Path, *, stdin: str | None = None) -> AgentResult: - result = subprocess.run(command, cwd=cwd, input=stdin, text=True, capture_output=True, check=False) + def run( + self, + command: list[str], + cwd: str | Path, + *, + stdin: str | None = None, + timeout_seconds: int | float | None = None, + ) -> AgentResult: + try: + result = subprocess.run( + command, + cwd=cwd, + input=stdin, + text=True, + capture_output=True, + check=False, + timeout=timeout_seconds, + ) + except subprocess.TimeoutExpired as exc: + stdout = exc.stdout if isinstance(exc.stdout, str) else (exc.stdout or b"").decode(errors="replace") + stderr = exc.stderr if isinstance(exc.stderr, str) else (exc.stderr or b"").decode(errors="replace") + message = f"command timed out after {timeout_seconds} seconds" + return AgentResult(exit_code=124, stdout=stdout or "", stderr=f"{stderr}\n{message}".strip()) return AgentResult(exit_code=result.returncode, stdout=result.stdout, stderr=result.stderr) diff --git a/src/agent_gitea/config.py b/src/agent_gitea/config.py index 9ce74ef..d3ef14d 100644 --- a/src/agent_gitea/config.py +++ b/src/agent_gitea/config.py @@ -24,6 +24,14 @@ class SchedulerConfig(BaseModel): interval_seconds: int = Field(default=60, ge=1) concurrency: int = Field(default=1, ge=1) lease_seconds: int = Field(default=1800, ge=30) + lease_renewal_interval_seconds: int | None = Field(default=None, ge=1) + agent_timeout_seconds: int = Field(default=7200, ge=1) + + @property + def effective_lease_renewal_interval_seconds(self) -> int: + if self.lease_renewal_interval_seconds is not None: + return self.lease_renewal_interval_seconds + return max(1, self.lease_seconds // 3) class WorkspaceConfig(BaseModel): diff --git a/src/agent_gitea/db.py b/src/agent_gitea/db.py index 16875b5..d479e07 100644 --- a/src/agent_gitea/db.py +++ b/src/agent_gitea/db.py @@ -441,49 +441,78 @@ class Database: def claim_next_task(self, worker_id: str, lease_seconds: int) -> TaskRecord | None: now = utcnow() expires = now + timedelta(seconds=lease_seconds) - row = self.conn.execute( - """ - SELECT * FROM tasks - WHERE state = ? - OR (state IN (?, ?, ?, ?, ?, ?, ?) AND lease_expires_at IS NOT NULL AND lease_expires_at < ?) - ORDER BY created_at - LIMIT 1 - """, - ( - TaskState.DISCOVERED.value, - TaskState.CLAIMED.value, - TaskState.PLANNING.value, - TaskState.IMPLEMENTING.value, - TaskState.TESTING.value, - TaskState.PR_OPENED.value, - TaskState.REVIEWING.value, - TaskState.DISCOVERED.value, - dt_to_db(now), - ), - ).fetchone() - if row is None: - return None - task = self._task(row) - self.conn.execute( - """ - UPDATE tasks - SET state = ?, lease_owner = ?, lease_expires_at = ?, updated_at = ? - WHERE id = ? - """, - ( - TaskState.CLAIMED.value, - worker_id, - dt_to_db(expires), - dt_to_db(now), - task.id, - ), - ) - self.conn.commit() + try: + self.conn.execute("BEGIN IMMEDIATE") + row = self.conn.execute( + """ + SELECT * FROM tasks + WHERE state = ? + OR (state IN (?, ?, ?, ?, ?, ?, ?) AND lease_expires_at IS NOT NULL AND lease_expires_at < ?) + ORDER BY created_at + LIMIT 1 + """, + ( + TaskState.DISCOVERED.value, + TaskState.CLAIMED.value, + TaskState.PLANNING.value, + TaskState.IMPLEMENTING.value, + TaskState.TESTING.value, + TaskState.PR_OPENED.value, + TaskState.REVIEWING.value, + TaskState.DISCOVERED.value, + dt_to_db(now), + ), + ).fetchone() + if row is None: + self.conn.commit() + return None + task = self._task(row) + self.conn.execute( + """ + UPDATE tasks + SET state = ?, lease_owner = ?, lease_expires_at = ?, updated_at = ? + WHERE id = ? + """, + ( + TaskState.CLAIMED.value, + worker_id, + dt_to_db(expires), + dt_to_db(now), + task.id, + ), + ) + self.conn.commit() + except Exception: + self.conn.rollback() + raise self.add_event(task.id, task.state, TaskState.CLAIMED, f"claimed by {worker_id}") claimed = self.get_task(task.id) assert claimed is not None return claimed + def renew_task_lease(self, task_id: int, worker_id: str, lease_seconds: int) -> bool: + now = utcnow() + expires = now + timedelta(seconds=lease_seconds) + placeholders = ",".join("?" for _ in ACTIVE_STATES) + cursor = self.conn.execute( + f""" + UPDATE tasks + SET lease_expires_at = ?, updated_at = ? + WHERE id = ? + AND lease_owner = ? + AND state IN ({placeholders}) + """, + ( + dt_to_db(expires), + dt_to_db(now), + task_id, + worker_id, + *[state.value for state in ACTIVE_STATES], + ), + ) + self.conn.commit() + return cursor.rowcount == 1 + def transition( self, task_id: int, diff --git a/src/agent_gitea/service.py b/src/agent_gitea/service.py index 182c330..cb3f61d 100644 --- a/src/agent_gitea/service.py +++ b/src/agent_gitea/service.py @@ -3,6 +3,7 @@ from __future__ import annotations import socket import time import logging +import threading from dataclasses import dataclass from pathlib import Path @@ -34,6 +35,46 @@ class PullRequestFeedbackSnapshot: newest_cursor: PullRequestFeedbackCursor +class TaskLeaseRenewer: + def __init__( + self, + *, + db_path: Path, + task_id: int, + worker_id: str, + lease_seconds: int, + interval_seconds: int, + ): + self.db_path = db_path + self.task_id = task_id + self.worker_id = worker_id + self.lease_seconds = lease_seconds + self.interval_seconds = interval_seconds + self._stop = threading.Event() + self._thread = threading.Thread(target=self._run, name=f"lease-renewer-{task_id}", daemon=True) + + def __enter__(self) -> TaskLeaseRenewer: + self._thread.start() + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self._stop.set() + self._thread.join(timeout=max(1, self.interval_seconds)) + + def _run(self) -> None: + database = Database(self.db_path) + try: + while not self._stop.wait(self.interval_seconds): + renewed = database.renew_task_lease(self.task_id, self.worker_id, self.lease_seconds) + if not renewed: + logger.warning("stopping lease renewal for task %d; lease no longer belongs to %s", self.task_id, self.worker_id) + return + except Exception: + logger.exception("lease renewal failed for task %d", self.task_id) + finally: + database.close() + + def sync_repositories(db: Database, config: AppConfig, client: GiteaClient) -> list[RepositoryRecord]: synced: list[RepositoryRecord] = [] discovered = client.list_owned_repositories() @@ -388,7 +429,7 @@ class TaskRunner: issue_title=issue.title, branch_name=branch_name, ) - result = self.command_runner.run(command, workspace, stdin=prompt) + result = self._run_agent_command(task, command, workspace, prompt) report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md") self.db.add_agent_run( task_id=task.id, @@ -428,7 +469,7 @@ class TaskRunner: pr_number=pr_number, branch_name=branch_name, ) - result = self.command_runner.run(command, workspace, stdin=prompt) + result = self._run_agent_command(task, command, workspace, prompt) report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md") self.db.add_agent_run( task_id=task.id, @@ -465,7 +506,7 @@ class TaskRunner: issue_title=issue.title, pr_number=pr_number, ) - result = self.command_runner.run(command, workspace, stdin=prompt) + result = self._run_agent_command(task, command, workspace, prompt) report = read_report(output_dir / "AGENT_REVIEW_REPORT.md") self.db.add_agent_run( task_id=task.id, @@ -481,6 +522,21 @@ class TaskRunner: raise RuntimeError(f"reviewer failed with exit code {result.exit_code}") return report + def _run_agent_command(self, task: TaskRecord, command: list[str], workspace: Path, prompt: str) -> AgentResult: + with TaskLeaseRenewer( + db_path=self.db.path, + task_id=task.id, + worker_id=self.worker_id, + lease_seconds=self.config.scheduler.lease_seconds, + interval_seconds=self.config.scheduler.effective_lease_renewal_interval_seconds, + ): + return self.command_runner.run( + command, + workspace, + stdin=prompt, + timeout_seconds=self.config.scheduler.agent_timeout_seconds, + ) + def _load_context(self, task: TaskRecord) -> tuple[RepositoryRecord, IssueRecord]: return load_task_context(self.db, task) diff --git a/tests/test_gitea_service.py b/tests/test_gitea_service.py index 3d6d497..8591c30 100644 --- a/tests/test_gitea_service.py +++ b/tests/test_gitea_service.py @@ -1,11 +1,16 @@ from __future__ import annotations import json +import sys +import threading +import time from pathlib import Path import httpx +from agent_gitea.agents import CommandRunner from agent_gitea.config import AppConfig +from agent_gitea.db import Database from agent_gitea.gitea import GiteaClient from agent_gitea.models import AgentResult, TaskState from agent_gitea.service import ( @@ -208,7 +213,7 @@ class FakeRunner: def __init__(self, *, fail_role: str | None = None): self.fail_role = fail_role - def run(self, command, cwd, *, stdin=None): + def run(self, command, cwd, *, stdin=None, timeout_seconds=None): role = command[0] assert stdin if role == self.fail_role: @@ -250,6 +255,126 @@ def seed_task(db): return db.create_task(repo.id, 1) +def test_claim_next_task_allows_only_one_worker_during_race(db, tmp_path): + seed_task(db) + db.conn.create_function("sleep_ms", 1, lambda ms: time.sleep(ms / 1000)) + db.conn.execute( + """ + CREATE TRIGGER slow_claim_update + BEFORE UPDATE OF state ON tasks + WHEN NEW.state = 'CLAIMED' + BEGIN + SELECT sleep_ms(150); + END + """ + ) + db.conn.commit() + db.close() + + barrier = threading.Barrier(2) + results: dict[str, int | None] = {} + errors: list[BaseException] = [] + + def claim(worker_id: str) -> None: + database = Database(tmp_path / "state.sqlite3") + database.conn.create_function("sleep_ms", 1, lambda ms: time.sleep(ms / 1000)) + try: + barrier.wait() + task = database.claim_next_task(worker_id, 60) + results[worker_id] = task.id if task else None + except BaseException as exc: # pragma: no cover - assertion below reports thread failures + errors.append(exc) + finally: + database.close() + + threads = [threading.Thread(target=claim, args=(worker_id,)) for worker_id in ("worker-a", "worker-b")] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert errors == [] + assert list(results.values()).count(1) == 1 + assert list(results.values()).count(None) == 1 + + +def test_command_runner_returns_timeout_result(tmp_path): + result = CommandRunner().run( + [sys.executable, "-c", "import time; time.sleep(5)"], + tmp_path, + timeout_seconds=0.1, + ) + + assert result.exit_code == 124 + assert "timed out" in result.stderr + + +class ObservingSlowRunner: + def __init__(self, db, task_id: int): + self.db = db + self.task_id = task_id + self.lease_before = None + self.lease_during = None + + def run(self, command, cwd, *, stdin=None, timeout_seconds=None): + role = command[0] + if role == "implementer": + self.lease_before = self.db.get_task(self.task_id).lease_expires_at # type: ignore[union-attr] + time.sleep(1.2) + self.lease_during = self.db.get_task(self.task_id).lease_expires_at # type: ignore[union-attr] + output_dir = Path(cwd) / ".agent-output" + output_dir.mkdir(exist_ok=True) + (output_dir / "AGENT_IMPLEMENTATION_REPORT.md").write_text( + "## Summary\nImplemented\n\n## Test commands run\npytest\n", + encoding="utf-8", + ) + if role == "reviewer": + output_dir = Path(cwd) / ".agent-output" + output_dir.mkdir(exist_ok=True) + (output_dir / "AGENT_REVIEW_REPORT.md").write_text( + "## Verdict\nAPPROVE\n\n## Suggested PR Comment\nLooks good.\n", + encoding="utf-8", + ) + return AgentResult(exit_code=0, stdout="ok", stderr="") + + +def test_task_runner_renews_lease_while_agent_runs(db, tmp_path): + config = make_config( + tmp_path, + scheduler={ + "interval_seconds": 1, + "concurrency": 1, + "lease_seconds": 60, + "lease_renewal_interval_seconds": 1, + }, + ) + task = seed_task(db) + runner = ObservingSlowRunner(db, task.id) + + def handler(request: httpx.Request) -> httpx.Response: + payload = json.loads(request.content.decode() or "{}") + if request.url.path == "/api/v1/repos/acme/service/pulls": + return httpx.Response(201, json={"number": 5, "state": "open", "merged": False}) + if request.url.path == "/api/v1/repos/acme/service/issues/5/comments": + return httpx.Response(201, json={"id": 1, "body": payload.get("body", ""), "user": {"login": "agent-bot"}}) + return httpx.Response(404) + + finished = TaskRunner( + db=db, + config=config, + gitea=make_client(handler), + workspace_manager=FakeWorkspaceManager(tmp_path / "work"), + command_runner=runner, + worker_id="worker", + ).run_once() + + assert finished is not None + assert finished.state == TaskState.HUMAN_REVIEW_READY + assert runner.lease_before is not None + assert runner.lease_during is not None + assert runner.lease_during > runner.lease_before + + def transition_to_human_review_ready(db, task_id: int, *, pr_number: int = 5, branch_name: str | None = None): db.transition(task_id, TaskState.CLAIMED) db.transition(task_id, TaskState.PLANNING)