334 lines
12 KiB
Python
334 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import Annotated
|
|
from uuid import uuid4
|
|
|
|
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.core.config import get_settings
|
|
from app.core.constants import ERROR_CODE_OK
|
|
from app.core.time import format_now
|
|
from app.db.session import get_db
|
|
from app.repositories.approval_repository import ApprovalRepository
|
|
from app.repositories.audit_repository import AuditRepository
|
|
from app.repositories.edge_repository import EdgeTaskRepository
|
|
from app.repositories.tool_call_repository import ToolCallRepository
|
|
from app.schemas.common import ApiResponse
|
|
from app.schemas.task import (
|
|
ApprovalTraceItem,
|
|
AuditTraceItem,
|
|
CancelTaskRequest,
|
|
ConfirmTaskData,
|
|
ConfirmTaskRequest,
|
|
CreateTaskData,
|
|
CreateTaskRequest,
|
|
ParsedIntent,
|
|
TaskBasic,
|
|
TaskDetailData,
|
|
TaskReportData,
|
|
ToolTraceItem,
|
|
ToolCallItem,
|
|
VerificationTraceItem,
|
|
)
|
|
from app.services.task_service import TaskConflictError, TaskNotFoundError, TaskService
|
|
from app.services.task_service import TaskPermissionError
|
|
|
|
# Router for the agent task endpoints in this module; all routes share the /api/agent/tasks prefix.
router = APIRouter(prefix="/api/agent/tasks", tags=["agent-task"])
|
|
|
|
|
|
def build_request_id(header_value: str | None) -> str:
    """Return the caller-supplied request id, or mint a new one.

    A missing or empty ``header_value`` yields ``req-`` followed by the
    first 12 hex characters of a freshly generated UUID4.
    """
    if header_value:
        return header_value
    return f"req-{uuid4().hex[:12]}"
|
|
|
|
|
|
@router.post("", response_model=ApiResponse[CreateTaskData])
def create_task(
    payload: CreateTaskRequest,
    db: Annotated[Session, Depends(get_db)],
    x_request_id: Annotated[str | None, Header(alias="X-Request-Id")] = None,
) -> ApiResponse[CreateTaskData]:
    """Create a task from the request payload and report the next step.

    The task is created through ``TaskService.create_task``. The response
    echoes the stored parsed intent and missing slots; ``next_action`` is
    ``FILL_MISSING_SLOTS`` while any slot is still missing and
    ``CONFIRM_TASK`` otherwise.
    """
    settings = get_settings()
    request_id = build_request_id(x_request_id)
    task = TaskService(db, settings.default_timezone).create_task(payload, request_id)

    # Both attributes hold JSON text; decode them once before building the payload.
    missing_slots = json.loads(task.missing_slots_json)
    intent = ParsedIntent(**json.loads(task.parsed_intent_json))

    if missing_slots:
        next_action = "FILL_MISSING_SLOTS"
    else:
        next_action = "CONFIRM_TASK"

    data = CreateTaskData(
        task_id=task.task_id,
        parsed_intent=intent,
        missing_slots=missing_slots,
        risk_level=task.risk_level,
        task_status=task.task_status,
        next_action=next_action,
    )
    return ApiResponse[CreateTaskData](
        request_id=request_id,
        success=True,
        code=ERROR_CODE_OK,
        message="success",
        data=data,
        timestamp=format_now(settings.default_timezone),
    )
|
|
|
|
|
|
@router.post("/{task_id}/confirm", response_model=ApiResponse[ConfirmTaskData])
def confirm_task(
    task_id: str,
    payload: ConfirmTaskRequest,
    db: Annotated[Session, Depends(get_db)],
    x_request_id: Annotated[str | None, Header(alias="X-Request-Id")] = None,
) -> ApiResponse[ConfirmTaskData]:
    """Confirm a task and return its updated status.

    Raises:
        HTTPException: 404 when the service reports ``TaskNotFoundError``,
            409 for ``TaskConflictError`` and 403 for
            ``TaskPermissionError``. The detail carries the exception's
            ``code`` plus a message (the exception's first argument when
            present, for the 409 and 403 cases).
    """
    settings = get_settings()
    request_id = build_request_id(x_request_id)
    service = TaskService(db, settings.default_timezone)

    try:
        task, approval_id = service.confirm_task(task_id, payload, request_id=request_id)
    except TaskNotFoundError as exc:
        detail = {"code": exc.code, "message": "task not found"}
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail) from exc
    except TaskConflictError as exc:
        # Prefer the service's own explanation; fall back to a generic one.
        if exc.args:
            reason = exc.args[0]
        else:
            reason = "task state conflict"
        detail = {"code": exc.code, "message": reason}
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=detail) from exc
    except TaskPermissionError as exc:
        if exc.args:
            reason = exc.args[0]
        else:
            reason = "permission denied"
        detail = {"code": exc.code, "message": reason}
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail) from exc

    data = ConfirmTaskData(
        task_id=task.task_id,
        task_status=task.task_status,
        approval_status=task.approval_status,
        approval_id=approval_id,
        software_a_task_id=task.software_a_task_id,
        software_a_task_status=task.software_a_task_status,
    )
    return ApiResponse[ConfirmTaskData](
        request_id=request_id,
        success=True,
        code=ERROR_CODE_OK,
        message="task confirmed",
        data=data,
        timestamp=format_now(settings.default_timezone),
    )
|
|
|
|
|
|
@router.post("/{task_id}/cancel", response_model=ApiResponse[ConfirmTaskData])
def cancel_task(
    task_id: str,
    payload: CancelTaskRequest,
    db: Annotated[Session, Depends(get_db)],
    x_request_id: Annotated[str | None, Header(alias="X-Request-Id")] = None,
) -> ApiResponse[ConfirmTaskData]:
    """Cancel a task, passing along the reason given in the payload.

    The response reuses ``ConfirmTaskData`` with ``approval_id`` always
    set to ``None``.

    Raises:
        HTTPException: 404 when the service reports ``TaskNotFoundError``
            and 409 for ``TaskConflictError``.
    """
    settings = get_settings()
    request_id = build_request_id(x_request_id)
    service = TaskService(db, settings.default_timezone)

    try:
        task = service.cancel_task(task_id, payload.reason, request_id=request_id)
    except TaskNotFoundError as exc:
        detail = {"code": exc.code, "message": "task not found"}
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail) from exc
    except TaskConflictError as exc:
        # Prefer the service's own explanation; fall back to a generic one.
        if exc.args:
            reason = exc.args[0]
        else:
            reason = "task state conflict"
        detail = {"code": exc.code, "message": reason}
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=detail) from exc

    data = ConfirmTaskData(
        task_id=task.task_id,
        task_status=task.task_status,
        approval_status=task.approval_status,
        approval_id=None,
        software_a_task_id=task.software_a_task_id,
        software_a_task_status=task.software_a_task_status,
    )
    return ApiResponse[ConfirmTaskData](
        request_id=request_id,
        success=True,
        code=ERROR_CODE_OK,
        message="task cancelled",
        data=data,
        timestamp=format_now(settings.default_timezone),
    )
|
|
|
|
|
|
@router.get("/{task_id}", response_model=ApiResponse[TaskDetailData])
def get_task(
    task_id: str,
    db: Annotated[Session, Depends(get_db)],
    x_request_id: Annotated[str | None, Header(alias="X-Request-Id")] = None,
) -> ApiResponse[TaskDetailData]:
    """Return the current detail view of a task.

    The view combines the task record with its tool calls and a
    ``verification_result`` derived from the first edge task returned by
    the repository. ``verification_result`` stays ``None`` when there are
    no edge tasks or that edge task has no recorded ``success`` value.

    Raises:
        HTTPException: 404 when the service reports ``TaskNotFoundError``.
    """
    settings = get_settings()
    request_id = build_request_id(x_request_id)
    service = TaskService(db, settings.default_timezone)

    try:
        task = service.get_task(task_id)
    except TaskNotFoundError as exc:
        detail = {"code": exc.code, "message": "task not found"}
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail) from exc

    edge_tasks = EdgeTaskRepository(db).list_by_task_id(task_id)
    tool_calls = ToolCallRepository(db).list_by_task_id(task_id)

    verification_result = None
    # NOTE(review): index 0 is treated as the latest edge task — presumably the
    # repository returns newest first; confirm the ordering there.
    if edge_tasks and edge_tasks[0].success is not None:
        passed = bool(edge_tasks[0].success)
        verification_result = {
            "http_ok": passed,
            "process_ok": None,
            "port_ok": None,
            # Derived from the success flag only: 0 on success, 1 on failure.
            "log_error_count": 0 if passed else 1,
        }

    tool_call_items = [
        ToolCallItem(tool_name=call.tool_name, success=bool(call.success))
        for call in tool_calls
    ]
    data = TaskDetailData(
        task_id=task.task_id,
        task_status=task.task_status,
        approval_status=task.approval_status,
        risk_level=task.risk_level,
        intent=ParsedIntent(**json.loads(task.parsed_intent_json)),
        software_a_task_id=task.software_a_task_id,
        software_a_task_status=task.software_a_task_status,
        tool_calls=tool_call_items,
        verification_result=verification_result,
        summary=task.summary,
    )
    return ApiResponse[TaskDetailData](
        request_id=request_id,
        success=True,
        code=ERROR_CODE_OK,
        message="success",
        data=data,
        timestamp=format_now(settings.default_timezone),
    )
|
|
|
|
|
|
def _approval_trace_item(approval) -> ApprovalTraceItem:
    """Convert one approval record into its report entry."""
    return ApprovalTraceItem(
        approval_id=approval.approval_id,
        approval_status=approval.approval_status,
        risk_level=approval.risk_level,
        # Approver ids are kept as JSON text on the record.
        approvers=json.loads(approval.approver_user_ids_json),
        reason=approval.reason,
        created_at=approval.created_at,
        updated_at=approval.updated_at,
    )


def _tool_trace_item(call) -> ToolTraceItem:
    """Convert one tool-call record into its report entry."""
    return ToolTraceItem(
        tool_call_id=call.tool_call_id,
        request_id=call.request_id,
        operator_user_id=call.operator_user_id,
        operator_user_name=call.operator_user_name,
        tool_name=call.tool_name,
        success=bool(call.success),
        duration_ms=call.duration_ms,
        started_at=call.started_at,
        finished_at=call.finished_at,
        request_payload=json.loads(call.request_payload_json),
        response_payload=json.loads(call.response_payload_json),
    )


def _verification_trace_item(edge_task) -> VerificationTraceItem:
    """Convert one edge-task record into its report entry."""
    return VerificationTraceItem(
        step_id=edge_task.step_id,
        edge_id=edge_task.edge_id,
        tool_name=edge_task.tool_name,
        step_status=edge_task.step_status,
        # Keep "no result recorded" (None) distinct from an explicit failure.
        success=None if edge_task.success is None else bool(edge_task.success),
        message=edge_task.message,
        params=json.loads(edge_task.params_json),
        result_data=json.loads(edge_task.result_data_json),
        evidence=json.loads(edge_task.evidence_json),
        started_at=edge_task.started_at,
        finished_at=edge_task.finished_at,
    )


def _audit_trace_item(entry) -> AuditTraceItem:
    """Convert one audit-log record into its report entry."""
    return AuditTraceItem(
        audit_id=entry.audit_id,
        request_id=entry.request_id,
        action=entry.action,
        result=entry.result,
        operator_user_id=entry.operator_user_id,
        operator_user_name=entry.operator_user_name,
        target=entry.target,
        detail=json.loads(entry.detail_json),
        timestamp=entry.timestamp,
    )


@router.get("/{task_id}/report", response_model=ApiResponse[TaskReportData])
def get_task_report(
    task_id: str,
    db: Annotated[Session, Depends(get_db)],
    x_request_id: Annotated[str | None, Header(alias="X-Request-Id")] = None,
) -> ApiResponse[TaskReportData]:
    """Return the full report for a task.

    The report bundles the task's basic fields, its parsed intent, the
    approval record (if one exists), every tool call, every edge task and
    every audit log entry found for ``task_id``.

    Raises:
        HTTPException: 404 when the service reports ``TaskNotFoundError``.
    """
    settings = get_settings()
    request_id = build_request_id(x_request_id)
    service = TaskService(db, settings.default_timezone)

    try:
        task = service.get_task(task_id)
    except TaskNotFoundError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={"code": exc.code, "message": "task not found"},
        ) from exc

    # Load every related record before converting any of them.
    approval = ApprovalRepository(db).get_by_task_id(task_id)
    tool_calls = ToolCallRepository(db).list_by_task_id(task_id)
    edge_tasks = EdgeTaskRepository(db).list_by_task_id(task_id)
    audit_logs = AuditRepository(db).list_by_task_id(task_id)

    # The approval trace holds at most one entry: the task's approval record.
    approval_trace = [_approval_trace_item(approval)] if approval else []
    tool_trace = [_tool_trace_item(call) for call in tool_calls]
    verification_trace = [_verification_trace_item(edge_task) for edge_task in edge_tasks]
    audit_trace = [_audit_trace_item(entry) for entry in audit_logs]

    return ApiResponse[TaskReportData](
        request_id=request_id,
        success=True,
        code=ERROR_CODE_OK,
        message="success",
        data=TaskReportData(
            task_basic=TaskBasic(
                task_id=task.task_id,
                task_status=task.task_status,
                approval_status=task.approval_status,
                risk_level=task.risk_level,
                created_at=task.created_at,
                updated_at=task.updated_at,
                confirmed_at=task.confirmed_at,
            ),
            intent_snapshot=ParsedIntent(**json.loads(task.parsed_intent_json)),
            approval_trace=approval_trace,
            tool_trace=tool_trace,
            verification_trace=verification_trace,
            result_summary=task.summary,
            audit_trace=audit_trace,
        ),
        timestamp=format_now(settings.default_timezone),
    )
|