@dataclass
class MemoryController:
    """Schema-enforced memory controller with field-level retention.

    Values live in per-agent buckets keyed by schema field name. Fields whose
    schema entry is marked ``encrypted`` are not stored in the bucket
    directly: the value is Fernet-encrypted into ``_handles`` and the bucket
    holds only the opaque handle id.
    """

    schema: MemorySchemaModel
    # agent_id -> field name -> stored entry.
    _data: dict[str, dict[str, MemoryEntry]] = field(default_factory=dict)
    # handle id -> encrypted payload plus its agent/tag/expiry metadata.
    _handles: dict[str, HandleEntry] = field(default_factory=dict)
    # Excluded from repr() so key material never leaks into logs or tracebacks.
    _fernet_key: bytes | None = field(default=None, repr=False)
    _fernet: Fernet = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Build the cipher, generating a fresh key when none was supplied.

        A generated key is not written back to ``_fernet_key``.
        """
        self._fernet = Fernet(self._fernet_key or Fernet.generate_key())

    @classmethod
    def from_schema_file(
        cls,
        path: str | Path,
        *,
        version: str = "v1alpha1",
    ) -> "MemoryController":
        """Load memory documents from ``path`` and build a controller.

        Raises:
            ValueError: if the loader returns no documents, or if none of the
                documents contains a memory definition (see ``from_documents``).
        """
        loaded = load_memory_documents(path, [Path(path).name], version=version)
        if not loaded:
            raise ValueError(f"No memory definitions found in {path}")
        return cls.from_documents(loaded)

    @classmethod
    def from_documents(cls, documents: list[dict[str, Any]]) -> "MemoryController":
        """Validate raw schema documents and build a controller from them.

        Each document may contribute a single ``memory`` definition and/or a
        list of ``memories``; definitions are collected in document order.

        Raises:
            ValueError: if ``documents`` is empty or no definition is found.
        """
        if not documents:
            raise ValueError("No memory schema documents provided")
        parsed_definitions: list[MemorySchemaModel] = []
        for doc in documents:
            parsed = MemorySchemaDocumentModel.model_validate(doc)
            if parsed.memory:
                parsed_definitions.append(parsed.memory)
            if parsed.memories:
                parsed_definitions.extend(parsed.memories)
        if not parsed_definitions:
            raise ValueError("No memory definitions present after validation")
        # MVP uses first memory definition; multi-profile routing is post-MVP.
        return cls(schema=parsed_definitions[0])

    @property
    def allowed_fields(self) -> set[str]:
        """Names of every field declared in the schema."""
        return {spec.name for spec in self.schema.fields}

    def write(self, key: str, value: Any, agent_id: str, *, strict: bool = False) -> MemoryWriteResult:
        """Store ``value`` under ``key`` in the bucket of ``agent_id``.

        The write is rejected when ``key`` is not a schema field, when
        ``value`` does not match the field's declared type, or when the bucket
        is full and ``key`` would be a new entry. Expired entries are evicted
        before the bucket is considered full.

        Args:
            key: Schema field name.
            value: Value to store; encrypted into a handle when the field is
                marked ``encrypted``.
            agent_id: Owner of the bucket.
            strict: When true, rejections raise instead of returning a result.

        Returns:
            ``MemoryWriteResult(success=True)`` on success, otherwise a failed
            result carrying the rejection reason.

        Raises:
            MemoryValidationError: on rejection when ``strict`` is true.
        """
        field_spec = self._field(key)
        if field_spec is None:
            return self._reject(
                f"Field '{key}' is not defined in memory schema. Allowed fields: {sorted(self.allowed_fields)}",
                strict=strict,
            )
        if not _matches_declared_type(value, field_spec.type):
            return self._reject(
                f"Type mismatch for field '{key}': expected {field_spec.type}, got {type(value).__name__}",
                strict=strict,
            )

        bucket = self._data.setdefault(agent_id, {})
        if key not in bucket:
            if len(bucket) >= self.schema.max_entries:
                # Expired entries are already invisible to read(); reclaim
                # their slots instead of rejecting the write because of them.
                self._evict_expired(bucket, datetime.now(timezone.utc))
            if len(bucket) >= self.schema.max_entries:
                return self._reject(
                    f"Memory at capacity ({self.schema.max_entries} entries) for agent '{agent_id}'. Cannot add new key '{key}'.",
                    strict=strict,
                )

        expiry = _compute_expiry(field_spec.retention or self.schema.default_retention)
        previous = bucket.get(key)

        stored_value: Any
        if field_spec.encrypted:
            stored_value = self._store_handle(
                value=value,
                expires_at=expiry,
                tag=field_spec.tag,
                agent_id=agent_id,
            )
        else:
            stored_value = value
        new_entry = MemoryEntry(
            value=stored_value,
            expires_at=expiry,
            tag=field_spec.tag,
            encrypted=field_spec.encrypted,
        )

        # Drop the superseded handle only after the replacement exists, so a
        # failure above never leaves the bucket pointing at a deleted handle.
        if previous is not None and previous.encrypted:
            self._handles.pop(str(previous.value), None)
        bucket[key] = new_entry
        return MemoryWriteResult(success=True)

    def read(self, key: str, agent_id: str) -> MemoryReadResult:
        """Return the stored value for ``key`` in the bucket of ``agent_id``.

        For encrypted fields the returned value is the handle id written by
        ``write``, not the plaintext. An expired entry is purged on access and
        reported as not found.
        """
        bucket = self._data.get(agent_id)
        if not bucket:
            return MemoryReadResult(found=False, reason=f"No memory bucket for agent '{agent_id}'")
        entry = bucket.get(key)
        if entry is None:
            return MemoryReadResult(found=False, reason=f"Key '{key}' not found for agent '{agent_id}'")
        if entry.expires_at <= datetime.now(timezone.utc):
            self._drop_entry(bucket=bucket, key=key, entry=entry)
            return MemoryReadResult(found=False, reason=f"Key '{key}' expired and was purged")
        return MemoryReadResult(value=entry.value, found=True)

    def purge(self, agent_id: str | None = None) -> int:
        """Remove all entries for one agent, or for every agent when ``None``.

        Purging everything also clears every handle. Purging one agent removes
        the handles referenced by that agent's encrypted entries.

        Returns:
            The number of bucket entries removed.
        """
        if agent_id is None:
            count = 0
            for bucket in self._data.values():
                count += len(bucket)
                bucket.clear()
            self._data.clear()
            self._handles.clear()
            return count

        bucket = self._data.pop(agent_id, None)
        if not bucket:
            return 0
        removed = len(bucket)
        for key, entry in list(bucket.items()):
            self._drop_entry(bucket=bucket, key=key, entry=entry)
        return removed

    def purge_expired(self) -> int:
        """Remove every expired entry and every expired handle.

        Buckets left empty are removed as well.

        Returns:
            The number of bucket entries removed (handles that expired without
            a matching entry are swept but not counted).
        """
        now = datetime.now(timezone.utc)
        purged = 0
        for agent_id, bucket in list(self._data.items()):
            purged += self._evict_expired(bucket, now)
            if not bucket:
                del self._data[agent_id]
        expired_handles = [
            handle_id
            for handle_id, handle in self._handles.items()
            if handle.expires_at <= now
        ]
        for handle_id in expired_handles:
            del self._handles[handle_id]
        return purged

    def handle_metadata(self, handle_id: str) -> dict[str, Any] | None:
        """Return ``tag``, ``agent_id`` and ``expires_at`` for a live handle.

        Returns ``None`` when the id does not normalize to a token, is
        unknown, or has expired. An expired handle is removed from the handle
        table as a side effect.
        """
        token = _normalize_handle_id(handle_id)
        if not token:
            return None
        entry = self._handles.get(token)
        if entry is None:
            return None
        if entry.expires_at <= datetime.now(timezone.utc):
            self._handles.pop(token, None)
            return None
        return {
            "tag": entry.tag,
            "agent_id": entry.agent_id,
            "expires_at": entry.expires_at,
        }

    def resolve_handle(self, handle_id: str, *, agent_id: str) -> Any:
        """Decrypt and return the value behind ``handle_id``.

        Args:
            handle_id: Handle id as returned by ``read`` for an encrypted field.
            agent_id: Caller identity; must equal the agent the handle was
                created for (compared after ``str(...).strip()``).

        Raises:
            KeyError: if the handle id is invalid, unknown, or expired.
            PermissionError: if the handle belongs to a different agent.
            ValueError: if the decrypted payload is not the expected
                ``{"value": ...}`` object.

        Errors raised by the cipher or the JSON decoder propagate unchanged.
        """
        token = _normalize_handle_id(handle_id)
        if not token:
            raise KeyError(f"Invalid memory handle '{handle_id}'")
        entry = self._handles.get(token)
        if entry is None:
            raise KeyError(f"Memory handle '{token}' not found")
        if entry.expires_at <= datetime.now(timezone.utc):
            self._handles.pop(token, None)
            raise KeyError(f"Memory handle '{token}' expired")
        if entry.agent_id != str(agent_id).strip():
            raise PermissionError("memory handle agent binding mismatch")
        decrypted = self._fernet.decrypt(entry.ciphertext)
        payload = json.loads(decrypted.decode("utf-8"))
        if not isinstance(payload, dict) or "value" not in payload:
            raise ValueError("memory handle payload is invalid")
        return payload["value"]

    def _field(self, key: str) -> MemoryFieldModel | None:
        """Return the schema field named ``key``, or ``None`` if undeclared."""
        for field_spec in self.schema.fields:
            if field_spec.name == key:
                return field_spec
        return None

    def _reject(self, reason: str, *, strict: bool) -> MemoryWriteResult:
        """Log a rejected write, then raise (strict) or return a failed result."""
        logger.warning("memory_write rejected: %s", reason)
        if strict:
            raise MemoryValidationError(reason)
        return MemoryWriteResult(success=False, reason=reason)

    def _evict_expired(self, bucket: dict[str, MemoryEntry], now: datetime) -> int:
        """Drop entries of ``bucket`` that expired at or before ``now``; return the count."""
        stale = [(key, entry) for key, entry in bucket.items() if entry.expires_at <= now]
        for key, entry in stale:
            self._drop_entry(bucket=bucket, key=key, entry=entry)
        return len(stale)

    def _drop_entry(self, *, bucket: dict[str, MemoryEntry], key: str, entry: MemoryEntry) -> None:
        """Remove ``key`` from ``bucket`` and, for encrypted entries, its handle."""
        bucket.pop(key, None)
        if entry.encrypted:
            # For encrypted entries the stored value is the handle id.
            self._handles.pop(str(entry.value), None)

    def _store_handle(self, *, value: Any, expires_at: datetime, tag: str, agent_id: str) -> str:
        """Encrypt ``value`` into a new handle and return the handle id.

        The tag is stored stripped and lower-cased; the agent id is stored
        stripped, matching the comparison done in ``resolve_handle``.
        """
        handle_id = f"hdl_{uuid4().hex[:24]}"
        # NOTE: default=str stringifies anything JSON cannot encode, so such
        # values come back from resolve_handle() as strings, not original types.
        payload = json.dumps({"value": value}, sort_keys=True, default=str, ensure_ascii=True).encode("utf-8")
        ciphertext = self._fernet.encrypt(payload)
        self._handles[handle_id] = HandleEntry(
            ciphertext=ciphertext,
            expires_at=expires_at,
            tag=str(tag).strip().lower(),
            agent_id=str(agent_id).strip(),
        )
        return handle_id