# - Add ADR 001 for the hybrid search architecture.
# - Implement Phase 1 (exact match) and Phase 2 (semantic fallback) in ChromaStore.
# - Wrap blocking ChromaDB calls in asyncio.to_thread.
# - Update the IVectorStore interface to support category filtering and thresholds.
# - Add comprehensive tests for the hybrid search logic.
# (Source view: 121 lines, 4.1 KiB, Python.)
from datetime import datetime, timezone
from unittest.mock import MagicMock

import pytest
import pytest_asyncio

from src.processor.dto import EnrichedNewsItemDTO
from src.storage.chroma_store import ChromaStore

@pytest.fixture
def mock_chroma_client():
    """Provide a fake ChromaDB client together with the collection it hands out.

    Returns:
        A ``(client, collection)`` tuple of ``MagicMock`` objects, where any
        call to ``client.get_or_create_collection(...)`` returns ``collection``.
    """
    fake_collection = MagicMock()
    # configure_mock-style kwargs: dotted names set attributes on child mocks.
    fake_client = MagicMock(
        **{"get_or_create_collection.return_value": fake_collection}
    )
    return fake_client, fake_collection


@pytest_asyncio.fixture
async def chroma_store(mock_chroma_client):
    """Build a ``ChromaStore`` wired to the mocked client.

    Only the client half of the ``mock_chroma_client`` tuple is used here.
    Tests that also request ``mock_chroma_client`` get the same tuple, because
    a function-scoped fixture is created once per test. Through it they can
    program the collection mock that ``get_or_create_collection`` returns.
    """
    fake_client, _unused_collection = mock_chroma_client
    return ChromaStore(client=fake_client, collection_name="test_collection")


@pytest.mark.asyncio
async def test_search_with_category_filter(chroma_store, mock_chroma_client):
    """A ``category`` argument must reach ChromaDB as a ``where`` metadata filter.

    The mocked ``collection.query`` returns a single "Robotics" hit. The test
    checks both the filter that was sent and that the returned result carries
    the same category.
    """
    _, collection = mock_chroma_client

    # Canned ChromaDB query response: one id with its metadata, document and
    # distance (ChromaDB nests each field in a per-query outer list).
    collection.query.return_value = {
        "ids": [["id1"]],
        "metadatas": [[{
            "title": "AI in Robotics",
            "url": "https://example.com/robotics",
            "source": "Tech",
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "relevance_score": 8,
            "summary_ru": "AI в робототехнике",
            "category": "Robotics",
            "anomalies_detected": "",
        }]],
        "documents": [["Full content here"]],
        "distances": [[0.1]],
    }

    results = await chroma_store.search(query="AI", limit=5, category="Robotics")

    # Fail with a clear mock assertion if search() never reached
    # collection.query, rather than a TypeError from unpacking a None
    # call_args.
    collection.query.assert_called()
    assert collection.query.call_args.kwargs["where"] == {"category": "Robotics"}
    assert len(results) == 1
    assert results[0].category == "Robotics"


@pytest.mark.asyncio
async def test_search_with_relevance_threshold(chroma_store, mock_chroma_client):
    """Hits whose distance is above ``threshold`` are dropped from the results.

    A lower distance means a closer match, so with ``threshold=0.5`` the hit
    at distance 0.2 should be kept and the hit at 0.8 discarded.
    """
    collection = mock_chroma_client[1]

    def make_metadata(title, url, relevance_score, summary_ru):
        # Every field except the ones passed in is shared by both hits.
        return {
            "title": title,
            "url": url,
            "source": "s",
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "relevance_score": relevance_score,
            "summary_ru": summary_ru,
            "category": "C",
            "anomalies_detected": "",
        }

    relevant_meta = make_metadata("Relevant News", "url1", 9, "Р")
    irrelevant_meta = make_metadata("Irrelevant News", "url2", 1, "И")

    collection.query.return_value = {
        "ids": [["id-rel", "id-irrel"]],
        "metadatas": [[relevant_meta, irrelevant_meta]],
        "documents": [["doc1", "doc2"]],
        "distances": [[0.2, 0.8]],
    }

    found = await chroma_store.search(query="test", limit=10, threshold=0.5)

    assert len(found) == 1
    assert found[0].title == "Relevant News"


@pytest.mark.asyncio
async def test_get_latest_semantic_threshold(chroma_store, mock_chroma_client):
    """
    An empty-query search with a category still filters by that category.

    This presumably mirrors what a /latest lookup sends (no query text, only
    a category). ``collection.query`` must receive
    ``where={"category": "Tech"}``, and the single mocked hit (distance 0.05)
    must be returned.

    NOTE(review): despite the test name, no ``threshold`` is passed here.
    Threshold handling for plain (category-less) searches is therefore not
    covered by this test.
    """
    _, collection = mock_chroma_client

    collection.query.return_value = {
        "ids": [["id1"]],
        "metadatas": [[{
            "title": "Latest News",
            "url": "url",
            "source": "s",
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "relevance_score": 5,
            "summary_ru": "L",
            "category": "Tech",
            "anomalies_detected": "",
        }]],
        "documents": [["doc"]],
        "distances": [[0.05]],
    }

    # If a category is provided, the category filter should be used.
    results = await chroma_store.search(query="", limit=10, category="Tech")

    # Assert the query actually happened before inspecting its kwargs, so a
    # missing call fails with a clear message instead of a TypeError.
    collection.query.assert_called()
    assert collection.query.call_args.kwargs["where"] == {"category": "Tech"}
    assert len(results) == 1