Compare commits
4 Commits
70a17d6675
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 17d723ecca | |||
| 9fc8c14445 | |||
| 2ae22b3492 | |||
| 3c624cc46d |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,4 +1,5 @@
|
||||
.agent-gitea/
|
||||
.agent-output/
|
||||
.env
|
||||
.pytest_cache/
|
||||
.ruff_cache/
|
||||
|
||||
@@ -7,8 +7,29 @@ from .models import AgentResult
|
||||
|
||||
|
||||
class CommandRunner:
|
||||
def run(self, command: list[str], cwd: str | Path, *, stdin: str | None = None) -> AgentResult:
|
||||
result = subprocess.run(command, cwd=cwd, input=stdin, text=True, capture_output=True, check=False)
|
||||
def run(
    self,
    command: list[str],
    cwd: str | Path,
    *,
    stdin: str | None = None,
    timeout_seconds: int | float | None = None,
) -> AgentResult:
    """Execute ``command`` in ``cwd`` and capture its output as text.

    Args:
        command: Argument vector handed to ``subprocess.run``.
        cwd: Working directory for the child process.
        stdin: Optional text passed as the child's input.
        timeout_seconds: Forwarded as ``subprocess.run``'s ``timeout``.

    Returns:
        An ``AgentResult`` with the child's return code and captured streams.
        When ``subprocess.TimeoutExpired`` is raised, the result instead has
        exit code 124, the output attached to the exception, and a
        "command timed out" line appended to stderr.
    """

    def as_text(captured: str | bytes | None) -> str:
        # The exception's stdout/stderr are accepted as str, bytes or None;
        # everything is normalised to str before building the result.
        if isinstance(captured, str):
            return captured
        return (captured or b"").decode(errors="replace")

    try:
        completed = subprocess.run(
            command,
            cwd=cwd,
            input=stdin,
            text=True,
            capture_output=True,
            check=False,
            timeout=timeout_seconds,
        )
    except subprocess.TimeoutExpired as expired:
        partial_out = as_text(expired.stdout)
        partial_err = as_text(expired.stderr)
        notice = f"command timed out after {timeout_seconds} seconds"
        return AgentResult(
            exit_code=124,
            stdout=partial_out or "",
            stderr=f"{partial_err}\n{notice}".strip(),
        )
    return AgentResult(
        exit_code=completed.returncode,
        stdout=completed.stdout,
        stderr=completed.stderr,
    )
|
||||
|
||||
|
||||
|
||||
@@ -24,6 +24,14 @@ class SchedulerConfig(BaseModel):
|
||||
interval_seconds: int = Field(default=60, ge=1)
|
||||
concurrency: int = Field(default=1, ge=1)
|
||||
lease_seconds: int = Field(default=1800, ge=30)
|
||||
lease_renewal_interval_seconds: int | None = Field(default=None, ge=1)
|
||||
agent_timeout_seconds: int = Field(default=7200, ge=1)
|
||||
|
||||
@property
def effective_lease_renewal_interval_seconds(self) -> int:
    """Return the interval, in seconds, between lease renewals.

    An explicitly set ``lease_renewal_interval_seconds`` is returned
    unchanged. Otherwise the interval is one third of ``lease_seconds``
    (integer division), clamped to at least one second.
    """
    configured = self.lease_renewal_interval_seconds
    if configured is None:
        return max(1, self.lease_seconds // 3)
    return configured
|
||||
|
||||
|
||||
class WorkspaceConfig(BaseModel):
|
||||
|
||||
@@ -292,12 +292,23 @@ class Database:
|
||||
SELECT t.*
|
||||
FROM tasks t
|
||||
JOIN issues i ON i.repo_id = t.repo_id AND i.issue_number = t.issue_number
|
||||
WHERE t.state = ?
|
||||
WHERE t.state != ?
|
||||
AND t.pr_number IS NOT NULL
|
||||
AND i.state = 'open'
|
||||
AND (
|
||||
t.state IN (?, ?)
|
||||
OR t.lease_owner IS NULL
|
||||
OR t.lease_expires_at IS NULL
|
||||
OR t.lease_expires_at < ?
|
||||
)
|
||||
ORDER BY t.id
|
||||
""",
|
||||
(TaskState.HUMAN_REVIEW_READY.value,),
|
||||
(
|
||||
TaskState.CANCELLED.value,
|
||||
TaskState.HUMAN_REVIEW_READY.value,
|
||||
TaskState.FAILED.value,
|
||||
dt_to_db(utcnow()),
|
||||
),
|
||||
).fetchall()
|
||||
return [self._task(row) for row in rows]
|
||||
|
||||
@@ -430,6 +441,8 @@ class Database:
|
||||
def claim_next_task(self, worker_id: str, lease_seconds: int) -> TaskRecord | None:
|
||||
now = utcnow()
|
||||
expires = now + timedelta(seconds=lease_seconds)
|
||||
try:
|
||||
self.conn.execute("BEGIN IMMEDIATE")
|
||||
row = self.conn.execute(
|
||||
"""
|
||||
SELECT * FROM tasks
|
||||
@@ -451,6 +464,7 @@ class Database:
|
||||
),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
self.conn.commit()
|
||||
return None
|
||||
task = self._task(row)
|
||||
self.conn.execute(
|
||||
@@ -468,11 +482,37 @@ class Database:
|
||||
),
|
||||
)
|
||||
self.conn.commit()
|
||||
except Exception:
|
||||
self.conn.rollback()
|
||||
raise
|
||||
self.add_event(task.id, task.state, TaskState.CLAIMED, f"claimed by {worker_id}")
|
||||
claimed = self.get_task(task.id)
|
||||
assert claimed is not None
|
||||
return claimed
|
||||
|
||||
def renew_task_lease(self, task_id: int, worker_id: str, lease_seconds: int) -> bool:
    """Extend the lease on a task that ``worker_id`` still holds.

    The row is updated only while its ``lease_owner`` equals ``worker_id``
    and its ``state`` is one of ``ACTIVE_STATES``.

    Args:
        task_id: Primary key of the task row.
        worker_id: Expected current lease owner.
        lease_seconds: New lease length, counted from the current time.

    Returns:
        ``True`` when exactly one row was updated, ``False`` otherwise.

    Raises:
        Exception: Whatever the connection raises; the pending transaction
            is rolled back before the error is re-raised.
    """
    now = utcnow()
    expires = now + timedelta(seconds=lease_seconds)
    # Materialise once so the placeholders and the bound values always agree.
    active_values = [state.value for state in ACTIVE_STATES]
    # Only "?" markers are interpolated into the SQL; all values are bound.
    placeholders = ",".join("?" for _ in active_values)
    try:
        cursor = self.conn.execute(
            f"""
            UPDATE tasks
            SET lease_expires_at = ?, updated_at = ?
            WHERE id = ?
              AND lease_owner = ?
              AND state IN ({placeholders})
            """,
            (
                dt_to_db(expires),
                dt_to_db(now),
                task_id,
                worker_id,
                *active_values,
            ),
        )
        self.conn.commit()
    except Exception:
        # Do not leave a half-finished write transaction open on this
        # connection if the update or the commit fails.
        self.conn.rollback()
        raise
    return cursor.rowcount == 1
|
||||
|
||||
def transition(
|
||||
self,
|
||||
task_id: int,
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import socket
|
||||
import time
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
@@ -10,7 +11,7 @@ from .agents import CommandRunner, read_report, render_command, write_prompt
|
||||
from .config import AppConfig
|
||||
from .db import Database, PullRequestFeedbackCursor
|
||||
from .gitea import GiteaClient, GiteaComment, GiteaPullReview
|
||||
from .models import IssueRecord, RepositoryRecord, TaskRecord, TaskState
|
||||
from .models import ACTIVE_STATES, IssueRecord, RepositoryRecord, TaskRecord, TaskState
|
||||
from .rendering import (
|
||||
parse_review_report,
|
||||
render_human_review_summary,
|
||||
@@ -34,6 +35,46 @@ class PullRequestFeedbackSnapshot:
|
||||
newest_cursor: PullRequestFeedbackCursor
|
||||
|
||||
|
||||
class TaskLeaseRenewer:
    """Context manager that renews a task lease from a background thread.

    Entering the context starts a daemon thread that calls
    ``Database.renew_task_lease`` every ``interval_seconds``. Leaving the
    context signals the thread to stop and waits a bounded time for it.
    The thread opens its own ``Database`` from ``db_path`` (presumably so
    the caller's connection is not shared across threads; confirm).
    """

    def __init__(
        self,
        *,
        db_path: Path,
        task_id: int,
        worker_id: str,
        lease_seconds: int,
        interval_seconds: int,
    ) -> None:
        """Store the renewal parameters and prepare the (unstarted) thread.

        Args:
            db_path: Path passed to ``Database`` inside the worker thread.
            task_id: Task whose lease is renewed.
            worker_id: Owner passed to ``renew_task_lease``.
            lease_seconds: Lease length requested on every renewal.
            interval_seconds: Wait between two renewal attempts.
        """
        self.db_path = db_path
        self.task_id = task_id
        self.worker_id = worker_id
        self.lease_seconds = lease_seconds
        self.interval_seconds = interval_seconds
        self._stop = threading.Event()
        self._thread = threading.Thread(
            target=self._run,
            name=f"lease-renewer-{task_id}",
            daemon=True,
        )

    def __enter__(self) -> "TaskLeaseRenewer":
        """Start the renewal thread and return this renewer."""
        self._thread.start()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Signal the thread to stop and wait up to one interval for it."""
        self._stop.set()
        # Bounded join: a renewal stuck in the database must not block the caller forever.
        self._thread.join(timeout=max(1, self.interval_seconds))

    def _run(self) -> None:
        """Renew the lease until stopped, ownership is lost, or an error occurs."""
        database = None
        try:
            # Opened inside the ``try`` so that a failure to open the database
            # is logged like any other renewal failure.
            database = Database(self.db_path)
            # ``Event.wait`` returns True once ``_stop`` is set, ending the loop.
            while not self._stop.wait(self.interval_seconds):
                renewed = database.renew_task_lease(self.task_id, self.worker_id, self.lease_seconds)
                if not renewed:
                    logger.warning(
                        "stopping lease renewal for task %d; lease no longer belongs to %s",
                        self.task_id,
                        self.worker_id,
                    )
                    return
        except Exception:
            logger.exception("lease renewal failed for task %d", self.task_id)
        finally:
            if database is not None:
                database.close()
|
||||
|
||||
|
||||
def sync_repositories(db: Database, config: AppConfig, client: GiteaClient) -> list[RepositoryRecord]:
|
||||
synced: list[RepositoryRecord] = []
|
||||
discovered = client.list_owned_repositories()
|
||||
@@ -92,12 +133,12 @@ def close_issues_for_merged_pull_requests(db: Database, client: GiteaClient) ->
|
||||
)
|
||||
client.close_issue(owner=repo.owner, name=repo.name, issue_number=issue.issue_number)
|
||||
db.update_issue_state(task.repo_id, task.issue_number, "closed")
|
||||
db.add_event(
|
||||
task.id,
|
||||
task.state,
|
||||
task.state,
|
||||
f"closed issue #{issue.issue_number} after merged PR #{task.pr_number}",
|
||||
)
|
||||
message = f"closed issue #{issue.issue_number} after merged PR #{task.pr_number}"
|
||||
if task.state in ACTIVE_STATES:
|
||||
db.clear_pr_feedback_pending(task.id)
|
||||
db.transition(task.id, TaskState.CANCELLED, message=message, clear_lease=True)
|
||||
else:
|
||||
db.add_event(task.id, task.state, task.state, message)
|
||||
closed += 1
|
||||
return closed
|
||||
|
||||
@@ -388,7 +429,7 @@ class TaskRunner:
|
||||
issue_title=issue.title,
|
||||
branch_name=branch_name,
|
||||
)
|
||||
result = self.command_runner.run(command, workspace, stdin=prompt)
|
||||
result = self._run_agent_command(task, command, workspace, prompt)
|
||||
report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md")
|
||||
self.db.add_agent_run(
|
||||
task_id=task.id,
|
||||
@@ -428,7 +469,7 @@ class TaskRunner:
|
||||
pr_number=pr_number,
|
||||
branch_name=branch_name,
|
||||
)
|
||||
result = self.command_runner.run(command, workspace, stdin=prompt)
|
||||
result = self._run_agent_command(task, command, workspace, prompt)
|
||||
report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md")
|
||||
self.db.add_agent_run(
|
||||
task_id=task.id,
|
||||
@@ -465,7 +506,7 @@ class TaskRunner:
|
||||
issue_title=issue.title,
|
||||
pr_number=pr_number,
|
||||
)
|
||||
result = self.command_runner.run(command, workspace, stdin=prompt)
|
||||
result = self._run_agent_command(task, command, workspace, prompt)
|
||||
report = read_report(output_dir / "AGENT_REVIEW_REPORT.md")
|
||||
self.db.add_agent_run(
|
||||
task_id=task.id,
|
||||
@@ -481,6 +522,21 @@ class TaskRunner:
|
||||
raise RuntimeError(f"reviewer failed with exit code {result.exit_code}")
|
||||
return report
|
||||
|
||||
def _run_agent_command(self, task: TaskRecord, command: list[str], workspace: Path, prompt: str) -> AgentResult:
    """Run an agent command while a ``TaskLeaseRenewer`` is active for ``task``.

    The renewer is configured from ``self.config.scheduler`` (lease length
    and effective renewal interval). The command runs in ``workspace`` with
    ``prompt`` on stdin and the scheduler's ``agent_timeout_seconds`` as its
    timeout.

    Returns:
        Whatever ``self.command_runner.run`` returns.
    """
    # NOTE(review): ``AgentResult`` is not in the visible ``.models`` import
    # line of this module; confirm it is imported elsewhere.
    scheduler = self.config.scheduler
    lease_keeper = TaskLeaseRenewer(
        db_path=self.db.path,
        task_id=task.id,
        worker_id=self.worker_id,
        lease_seconds=scheduler.lease_seconds,
        interval_seconds=scheduler.effective_lease_renewal_interval_seconds,
    )
    with lease_keeper:
        return self.command_runner.run(
            command,
            workspace,
            stdin=prompt,
            timeout_seconds=scheduler.agent_timeout_seconds,
        )
|
||||
|
||||
def _load_context(self, task: TaskRecord) -> tuple[RepositoryRecord, IssueRecord]:
|
||||
return load_task_context(self.db, task)
|
||||
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
from agent_gitea.agents import CommandRunner
|
||||
from agent_gitea.config import AppConfig
|
||||
from agent_gitea.db import Database
|
||||
from agent_gitea.gitea import GiteaClient
|
||||
from agent_gitea.models import AgentResult, TaskState
|
||||
from agent_gitea.service import (
|
||||
@@ -208,7 +213,7 @@ class FakeRunner:
|
||||
def __init__(self, *, fail_role: str | None = None):
|
||||
self.fail_role = fail_role
|
||||
|
||||
def run(self, command, cwd, *, stdin=None):
|
||||
def run(self, command, cwd, *, stdin=None, timeout_seconds=None):
|
||||
role = command[0]
|
||||
assert stdin
|
||||
if role == self.fail_role:
|
||||
@@ -250,6 +255,126 @@ def seed_task(db):
|
||||
return db.create_task(repo.id, 1)
|
||||
|
||||
|
||||
def test_claim_next_task_allows_only_one_worker_during_race(db, tmp_path):
    """Two workers racing on ``claim_next_task`` yield exactly one claim."""

    def sleep_ms(ms):
        # SQL-callable delay used by the trigger to widen the race window.
        time.sleep(ms / 1000)

    seed_task(db)
    db.conn.create_function("sleep_ms", 1, sleep_ms)
    db.conn.execute(
        """
        CREATE TRIGGER slow_claim_update
        BEFORE UPDATE OF state ON tasks
        WHEN NEW.state = 'CLAIMED'
        BEGIN
            SELECT sleep_ms(150);
        END
        """
    )
    db.conn.commit()
    db.close()

    start_line = threading.Barrier(2)
    outcomes: dict[str, int | None] = {}
    failures: list[BaseException] = []

    def claim(worker_id: str) -> None:
        # Each worker opens its own connection; assumes the ``db`` fixture
        # lives at ``tmp_path / "state.sqlite3"`` (verify against the fixture).
        database = Database(tmp_path / "state.sqlite3")
        database.conn.create_function("sleep_ms", 1, sleep_ms)
        try:
            start_line.wait()
            claimed = database.claim_next_task(worker_id, 60)
            outcomes[worker_id] = claimed.id if claimed else None
        except BaseException as exc:  # pragma: no cover - assertion below reports thread failures
            failures.append(exc)
        finally:
            database.close()

    racers = [threading.Thread(target=claim, args=(name,)) for name in ("worker-a", "worker-b")]
    for racer in racers:
        racer.start()
    for racer in racers:
        racer.join()

    claimed_ids = list(outcomes.values())
    assert failures == []
    assert claimed_ids.count(1) == 1
    assert claimed_ids.count(None) == 1
|
||||
|
||||
|
||||
def test_command_runner_returns_timeout_result(tmp_path):
    """A child that outlives its timeout yields exit code 124 and a stderr note."""
    sleeper = [sys.executable, "-c", "import time; time.sleep(5)"]

    outcome = CommandRunner().run(sleeper, tmp_path, timeout_seconds=0.1)

    assert outcome.exit_code == 124
    assert "timed out" in outcome.stderr
|
||||
|
||||
|
||||
class ObservingSlowRunner:
    """Fake command runner that samples the task's lease expiry around a slow run.

    For the ``implementer`` role it records ``lease_expires_at`` before and
    after a 1.2 second sleep and writes an implementation report. For the
    ``reviewer`` role it writes an approving review report. Every call
    returns a successful ``AgentResult``.
    """

    def __init__(self, db, task_id: int):
        self.db = db
        self.task_id = task_id
        self.lease_before = None
        self.lease_during = None

    def _lease_expiry(self):
        """Read the task's current ``lease_expires_at`` from the database."""
        return self.db.get_task(self.task_id).lease_expires_at  # type: ignore[union-attr]

    @staticmethod
    def _write_report(cwd, filename: str, text: str) -> None:
        """Write ``text`` to ``<cwd>/.agent-output/<filename>``, creating the directory."""
        report_dir = Path(cwd) / ".agent-output"
        report_dir.mkdir(exist_ok=True)
        (report_dir / filename).write_text(text, encoding="utf-8")

    def run(self, command, cwd, *, stdin=None, timeout_seconds=None):
        role = command[0]
        if role == "implementer":
            self.lease_before = self._lease_expiry()
            # Sleep long enough for a background lease renewal to happen.
            time.sleep(1.2)
            self.lease_during = self._lease_expiry()
            self._write_report(
                cwd,
                "AGENT_IMPLEMENTATION_REPORT.md",
                "## Summary\nImplemented\n\n## Test commands run\npytest\n",
            )
        elif role == "reviewer":
            self._write_report(
                cwd,
                "AGENT_REVIEW_REPORT.md",
                "## Verdict\nAPPROVE\n\n## Suggested PR Comment\nLooks good.\n",
            )
        return AgentResult(exit_code=0, stdout="ok", stderr="")
|
||||
|
||||
|
||||
def test_task_runner_renews_lease_while_agent_runs(db, tmp_path):
    """The lease expiry seen by the agent moves forward while the agent runs."""
    # A one-second renewal interval against the runner's 1.2 second sleep
    # leaves room for one renewal during the implementer step.
    scheduler_settings = {
        "interval_seconds": 1,
        "concurrency": 1,
        "lease_seconds": 60,
        "lease_renewal_interval_seconds": 1,
    }
    config = make_config(tmp_path, scheduler=scheduler_settings)
    task = seed_task(db)
    runner = ObservingSlowRunner(db, task.id)

    def handler(request: httpx.Request) -> httpx.Response:
        payload = json.loads(request.content.decode() or "{}")
        path = request.url.path
        if path == "/api/v1/repos/acme/service/pulls":
            return httpx.Response(201, json={"number": 5, "state": "open", "merged": False})
        if path == "/api/v1/repos/acme/service/issues/5/comments":
            comment = {"id": 1, "body": payload.get("body", ""), "user": {"login": "agent-bot"}}
            return httpx.Response(201, json=comment)
        return httpx.Response(404)

    task_runner = TaskRunner(
        db=db,
        config=config,
        gitea=make_client(handler),
        workspace_manager=FakeWorkspaceManager(tmp_path / "work"),
        command_runner=runner,
        worker_id="worker",
    )
    finished = task_runner.run_once()

    assert finished is not None
    assert finished.state == TaskState.HUMAN_REVIEW_READY
    assert runner.lease_before is not None
    assert runner.lease_during is not None
    assert runner.lease_during > runner.lease_before
|
||||
|
||||
|
||||
def transition_to_human_review_ready(db, task_id: int, *, pr_number: int = 5, branch_name: str | None = None):
|
||||
db.transition(task_id, TaskState.CLAIMED)
|
||||
db.transition(task_id, TaskState.PLANNING)
|
||||
@@ -656,6 +781,56 @@ def test_close_issues_for_merged_pull_requests_skips_unmerged_pr(db):
|
||||
assert db.get_issue(repo.id, 1).state == "open" # type: ignore[union-attr]
|
||||
|
||||
|
||||
def test_close_issues_for_merged_pull_requests_handles_queued_feedback_task(db):
    """A task re-queued for PR feedback is cancelled once its PR is merged.

    The task is moved back to ``DISCOVERED`` with pending feedback. After the
    fake Gitea reports PR #5 as merged, the issue must be closed, the task
    cancelled and the pending-feedback flag cleared.
    """
    repo = db.upsert_repository(
        owner="acme",
        name="service",
        clone_url="https://gitea.test/acme/service.git",
        default_branch="main",
        enabled=True,
    )
    db.upsert_issue(
        repo_id=repo.id,
        issue_number=1,
        title="Ready issue",
        body="Body",
        labels=["agent:ready"],
        state="open",
        html_url="https://gitea.test/acme/service/issues/1",
    )
    task = db.create_task(repo.id, 1)
    task = transition_to_human_review_ready(db, task.id, pr_number=5, branch_name="agent/issue-1-ready-issue")
    db.mark_pr_feedback_pending(task.id)
    db.transition(
        task.id,
        TaskState.DISCOVERED,
        message="queued PR feedback from 1 human comment(s)",
        clear_lease=True,
    )

    seen: list[tuple[str, str, dict]] = []
    # Canned (status, JSON body) per request path; anything else is a 404.
    canned = {
        "/api/v1/repos/acme/service/pulls/5": (200, {"number": 5, "state": "closed", "merged": True}),
        "/api/v1/repos/acme/service/issues/1/comments": (201, {"id": 1}),
        "/api/v1/repos/acme/service/issues/1": (200, {"number": 1, "state": "closed"}),
    }

    def handler(request: httpx.Request) -> httpx.Response:
        payload = json.loads(request.content.decode() or "{}")
        seen.append((request.method, request.url.path, payload))
        match = canned.get(request.url.path)
        if match is None:
            return httpx.Response(404)
        status, body = match
        return httpx.Response(status, json=body)

    closed = close_issues_for_merged_pull_requests(db, make_client(handler))

    updated_task = db.get_task(task.id)
    assert closed == 1
    assert db.get_issue(repo.id, 1).state == "closed"  # type: ignore[union-attr]
    assert updated_task is not None
    assert updated_task.state == TaskState.CANCELLED
    assert not db.has_pending_pr_feedback(task.id)
    assert ("PATCH", "/api/v1/repos/acme/service/issues/1", {"state": "closed"}) in seen
|
||||
|
||||
|
||||
def test_run_task_no_diff_becomes_blocked(db, tmp_path):
|
||||
config = make_config(tmp_path)
|
||||
seed_task(db)
|
||||
|
||||
Reference in New Issue
Block a user