Skip to content

Commit 7582233

Browse files
committed
Context fields may expose sensitive information if used in filtering
Signed-off-by: Edwin Yu <edwinyyyu@gmail.com>
1 parent b4c5ddc commit 7582233

2 files changed

Lines changed: 24 additions & 25 deletions

File tree

packages/server/server_tests/memmachine_server/episodic_memory/event_memory/segment_store/test_sqlalchemy_segment_store.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,12 @@ async def test_add_segments_with_message_context(
155155
result = await partition.get_segment_contexts([seg.uuid])
156156
assert result[seg.uuid][0].context == ctx
157157

158+
async with partition._create_session() as session:
159+
row = (
160+
await session.execute(select(SegmentRow).where(SegmentRow.uuid == seg.uuid))
161+
).scalar_one()
162+
assert row.context == {"type": "message", "source": "User"}
163+
158164

159165
@pytest.mark.asyncio
160166
async def test_add_segments_with_citation_context(
@@ -386,7 +392,7 @@ async def test_contexts_property_filter(
386392
async def test_contexts_filter_by_context_source(
387393
partition: SQLAlchemySegmentStorePartition,
388394
) -> None:
389-
"""Filter using ``context.source`` resolves via the type-tagged JSON path."""
395+
"""Filter using ``context.source`` is not supported."""
390396
ep = uuid4()
391397
s0 = _seg(
392398
event_uuid=ep,
@@ -409,22 +415,20 @@ async def test_contexts_filter_by_context_source(
409415
await partition.add_segments(_links(s0, s1, s2))
410416

411417
filt = Comparison(field="context.source", op="=", value="Alice")
412-
result = await partition.get_segment_contexts(
413-
[s0.uuid],
414-
max_backward_segments=5,
415-
max_forward_segments=5,
416-
property_filter=filt,
417-
)
418-
ctx = result[s0.uuid]
419-
# s1 excluded (source=Bob); s0 seed, s2 forward.
420-
assert [s.uuid for s in ctx] == [s0.uuid, s2.uuid]
418+
with pytest.raises(ValueError, match="Unknown filter field"):
419+
await partition.get_segment_contexts(
420+
[s0.uuid],
421+
max_backward_segments=5,
422+
max_forward_segments=5,
423+
property_filter=filt,
424+
)
421425

422426

423427
@pytest.mark.asyncio
424428
async def test_contexts_filter_by_context_type(
425429
partition: SQLAlchemySegmentStorePartition,
426430
) -> None:
427-
"""Filter using ``context.type`` discriminates between Context subtypes."""
431+
"""Filter using ``context.type`` is not supported."""
428432
ep = uuid4()
429433
s0 = _seg(
430434
event_uuid=ep,
@@ -447,15 +451,13 @@ async def test_contexts_filter_by_context_type(
447451
await partition.add_segments(_links(s0, s1, s2))
448452

449453
filt = Comparison(field="context.type", op="=", value="message")
450-
result = await partition.get_segment_contexts(
451-
[s0.uuid],
452-
max_backward_segments=5,
453-
max_forward_segments=5,
454-
property_filter=filt,
455-
)
456-
ctx = result[s0.uuid]
457-
# s1 excluded (type=citation); s0 seed, s2 forward.
458-
assert [s.uuid for s in ctx] == [s0.uuid, s2.uuid]
454+
with pytest.raises(ValueError, match="Unknown filter field"):
455+
await partition.get_segment_contexts(
456+
[s0.uuid],
457+
max_backward_segments=5,
458+
max_forward_segments=5,
459+
property_filter=filt,
460+
)
459461

460462

461463
@pytest.mark.asyncio

packages/server/src/memmachine_server/episodic_memory/event_memory/segment_store/sqlalchemy_segment_store.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ async def _insert_segments(
233233
"timestamp": ensure_tz_aware(segment.timestamp),
234234
"timestamp_timezone_offset": utc_offset_seconds(segment.timestamp),
235235
"context": (
236-
encode_properties(segment.context.model_dump(exclude_none=True))
236+
segment.context.model_dump(mode="json")
237237
if segment.context is not None
238238
else None
239239
),
@@ -650,9 +650,6 @@ def _resolve_segment_field(
650650
"""Map a filter field name to a segment column and encoding."""
651651
if field == "timestamp":
652652
return SegmentRow.timestamp.expression, "column"
653-
if field.startswith("context."):
654-
key = field.removeprefix("context.")
655-
return SegmentRow.context[key], "properties_json"
656653
internal_name, is_user_metadata = normalize_filter_field(field)
657654
if is_user_metadata:
658655
key = demangle_user_metadata_key(internal_name)
@@ -665,7 +662,7 @@ def _segment_from_segment_row(
665662
) -> Segment:
666663
"""Convert a SegmentRow into a Segment."""
667664
context = (
668-
_ContextAdapter.validate_python(decode_properties(row.context))
665+
_ContextAdapter.validate_python(row.context)
669666
if row.context is not None
670667
else None
671668
)

0 commit comments

Comments
 (0)