Compare commits
4 Commits
70a17d6675
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 17d723ecca | |||
| 9fc8c14445 | |||
| 2ae22b3492 | |||
| 3c624cc46d |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,4 +1,5 @@
|
|||||||
.agent-gitea/
|
.agent-gitea/
|
||||||
|
.agent-output/
|
||||||
.env
|
.env
|
||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
.ruff_cache/
|
.ruff_cache/
|
||||||
|
|||||||
@@ -7,8 +7,29 @@ from .models import AgentResult
|
|||||||
|
|
||||||
|
|
||||||
class CommandRunner:
|
class CommandRunner:
|
||||||
def run(self, command: list[str], cwd: str | Path, *, stdin: str | None = None) -> AgentResult:
|
def run(
|
||||||
result = subprocess.run(command, cwd=cwd, input=stdin, text=True, capture_output=True, check=False)
|
self,
|
||||||
|
command: list[str],
|
||||||
|
cwd: str | Path,
|
||||||
|
*,
|
||||||
|
stdin: str | None = None,
|
||||||
|
timeout_seconds: int | float | None = None,
|
||||||
|
) -> AgentResult:
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
command,
|
||||||
|
cwd=cwd,
|
||||||
|
input=stdin,
|
||||||
|
text=True,
|
||||||
|
capture_output=True,
|
||||||
|
check=False,
|
||||||
|
timeout=timeout_seconds,
|
||||||
|
)
|
||||||
|
except subprocess.TimeoutExpired as exc:
|
||||||
|
stdout = exc.stdout if isinstance(exc.stdout, str) else (exc.stdout or b"").decode(errors="replace")
|
||||||
|
stderr = exc.stderr if isinstance(exc.stderr, str) else (exc.stderr or b"").decode(errors="replace")
|
||||||
|
message = f"command timed out after {timeout_seconds} seconds"
|
||||||
|
return AgentResult(exit_code=124, stdout=stdout or "", stderr=f"{stderr}\n{message}".strip())
|
||||||
return AgentResult(exit_code=result.returncode, stdout=result.stdout, stderr=result.stderr)
|
return AgentResult(exit_code=result.returncode, stdout=result.stdout, stderr=result.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,14 @@ class SchedulerConfig(BaseModel):
|
|||||||
interval_seconds: int = Field(default=60, ge=1)
|
interval_seconds: int = Field(default=60, ge=1)
|
||||||
concurrency: int = Field(default=1, ge=1)
|
concurrency: int = Field(default=1, ge=1)
|
||||||
lease_seconds: int = Field(default=1800, ge=30)
|
lease_seconds: int = Field(default=1800, ge=30)
|
||||||
|
lease_renewal_interval_seconds: int | None = Field(default=None, ge=1)
|
||||||
|
agent_timeout_seconds: int = Field(default=7200, ge=1)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def effective_lease_renewal_interval_seconds(self) -> int:
|
||||||
|
if self.lease_renewal_interval_seconds is not None:
|
||||||
|
return self.lease_renewal_interval_seconds
|
||||||
|
return max(1, self.lease_seconds // 3)
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceConfig(BaseModel):
|
class WorkspaceConfig(BaseModel):
|
||||||
|
|||||||
@@ -292,12 +292,23 @@ class Database:
|
|||||||
SELECT t.*
|
SELECT t.*
|
||||||
FROM tasks t
|
FROM tasks t
|
||||||
JOIN issues i ON i.repo_id = t.repo_id AND i.issue_number = t.issue_number
|
JOIN issues i ON i.repo_id = t.repo_id AND i.issue_number = t.issue_number
|
||||||
WHERE t.state = ?
|
WHERE t.state != ?
|
||||||
AND t.pr_number IS NOT NULL
|
AND t.pr_number IS NOT NULL
|
||||||
AND i.state = 'open'
|
AND i.state = 'open'
|
||||||
|
AND (
|
||||||
|
t.state IN (?, ?)
|
||||||
|
OR t.lease_owner IS NULL
|
||||||
|
OR t.lease_expires_at IS NULL
|
||||||
|
OR t.lease_expires_at < ?
|
||||||
|
)
|
||||||
ORDER BY t.id
|
ORDER BY t.id
|
||||||
""",
|
""",
|
||||||
(TaskState.HUMAN_REVIEW_READY.value,),
|
(
|
||||||
|
TaskState.CANCELLED.value,
|
||||||
|
TaskState.HUMAN_REVIEW_READY.value,
|
||||||
|
TaskState.FAILED.value,
|
||||||
|
dt_to_db(utcnow()),
|
||||||
|
),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
return [self._task(row) for row in rows]
|
return [self._task(row) for row in rows]
|
||||||
|
|
||||||
@@ -430,6 +441,8 @@ class Database:
|
|||||||
def claim_next_task(self, worker_id: str, lease_seconds: int) -> TaskRecord | None:
|
def claim_next_task(self, worker_id: str, lease_seconds: int) -> TaskRecord | None:
|
||||||
now = utcnow()
|
now = utcnow()
|
||||||
expires = now + timedelta(seconds=lease_seconds)
|
expires = now + timedelta(seconds=lease_seconds)
|
||||||
|
try:
|
||||||
|
self.conn.execute("BEGIN IMMEDIATE")
|
||||||
row = self.conn.execute(
|
row = self.conn.execute(
|
||||||
"""
|
"""
|
||||||
SELECT * FROM tasks
|
SELECT * FROM tasks
|
||||||
@@ -451,6 +464,7 @@ class Database:
|
|||||||
),
|
),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
if row is None:
|
if row is None:
|
||||||
|
self.conn.commit()
|
||||||
return None
|
return None
|
||||||
task = self._task(row)
|
task = self._task(row)
|
||||||
self.conn.execute(
|
self.conn.execute(
|
||||||
@@ -468,11 +482,37 @@ class Database:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self.conn.rollback()
|
||||||
|
raise
|
||||||
self.add_event(task.id, task.state, TaskState.CLAIMED, f"claimed by {worker_id}")
|
self.add_event(task.id, task.state, TaskState.CLAIMED, f"claimed by {worker_id}")
|
||||||
claimed = self.get_task(task.id)
|
claimed = self.get_task(task.id)
|
||||||
assert claimed is not None
|
assert claimed is not None
|
||||||
return claimed
|
return claimed
|
||||||
|
|
||||||
|
def renew_task_lease(self, task_id: int, worker_id: str, lease_seconds: int) -> bool:
|
||||||
|
now = utcnow()
|
||||||
|
expires = now + timedelta(seconds=lease_seconds)
|
||||||
|
placeholders = ",".join("?" for _ in ACTIVE_STATES)
|
||||||
|
cursor = self.conn.execute(
|
||||||
|
f"""
|
||||||
|
UPDATE tasks
|
||||||
|
SET lease_expires_at = ?, updated_at = ?
|
||||||
|
WHERE id = ?
|
||||||
|
AND lease_owner = ?
|
||||||
|
AND state IN ({placeholders})
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
dt_to_db(expires),
|
||||||
|
dt_to_db(now),
|
||||||
|
task_id,
|
||||||
|
worker_id,
|
||||||
|
*[state.value for state in ACTIVE_STATES],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
return cursor.rowcount == 1
|
||||||
|
|
||||||
def transition(
|
def transition(
|
||||||
self,
|
self,
|
||||||
task_id: int,
|
task_id: int,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
|
import threading
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -10,7 +11,7 @@ from .agents import CommandRunner, read_report, render_command, write_prompt
|
|||||||
from .config import AppConfig
|
from .config import AppConfig
|
||||||
from .db import Database, PullRequestFeedbackCursor
|
from .db import Database, PullRequestFeedbackCursor
|
||||||
from .gitea import GiteaClient, GiteaComment, GiteaPullReview
|
from .gitea import GiteaClient, GiteaComment, GiteaPullReview
|
||||||
from .models import IssueRecord, RepositoryRecord, TaskRecord, TaskState
|
from .models import ACTIVE_STATES, IssueRecord, RepositoryRecord, TaskRecord, TaskState
|
||||||
from .rendering import (
|
from .rendering import (
|
||||||
parse_review_report,
|
parse_review_report,
|
||||||
render_human_review_summary,
|
render_human_review_summary,
|
||||||
@@ -34,6 +35,46 @@ class PullRequestFeedbackSnapshot:
|
|||||||
newest_cursor: PullRequestFeedbackCursor
|
newest_cursor: PullRequestFeedbackCursor
|
||||||
|
|
||||||
|
|
||||||
|
class TaskLeaseRenewer:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db_path: Path,
|
||||||
|
task_id: int,
|
||||||
|
worker_id: str,
|
||||||
|
lease_seconds: int,
|
||||||
|
interval_seconds: int,
|
||||||
|
):
|
||||||
|
self.db_path = db_path
|
||||||
|
self.task_id = task_id
|
||||||
|
self.worker_id = worker_id
|
||||||
|
self.lease_seconds = lease_seconds
|
||||||
|
self.interval_seconds = interval_seconds
|
||||||
|
self._stop = threading.Event()
|
||||||
|
self._thread = threading.Thread(target=self._run, name=f"lease-renewer-{task_id}", daemon=True)
|
||||||
|
|
||||||
|
def __enter__(self) -> TaskLeaseRenewer:
|
||||||
|
self._thread.start()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb) -> None:
|
||||||
|
self._stop.set()
|
||||||
|
self._thread.join(timeout=max(1, self.interval_seconds))
|
||||||
|
|
||||||
|
def _run(self) -> None:
|
||||||
|
database = Database(self.db_path)
|
||||||
|
try:
|
||||||
|
while not self._stop.wait(self.interval_seconds):
|
||||||
|
renewed = database.renew_task_lease(self.task_id, self.worker_id, self.lease_seconds)
|
||||||
|
if not renewed:
|
||||||
|
logger.warning("stopping lease renewal for task %d; lease no longer belongs to %s", self.task_id, self.worker_id)
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
logger.exception("lease renewal failed for task %d", self.task_id)
|
||||||
|
finally:
|
||||||
|
database.close()
|
||||||
|
|
||||||
|
|
||||||
def sync_repositories(db: Database, config: AppConfig, client: GiteaClient) -> list[RepositoryRecord]:
|
def sync_repositories(db: Database, config: AppConfig, client: GiteaClient) -> list[RepositoryRecord]:
|
||||||
synced: list[RepositoryRecord] = []
|
synced: list[RepositoryRecord] = []
|
||||||
discovered = client.list_owned_repositories()
|
discovered = client.list_owned_repositories()
|
||||||
@@ -92,12 +133,12 @@ def close_issues_for_merged_pull_requests(db: Database, client: GiteaClient) ->
|
|||||||
)
|
)
|
||||||
client.close_issue(owner=repo.owner, name=repo.name, issue_number=issue.issue_number)
|
client.close_issue(owner=repo.owner, name=repo.name, issue_number=issue.issue_number)
|
||||||
db.update_issue_state(task.repo_id, task.issue_number, "closed")
|
db.update_issue_state(task.repo_id, task.issue_number, "closed")
|
||||||
db.add_event(
|
message = f"closed issue #{issue.issue_number} after merged PR #{task.pr_number}"
|
||||||
task.id,
|
if task.state in ACTIVE_STATES:
|
||||||
task.state,
|
db.clear_pr_feedback_pending(task.id)
|
||||||
task.state,
|
db.transition(task.id, TaskState.CANCELLED, message=message, clear_lease=True)
|
||||||
f"closed issue #{issue.issue_number} after merged PR #{task.pr_number}",
|
else:
|
||||||
)
|
db.add_event(task.id, task.state, task.state, message)
|
||||||
closed += 1
|
closed += 1
|
||||||
return closed
|
return closed
|
||||||
|
|
||||||
@@ -388,7 +429,7 @@ class TaskRunner:
|
|||||||
issue_title=issue.title,
|
issue_title=issue.title,
|
||||||
branch_name=branch_name,
|
branch_name=branch_name,
|
||||||
)
|
)
|
||||||
result = self.command_runner.run(command, workspace, stdin=prompt)
|
result = self._run_agent_command(task, command, workspace, prompt)
|
||||||
report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md")
|
report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md")
|
||||||
self.db.add_agent_run(
|
self.db.add_agent_run(
|
||||||
task_id=task.id,
|
task_id=task.id,
|
||||||
@@ -428,7 +469,7 @@ class TaskRunner:
|
|||||||
pr_number=pr_number,
|
pr_number=pr_number,
|
||||||
branch_name=branch_name,
|
branch_name=branch_name,
|
||||||
)
|
)
|
||||||
result = self.command_runner.run(command, workspace, stdin=prompt)
|
result = self._run_agent_command(task, command, workspace, prompt)
|
||||||
report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md")
|
report = read_report(output_dir / "AGENT_IMPLEMENTATION_REPORT.md")
|
||||||
self.db.add_agent_run(
|
self.db.add_agent_run(
|
||||||
task_id=task.id,
|
task_id=task.id,
|
||||||
@@ -465,7 +506,7 @@ class TaskRunner:
|
|||||||
issue_title=issue.title,
|
issue_title=issue.title,
|
||||||
pr_number=pr_number,
|
pr_number=pr_number,
|
||||||
)
|
)
|
||||||
result = self.command_runner.run(command, workspace, stdin=prompt)
|
result = self._run_agent_command(task, command, workspace, prompt)
|
||||||
report = read_report(output_dir / "AGENT_REVIEW_REPORT.md")
|
report = read_report(output_dir / "AGENT_REVIEW_REPORT.md")
|
||||||
self.db.add_agent_run(
|
self.db.add_agent_run(
|
||||||
task_id=task.id,
|
task_id=task.id,
|
||||||
@@ -481,6 +522,21 @@ class TaskRunner:
|
|||||||
raise RuntimeError(f"reviewer failed with exit code {result.exit_code}")
|
raise RuntimeError(f"reviewer failed with exit code {result.exit_code}")
|
||||||
return report
|
return report
|
||||||
|
|
||||||
|
def _run_agent_command(self, task: TaskRecord, command: list[str], workspace: Path, prompt: str) -> AgentResult:
|
||||||
|
with TaskLeaseRenewer(
|
||||||
|
db_path=self.db.path,
|
||||||
|
task_id=task.id,
|
||||||
|
worker_id=self.worker_id,
|
||||||
|
lease_seconds=self.config.scheduler.lease_seconds,
|
||||||
|
interval_seconds=self.config.scheduler.effective_lease_renewal_interval_seconds,
|
||||||
|
):
|
||||||
|
return self.command_runner.run(
|
||||||
|
command,
|
||||||
|
workspace,
|
||||||
|
stdin=prompt,
|
||||||
|
timeout_seconds=self.config.scheduler.agent_timeout_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
def _load_context(self, task: TaskRecord) -> tuple[RepositoryRecord, IssueRecord]:
|
def _load_context(self, task: TaskRecord) -> tuple[RepositoryRecord, IssueRecord]:
|
||||||
return load_task_context(self.db, task)
|
return load_task_context(self.db, task)
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
from agent_gitea.agents import CommandRunner
|
||||||
from agent_gitea.config import AppConfig
|
from agent_gitea.config import AppConfig
|
||||||
|
from agent_gitea.db import Database
|
||||||
from agent_gitea.gitea import GiteaClient
|
from agent_gitea.gitea import GiteaClient
|
||||||
from agent_gitea.models import AgentResult, TaskState
|
from agent_gitea.models import AgentResult, TaskState
|
||||||
from agent_gitea.service import (
|
from agent_gitea.service import (
|
||||||
@@ -208,7 +213,7 @@ class FakeRunner:
|
|||||||
def __init__(self, *, fail_role: str | None = None):
|
def __init__(self, *, fail_role: str | None = None):
|
||||||
self.fail_role = fail_role
|
self.fail_role = fail_role
|
||||||
|
|
||||||
def run(self, command, cwd, *, stdin=None):
|
def run(self, command, cwd, *, stdin=None, timeout_seconds=None):
|
||||||
role = command[0]
|
role = command[0]
|
||||||
assert stdin
|
assert stdin
|
||||||
if role == self.fail_role:
|
if role == self.fail_role:
|
||||||
@@ -250,6 +255,126 @@ def seed_task(db):
|
|||||||
return db.create_task(repo.id, 1)
|
return db.create_task(repo.id, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_claim_next_task_allows_only_one_worker_during_race(db, tmp_path):
|
||||||
|
seed_task(db)
|
||||||
|
db.conn.create_function("sleep_ms", 1, lambda ms: time.sleep(ms / 1000))
|
||||||
|
db.conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TRIGGER slow_claim_update
|
||||||
|
BEFORE UPDATE OF state ON tasks
|
||||||
|
WHEN NEW.state = 'CLAIMED'
|
||||||
|
BEGIN
|
||||||
|
SELECT sleep_ms(150);
|
||||||
|
END
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
db.conn.commit()
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
barrier = threading.Barrier(2)
|
||||||
|
results: dict[str, int | None] = {}
|
||||||
|
errors: list[BaseException] = []
|
||||||
|
|
||||||
|
def claim(worker_id: str) -> None:
|
||||||
|
database = Database(tmp_path / "state.sqlite3")
|
||||||
|
database.conn.create_function("sleep_ms", 1, lambda ms: time.sleep(ms / 1000))
|
||||||
|
try:
|
||||||
|
barrier.wait()
|
||||||
|
task = database.claim_next_task(worker_id, 60)
|
||||||
|
results[worker_id] = task.id if task else None
|
||||||
|
except BaseException as exc: # pragma: no cover - assertion below reports thread failures
|
||||||
|
errors.append(exc)
|
||||||
|
finally:
|
||||||
|
database.close()
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=claim, args=(worker_id,)) for worker_id in ("worker-a", "worker-b")]
|
||||||
|
for thread in threads:
|
||||||
|
thread.start()
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
assert errors == []
|
||||||
|
assert list(results.values()).count(1) == 1
|
||||||
|
assert list(results.values()).count(None) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_command_runner_returns_timeout_result(tmp_path):
|
||||||
|
result = CommandRunner().run(
|
||||||
|
[sys.executable, "-c", "import time; time.sleep(5)"],
|
||||||
|
tmp_path,
|
||||||
|
timeout_seconds=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code == 124
|
||||||
|
assert "timed out" in result.stderr
|
||||||
|
|
||||||
|
|
||||||
|
class ObservingSlowRunner:
|
||||||
|
def __init__(self, db, task_id: int):
|
||||||
|
self.db = db
|
||||||
|
self.task_id = task_id
|
||||||
|
self.lease_before = None
|
||||||
|
self.lease_during = None
|
||||||
|
|
||||||
|
def run(self, command, cwd, *, stdin=None, timeout_seconds=None):
|
||||||
|
role = command[0]
|
||||||
|
if role == "implementer":
|
||||||
|
self.lease_before = self.db.get_task(self.task_id).lease_expires_at # type: ignore[union-attr]
|
||||||
|
time.sleep(1.2)
|
||||||
|
self.lease_during = self.db.get_task(self.task_id).lease_expires_at # type: ignore[union-attr]
|
||||||
|
output_dir = Path(cwd) / ".agent-output"
|
||||||
|
output_dir.mkdir(exist_ok=True)
|
||||||
|
(output_dir / "AGENT_IMPLEMENTATION_REPORT.md").write_text(
|
||||||
|
"## Summary\nImplemented\n\n## Test commands run\npytest\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
if role == "reviewer":
|
||||||
|
output_dir = Path(cwd) / ".agent-output"
|
||||||
|
output_dir.mkdir(exist_ok=True)
|
||||||
|
(output_dir / "AGENT_REVIEW_REPORT.md").write_text(
|
||||||
|
"## Verdict\nAPPROVE\n\n## Suggested PR Comment\nLooks good.\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
return AgentResult(exit_code=0, stdout="ok", stderr="")
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_runner_renews_lease_while_agent_runs(db, tmp_path):
|
||||||
|
config = make_config(
|
||||||
|
tmp_path,
|
||||||
|
scheduler={
|
||||||
|
"interval_seconds": 1,
|
||||||
|
"concurrency": 1,
|
||||||
|
"lease_seconds": 60,
|
||||||
|
"lease_renewal_interval_seconds": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
task = seed_task(db)
|
||||||
|
runner = ObservingSlowRunner(db, task.id)
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
payload = json.loads(request.content.decode() or "{}")
|
||||||
|
if request.url.path == "/api/v1/repos/acme/service/pulls":
|
||||||
|
return httpx.Response(201, json={"number": 5, "state": "open", "merged": False})
|
||||||
|
if request.url.path == "/api/v1/repos/acme/service/issues/5/comments":
|
||||||
|
return httpx.Response(201, json={"id": 1, "body": payload.get("body", ""), "user": {"login": "agent-bot"}})
|
||||||
|
return httpx.Response(404)
|
||||||
|
|
||||||
|
finished = TaskRunner(
|
||||||
|
db=db,
|
||||||
|
config=config,
|
||||||
|
gitea=make_client(handler),
|
||||||
|
workspace_manager=FakeWorkspaceManager(tmp_path / "work"),
|
||||||
|
command_runner=runner,
|
||||||
|
worker_id="worker",
|
||||||
|
).run_once()
|
||||||
|
|
||||||
|
assert finished is not None
|
||||||
|
assert finished.state == TaskState.HUMAN_REVIEW_READY
|
||||||
|
assert runner.lease_before is not None
|
||||||
|
assert runner.lease_during is not None
|
||||||
|
assert runner.lease_during > runner.lease_before
|
||||||
|
|
||||||
|
|
||||||
def transition_to_human_review_ready(db, task_id: int, *, pr_number: int = 5, branch_name: str | None = None):
|
def transition_to_human_review_ready(db, task_id: int, *, pr_number: int = 5, branch_name: str | None = None):
|
||||||
db.transition(task_id, TaskState.CLAIMED)
|
db.transition(task_id, TaskState.CLAIMED)
|
||||||
db.transition(task_id, TaskState.PLANNING)
|
db.transition(task_id, TaskState.PLANNING)
|
||||||
@@ -656,6 +781,56 @@ def test_close_issues_for_merged_pull_requests_skips_unmerged_pr(db):
|
|||||||
assert db.get_issue(repo.id, 1).state == "open" # type: ignore[union-attr]
|
assert db.get_issue(repo.id, 1).state == "open" # type: ignore[union-attr]
|
||||||
|
|
||||||
|
|
||||||
|
def test_close_issues_for_merged_pull_requests_handles_queued_feedback_task(db):
|
||||||
|
repo = db.upsert_repository(
|
||||||
|
owner="acme",
|
||||||
|
name="service",
|
||||||
|
clone_url="https://gitea.test/acme/service.git",
|
||||||
|
default_branch="main",
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
db.upsert_issue(
|
||||||
|
repo_id=repo.id,
|
||||||
|
issue_number=1,
|
||||||
|
title="Ready issue",
|
||||||
|
body="Body",
|
||||||
|
labels=["agent:ready"],
|
||||||
|
state="open",
|
||||||
|
html_url="https://gitea.test/acme/service/issues/1",
|
||||||
|
)
|
||||||
|
task = db.create_task(repo.id, 1)
|
||||||
|
task = transition_to_human_review_ready(db, task.id, pr_number=5, branch_name="agent/issue-1-ready-issue")
|
||||||
|
db.mark_pr_feedback_pending(task.id)
|
||||||
|
db.transition(
|
||||||
|
task.id,
|
||||||
|
TaskState.DISCOVERED,
|
||||||
|
message="queued PR feedback from 1 human comment(s)",
|
||||||
|
clear_lease=True,
|
||||||
|
)
|
||||||
|
requests: list[tuple[str, str, dict]] = []
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
payload = json.loads(request.content.decode() or "{}")
|
||||||
|
requests.append((request.method, request.url.path, payload))
|
||||||
|
if request.url.path == "/api/v1/repos/acme/service/pulls/5":
|
||||||
|
return httpx.Response(200, json={"number": 5, "state": "closed", "merged": True})
|
||||||
|
if request.url.path == "/api/v1/repos/acme/service/issues/1/comments":
|
||||||
|
return httpx.Response(201, json={"id": 1})
|
||||||
|
if request.url.path == "/api/v1/repos/acme/service/issues/1":
|
||||||
|
return httpx.Response(200, json={"number": 1, "state": "closed"})
|
||||||
|
return httpx.Response(404)
|
||||||
|
|
||||||
|
closed = close_issues_for_merged_pull_requests(db, make_client(handler))
|
||||||
|
|
||||||
|
updated_task = db.get_task(task.id)
|
||||||
|
assert closed == 1
|
||||||
|
assert db.get_issue(repo.id, 1).state == "closed" # type: ignore[union-attr]
|
||||||
|
assert updated_task is not None
|
||||||
|
assert updated_task.state == TaskState.CANCELLED
|
||||||
|
assert not db.has_pending_pr_feedback(task.id)
|
||||||
|
assert ("PATCH", "/api/v1/repos/acme/service/issues/1", {"state": "closed"}) in requests
|
||||||
|
|
||||||
|
|
||||||
def test_run_task_no_diff_becomes_blocked(db, tmp_path):
|
def test_run_task_no_diff_becomes_blocked(db, tmp_path):
|
||||||
config = make_config(tmp_path)
|
config = make_config(tmp_path)
|
||||||
seed_task(db)
|
seed_task(db)
|
||||||
|
|||||||
Reference in New Issue
Block a user