AI-Trend-Scout/tests/storage/test_semantic_search.py
Artur Mukhamadiev 65fccbc614 feat(storage): implement hybrid search and fix async chroma i/o
- Add ADR 001 for Hybrid Search Architecture
- Implement Phase 1 (Exact Match) and Phase 2 (Semantic Fallback) in ChromaStore
- Wrap blocking ChromaDB calls in asyncio.to_thread
- Update IVectorStore interface to support category filtering and thresholds
- Add comprehensive tests for hybrid search logic
2026-03-16 00:11:07 +03:00

121 lines
4.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pytest
import pytest_asyncio
from unittest.mock import MagicMock
from datetime import datetime, timezone
from src.storage.chroma_store import ChromaStore
from src.processor.dto import EnrichedNewsItemDTO
@pytest.fixture
def mock_chroma_client():
    """Provide a mocked ChromaDB client together with its mocked collection.

    Returns:
        A ``(client, collection)`` tuple in which
        ``client.get_or_create_collection()`` returns ``collection``.
    """
    fake_collection = MagicMock()
    fake_client = MagicMock()
    # Whatever collection name is requested, hand back the same mock so
    # tests can stub and inspect it.
    fake_client.get_or_create_collection.return_value = fake_collection
    return fake_client, fake_collection
@pytest_asyncio.fixture
async def chroma_store(mock_chroma_client):
    """Build a ``ChromaStore`` backed by the mocked ChromaDB client.

    Only the client half of ``mock_chroma_client`` is used here; the
    collection mock is not touched by this fixture.
    """
    fake_client, _ = mock_chroma_client
    return ChromaStore(client=fake_client, collection_name="test_collection")
@pytest.mark.asyncio
async def test_search_with_category_filter(chroma_store, mock_chroma_client):
    """Check that ``category`` is forwarded to ChromaDB as a ``where`` filter."""
    _, collection = mock_chroma_client

    # Canned single-hit response returned by the mocked collection.query.
    hit_metadata = {
        "title": "AI in Robotics",
        "url": "https://example.com/robotics",
        "source": "Tech",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "relevance_score": 8,
        "summary_ru": "AI в робототехнике",
        "category": "Robotics",
        "anomalies_detected": "",
    }
    collection.query.return_value = {
        "ids": [["id1"]],
        "metadatas": [[hit_metadata]],
        "documents": [["Full content here"]],
        "distances": [[0.1]],
    }

    results = await chroma_store.search(query="AI", limit=5, category="Robotics")

    # The category must reach collection.query as an equality 'where' clause.
    query_kwargs = collection.query.call_args.kwargs
    assert query_kwargs["where"] == {"category": "Robotics"}
    assert len(results) == 1
    assert results[0].category == "Robotics"
@pytest.mark.asyncio
async def test_search_with_relevance_threshold(chroma_store, mock_chroma_client):
    """Check that hits whose distance exceeds ``threshold`` are dropped."""
    _, collection = mock_chroma_client

    def make_metadata(title, url, score, summary):
        # The two hits differ only in these four fields.
        return {
            "title": title,
            "url": url,
            "source": "s",
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "relevance_score": score,
            "summary_ru": summary,
            "category": "C",
            "anomalies_detected": "",
        }

    close_hit = make_metadata("Relevant News", "url1", 9, "Р")
    far_hit = make_metadata("Irrelevant News", "url2", 1, "И")

    # A lower distance means a more similar document: 0.2 is the relevant
    # hit, 0.8 the irrelevant one.
    collection.query.return_value = {
        "ids": [["id-rel", "id-irrel"]],
        "metadatas": [[close_hit, far_hit]],
        "documents": [["doc1", "doc2"]],
        "distances": [[0.2, 0.8]],
    }

    # threshold=0.5 keeps only hits with distance <= 0.5.
    results = await chroma_store.search(query="test", limit=10, threshold=0.5)

    assert len(results) == 1
    assert results[0].title == "Relevant News"
@pytest.mark.asyncio
async def test_get_latest_semantic_threshold(chroma_store, mock_chroma_client):
    """Check that an empty-query search still applies the category filter.

    This models a "/latest"-style lookup: no query text, only a category.
    No ``threshold`` argument is passed, so threshold filtering is NOT
    exercised by this test; it only verifies the ``where`` clause sent to
    ``collection.query`` and the single item that comes back.
    """
    _, collection = mock_chroma_client

    # Canned single-hit response returned by the mocked collection.query.
    collection.query.return_value = {
        "ids": [["id1"]],
        "metadatas": [[{
            "title": "Latest News",
            "url": "url",
            "source": "s",
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "relevance_score": 5,
            "summary_ru": "L",
            "category": "Tech",
            "anomalies_detected": "",
        }]],
        "documents": [["doc"]],
        "distances": [[0.05]],
    }

    # Empty query text plus a category: the category filter must still apply.
    results = await chroma_store.search(query="", limit=10, category="Tech")

    query_kwargs = collection.query.call_args.kwargs
    assert query_kwargs["where"] == {"category": "Tech"}
    assert len(results) == 1
    # Same check as the category-filter test: the returned item carries the
    # category from the mocked metadata.
    assert results[0].category == "Tech"