class PolicyEngine:
"""Deterministic first-match policy evaluator with default deny."""
def __init__(self, rules: list[PolicyRule] | None = None) -> None:
self._lock = RLock()
self._rules: list[PolicyRule] = sorted(rules or [], key=lambda item: item.priority)
self._reload_callback: PolicyRuleLoader | None = None
self._watched_files: tuple[Path, ...] = ()
self._file_mtimes: dict[Path, int] = {}
def load(self, rules: list[PolicyRule]) -> None:
with self._lock:
self._rules = sorted(rules, key=lambda item: item.priority)
def evaluate(self, context: PolicyContext) -> PolicyDecision:
with self._lock:
rules = tuple(self._rules)
for rule in rules:
if rule.tenant_id is not None and rule.tenant_id != context.tenant_id:
continue
if self._matches(rule, context):
validated = PolicyDecisionModel(
action=rule.action,
policy_name=rule.name,
reason=rule.reason,
fallback_template=rule.fallback_template,
routing_constraint=rule.allowed_providers,
)
return PolicyDecision(**validated.model_dump())
validated = PolicyDecisionModel(
action="block",
policy_name=None,
reason="default deny",
fallback_template=None,
)
return PolicyDecision(**validated.model_dump())
def register_reload(self, files: list[Path], loader: PolicyRuleLoader) -> None:
watched = tuple(sorted({Path(path).expanduser().resolve() for path in files}, key=str))
with self._lock:
self._reload_callback = loader
self._watched_files = watched
self._file_mtimes = self._snapshot_mtimes(watched)
def reload_if_changed(self) -> bool:
with self._lock:
watched = self._watched_files
previous = dict(self._file_mtimes)
callback = self._reload_callback
if callback is None or not watched:
return False
current = self._snapshot_mtimes(watched)
if current == previous:
return False
self.reload()
return True
def reload(self) -> bool:
with self._lock:
callback = self._reload_callback
watched = self._watched_files
if callback is None:
return False
fresh_rules = sorted(callback(), key=lambda item: item.priority)
fresh_mtimes = self._snapshot_mtimes(watched)
with self._lock:
self._rules = fresh_rules
self._file_mtimes = fresh_mtimes
return True
def _matches(self, rule: PolicyRule, context: PolicyContext) -> bool:
if context.boundary not in rule.boundary:
return False
cond = rule.condition or {}
data_tags = _coerce_values(cond.get("data_tags"), lower=True)
context_tags = expand_tag_hierarchy(context.data_tags)
if data_tags and not data_tags.intersection(context_tags):
return False
tools = _coerce_values(cond.get("tools"))
tool = cond.get("tool")
if tool:
tools.update(_coerce_values(tool))
if tools and context.tool_name not in tools:
return False
agents = _coerce_values(cond.get("agents"))
agent = cond.get("agent")
if agent:
agents.update(_coerce_values(agent))
if agents and context.agent_id not in agents:
return False
return True
@staticmethod
def _snapshot_mtimes(files: tuple[Path, ...]) -> dict[Path, int]:
mtimes: dict[Path, int] = {}
for file_path in files:
try:
mtimes[file_path] = file_path.stat().st_mtime_ns
except OSError:
mtimes[file_path] = -1
return mtimes