feat: close issues after merged agent PRs

This commit is contained in:
2026-05-06 16:19:38 +08:00
parent 6d1a6d037e
commit aa8afa1a63
6 changed files with 177 additions and 10 deletions

View File

@@ -195,6 +195,17 @@ class Database:
).fetchall() ).fetchall()
return [self._issue(row) for row in rows] return [self._issue(row) for row in rows]
def update_issue_state(self, repo_id: int, issue_number: int, state: str) -> None:
self.conn.execute(
"""
UPDATE issues
SET state = ?, updated_at = ?
WHERE repo_id = ? AND issue_number = ?
""",
(state, dt_to_db(utcnow()), repo_id, issue_number),
)
self.conn.commit()
def active_task_for_issue(self, repo_id: int, issue_number: int) -> TaskRecord | None: def active_task_for_issue(self, repo_id: int, issue_number: int) -> TaskRecord | None:
placeholders = ",".join("?" for _ in ACTIVE_STATES) placeholders = ",".join("?" for _ in ACTIVE_STATES)
rows = self.conn.execute( rows = self.conn.execute(
@@ -245,6 +256,21 @@ class Database:
).fetchone() ).fetchone()
return self._task(row) if row else None return self._task(row) if row else None
def list_tasks_pending_issue_close(self) -> list[TaskRecord]:
rows = self.conn.execute(
"""
SELECT t.*
FROM tasks t
JOIN issues i ON i.repo_id = t.repo_id AND i.issue_number = t.issue_number
WHERE t.state = ?
AND t.pr_number IS NOT NULL
AND i.state = 'open'
ORDER BY t.id
""",
(TaskState.HUMAN_REVIEW_READY.value,),
).fetchall()
return [self._task(row) for row in rows]
def claim_next_task(self, worker_id: str, lease_seconds: int) -> TaskRecord | None: def claim_next_task(self, worker_id: str, lease_seconds: int) -> TaskRecord | None:
now = utcnow() now = utcnow()
expires = now + timedelta(seconds=lease_seconds) expires = now + timedelta(seconds=lease_seconds)

View File

@@ -23,6 +23,8 @@ class GiteaIssue:
class GiteaPullRequest: class GiteaPullRequest:
number: int number: int
html_url: str html_url: str
state: str
merged: bool
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -131,7 +133,12 @@ class GiteaClient:
) )
response.raise_for_status() response.raise_for_status()
payload = response.json() payload = response.json()
return GiteaPullRequest(number=int(payload["number"]), html_url=payload.get("html_url") or "") return pull_request_from_payload(payload)
def get_pull_request(self, *, owner: str, name: str, pr_number: int) -> GiteaPullRequest:
response = self.client.get(f"/repos/{owner}/{name}/pulls/{pr_number}")
response.raise_for_status()
return pull_request_from_payload(response.json())
def post_issue_comment(self, *, owner: str, name: str, issue_number: int, body: str) -> None: def post_issue_comment(self, *, owner: str, name: str, issue_number: int, body: str) -> None:
response = self.client.post( response = self.client.post(
@@ -140,6 +147,13 @@ class GiteaClient:
) )
response.raise_for_status() response.raise_for_status()
def close_issue(self, *, owner: str, name: str, issue_number: int) -> None:
response = self.client.patch(
f"/repos/{owner}/{name}/issues/{issue_number}",
json={"state": "closed"},
)
response.raise_for_status()
def clone_url_from_repo_payload(payload: dict[str, Any], fallback_base_url: str, owner: str, name: str) -> str: def clone_url_from_repo_payload(payload: dict[str, Any], fallback_base_url: str, owner: str, name: str) -> str:
return ( return (
@@ -168,3 +182,13 @@ def repository_from_payload(payload: dict[str, Any], fallback_base_url: str) ->
clone_url=clone_url_from_repo_payload(payload, fallback_base_url, str(owner), str(name)), clone_url=clone_url_from_repo_payload(payload, fallback_base_url, str(owner), str(name)),
default_branch=payload.get("default_branch") or "main", default_branch=payload.get("default_branch") or "main",
) )
def pull_request_from_payload(payload: dict[str, Any]) -> GiteaPullRequest:
merged = bool(payload.get("merged") or payload.get("has_merged") or payload.get("merged_at"))
return GiteaPullRequest(
number=int(payload["number"]),
html_url=payload.get("html_url") or payload.get("url") or "",
state=payload.get("state") or "",
merged=merged,
)

View File

@@ -74,6 +74,8 @@ but write the section content and Suggested PR Comment in Chinese:
def render_pr_body(issue: IssueRecord, implementation_report: str) -> str: def render_pr_body(issue: IssueRecord, implementation_report: str) -> str:
return f"""关联 Issue#{issue.issue_number} return f"""关联 Issue#{issue.issue_number}
合并后自动关闭Closes #{issue.issue_number}
## 代理实现报告 ## 代理实现报告
{implementation_report.strip()} {implementation_report.strip()}

View File

@@ -67,6 +67,32 @@ def scan_issues(db: Database, config: AppConfig, client: GiteaClient) -> list[in
return scan_eligible_issues(db, config.labels) return scan_eligible_issues(db, config.labels)
def close_issues_for_merged_pull_requests(db: Database, client: GiteaClient) -> int:
closed = 0
for task in db.list_tasks_pending_issue_close():
repo, issue = load_task_context(db, task)
assert task.pr_number is not None
pull_request = client.get_pull_request(owner=repo.owner, name=repo.name, pr_number=task.pr_number)
if not pull_request.merged:
continue
client.post_issue_comment(
owner=repo.owner,
name=repo.name,
issue_number=issue.issue_number,
body=f"关联 PR #{task.pr_number} 已合并agent-manager 自动关闭该 issue。",
)
client.close_issue(owner=repo.owner, name=repo.name, issue_number=issue.issue_number)
db.update_issue_state(task.repo_id, task.issue_number, "closed")
db.add_event(
task.id,
task.state,
task.state,
f"closed issue #{issue.issue_number} after merged PR #{task.pr_number}",
)
closed += 1
return closed
class TaskRunner: class TaskRunner:
def __init__( def __init__(
self, self,
@@ -96,6 +122,8 @@ class TaskRunner:
try: try:
repos = sync_repositories(self.db, self.config, self.gitea) repos = sync_repositories(self.db, self.config, self.gitea)
logger.info("synced %d repositories", len(repos)) logger.info("synced %d repositories", len(repos))
closed = close_issues_for_merged_pull_requests(self.db, self.gitea)
logger.info("closed %d issues for merged pull requests", closed)
task_ids = scan_issues(self.db, self.config, self.gitea) task_ids = scan_issues(self.db, self.config, self.gitea)
logger.info("created %d tasks from issue scan", len(task_ids)) logger.info("created %d tasks from issue scan", len(task_ids))
task = self.run_once() task = self.run_once()
@@ -260,11 +288,15 @@ class TaskRunner:
return report return report
def _load_context(self, task: TaskRecord) -> tuple[RepositoryRecord, IssueRecord]: def _load_context(self, task: TaskRecord) -> tuple[RepositoryRecord, IssueRecord]:
repo_row = self.db.conn.execute("SELECT * FROM repositories WHERE id = ?", (task.repo_id,)).fetchone() return load_task_context(self.db, task)
def load_task_context(db: Database, task: TaskRecord) -> tuple[RepositoryRecord, IssueRecord]:
repo_row = db.conn.execute("SELECT * FROM repositories WHERE id = ?", (task.repo_id,)).fetchone()
if repo_row is None: if repo_row is None:
raise ValueError(f"repository not found for task {task.id}") raise ValueError(f"repository not found for task {task.id}")
repo = self.db._repo(repo_row) repo = db._repo(repo_row)
issue = self.db.get_issue(task.repo_id, task.issue_number) issue = db.get_issue(task.repo_id, task.issue_number)
if issue is None: if issue is None:
raise ValueError(f"issue not found for task {task.id}") raise ValueError(f"issue not found for task {task.id}")
return repo, issue return repo, issue

View File

@@ -8,7 +8,7 @@ import httpx
from agent_gitea.config import AppConfig from agent_gitea.config import AppConfig
from agent_gitea.gitea import GiteaClient from agent_gitea.gitea import GiteaClient
from agent_gitea.models import AgentResult, TaskState from agent_gitea.models import AgentResult, TaskState
from agent_gitea.service import TaskRunner, scan_issues, sync_repositories from agent_gitea.service import TaskRunner, close_issues_for_merged_pull_requests, scan_issues, sync_repositories
def make_config(tmp_path: Path, **overrides: object) -> AppConfig: def make_config(tmp_path: Path, **overrides: object) -> AppConfig:
@@ -265,12 +265,94 @@ def test_run_task_success_posts_review_comments(db, tmp_path):
pull_requests = [payload for _, path, payload in requests if path == "/api/v1/repos/acme/service/pulls"] pull_requests = [payload for _, path, payload in requests if path == "/api/v1/repos/acme/service/pulls"]
assert pull_requests[0]["title"] == "代理实现Ready issue" assert pull_requests[0]["title"] == "代理实现Ready issue"
assert "代理实现报告" in pull_requests[0]["body"] assert "代理实现报告" in pull_requests[0]["body"]
assert "Closes #1" in pull_requests[0]["body"]
command = json.loads(db.list_agent_runs(task.id)[0]["command_json"]) command = json.loads(db.list_agent_runs(task.id)[0]["command_json"])
assert command[1] == "--cd" assert command[1] == "--cd"
assert Path(command[2]).is_absolute() assert Path(command[2]).is_absolute()
assert [path for _, path, _ in requests].count("/api/v1/repos/acme/service/issues/5/comments") == 2 assert [path for _, path, _ in requests].count("/api/v1/repos/acme/service/issues/5/comments") == 2
def test_close_issues_for_merged_pull_requests_closes_linked_issue(db):
repo = db.upsert_repository(
owner="acme",
name="service",
clone_url="https://gitea.test/acme/service.git",
default_branch="main",
enabled=True,
)
db.upsert_issue(
repo_id=repo.id,
issue_number=1,
title="Ready issue",
body="Body",
labels=["agent:ready"],
state="open",
html_url="https://gitea.test/acme/service/issues/1",
)
task = db.create_task(repo.id, 1)
db.transition(task.id, TaskState.CLAIMED)
db.transition(task.id, TaskState.PLANNING)
db.transition(task.id, TaskState.IMPLEMENTING)
db.transition(task.id, TaskState.TESTING)
db.transition(task.id, TaskState.PR_OPENED, pr_number=5)
db.transition(task.id, TaskState.REVIEWING)
db.transition(task.id, TaskState.HUMAN_REVIEW_READY, clear_lease=True)
requests: list[tuple[str, str, dict]] = []
def handler(request: httpx.Request) -> httpx.Response:
payload = json.loads(request.content.decode() or "{}")
requests.append((request.method, request.url.path, payload))
if request.url.path == "/api/v1/repos/acme/service/pulls/5":
return httpx.Response(200, json={"number": 5, "state": "closed", "merged": True})
if request.url.path == "/api/v1/repos/acme/service/issues/1/comments":
return httpx.Response(201, json={"id": 1})
if request.url.path == "/api/v1/repos/acme/service/issues/1":
return httpx.Response(200, json={"number": 1, "state": "closed"})
return httpx.Response(404)
closed = close_issues_for_merged_pull_requests(db, make_client(handler))
assert closed == 1
assert db.get_issue(repo.id, 1).state == "closed" # type: ignore[union-attr]
assert ("PATCH", "/api/v1/repos/acme/service/issues/1", {"state": "closed"}) in requests
def test_close_issues_for_merged_pull_requests_skips_unmerged_pr(db):
repo = db.upsert_repository(
owner="acme",
name="service",
clone_url="https://gitea.test/acme/service.git",
default_branch="main",
enabled=True,
)
db.upsert_issue(
repo_id=repo.id,
issue_number=1,
title="Ready issue",
body="Body",
labels=["agent:ready"],
state="open",
html_url="https://gitea.test/acme/service/issues/1",
)
task = db.create_task(repo.id, 1)
db.transition(task.id, TaskState.CLAIMED)
db.transition(task.id, TaskState.PLANNING)
db.transition(task.id, TaskState.IMPLEMENTING)
db.transition(task.id, TaskState.TESTING)
db.transition(task.id, TaskState.PR_OPENED, pr_number=5)
db.transition(task.id, TaskState.REVIEWING)
db.transition(task.id, TaskState.HUMAN_REVIEW_READY, clear_lease=True)
def handler(request: httpx.Request) -> httpx.Response:
assert request.url.path == "/api/v1/repos/acme/service/pulls/5"
return httpx.Response(200, json={"number": 5, "state": "open", "merged": False})
closed = close_issues_for_merged_pull_requests(db, make_client(handler))
assert closed == 0
assert db.get_issue(repo.id, 1).state == "open" # type: ignore[union-attr]
def test_run_task_no_diff_becomes_blocked(db, tmp_path): def test_run_task_no_diff_becomes_blocked(db, tmp_path):
config = make_config(tmp_path) config = make_config(tmp_path)
seed_task(db) seed_task(db)

View File

@@ -40,6 +40,7 @@ def test_prompt_and_pr_body_include_contract_sections(db):
assert ".agent-output/AGENT_IMPLEMENTATION_REPORT.md" in prompt assert ".agent-output/AGENT_IMPLEMENTATION_REPORT.md" in prompt
assert "关联 Issue#7" in body assert "关联 Issue#7" in body
assert "Closes #7" in body
assert "人工审核" in body assert "人工审核" in body