Skip to content

Approval

Human-in-the-loop approval workflow manager.

approval

Approval workflow manager for high-risk action gating.

ApprovalManager

Stateful approval gate with optional file-backed persistence.

Source code in safeai/core/approval.py
class ApprovalManager:
    """Stateful approval gate with optional file-backed persistence."""

    def __init__(
        self,
        *,
        file_path: str | Path | None = None,
        default_ttl: str = "30m",
        clock: Clock | None = None,
    ) -> None:
        self._clock = clock or (lambda: datetime.now(timezone.utc))
        self._default_ttl = default_ttl
        self._requests: dict[str, ApprovalRequest] = {}
        self._file_path = Path(file_path).expanduser().resolve() if file_path else None
        self._last_mtime_ns: int | None = None
        self._load()

    def create_request(
        self,
        *,
        reason: str,
        policy_name: str | None,
        agent_id: str,
        tool_name: str,
        session_id: str | None = None,
        action_type: str = "tool_call",
        data_tags: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        ttl: str | None = None,
        dedupe_key: str | None = None,
    ) -> ApprovalRequest:
        self._reload_if_changed()
        now = self._clock()
        normalized_dedupe = _normalize_optional_token(dedupe_key)
        if normalized_dedupe:
            existing = self._find_pending_by_dedupe(normalized_dedupe, now=now)
            if existing is not None:
                return existing

        duration = _parse_duration(ttl or self._default_ttl)
        request = ApprovalRequest(
            request_id=f"apr_{uuid4().hex[:12]}",
            status="pending",
            reason=str(reason).strip(),
            policy_name=_normalize_optional_token(policy_name),
            agent_id=_normalize_required_token(agent_id, field_name="agent_id"),
            tool_name=_normalize_required_token(tool_name, field_name="tool_name"),
            session_id=_normalize_optional_token(session_id),
            action_type=_normalize_optional_token(action_type) or "tool_call",
            data_tags=sorted({str(tag).strip().lower() for tag in (data_tags or []) if str(tag).strip()}),
            requested_at=now,
            expires_at=now + duration,
            metadata=dict(metadata or {}),
            dedupe_key=normalized_dedupe,
        )
        self._requests[request.request_id] = request
        self._persist()
        return request

    def get(self, request_id: str) -> ApprovalRequest | None:
        self._reload_if_changed()
        token = _normalize_optional_token(request_id)
        if not token:
            return None
        row = self._requests.get(token)
        if row is None:
            return None
        if row.status == "pending" and row.is_expired(now=self._clock()):
            row = row.__class__(**{**row.__dict__, "status": "expired"})
            self._requests[row.request_id] = row
            self._persist()
        return row

    def list_requests(
        self,
        *,
        status: ApprovalStatus | None = None,
        agent_id: str | None = None,
        tool_name: str | None = None,
        newest_first: bool = True,
        limit: int = 100,
    ) -> list[ApprovalRequest]:
        self._reload_if_changed()
        now = self._clock()
        rows: list[ApprovalRequest] = []
        for item in self._requests.values():
            row = item
            if row.status == "pending" and row.is_expired(now=now):
                row = row.__class__(**{**row.__dict__, "status": "expired"})
                self._requests[row.request_id] = row
            if status and row.status != status:
                continue
            if agent_id and row.agent_id != _normalize_required_token(agent_id, field_name="agent_id"):
                continue
            if tool_name and row.tool_name != _normalize_required_token(tool_name, field_name="tool_name"):
                continue
            rows.append(row)
        if rows:
            self._persist()
        rows.sort(key=lambda item: item.requested_at, reverse=newest_first)
        if limit <= 0:
            return rows
        return rows[:limit]

    def approve(self, request_id: str, *, approver_id: str, note: str | None = None) -> bool:
        return self._decide(
            request_id=request_id,
            status="approved",
            approver_id=approver_id,
            note=note,
        )

    def deny(self, request_id: str, *, approver_id: str, note: str | None = None) -> bool:
        return self._decide(
            request_id=request_id,
            status="denied",
            approver_id=approver_id,
            note=note,
        )

    def validate(
        self,
        request_id: str,
        *,
        agent_id: str,
        tool_name: str,
        session_id: str | None = None,
    ) -> ApprovalValidationResult:
        row = self.get(request_id)
        if row is None:
            return ApprovalValidationResult(
                allowed=False,
                reason=f"approval request '{request_id}' not found",
                request=None,
            )
        if row.status == "expired":
            return ApprovalValidationResult(
                allowed=False,
                reason=f"approval request '{request_id}' expired",
                request=row,
            )
        if row.status == "denied":
            return ApprovalValidationResult(
                allowed=False,
                reason=f"approval request '{request_id}' denied",
                request=row,
            )
        if row.status == "pending":
            return ApprovalValidationResult(
                allowed=False,
                reason=f"approval request '{request_id}' pending",
                request=row,
            )

        normalized_agent = _normalize_required_token(agent_id, field_name="agent_id")
        if row.agent_id != normalized_agent:
            return ApprovalValidationResult(
                allowed=False,
                reason="approval request agent binding mismatch",
                request=row,
            )
        normalized_tool = _normalize_required_token(tool_name, field_name="tool_name")
        if row.tool_name != normalized_tool:
            return ApprovalValidationResult(
                allowed=False,
                reason="approval request tool binding mismatch",
                request=row,
            )
        normalized_session = _normalize_optional_token(session_id)
        if row.session_id and row.session_id != normalized_session:
            return ApprovalValidationResult(
                allowed=False,
                reason="approval request session binding mismatch",
                request=row,
            )
        return ApprovalValidationResult(
            allowed=True,
            reason="approval request approved",
            request=row,
        )

    def purge_expired(self) -> int:
        self._reload_if_changed()
        now = self._clock()
        purged = 0
        for request_id in list(self._requests.keys()):
            row = self._requests[request_id]
            if row.status == "pending" and row.is_expired(now=now):
                self._requests.pop(request_id, None)
                purged += 1
        if purged:
            self._persist()
        return purged

    def _decide(
        self,
        *,
        request_id: str,
        status: ApprovalStatus,
        approver_id: str,
        note: str | None,
    ) -> bool:
        self._reload_if_changed()
        token = _normalize_required_token(request_id, field_name="request_id")
        row = self._requests.get(token)
        if row is None:
            return False
        if row.status != "pending" or row.is_expired(now=self._clock()):
            return False
        updated = row.__class__(
            **{
                **row.__dict__,
                "status": status,
                "approver_id": _normalize_required_token(approver_id, field_name="approver_id"),
                "decision_note": _normalize_optional_token(note),
                "decided_at": self._clock(),
            }
        )
        self._requests[token] = updated
        self._persist()
        return True

    def _find_pending_by_dedupe(self, dedupe_key: str, *, now: datetime) -> ApprovalRequest | None:
        for row in self._requests.values():
            if row.dedupe_key != dedupe_key:
                continue
            if row.status != "pending":
                continue
            if row.is_expired(now=now):
                continue
            return row
        return None

    def _reload_if_changed(self) -> None:
        if self._file_path is None or not self._file_path.exists():
            return
        try:
            mtime_ns = self._file_path.stat().st_mtime_ns
        except OSError:
            return
        if self._last_mtime_ns is not None and mtime_ns == self._last_mtime_ns:
            return
        self._load()

    def _load(self) -> None:
        if self._file_path is None:
            return
        self._file_path.parent.mkdir(parents=True, exist_ok=True)
        if not self._file_path.exists():
            self._file_path.write_text("", encoding="utf-8")
            try:
                self._last_mtime_ns = self._file_path.stat().st_mtime_ns
            except OSError:
                self._last_mtime_ns = None
            return

        rows: dict[str, ApprovalRequest] = {}
        for line in self._file_path.read_text(encoding="utf-8").splitlines():
            line = line.strip()
            if not line:
                continue
            try:
                payload = json.loads(line)
                row = _request_from_payload(payload)
            except Exception:
                continue
            rows[row.request_id] = row
        self._requests = rows
        try:
            self._last_mtime_ns = self._file_path.stat().st_mtime_ns
        except OSError:
            self._last_mtime_ns = None

    def _persist(self) -> None:
        if self._file_path is None:
            return
        self._file_path.parent.mkdir(parents=True, exist_ok=True)
        rows = sorted(self._requests.values(), key=lambda item: item.requested_at)
        encoded = "\n".join(json.dumps(_request_to_payload(row), separators=(",", ":"), ensure_ascii=True) for row in rows)
        self._file_path.write_text(encoded + ("\n" if encoded else ""), encoding="utf-8")
        try:
            self._last_mtime_ns = self._file_path.stat().st_mtime_ns
        except OSError:
            self._last_mtime_ns = None