class ToolContractRegistry:
    """Runtime registry for declared tool contracts.

    Contracts are keyed by their ``tool_name`` with surrounding whitespace
    stripped, and lookups normalize the queried name the same way. Request
    validation compares each (stripped, lower-cased) data tag on a request,
    expanded via ``expand_tag_hierarchy``, against the (stripped, lower-cased)
    tags the tool's contract lists in ``accepts_tags``.
    """

    def __init__(self, contracts: list[ToolContract] | None = None) -> None:
        """Create a registry, optionally pre-loaded with *contracts*."""
        self._contracts: dict[str, ToolContract] = {}
        self.load(contracts or [])

    @staticmethod
    def _normalize_name(tool_name: str) -> str:
        """Return the lookup key used for *tool_name* (stringified, stripped)."""
        return str(tool_name).strip()

    @staticmethod
    def _normalize_tags(tags: list[str] | None) -> list[str]:
        """Return *tags* stripped and lower-cased, dropping blank entries.

        ``None`` is treated as an empty list. Input order is preserved.
        """
        normalized: list[str] = []
        for raw_tag in tags or []:
            token = str(raw_tag).strip().lower()
            if token:
                normalized.append(token)
        return normalized

    def load(self, contracts: list[ToolContract]) -> None:
        """Replace every registered contract with *contracts*.

        Keys are normalized exactly like lookups in ``get`` so a contract is
        always reachable by its own name. When two contracts normalize to the
        same name, the later one in *contracts* wins.
        """
        self._contracts = {
            self._normalize_name(item.tool_name): item for item in contracts
        }

    def get(self, tool_name: str) -> ToolContract | None:
        """Return the contract registered for *tool_name*, or ``None``."""
        return self._contracts.get(self._normalize_name(tool_name))

    def has(self, tool_name: str) -> bool:
        """Return whether a contract is registered for *tool_name*."""
        return self.get(tool_name) is not None

    def validate_request(
        self, tool_name: str, data_tags: list[str] | None
    ) -> ContractValidationResult:
        """Check whether *tool_name* may receive data carrying *data_tags*.

        Args:
            tool_name: Name of the tool being invoked.
            data_tags: Classification tags attached to the request data.
                Tags are stripped and lower-cased; blank tags are ignored.
                ``None`` is treated as no tags.

        Returns:
            A ``ContractValidationResult``. ``allowed`` is False when the tool
            has no contract, or when any request tag (after hierarchy
            expansion) shares nothing with the contract's ``accepts_tags``.
            ``unauthorized_tags`` is always a sorted, de-duplicated list of
            normalized tags.
        """
        contract = self.get(tool_name)
        tokens = self._normalize_tags(data_tags)

        if contract is None:
            # Without a contract nothing is allowed; report every tag so the
            # caller can see what was being sent.
            return ContractValidationResult(
                allowed=False,
                reason=f"tool '{tool_name}' has no declared contract",
                unauthorized_tags=sorted(set(tokens)),
                contract=None,
            )

        if not tokens:
            return ContractValidationResult(
                allowed=True,
                reason="no classified data tags on request",
                unauthorized_tags=[],
                contract=contract,
            )

        accepted = set(self._normalize_tags(contract.accepts_tags))
        # A tag is authorized when any member of its expanded hierarchy is
        # accepted. Iterating the unique tokens avoids re-expanding duplicates.
        unauthorized = sorted(
            token
            for token in set(tokens)
            if accepted.isdisjoint(expand_tag_hierarchy([token]))
        )

        if unauthorized:
            rejected = ",".join(unauthorized)
            return ContractValidationResult(
                allowed=False,
                reason=f"tool '{tool_name}' does not accept data tags: {rejected}",
                unauthorized_tags=unauthorized,
                contract=contract,
            )

        return ContractValidationResult(
            allowed=True,
            reason="tool contract allows request tags",
            unauthorized_tags=[],
            contract=contract,
        )

    def all(self) -> list[ToolContract]:
        """Return a new list of every registered contract."""
        return list(self._contracts.values())