feat: close issues after merged agent PRs
This commit is contained in:
@@ -195,6 +195,17 @@ class Database:
|
||||
).fetchall()
|
||||
return [self._issue(row) for row in rows]
|
||||
|
||||
def update_issue_state(self, repo_id: int, issue_number: int, state: str) -> None:
|
||||
self.conn.execute(
|
||||
"""
|
||||
UPDATE issues
|
||||
SET state = ?, updated_at = ?
|
||||
WHERE repo_id = ? AND issue_number = ?
|
||||
""",
|
||||
(state, dt_to_db(utcnow()), repo_id, issue_number),
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def active_task_for_issue(self, repo_id: int, issue_number: int) -> TaskRecord | None:
|
||||
placeholders = ",".join("?" for _ in ACTIVE_STATES)
|
||||
rows = self.conn.execute(
|
||||
@@ -245,6 +256,21 @@ class Database:
|
||||
).fetchone()
|
||||
return self._task(row) if row else None
|
||||
|
||||
def list_tasks_pending_issue_close(self) -> list[TaskRecord]:
|
||||
rows = self.conn.execute(
|
||||
"""
|
||||
SELECT t.*
|
||||
FROM tasks t
|
||||
JOIN issues i ON i.repo_id = t.repo_id AND i.issue_number = t.issue_number
|
||||
WHERE t.state = ?
|
||||
AND t.pr_number IS NOT NULL
|
||||
AND i.state = 'open'
|
||||
ORDER BY t.id
|
||||
""",
|
||||
(TaskState.HUMAN_REVIEW_READY.value,),
|
||||
).fetchall()
|
||||
return [self._task(row) for row in rows]
|
||||
|
||||
def claim_next_task(self, worker_id: str, lease_seconds: int) -> TaskRecord | None:
|
||||
now = utcnow()
|
||||
expires = now + timedelta(seconds=lease_seconds)
|
||||
|
||||
@@ -23,6 +23,8 @@ class GiteaIssue:
|
||||
class GiteaPullRequest:
|
||||
number: int
|
||||
html_url: str
|
||||
state: str
|
||||
merged: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -131,7 +133,12 @@ class GiteaClient:
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
return GiteaPullRequest(number=int(payload["number"]), html_url=payload.get("html_url") or "")
|
||||
return pull_request_from_payload(payload)
|
||||
|
||||
def get_pull_request(self, *, owner: str, name: str, pr_number: int) -> GiteaPullRequest:
|
||||
response = self.client.get(f"/repos/{owner}/{name}/pulls/{pr_number}")
|
||||
response.raise_for_status()
|
||||
return pull_request_from_payload(response.json())
|
||||
|
||||
def post_issue_comment(self, *, owner: str, name: str, issue_number: int, body: str) -> None:
|
||||
response = self.client.post(
|
||||
@@ -140,6 +147,13 @@ class GiteaClient:
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
def close_issue(self, *, owner: str, name: str, issue_number: int) -> None:
|
||||
response = self.client.patch(
|
||||
f"/repos/{owner}/{name}/issues/{issue_number}",
|
||||
json={"state": "closed"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def clone_url_from_repo_payload(payload: dict[str, Any], fallback_base_url: str, owner: str, name: str) -> str:
|
||||
return (
|
||||
@@ -168,3 +182,13 @@ def repository_from_payload(payload: dict[str, Any], fallback_base_url: str) ->
|
||||
clone_url=clone_url_from_repo_payload(payload, fallback_base_url, str(owner), str(name)),
|
||||
default_branch=payload.get("default_branch") or "main",
|
||||
)
|
||||
|
||||
|
||||
def pull_request_from_payload(payload: dict[str, Any]) -> GiteaPullRequest:
|
||||
merged = bool(payload.get("merged") or payload.get("has_merged") or payload.get("merged_at"))
|
||||
return GiteaPullRequest(
|
||||
number=int(payload["number"]),
|
||||
html_url=payload.get("html_url") or payload.get("url") or "",
|
||||
state=payload.get("state") or "",
|
||||
merged=merged,
|
||||
)
|
||||
|
||||
@@ -74,6 +74,8 @@ but write the section content and Suggested PR Comment in Chinese:
|
||||
def render_pr_body(issue: IssueRecord, implementation_report: str) -> str:
|
||||
return f"""关联 Issue:#{issue.issue_number}
|
||||
|
||||
合并后自动关闭:Closes #{issue.issue_number}
|
||||
|
||||
## 代理实现报告
|
||||
|
||||
{implementation_report.strip()}
|
||||
|
||||
@@ -67,6 +67,32 @@ def scan_issues(db: Database, config: AppConfig, client: GiteaClient) -> list[in
|
||||
return scan_eligible_issues(db, config.labels)
|
||||
|
||||
|
||||
def close_issues_for_merged_pull_requests(db: Database, client: GiteaClient) -> int:
|
||||
closed = 0
|
||||
for task in db.list_tasks_pending_issue_close():
|
||||
repo, issue = load_task_context(db, task)
|
||||
assert task.pr_number is not None
|
||||
pull_request = client.get_pull_request(owner=repo.owner, name=repo.name, pr_number=task.pr_number)
|
||||
if not pull_request.merged:
|
||||
continue
|
||||
client.post_issue_comment(
|
||||
owner=repo.owner,
|
||||
name=repo.name,
|
||||
issue_number=issue.issue_number,
|
||||
body=f"关联 PR #{task.pr_number} 已合并,agent-manager 自动关闭该 issue。",
|
||||
)
|
||||
client.close_issue(owner=repo.owner, name=repo.name, issue_number=issue.issue_number)
|
||||
db.update_issue_state(task.repo_id, task.issue_number, "closed")
|
||||
db.add_event(
|
||||
task.id,
|
||||
task.state,
|
||||
task.state,
|
||||
f"closed issue #{issue.issue_number} after merged PR #{task.pr_number}",
|
||||
)
|
||||
closed += 1
|
||||
return closed
|
||||
|
||||
|
||||
class TaskRunner:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -96,6 +122,8 @@ class TaskRunner:
|
||||
try:
|
||||
repos = sync_repositories(self.db, self.config, self.gitea)
|
||||
logger.info("synced %d repositories", len(repos))
|
||||
closed = close_issues_for_merged_pull_requests(self.db, self.gitea)
|
||||
logger.info("closed %d issues for merged pull requests", closed)
|
||||
task_ids = scan_issues(self.db, self.config, self.gitea)
|
||||
logger.info("created %d tasks from issue scan", len(task_ids))
|
||||
task = self.run_once()
|
||||
@@ -260,11 +288,15 @@ class TaskRunner:
|
||||
return report
|
||||
|
||||
def _load_context(self, task: TaskRecord) -> tuple[RepositoryRecord, IssueRecord]:
|
||||
repo_row = self.db.conn.execute("SELECT * FROM repositories WHERE id = ?", (task.repo_id,)).fetchone()
|
||||
if repo_row is None:
|
||||
raise ValueError(f"repository not found for task {task.id}")
|
||||
repo = self.db._repo(repo_row)
|
||||
issue = self.db.get_issue(task.repo_id, task.issue_number)
|
||||
if issue is None:
|
||||
raise ValueError(f"issue not found for task {task.id}")
|
||||
return repo, issue
|
||||
return load_task_context(self.db, task)
|
||||
|
||||
|
||||
def load_task_context(db: Database, task: TaskRecord) -> tuple[RepositoryRecord, IssueRecord]:
|
||||
repo_row = db.conn.execute("SELECT * FROM repositories WHERE id = ?", (task.repo_id,)).fetchone()
|
||||
if repo_row is None:
|
||||
raise ValueError(f"repository not found for task {task.id}")
|
||||
repo = db._repo(repo_row)
|
||||
issue = db.get_issue(task.repo_id, task.issue_number)
|
||||
if issue is None:
|
||||
raise ValueError(f"issue not found for task {task.id}")
|
||||
return repo, issue
|
||||
|
||||
@@ -8,7 +8,7 @@ import httpx
|
||||
from agent_gitea.config import AppConfig
|
||||
from agent_gitea.gitea import GiteaClient
|
||||
from agent_gitea.models import AgentResult, TaskState
|
||||
from agent_gitea.service import TaskRunner, scan_issues, sync_repositories
|
||||
from agent_gitea.service import TaskRunner, close_issues_for_merged_pull_requests, scan_issues, sync_repositories
|
||||
|
||||
|
||||
def make_config(tmp_path: Path, **overrides: object) -> AppConfig:
|
||||
@@ -265,12 +265,94 @@ def test_run_task_success_posts_review_comments(db, tmp_path):
|
||||
pull_requests = [payload for _, path, payload in requests if path == "/api/v1/repos/acme/service/pulls"]
|
||||
assert pull_requests[0]["title"] == "代理实现:Ready issue"
|
||||
assert "代理实现报告" in pull_requests[0]["body"]
|
||||
assert "Closes #1" in pull_requests[0]["body"]
|
||||
command = json.loads(db.list_agent_runs(task.id)[0]["command_json"])
|
||||
assert command[1] == "--cd"
|
||||
assert Path(command[2]).is_absolute()
|
||||
assert [path for _, path, _ in requests].count("/api/v1/repos/acme/service/issues/5/comments") == 2
|
||||
|
||||
|
||||
def test_close_issues_for_merged_pull_requests_closes_linked_issue(db):
|
||||
repo = db.upsert_repository(
|
||||
owner="acme",
|
||||
name="service",
|
||||
clone_url="https://gitea.test/acme/service.git",
|
||||
default_branch="main",
|
||||
enabled=True,
|
||||
)
|
||||
db.upsert_issue(
|
||||
repo_id=repo.id,
|
||||
issue_number=1,
|
||||
title="Ready issue",
|
||||
body="Body",
|
||||
labels=["agent:ready"],
|
||||
state="open",
|
||||
html_url="https://gitea.test/acme/service/issues/1",
|
||||
)
|
||||
task = db.create_task(repo.id, 1)
|
||||
db.transition(task.id, TaskState.CLAIMED)
|
||||
db.transition(task.id, TaskState.PLANNING)
|
||||
db.transition(task.id, TaskState.IMPLEMENTING)
|
||||
db.transition(task.id, TaskState.TESTING)
|
||||
db.transition(task.id, TaskState.PR_OPENED, pr_number=5)
|
||||
db.transition(task.id, TaskState.REVIEWING)
|
||||
db.transition(task.id, TaskState.HUMAN_REVIEW_READY, clear_lease=True)
|
||||
requests: list[tuple[str, str, dict]] = []
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
payload = json.loads(request.content.decode() or "{}")
|
||||
requests.append((request.method, request.url.path, payload))
|
||||
if request.url.path == "/api/v1/repos/acme/service/pulls/5":
|
||||
return httpx.Response(200, json={"number": 5, "state": "closed", "merged": True})
|
||||
if request.url.path == "/api/v1/repos/acme/service/issues/1/comments":
|
||||
return httpx.Response(201, json={"id": 1})
|
||||
if request.url.path == "/api/v1/repos/acme/service/issues/1":
|
||||
return httpx.Response(200, json={"number": 1, "state": "closed"})
|
||||
return httpx.Response(404)
|
||||
|
||||
closed = close_issues_for_merged_pull_requests(db, make_client(handler))
|
||||
|
||||
assert closed == 1
|
||||
assert db.get_issue(repo.id, 1).state == "closed" # type: ignore[union-attr]
|
||||
assert ("PATCH", "/api/v1/repos/acme/service/issues/1", {"state": "closed"}) in requests
|
||||
|
||||
|
||||
def test_close_issues_for_merged_pull_requests_skips_unmerged_pr(db):
|
||||
repo = db.upsert_repository(
|
||||
owner="acme",
|
||||
name="service",
|
||||
clone_url="https://gitea.test/acme/service.git",
|
||||
default_branch="main",
|
||||
enabled=True,
|
||||
)
|
||||
db.upsert_issue(
|
||||
repo_id=repo.id,
|
||||
issue_number=1,
|
||||
title="Ready issue",
|
||||
body="Body",
|
||||
labels=["agent:ready"],
|
||||
state="open",
|
||||
html_url="https://gitea.test/acme/service/issues/1",
|
||||
)
|
||||
task = db.create_task(repo.id, 1)
|
||||
db.transition(task.id, TaskState.CLAIMED)
|
||||
db.transition(task.id, TaskState.PLANNING)
|
||||
db.transition(task.id, TaskState.IMPLEMENTING)
|
||||
db.transition(task.id, TaskState.TESTING)
|
||||
db.transition(task.id, TaskState.PR_OPENED, pr_number=5)
|
||||
db.transition(task.id, TaskState.REVIEWING)
|
||||
db.transition(task.id, TaskState.HUMAN_REVIEW_READY, clear_lease=True)
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
assert request.url.path == "/api/v1/repos/acme/service/pulls/5"
|
||||
return httpx.Response(200, json={"number": 5, "state": "open", "merged": False})
|
||||
|
||||
closed = close_issues_for_merged_pull_requests(db, make_client(handler))
|
||||
|
||||
assert closed == 0
|
||||
assert db.get_issue(repo.id, 1).state == "open" # type: ignore[union-attr]
|
||||
|
||||
|
||||
def test_run_task_no_diff_becomes_blocked(db, tmp_path):
|
||||
config = make_config(tmp_path)
|
||||
seed_task(db)
|
||||
|
||||
@@ -40,6 +40,7 @@ def test_prompt_and_pr_body_include_contract_sections(db):
|
||||
|
||||
assert ".agent-output/AGENT_IMPLEMENTATION_REPORT.md" in prompt
|
||||
assert "关联 Issue:#7" in body
|
||||
assert "Closes #7" in body
|
||||
assert "人工审核" in body
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user