Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,17 @@ async def get_episode_messages(
end_time=end_time,
)

async def get_episode_ids(
    self,
    *,
    page_size: int,
    filter_expr: FilterExpr | None = None,
) -> list[EpisodeIdT]:
    """Return a page of episode IDs straight from the wrapped store.

    Both keyword arguments are forwarded unchanged and the wrapped
    store's result is returned as-is.
    """
    episode_ids = await self._wrapped.get_episode_ids(
        filter_expr=filter_expr,
        page_size=page_size,
    )
    return episode_ids

async def get_episode_messages_count(
self,
*,
Expand Down
21 changes: 21 additions & 0 deletions src/memmachine/common/episode_store/episode_sqlalchemy_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,27 @@ async def get_episode_messages_count(

return int(n_messages)

async def get_episode_ids(
    self,
    *,
    page_size: int,
    filter_expr: FilterExpr | None = None,
) -> list[EpisodeIdT]:
    """Return up to ``page_size`` episode IDs, oldest first.

    Args:
        page_size: Maximum number of IDs to return (applied as the
            query's LIMIT).
        filter_expr: Optional filter narrowing which episodes are
            considered; passed to ``_apply_episode_filter``.

    Returns:
        IDs of the matching episodes ordered by ascending ``created_at``.
    """
    query = (
        self._apply_episode_filter(select(Episode.id), filter_expr=filter_expr)
        .order_by(Episode.created_at.asc())
        .limit(page_size)
    )

    async with self._create_session() as session:
        raw_ids = (await session.execute(query)).scalars().all()

    return [EpisodeIdT(raw_id) for raw_id in raw_ids]

@validate_call
async def delete_episodes(self, episode_ids: list[EpisodeIdT]) -> None:
try:
Expand Down
9 changes: 9 additions & 0 deletions src/memmachine/common/episode_store/episode_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ async def get_episode_messages_count(
) -> int:
raise NotImplementedError

@abstractmethod
async def get_episode_ids(
self,
*,
page_size: int,
filter_expr: FilterExpr | None = None,
) -> list[EpisodeIdT]:
"""Return a page of at most ``page_size`` episode IDs matching ``filter_expr``.

Abstract: concrete stores must override this. ``filter_expr`` of ``None``
is the default; NOTE(review): presumably that means "no filtering" --
confirm against each implementation.
"""
raise NotImplementedError

@abstractmethod
async def delete_episodes(
self,
Expand Down
52 changes: 42 additions & 10 deletions src/memmachine/main/memmachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@


# Every member yielded by iterating MemoryType, captured once at import time.
ALL_MEMORY_TYPES: Final[list[MemoryType]] = list(MemoryType)
# Page size passed to get_episode_ids when a session's episodes are
# deleted in batches (see delete_session).
EPISODE_DELETE_BATCH_SIZE: Final[int] = 1000


class MemMachine:
Expand Down Expand Up @@ -311,6 +312,23 @@ async def get_session(
session_data_manager = await self._resources.get_session_data_manager()
return await session_data_manager.get_session_info(session_key)

async def _cleanup_semantic_history(
    self,
    episode_ids: list[EpisodeIdT],
) -> None:
    """Delete semantic history for ``episode_ids`` when the service is available.

    If obtaining the semantic service raises ``ResourceNotReadyError``,
    the failure is logged with its traceback and the cleanup is skipped
    instead of propagating the error.
    """
    try:
        service = await self._resources.get_semantic_service()
    except ResourceNotReadyError:
        logger.exception(
            "Semantic service not ready during history cleanup; "
            "skipping cleanup for episode IDs %s",
            episode_ids,
        )
    else:
        # Only reached when the service was obtained successfully.
        await service.delete_history(episode_ids)

async def delete_session(self, session_data: SessionData) -> None:
"""
Delete all data associated with a session.
Expand All @@ -336,9 +354,15 @@ async def _delete_episode_store() -> None:
op="=",
value=session_data.session_key,
)
await episode_store.delete_episode_messages(
filter_expr=session_filter,
)
while True:
episode_ids = await episode_store.get_episode_ids(
filter_expr=session_filter,
page_size=EPISODE_DELETE_BATCH_SIZE,
)
if not episode_ids:
break
await self._cleanup_semantic_history(episode_ids)
await episode_store.delete_episodes(episode_ids)

async def _delete_episodic_memory() -> None:
episodic_memory_manager = (
Expand All @@ -353,13 +377,8 @@ async def _delete_semantic_memory() -> None:
semantic_memory_manager = (
await self._resources.get_semantic_session_manager()
)
await asyncio.gather(
semantic_memory_manager.delete_feature_set(
session_data=session_data,
),
semantic_memory_manager.delete_all_project_messages(
session_data=session_data
),
await semantic_memory_manager.delete_feature_set(
session_data=session_data,
)

tasks = [_delete_episode_store()]
Expand Down Expand Up @@ -754,6 +773,19 @@ async def delete_episodes(
tasks.append(semantic_service.delete_history(episode_ids))
await asyncio.gather(*tasks)

async def _cleanup_semantic_history(
    self,
    episode_ids: list[EpisodeIdT],
) -> None:
    """Delete semantic history entries for the given episode IDs.

    If the semantic service is not ready, the condition is logged and the
    cleanup is skipped rather than raised, so a caller that deletes
    episodes right after this call is not aborted by an unavailable
    semantic service.

    Args:
        episode_ids: IDs of episodes whose semantic history should be removed.

    Returns:
        None.

    """
    # NOTE: this class defines ``_cleanup_semantic_history`` earlier as well.
    # A later definition in a class body rebinds the name, so this is the
    # one that takes effect and it must keep the not-ready handling.
    try:
        semantic_service = await self._resources.get_semantic_service()
    except ResourceNotReadyError:
        logger.exception(
            "Semantic service not ready during history cleanup; "
            "skipping cleanup for episode IDs %s",
            episode_ids,
        )
        return

    await semantic_service.delete_history(episode_ids)

async def delete_features(
self,
feature_ids: list[FeatureIdT],
Expand Down
6 changes: 0 additions & 6 deletions src/memmachine/semantic_memory/semantic_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,6 @@ async def add_message_to_sets(

_consolidate_errors_and_raise(res, "Failed to add message to sets")

async def delete_messages(self, *, set_ids: list[SetIdT]) -> None:
    """Log the request, then delete the stored history of every set in ``set_ids``."""
    logger.info("Deleting messages from sets %s", set_ids)

    storage = self._semantic_storage
    await storage.delete_history_set(set_ids=set_ids)

async def number_of_uningested(self, set_ids: list[SetIdT]) -> int:
logger.debug("Getting number of uningested messages for set ids %s", set_ids)

Expand Down Expand Up @@ -537,7 +532,6 @@ async def delete_set_id(self, *, set_ids: list[SetIdT]) -> None:
logger.info("Deleting set ids %s", set_ids)

async with asyncio.TaskGroup() as tg:
tg.create_task(self._semantic_storage.delete_history_set(set_ids=set_ids))
tg.create_task(
self._semantic_storage.delete_feature_set(
filter_expr=_with_has_set_ids(set_ids=set_ids, filter_expr=None),
Comment thread
o-love marked this conversation as resolved.
Expand Down
26 changes: 0 additions & 26 deletions src/memmachine/semantic_memory/semantic_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,32 +144,6 @@ async def add_message(
for e in episodes:
tg.create_task(self._add_single_episode(e, session_data))

async def delete_all_project_messages(
    self,
    session_data: SessionData,
) -> None:
    """Delete messages for all set IDs resolved from the session's org and project.

    ``session_data`` is first checked by
    ``_assert_session_data_implements_protocol``.
    """
    self._assert_session_data_implements_protocol(session_data=session_data)

    project_set_ids = list(
        await self._get_all_set_ids(
            org_id=session_data.org_id,
            project_id=session_data.project_id,
        )
    )
    await self._semantic_service.delete_messages(set_ids=project_set_ids)

async def delete_all_org_messages(
    self,
    session_data: SessionData,
) -> None:
    """Delete messages for all set IDs resolved from the session's org alone.

    ``project_id`` is passed as ``None`` to the set-ID lookup.
    ``session_data`` is first checked by
    ``_assert_session_data_implements_protocol``.
    """
    self._assert_session_data_implements_protocol(session_data=session_data)

    org_set_ids = list(
        await self._get_all_set_ids(
            org_id=session_data.org_id,
            project_id=None,
        )
    )
    await self._semantic_service.delete_messages(set_ids=org_set_ids)

async def search(
self,
message: str,
Expand Down
33 changes: 16 additions & 17 deletions src/memmachine/semantic_memory/storage/neo4j_semantic_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,26 +696,25 @@ async def delete_history(self, history_ids: list[EpisodeIdT]) -> None:
if not history_ids:
return

history_ids_param = [str(history_id) for history_id in history_ids]
timestamp = _utc_timestamp()
await self._driver.execute_query(
"""
MATCH (h:SetHistory)
WHERE h.history_id IN $history_ids
DELETE h
""",
history_ids=[str(history_id) for history_id in history_ids],
)

async def delete_history_set(self, set_ids: list[SetIdT]) -> None:
if not set_ids:
return

await self._driver.execute_query(
"""
MATCH (h:SetHistory)
WHERE h.set_id IN $set_ids
DELETE h
WITH $history_ids AS history_ids, $ts AS ts
CALL {
WITH history_ids
MATCH (h:SetHistory)
WHERE h.history_id IN history_ids
DETACH DELETE h
}
WITH history_ids, ts
MATCH (f:Feature)
WHERE any(id IN f.citations WHERE id IN history_ids)
SET f.citations = [id IN f.citations WHERE NOT id IN history_ids],
f.updated_at_ts = ts
""",
set_ids=[str(set_id) for set_id in set_ids],
history_ids=history_ids_param,
ts=timestamp,
)

async def mark_messages_ingested(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,24 +487,16 @@ async def delete_history(self, history_ids: list[EpisodeIdT]) -> None:
if not history_ids:
return

stmt = delete(SetIngestedHistory).where(
SetIngestedHistory.history_id.in_(history_ids),
stmt_citations = delete(citation_association_table).where(
citation_association_table.c.history_id.in_(history_ids)
)

async with self._create_session() as session:
await session.execute(stmt)
await session.commit()

async def delete_history_set(self, set_ids: list[SetIdT]) -> None:
if not set_ids:
return

stmt = delete(SetIngestedHistory).where(
SetIngestedHistory.set_id.in_(set_ids),
stmt_history = delete(SetIngestedHistory).where(
SetIngestedHistory.history_id.in_(history_ids)
)

async with self._create_session() as session:
await session.execute(stmt)
await session.execute(stmt_citations)
await session.execute(stmt_history)
await session.commit()

def _apply_history_filter(
Expand Down
5 changes: 1 addition & 4 deletions src/memmachine/semantic_memory/storage/storage_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,7 @@ async def add_history_to_set(self, set_id: SetIdT, history_id: EpisodeIdT) -> No

@abstractmethod
async def delete_history(self, history_ids: list[EpisodeIdT]) -> None:
raise NotImplementedError

@abstractmethod
async def delete_history_set(self, set_ids: list[SetIdT]) -> None:
"""Delete history references and citations for the episode IDs."""
raise NotImplementedError

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def wrapped_store():
store.get_episode = AsyncMock()
store.get_episode_messages = AsyncMock()
store.get_episode_messages_count = AsyncMock()
store.get_episode_ids = AsyncMock()
store.delete_episodes = AsyncMock()
store.delete_episode_messages = AsyncMock()
return store
Expand Down Expand Up @@ -108,3 +109,22 @@ async def test_non_session_filters_bypass_cache(wrapped_store):
assert first == 7
assert second == 9
assert wrapped_store.get_episode_messages_count.await_count == 2


@pytest.mark.asyncio
async def test_get_episode_ids_passes_through(wrapped_store):
    """get_episode_ids forwards its arguments to the wrapped store and returns its result."""
    wrapped_store.get_episode_ids.return_value = ["1", "2", "3"]
    session_filter = Comparison(field="session_key", op="=", value="s1")

    caching_store = CountCachingEpisodeStorage(wrapped_store)
    returned_ids = await caching_store.get_episode_ids(
        page_size=10,
        filter_expr=session_filter,
    )

    # The wrapper must neither alter the result nor call the store twice.
    assert returned_ids == ["1", "2", "3"]
    wrapped_store.get_episode_ids.assert_awaited_once_with(
        page_size=10,
        filter_expr=session_filter,
    )
6 changes: 1 addition & 5 deletions tests/memmachine/main/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,7 @@ class _SessionData:
await memmachine._resources.get_semantic_session_manager()
)

await asyncio.gather(
semantic_session.delete_feature_set(session_data=s_data),
semantic_session.delete_all_project_messages(session_data=s_data),
)
await semantic_session.delete_feature_set(session_data=s_data)

await memmachine.create_session(s_data.session_key)

Expand All @@ -313,5 +310,4 @@ class _SessionData:
await asyncio.gather(
memmachine.delete_session(session_data=s_data),
semantic_session.delete_feature_set(session_data=s_data),
semantic_session.delete_all_project_messages(session_data=s_data),
)
Loading
Loading