Skip to content

Memory

Schema-enforced encrypted memory controller.

memory

Schema-bound in-memory store with retention enforcement.

MemoryValidationError

Bases: Exception

Raised when a memory write fails validation in strict mode.

Source code in safeai/core/memory.py
class MemoryValidationError(Exception):
    """Error signalling that a memory write was rejected under strict validation.

    ``MemoryController.write(..., strict=True)`` raises this instead of
    returning an unsuccessful ``MemoryWriteResult``; the exception message is
    the human-readable rejection reason.
    """

MemoryWriteResult dataclass

Result of a memory write operation.

Source code in safeai/core/memory.py
@dataclass(frozen=True)
class MemoryWriteResult:
    """Result of a memory write operation."""

    success: bool
    reason: str | None = None

    def __bool__(self) -> bool:
        return self.success

MemoryReadResult dataclass

Result of a memory read operation.

Source code in safeai/core/memory.py
@dataclass(frozen=True)
class MemoryReadResult:
    """Result of a memory read operation."""

    value: Any = None
    found: bool = False
    reason: str | None = None

MemoryController dataclass

Schema-enforced memory controller with field-level retention.

Source code in safeai/core/memory.py
@dataclass
class MemoryController:
    """Schema-enforced memory controller with field-level retention.

    Values are kept in memory, bucketed per ``agent_id`` and keyed by field
    name. A write is accepted only when the field is declared in ``schema``,
    the value matches the declared type, and the agent's bucket has room for
    it (at most ``schema.max_entries`` distinct keys). Every entry receives an
    expiry computed from the field's ``retention`` (falling back to
    ``schema.default_retention``).

    Fields declared ``encrypted`` are never kept in plaintext in the bucket:
    the value is Fernet-encrypted into a handle and the bucket stores only the
    handle id. The agent that wrote it can exchange that id for the value via
    :meth:`resolve_handle`.
    """

    schema: MemorySchemaModel
    # agent_id -> {field name -> stored entry}
    _data: dict[str, dict[str, MemoryEntry]] = field(default_factory=dict)
    # handle id -> encrypted payload for fields declared ``encrypted``
    _handles: dict[str, HandleEntry] = field(default_factory=dict)
    # Fernet key; a random key is generated when omitted. Excluded from the
    # generated repr so printing/logging the controller never leaks it.
    _fernet_key: bytes | None = field(default=None, repr=False)
    _fernet: Fernet = field(init=False, repr=False)

    def __post_init__(self) -> None:
        # A generated key is not written back into ``_fernet_key``.
        self._fernet = Fernet(self._fernet_key or Fernet.generate_key())

    @classmethod
    def from_schema_file(
        cls,
        path: str | Path,
        *,
        version: str = "v1alpha1",
    ) -> "MemoryController":
        """Build a controller from a single memory schema file.

        Args:
            path: Schema file to load via ``load_memory_documents``.
            version: Schema version passed through to the loader.

        Raises:
            ValueError: If the file yields no memory documents (or, via
                :meth:`from_documents`, no memory definitions).
        """
        loaded = load_memory_documents(path, [Path(path).name], version=version)
        if not loaded:
            raise ValueError(f"No memory definitions found in {path}")
        return cls.from_documents(loaded)

    @classmethod
    def from_documents(cls, documents: list[dict[str, Any]]) -> "MemoryController":
        """Validate raw schema documents and build a controller from them.

        Each document is validated with ``MemorySchemaDocumentModel`` and may
        contribute a single ``memory`` definition and/or a ``memories`` list.
        Validation errors raised by ``model_validate`` propagate unchanged.

        Raises:
            ValueError: If ``documents`` is empty or no memory definition
                remains after validation.
        """
        if not documents:
            raise ValueError("No memory schema documents provided")

        parsed_definitions: list[MemorySchemaModel] = []
        for doc in documents:
            parsed = MemorySchemaDocumentModel.model_validate(doc)
            if parsed.memory:
                parsed_definitions.append(parsed.memory)
            if parsed.memories:
                parsed_definitions.extend(parsed.memories)

        if not parsed_definitions:
            raise ValueError("No memory definitions present after validation")

        # MVP uses first memory definition; multi-profile routing is post-MVP.
        return cls(schema=parsed_definitions[0])

    @property
    def allowed_fields(self) -> set[str]:
        """Names of all fields declared in the schema."""
        return {spec.name for spec in self.schema.fields}

    def write(self, key: str, value: Any, agent_id: str, *, strict: bool = False) -> MemoryWriteResult:
        """Store ``value`` under ``key`` in ``agent_id``'s bucket.

        Args:
            key: Field name; must be declared in the schema.
            value: Value to store; must match the field's declared type.
            agent_id: Owner of the bucket the value is written to.
            strict: Raise instead of returning a failed result on rejection.

        Returns:
            ``MemoryWriteResult(success=True)`` on success, otherwise a failed
            result whose ``reason`` explains the rejection (non-strict mode).

        Raises:
            MemoryValidationError: In strict mode, when the field is unknown,
                the type does not match, or the bucket is at capacity.
        """
        field_spec = self._field(key)
        if field_spec is None:
            return self._reject(
                f"Field '{key}' is not defined in memory schema. Allowed fields: {sorted(self.allowed_fields)}",
                strict=strict,
            )
        if not _matches_declared_type(value, field_spec.type):
            return self._reject(
                f"Type mismatch for field '{key}': expected {field_spec.type}, got {type(value).__name__}",
                strict=strict,
            )

        # Look the bucket up without inserting it, so a rejected write never
        # leaves an empty bucket behind; it is registered once the write lands.
        bucket = self._data.get(agent_id, {})
        if key not in bucket and len(bucket) >= self.schema.max_entries:
            # Expired entries must not occupy capacity: evict them before
            # concluding the bucket is full.
            self._evict_expired(bucket, datetime.now(timezone.utc))
            if len(bucket) >= self.schema.max_entries:
                return self._reject(
                    f"Memory at capacity ({self.schema.max_entries} entries) for agent '{agent_id}'. Cannot add new key '{key}'.",
                    strict=strict,
                )

        expiry = _compute_expiry(field_spec.retention or self.schema.default_retention)

        stored_value: Any
        if field_spec.encrypted:
            stored_value = self._store_handle(
                value=value,
                expires_at=expiry,
                tag=field_spec.tag,
                agent_id=agent_id,
            )
        else:
            stored_value = value

        # Release the handle of the value being overwritten only after the new
        # value is stored, so a failure above leaves the old entry intact.
        existing = bucket.get(key)
        if existing and existing.encrypted:
            self._handles.pop(str(existing.value), None)

        bucket[key] = MemoryEntry(
            value=stored_value,
            expires_at=expiry,
            tag=field_spec.tag,
            encrypted=field_spec.encrypted,
        )
        self._data[agent_id] = bucket
        return MemoryWriteResult(success=True)

    def read(self, key: str, agent_id: str) -> MemoryReadResult:
        """Return the live entry for ``key`` in ``agent_id``'s bucket.

        Expired entries are purged on access and reported as not found. For
        encrypted fields the returned ``value`` is the handle id; use
        :meth:`resolve_handle` to obtain the plaintext.
        """
        bucket = self._data.get(agent_id)
        if not bucket:
            return MemoryReadResult(found=False, reason=f"No memory bucket for agent '{agent_id}'")
        entry = bucket.get(key)
        if not entry:
            return MemoryReadResult(found=False, reason=f"Key '{key}' not found for agent '{agent_id}'")
        if entry.expires_at <= datetime.now(timezone.utc):
            self._drop_entry(bucket=bucket, key=key, entry=entry)
            return MemoryReadResult(found=False, reason=f"Key '{key}' expired and was purged")
        return MemoryReadResult(value=entry.value, found=True)

    def purge(self, agent_id: str | None = None) -> int:
        """Remove all entries for ``agent_id``, or for every agent when None.

        Handles belonging to removed encrypted entries are removed too; a full
        purge (``agent_id is None``) clears every handle.

        Returns:
            Number of bucket entries removed.
        """
        if agent_id is None:
            count = sum(len(bucket) for bucket in self._data.values())
            self._data.clear()
            self._handles.clear()
            return count
        bucket = self._data.pop(agent_id, {})
        removed = len(bucket)
        for key, entry in list(bucket.items()):
            self._drop_entry(bucket=bucket, key=key, entry=entry)
        return removed

    def purge_expired(self) -> int:
        """Remove every expired entry and expired handle.

        Buckets left empty are removed as well.

        Returns:
            Number of bucket entries removed (swept handles are not counted).
        """
        now = datetime.now(timezone.utc)
        purged = 0
        for agent_id, bucket in list(self._data.items()):
            purged += self._evict_expired(bucket, now)
            if not bucket:
                del self._data[agent_id]
        # Sweep expired handles that were not reached through a bucket entry.
        for handle_id, handle in list(self._handles.items()):
            if handle.expires_at <= now:
                del self._handles[handle_id]
        return purged

    def handle_metadata(self, handle_id: str) -> dict[str, Any] | None:
        """Return non-secret metadata for a live handle, or None.

        The result has ``tag``, ``agent_id`` and ``expires_at`` keys. None is
        returned for ids rejected by ``_normalize_handle_id``, unknown ids and
        expired handles (the latter are removed).
        """
        token = _normalize_handle_id(handle_id)
        if not token:
            return None
        entry = self._handles.get(token)
        if entry is None:
            return None
        if entry.expires_at <= datetime.now(timezone.utc):
            self._handles.pop(token, None)
            return None
        return {
            "tag": entry.tag,
            "agent_id": entry.agent_id,
            "expires_at": entry.expires_at,
        }

    def resolve_handle(self, handle_id: str, *, agent_id: str) -> Any:
        """Decrypt and return the value behind ``handle_id`` for its owner.

        Raises:
            KeyError: If the id is invalid, unknown, or expired (expired
                handles are removed).
            PermissionError: If ``agent_id`` is not the agent that stored it.
            ValueError: If the decrypted payload is not a ``{"value": ...}``
                mapping.
        """
        token = _normalize_handle_id(handle_id)
        if not token:
            raise KeyError(f"Invalid memory handle '{handle_id}'")
        entry = self._handles.get(token)
        if entry is None:
            raise KeyError(f"Memory handle '{token}' not found")
        if entry.expires_at <= datetime.now(timezone.utc):
            self._handles.pop(token, None)
            raise KeyError(f"Memory handle '{token}' expired")
        if entry.agent_id != str(agent_id).strip():
            raise PermissionError("memory handle agent binding mismatch")

        decrypted = self._fernet.decrypt(entry.ciphertext)
        payload = json.loads(decrypted.decode("utf-8"))
        if not isinstance(payload, dict) or "value" not in payload:
            raise ValueError("memory handle payload is invalid")
        return payload["value"]

    def _field(self, key: str) -> MemoryFieldModel | None:
        """Return the first schema field spec named ``key``, or None."""
        return next((spec for spec in self.schema.fields if spec.name == key), None)

    def _reject(self, reason: str, *, strict: bool) -> MemoryWriteResult:
        """Log a rejected write; raise in strict mode, else return a failure."""
        logger.warning("memory_write rejected: %s", reason)
        if strict:
            raise MemoryValidationError(reason)
        return MemoryWriteResult(success=False, reason=reason)

    def _evict_expired(self, bucket: dict[str, MemoryEntry], now: datetime) -> int:
        """Drop entries of ``bucket`` expiring at or before ``now``; return the count."""
        evicted = 0
        for key, entry in list(bucket.items()):
            if entry.expires_at <= now:
                self._drop_entry(bucket=bucket, key=key, entry=entry)
                evicted += 1
        return evicted

    def _drop_entry(self, *, bucket: dict[str, MemoryEntry], key: str, entry: MemoryEntry) -> None:
        """Remove ``key`` from ``bucket`` and release its handle if encrypted."""
        bucket.pop(key, None)
        if entry.encrypted:
            self._handles.pop(str(entry.value), None)

    def _store_handle(self, *, value: Any, expires_at: datetime, tag: str, agent_id: str) -> str:
        """Encrypt ``value`` into a new handle bound to ``agent_id``; return its id.

        The value is JSON-encoded as ``{"value": value}``; ``default=str``
        stringifies anything JSON cannot encode natively.
        """
        handle_id = f"hdl_{uuid4().hex[:24]}"
        payload = json.dumps({"value": value}, sort_keys=True, default=str, ensure_ascii=True).encode("utf-8")
        ciphertext = self._fernet.encrypt(payload)
        self._handles[handle_id] = HandleEntry(
            ciphertext=ciphertext,
            expires_at=expires_at,
            tag=str(tag).strip().lower(),
            agent_id=str(agent_id).strip(),
        )
        return handle_id