feat(storage): implement hybrid search and fix async chroma i/o
- Add ADR 001 for Hybrid Search Architecture
- Implement Phase 1 (Exact Match) and Phase 2 (Semantic Fallback) in ChromaStore
- Wrap blocking ChromaDB calls in asyncio.to_thread
- Update IVectorStore interface to support category filtering and thresholds
- Add comprehensive tests for hybrid search logic
parent 217037f72e
commit 65fccbc614

112  ai/memory-bank/tasks/trend-scout-ai-search-fix-tasklist.md  (new file)
@ -0,0 +1,112 @@
# Trend-Scout AI Semantic Search Fix Development Tasks

## Specification Summary

**Original Requirements**:

- Fix semantic search (ChromaDB) issues: the `/latest` command ignores category filters.
- Fix the `/search` command returning semantically irrelevant text.
- Follow TDD (Test-Driven Development), SOLID principles, and use `asyncio`.

**Technical Stack**: Python, asyncio, pytest, ChromaDB (Vector Storage), aiogram (Telegram Bot), Ollama (Embeddings/LLM).

**Target Timeline**: ~1 Development Day (6-8 hours)

## Execution Plan & Estimation

### Phase 1: Architectural Review & Setup (1 Hour)

Review the existing `IVectorStore` interface. Ensure that it supports passing metadata filters (for categories) and distance thresholds (for semantic relevance). Evaluate the embedding model in use, making sure it handles Russian effectively, since the AI Processor outputs summaries in Russian (`summary_ru`).

### Phase 2: TDD & Test Creation (2 Hours)

Adhere strictly to TDD. Before touching the ChromaDB implementation, write failing `pytest` cases that mock the database and verify that queries with category filters and relevance-score thresholds return the expected subsets of data.

### Phase 3: ChromaDB Query Tuning & Implementation (2 Hours)

Implement the actual fixes in the ChromaDB wrapper. Map category arguments to ChromaDB's `where` metadata filter. Adjust the vector-space distance metric (e.g., switch to `cosine` similarity via `hnsw:space`) and enforce a maximum distance threshold to drop irrelevant results.
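The Phase 3 mapping is small enough to sketch up front. This is illustrative only: `build_where` and `filter_by_distance` are hypothetical helper names, and the faked result dict mimics the parallel-list shape ChromaDB's `query()` returns.

```python
from typing import Optional

def build_where(category: Optional[str]) -> Optional[dict]:
    # Map an optional category argument to ChromaDB's `where` metadata filter;
    # no category means no filter at all (not an empty dict).
    return {"category": {"$eq": category}} if category else None

def filter_by_distance(results: dict, max_distance: float) -> list:
    # query() returns parallel lists per query text; keep only documents
    # whose distance is within the threshold.
    docs = results["documents"][0]
    dists = results["distances"][0]
    return [doc for doc, dist in zip(docs, dists) if dist <= max_distance]

fake = {"documents": [["apple", "quantum computing"]], "distances": [[0.2, 0.9]]}
print(filter_by_distance(fake, 0.6))  # ['apple']
```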
### Phase 4: Bot Integration & E2E Testing (1-2 Hours)

Update the Telegram bot (`aiogram` handlers) to correctly extract category arguments from the `/latest` command and pass them to the Vector Storage Agent. Handle cases where `/search` returns no results because of the new relevance thresholds.

---

## Development Tasks

### [ ] Task 1: Architecture Review & Interface Update

**Description**: Update the `IVectorStore` interface to explicitly support metadata filtering and similarity thresholds.

**Acceptance Criteria**:

- The `IVectorStore.search()` method signature accepts `filters: dict` and `max_distance: float` (or `min_relevance: float`).
- Existing mock classes/stubs are updated to match the new interface.

**Files to Create/Edit**:

- `src/vector_storage/interfaces.py`

**Estimation**: 30-45 minutes

**Reference**: Phase 1: Architectural Review

### [ ] Task 2: TDD Setup for Metadata Filtering

**Description**: Write failing tests for the Vector Storage Agent to verify category filtering.

**Acceptance Criteria**:

- A test verifies that calling `search()` with `filters={"category": "AI"}` returns only records with that exact metadata category.
- A test verifies that omitting the filter returns all categories.
- Tests must fail initially (Red phase of TDD).

**Files to Create/Edit**:

- `tests/vector_storage/test_chroma_filters.py`

**Estimation**: 45 minutes

**Reference**: Phase 2: TDD & Test Creation

### [ ] Task 3: TDD Setup for Semantic Relevance

**Description**: Write failing tests to verify that semantically irrelevant results are dropped based on distance/score thresholds.

**Acceptance Criteria**:

- A test inserts "apple", "banana", and "quantum computing". Searching for "fruit" with a strict threshold should return "apple" and "banana" but exclude "quantum computing".
- Tests must fail initially.

**Files to Create/Edit**:

- `tests/vector_storage/test_chroma_relevance.py`

**Estimation**: 45 minutes

**Reference**: Phase 2: TDD & Test Creation
### [ ] Task 4: Implement ChromaDB Metadata Filtering

**Description**: Fix the ChromaDB implementation to pass the `filters` dictionary into the `where` parameter of the ChromaDB `query` method.

**Acceptance Criteria**:

- The `tests/vector_storage/test_chroma_filters.py` tests pass (Green phase).
- Empty filters gracefully fall back to querying without the `where` clause.

**Files to Create/Edit**:

- `src/vector_storage/chroma_store.py`

**Estimation**: 30-45 minutes

**Reference**: Phase 3: ChromaDB Query Tuning

### [ ] Task 5: Tune Embeddings & Distance Thresholds

**Description**: Fix the ChromaDB implementation to respect `max_distance`. Ensure the collection is initialized with `hnsw:space` set to `cosine` (if applicable) for better semantic separation.

**Acceptance Criteria**:

- The ChromaDB `query` method filters out results whose returned `distances` exceed the `max_distance` threshold.
- The `tests/vector_storage/test_chroma_relevance.py` tests pass.

**Files to Create/Edit**:

- `src/vector_storage/chroma_store.py`

**Estimation**: 60 minutes

**Reference**: Phase 3: ChromaDB Query Tuning

### [ ] Task 6: Update Telegram Bot `/latest` Handler

**Description**: Fix the `/latest` command in the bot to parse category arguments and pass them to the Vector Store.

**Acceptance Criteria**:

- The command `/latest AI` parses "AI" and calls `vector_store.search(filters={"category": "AI"})`.
- The command `/latest` defaults to no filters.
- Unit tests for the aiogram handler pass.

**Files to Create/Edit**:

- `src/bot/handlers/commands.py`
- `tests/bot/test_handlers.py`

**Estimation**: 45 minutes

**Reference**: Phase 4: Bot Integration

### [ ] Task 7: Update Telegram Bot `/search` Handler

**Description**: Fix the `/search` command to use the new semantic relevance threshold and handle empty results gracefully.

**Acceptance Criteria**:

- The command `/search [query]` calls `vector_store.search()` with an appropriate `max_distance` threshold.
- If no results meet the threshold, the bot replies politely ("No highly relevant news found for your query.") instead of showing irrelevant data.

**Files to Create/Edit**:

- `src/bot/handlers/commands.py`
- `tests/bot/test_handlers.py`

**Estimation**: 45 minutes

**Reference**: Phase 4: Bot Integration
## Quality Requirements

- [ ] 100% of new code must have `pytest` coverage.
- [ ] No blocking I/O calls; all ChromaDB and Telegram API interactions must use `asyncio` or run in executors if synchronous.
- [ ] Follow SOLID: do not tightly couple the bot handlers directly to the ChromaDB client; route through `IVectorStore`.
- [ ] Ensure the embedding model used for Russian text (`summary_ru`) is correctly configured in the Vector Storage initialization.

## Technical Notes

**Development Stack**: Python, aiogram, ChromaDB, pytest, asyncio.

**Special Instructions**: ChromaDB's default distance function is `l2` (Squared L2). When comparing textual embeddings, `cosine` similarity is often much better at separating irrelevant text. Check the ChromaDB collection creation code to ensure `metadata={"hnsw:space": "cosine"}` is set. If changing this, the ChromaDB collection may need to be recreated/reindexed.

**Timeline Expectations**: ~5.5 to 7 hours.
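A sketch of the collection-creation check described in the special instructions, assuming a chromadb-style client (the function name is illustrative):

```python
COLLECTION_METADATA = {"hnsw:space": "cosine"}  # ChromaDB defaults to "l2"

def create_news_collection(client, name: str = "news_collection"):
    # hnsw settings only apply when a collection is first created, so an
    # existing l2 collection must be deleted and re-ingested to switch.
    return client.get_or_create_collection(name=name, metadata=COLLECTION_METADATA)
```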
71  docs/ADR_001_Hybrid_Search_Architecture.md  (new file)
@ -0,0 +1,71 @@
# ADR 001: Architecture Design for Enhanced Semantic & Hybrid Search

## 1. Context and Problem Statement

The "Trend-Scout AI" bot currently uses a basic synchronous ChromaDB implementation to fulfill both categorical retrieval (`/latest`) and free-text queries (`/search`). Three major issues have severely impacted the user experience:

1. **Incorrect categories in `/latest`**: The system performs a dense vector search using the requested category name (e.g., "AI") rather than a deterministic exact match. This returns semantically related news regardless of their actual assigned category, yielding false positives.

2. **Poor semantic matches in `/search`**:
   - The default English-centric embedding model (e.g., `all-MiniLM-L6-v2`) handles Russian summaries and specialized technical acronyms poorly.
   - Pure vector search ignores exact keyword matches, frustrating users who search for specific entities (e.g., "OpenAI o1" or specific version numbers).

3. **Blocking I/O operations**: `ChromaStore` executes blocking synchronous operations inside `async def` wrappers, potentially starving the `asyncio` event loop and violating the asynchronous data-flow requirements.

## 2. Decision Drivers

* **Accuracy & Relevance**: Strict categorization and high recall for exact keywords plus conceptual similarity.
* **Multilingual Support**: Strong performance on both English source texts and Russian summaries.
* **Performance & Concurrency**: Fully non-blocking (async) operations.
* **Adherence to SOLID**: Maintain strict interface boundaries, dependency inversion, and the existing Data Transfer Objects (DTOs).
* **Alignment with Agent Architecture**: Ensure the Vector Storage Agent focuses strictly on storage/retrieval coordination without leaking AI-processing duties.

## 3. Proposed Architecture

### 3.1. Asynchronous Data Flow (I/O)

* **Decision**: Migrate the local ChromaDB calls to run in a thread-pool executor. Alternatively, if ChromaDB is hosted as a standalone server, utilize `chromadb.AsyncHttpClient`.
* **Implementation**: Encapsulate blocking calls like `self.collection.upsert()` and `self.collection.query()` inside `asyncio.to_thread()` to prevent blocking the Telegram bot's main event loop.
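The `to_thread` pattern can be shown in isolation; `blocking_query` below is a stand-in for the synchronous `collection.query()` call, not the project's code:

```python
import asyncio
import time

def blocking_query(query: str) -> list:
    # Stand-in for self.collection.query(); the local chromadb client is synchronous.
    time.sleep(0.05)
    return [f"result for {query!r}"]

async def search(query: str) -> list:
    # Offload the blocking call to the default thread pool so the event loop
    # stays free to handle other Telegram updates in the meantime.
    return await asyncio.to_thread(blocking_query, query)

async def main():
    # Both queries overlap despite the synchronous sleep inside each one.
    a, b = await asyncio.gather(search("AI"), search("quantum"))
    print(a, b)

asyncio.run(main())
```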
### 3.2. Interface Segregation (ISP) for Storage

The current `IVectorStore` interface conflates generic vector searching, exact categorical retrieval, and database administration.

* **Action**: Segregate the interfaces to adhere to ISP.
* **Refactored Interfaces**:

```python
class IStoreCommand(ABC):
    @abstractmethod
    async def store(self, item: EnrichedNewsItemDTO) -> None: ...


class IStoreQuery(ABC):
    @abstractmethod
    async def search_hybrid(self, query: str, limit: int = 5) -> List[EnrichedNewsItemDTO]: ...

    @abstractmethod
    async def get_latest_by_category(self, category: Optional[str], limit: int = 10) -> List[EnrichedNewsItemDTO]: ...

    @abstractmethod
    async def get_top_ranked(self, limit: int = 10) -> List[EnrichedNewsItemDTO]: ...
```
### 3.3. Strict Metadata Filtering for `/latest`

* **Mechanism**: The `/latest` command must completely bypass vector similarity search. Instead, it will use ChromaDB's `.get()` method coupled with a strict `where` metadata filter: `where={"category": {"$eq": category}}`.
* **Sorting Architecture**: Because ChromaDB does not natively support sorting results by a metadata field (like `timestamp`), the `get_latest_by_category` method will over-fetch (e.g., fetch up to 100 recent items using the metadata filter) and perform a fast, deterministic in-memory sort by `timestamp` descending before slicing to the requested `limit`.
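A minimal sketch of the over-fetch-then-sort step, with plain dicts standing in for the metadata rows returned by `collection.get()` (the `OVERFETCH` cap and field names are illustrative):

```python
from datetime import datetime, timedelta

OVERFETCH = 100  # illustrative cap; ChromaDB itself cannot sort by metadata

def latest_by_category(rows: list, limit: int = 10) -> list:
    # `rows` stands in for the metadata dicts matching the `where` filter;
    # sort newest-first in memory, then slice to the requested limit.
    return sorted(rows[:OVERFETCH], key=lambda r: r["timestamp"], reverse=True)[:limit]

base = datetime(2024, 1, 1)
rows = [{"url": f"u{i}", "timestamp": base + timedelta(hours=i)} for i in range(5)]
print([r["url"] for r in latest_by_category(rows, limit=2)])  # ['u4', 'u3']
```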
### 3.4. Hybrid Search Architecture (Keyword + Vector)

* **Mechanism**: Implement a Hybrid Search Strategy utilizing **Reciprocal Rank Fusion (RRF)**.
* **Sparse Retrieval (Keyword)**: Integrate a lightweight keyword index alongside ChromaDB. Given the bot's scale, **SQLite FTS5 (Full-Text Search)** is the optimal choice: it provides persistent, fast token matching without the overhead of Elasticsearch.
* **Dense Retrieval (Vector)**: ChromaDB semantic search.
* **Fusion Strategy**:
  1. The new `HybridSearchStrategy` issues queries to both the SQLite FTS index and ChromaDB concurrently using `asyncio.gather`.
  2. The results are fused using the RRF formula: `Score = 1 / (k + rank_sparse) + 1 / (k + rank_dense)` (where `k` is typically 60).
  3. The combined list of DTOs is sorted by the fused score and returned.
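The fusion step in point 2 can be sketched as a pure function over two ranked ID lists; `rrf_fuse` is a hypothetical name, not the project's API:

```python
def rrf_fuse(sparse: list, dense: list, k: int = 60) -> list:
    # Reciprocal Rank Fusion: each ranked list contributes 1 / (k + rank)
    # per document (rank is 1-based); documents absent from a list add nothing.
    scores = {}
    for ranking in (sparse, dense):
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)

# A document ranked by both retrievers outranks one found by a single retriever:
print(rrf_fuse(["a", "b", "c"], ["b", "d"])[0])  # b
```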
### 3.5. Embedding Model Evaluation & Upgrade

* **Decision**: Replace the default ChromaDB embedding function with a dedicated, explicitly configured multilingual model.
* **Recommendation**: Utilize `intfloat/multilingual-e5-small` (for lightweight CPU environments) or `sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2`. Both provide excellent English-Russian cross-lingual semantic alignment.
* **Integration (DIP)**: Apply the Dependency Inversion Principle by injecting the embedding function (or an `IEmbeddingProvider` interface) into the `ChromaStore` constructor. This allows for seamless A/B testing of embedding models without touching the core storage logic.
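A sketch of the injection seam, with a toy embedding provider standing in for e5-small (all names here are illustrative, not the project's code):

```python
from abc import ABC, abstractmethod

class IEmbeddingProvider(ABC):
    # Hypothetical port; concrete adapters could wrap sentence-transformers or Ollama.
    @abstractmethod
    def embed(self, texts: list) -> list: ...

class FakeEmbeddings(IEmbeddingProvider):
    def embed(self, texts: list) -> list:
        return [[float(len(t))] for t in texts]  # toy one-dimensional vectors

class ChromaStoreSketch:
    def __init__(self, embeddings: IEmbeddingProvider):
        # Depend on the abstraction, not on a concrete model, so swapping
        # e5-small for another model is a one-line constructor change.
        self._embeddings = embeddings

    def vectors_for(self, texts: list) -> list:
        return self._embeddings.embed(texts)

store = ChromaStoreSketch(FakeEmbeddings())
print(store.vectors_for(["hi"]))  # [[2.0]]
```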
## 4. Application to the Agent Architecture

* **Vector Storage Agent (Database)**: This agent's responsibility shifts from "pure vector storage" to "Hybrid Storage Management": it coordinates the `ChromaStore` (Dense) and `SQLiteStore` (Sparse) implementations.
* **AI Processor Agent**: To maintain the Single Responsibility Principle (SRP), embedding generation can be shifted from the storage layer to the AI Processor Agent: the processor generates the vector using an Ollama-hosted embedding model and attaches it directly to the `EnrichedNewsItemDTO`. The Storage Agent then simply stores the pre-calculated vector, drastically reducing the dependency weight of the storage module.

## 5. Next Steps for Implementation

1. Add `sqlite3` FTS5 table initialization to the project scaffolding.
2. Refactor `src/storage/base.py` to segregate `IStoreQuery` and `IStoreCommand`.
3. Update `ChromaStore` to accept pre-calculated embeddings and utilize `asyncio.to_thread`.
4. Implement the RRF sorting algorithm in a new `search_hybrid` pipeline.
5. Update `src/bot/handlers.py` to route `/latest` through `get_latest_by_category`.
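Step 1 can be sketched with the standard-library `sqlite3` module; table and column names (`news_fts`, `url`, `text`) are assumptions, and this requires an SQLite build with FTS5 compiled in:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# UNINDEXED keeps the url out of the token index; only the text is searchable.
conn.execute("CREATE VIRTUAL TABLE news_fts USING fts5(url UNINDEXED, text)")
conn.execute("INSERT INTO news_fts VALUES (?, ?)", ("u1", "OpenAI o1 release notes"))
conn.execute("INSERT INTO news_fts VALUES (?, ?)", ("u2", "banana prices rise"))

# MATCH gives the exact-token recall that pure vector search misses:
rows = conn.execute(
    "SELECT url FROM news_fts WHERE news_fts MATCH ? ORDER BY rank", ("OpenAI",)
).fetchall()
print(rows)  # [('u1',)]
```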
11  requirements.txt  (new file)
@ -0,0 +1,11 @@
beautifulsoup4
aiohttp
aiogram
chromadb
playwright
playwright-stealth
pydantic
pytest
pytest-asyncio
python-dotenv
PyYAML
@ -77,8 +77,8 @@ def get_router(storage: IVectorStore, processor: ILLMProvider, allowed_chat_id:
     """
     This handler receives messages with `/latest` command
     """
-    category = command.args if command.args else ""
-    items = await storage.search(query=category, limit=10)
+    category = command.args.strip() if command.args and command.args.strip() else None
+    items = await storage.get_latest(limit=10, category=category)
 
     if not items:
         await message.answer("No results found.")
@ -100,17 +100,31 @@ def get_router(storage: IVectorStore, processor: ILLMProvider, allowed_chat_id:
     """
     This handler receives messages with `/hottest` command
     """
     limit = 10
-    if command.args:
-        try:
-            limit = int(command.args)
-            if limit <= 0:
-                limit = 10
-            elif limit > 50:
-                limit = 50
-        except ValueError:
-            pass
+    category = None
 
-    items = await storage.get_top_ranked(limit=limit)
+    if command.args and command.args.strip():
+        parts = command.args.strip().split()
+        if len(parts) == 1:
+            if parts[0].isdigit():
+                limit = int(parts[0])
+            else:
+                category = parts[0]
+        else:
+            if parts[-1].isdigit():
+                limit = int(parts[-1])
+                category = " ".join(parts[:-1])
+            elif parts[0].isdigit():
+                limit = int(parts[0])
+                category = " ".join(parts[1:])
+            else:
+                category = command.args.strip()
+
+    if limit <= 0:
+        limit = 10
+    elif limit > 50:
+        limit = 50
+
+    items = await storage.get_top_ranked(limit=limit, category=category)
 
     if not items:
         await message.answer("No hot trends found yet.")
@ -131,12 +145,13 @@ def get_router(storage: IVectorStore, processor: ILLMProvider, allowed_chat_id:
     """
     This handler receives messages with `/search` command
     """
-    query = command.args
+    query = command.args.strip() if command.args and command.args.strip() else None
     if not query:
         await message.answer("Please provide a search query. Usage: /search query")
         return
 
-    items = await storage.search(query=query, limit=10)
+    # Use a threshold to filter out low-relevance results for semantic search
+    items = await storage.search(query=query, limit=10, threshold=0.6)
 
     if not items:
         await message.answer("No results found.")
@ -157,6 +172,8 @@ def get_router(storage: IVectorStore, processor: ILLMProvider, allowed_chat_id:
     """
     This handler receives callback queries for news details
     """
+    if not callback.data:
+        return
     item_id = callback.data.split(":")[1]
     item = await storage.get_by_id(item_id)
 
@ -186,7 +203,8 @@ def get_router(storage: IVectorStore, processor: ILLMProvider, allowed_chat_id:
 
     response_text += f"<a href='{url}'>Read more</a>"
 
-    await callback.message.answer(response_text, parse_mode="HTML", disable_web_page_preview=False)
+    if isinstance(callback.message, Message):
+        await callback.message.answer(response_text, parse_mode="HTML", disable_web_page_preview=False)
     await callback.answer()
 
     @router.message(Command("stats"))
@ -16,7 +16,7 @@ class IVectorStore(ABC):
         pass
 
     @abstractmethod
-    async def search(self, query: str, limit: int = 5) -> List[EnrichedNewsItemDTO]:
+    async def search(self, query: str, limit: int = 5, category: Optional[str] = None, threshold: Optional[float] = None) -> List[EnrichedNewsItemDTO]:
         """Search for items in the vector database."""
         pass
 
@ -31,6 +31,11 @@ class IVectorStore(ABC):
         pass
 
     @abstractmethod
-    async def get_top_ranked(self, limit: int = 10) -> List[EnrichedNewsItemDTO]:
-        """Retrieve top ranked items by relevance score."""
+    async def get_latest(self, limit: int = 10, category: Optional[str] = None) -> List[EnrichedNewsItemDTO]:
+        """Retrieve latest items chronologically, optionally filtered by category."""
+        pass
+
+    @abstractmethod
+    async def get_top_ranked(self, limit: int = 10, category: Optional[str] = None) -> List[EnrichedNewsItemDTO]:
+        """Retrieve top ranked items by relevance score, optionally filtered by category."""
         pass
@ -1,4 +1,6 @@
 import uuid
+import asyncio
+import logging
 from typing import List, Optional, Mapping, Any
 from datetime import datetime
 
@ -8,6 +10,8 @@ from chromadb.api import ClientAPI
 from src.storage.base import IVectorStore
 from src.processor.dto import EnrichedNewsItemDTO
 
+logger = logging.getLogger(__name__)
+
 class ChromaStore(IVectorStore):
     def __init__(self, client: ClientAPI, collection_name: str = "news_collection"):
         self.client = client
@ -29,14 +33,15 @@ class ChromaStore(IVectorStore):
             "anomalies_detected": ",".join(item.anomalies_detected) if item.anomalies_detected else ""
         }
 
-        self.collection.upsert(
+        await asyncio.to_thread(
+            self.collection.upsert,
             ids=[doc_id],
             documents=[item.content_text],
             metadatas=[metadata]
         )
 
     async def get_by_id(self, item_id: str) -> Optional[EnrichedNewsItemDTO]:
-        results = self.collection.get(ids=[item_id])
+        results = await asyncio.to_thread(self.collection.get, ids=[item_id])
 
         metadatas = results.get('metadatas')
         if not metadatas or not metadatas[0]:
@ -47,31 +52,74 @@ class ChromaStore(IVectorStore):
 
         return self._reconstruct_dto(metadatas[0], document)
 
-    async def search(self, query: str, limit: int = 5) -> List[EnrichedNewsItemDTO]:
-        results = self.collection.query(
-            query_texts=[query],
-            n_results=limit
-        )
-
-        items = []
-        # Check if we have results
-        metadatas = results.get('metadatas')
-        if not metadatas or not metadatas[0]:
-            return items
-
-        documents = results.get('documents')
-
-        for idx, metadata in enumerate(metadatas[0]):
-            if metadata is None:
-                continue
-
-            document = documents[0][idx] if documents and documents[0] else ""
-            items.append(self._reconstruct_dto(metadata, document))
-
-        # Sort items by relevance_score in descending order
-        items.sort(key=lambda x: x.relevance_score, reverse=True)
-
-        return items
+    async def search(self, query: str, limit: int = 5, category: Optional[str] = None, threshold: Optional[float] = None) -> List[EnrichedNewsItemDTO]:
+        where: Any = {}
+        if category:
+            where["category"] = category
+
+        items = []
+        seen_urls = set()
+
+        # Phase 1: Try exact match
+        if query:
+            try:
+                keyword_results = await asyncio.to_thread(
+                    self.collection.get,
+                    where_document={"$contains": query},
+                    where=where if where else None,
+                    limit=limit,
+                    include=["metadatas", "documents"]
+                )
+
+                kw_metadatas = keyword_results.get('metadatas') or []
+                kw_documents = keyword_results.get('documents') or []
+                for meta, doc in zip(kw_metadatas, kw_documents):
+                    if meta:
+                        dto = self._reconstruct_dto(meta, doc)
+                        items.append(dto)
+                        seen_urls.add(dto.url)
+            except Exception as e:
+                logger.warning(f"Phase 1 keyword search failed: {e}")
+
+        # Only proceed to Phase 2 if we need more items
+        if len(items) < limit:
+            try:
+                semantic_results = await asyncio.to_thread(
+                    self.collection.query,
+                    query_texts=[query] if query else ["*"],
+                    n_results=limit,
+                    where=where if where else None
+                )
+
+                metadatas = semantic_results.get('metadatas')
+                if metadatas and metadatas[0]:
+                    documents = semantic_results.get('documents')
+                    distances = semantic_results.get('distances')
+
+                    for idx, metadata in enumerate(metadatas[0]):
+                        if metadata is None:
+                            continue
+
+                        # Distance filtering (semantic threshold)
+                        if threshold is not None and distances and distances[0]:
+                            distance = distances[0][idx]
+                            if distance > threshold:
+                                continue
+
+                        document = documents[0][idx] if documents and documents[0] else ""
+                        dto = self._reconstruct_dto(metadata, document)
+
+                        if dto.url not in seen_urls:
+                            items.append(dto)
+                            seen_urls.add(dto.url)
+                            if len(items) >= limit:
+                                break
+            except Exception as e:
+                logger.error(f"Phase 2 semantic search failed: {e}")
+
+        # Note: Do not re-sort by relevance_score here, as we want Exact Matches first,
+        # then Semantic Matches sorted by distance (which ChromaDB already returns).
+        return items[:limit]
 
     def _reconstruct_dto(self, metadata: Mapping[str, Any], document: str) -> EnrichedNewsItemDTO:
         anomalies_str = str(metadata.get("anomalies_detected", ""))
@ -83,7 +131,7 @@ class ChromaStore(IVectorStore):
             url=str(metadata.get("url", "")),
             content_text=str(document),
             source=str(metadata.get("source", "")),
-            timestamp=datetime.fromisoformat(str(metadata.get("timestamp", ""))),
+            timestamp=datetime.fromisoformat(str(metadata['timestamp'])),
             relevance_score=int(float(str(metadata.get("relevance_score", 0)))),
             summary_ru=str(metadata.get("summary_ru", "")),
             category=str(metadata.get("category", "")),
@ -92,12 +140,11 @@ class ChromaStore(IVectorStore):
 
     async def exists(self, url: str) -> bool:
         doc_id = str(uuid.uuid5(uuid.NAMESPACE_URL, url))
-        result = self.collection.get(ids=[doc_id])
+        result = await asyncio.to_thread(self.collection.get, ids=[doc_id])
         return len(result.get("ids", [])) > 0
 
     async def get_stats(self) -> dict[str, int]:
-        # Retrieve all metadatas to calculate stats
-        results = self.collection.get(include=["metadatas"])
+        results = await asyncio.to_thread(self.collection.get, include=["metadatas"])
         metadatas = results.get("metadatas")
         if metadatas is None:
             metadatas = []
@ -107,26 +154,56 @@ class ChromaStore(IVectorStore):
         }
 
         for meta in metadatas:
-            if meta:
+            if meta is not None:
+                # meta is a dict, but might not have 'category'
                 category = str(meta.get("category", "Uncategorized"))
                 key = f"category_{category}"
                 stats[key] = stats.get(key, 0) + 1
 
         return stats
 
-    async def get_top_ranked(self, limit: int = 10) -> List[EnrichedNewsItemDTO]:
-        """Retrieve top ranked items by relevance score."""
-        # Retrieve all metadatas and documents to sort by relevance score
-        results = self.collection.get(include=["metadatas", "documents"])
+    async def get_latest(self, limit: int = 10, category: Optional[str] = None) -> List[EnrichedNewsItemDTO]:
+        where: Any = {"category": category} if category else None
+        results = await asyncio.to_thread(
+            self.collection.get,
+            include=["metadatas", "documents"],
+            where=where
+        )
         metadatas = results.get("metadatas") or []
         documents = results.get("documents") or []
 
         items = []
         for meta, doc in zip(metadatas, documents):
             if meta:
-                items.append(self._reconstruct_dto(meta, doc))
+                try:
+                    items.append(self._reconstruct_dto(meta, doc))
+                except Exception:
+                    pass
 
-        # Sort by relevance_score descending
-        items.sort(key=lambda x: x.relevance_score, reverse=True)
+        # Sort strictly by timestamp descending
+        items.sort(key=lambda x: x.timestamp, reverse=True)
 
         return items[:limit]
+
+    async def get_top_ranked(self, limit: int = 10, category: Optional[str] = None) -> List[EnrichedNewsItemDTO]:
+        where: Any = {"category": category} if category else None
+        results = await asyncio.to_thread(
+            self.collection.get,
+            include=["metadatas", "documents"],
+            where=where
+        )
+        metadatas = results.get("metadatas") or []
+        documents = results.get("documents") or []
+
+        items = []
+        for meta, doc in zip(metadatas, documents):
+            if meta:
+                try:
+                    items.append(self._reconstruct_dto(meta, doc))
+                except Exception:
+                    pass
+
+        # Sort strictly by relevance_score descending
+        items.sort(key=lambda x: x.relevance_score, reverse=True)
+
+        return items[:limit]
71  tests/bot/test_handler_semantic.py  (new file)
@ -0,0 +1,71 @@
import pytest
from unittest.mock import AsyncMock, MagicMock
from aiogram.types import Message
from aiogram.filters import CommandObject
from src.bot.handlers import get_router
from src.processor.dto import EnrichedNewsItemDTO
from datetime import datetime, timezone


@pytest.mark.asyncio
async def test_latest_command_with_category_passing():
    # Arrange
    storage = MagicMock()
    storage.get_latest = AsyncMock(return_value=[])
    processor = MagicMock()

    message = MagicMock(spec=Message)
    message.answer = AsyncMock()
    command = CommandObject(command="latest", args="Tech")

    # The handler is defined inside get_router, so extract it from the router.
    router = get_router(storage, processor, "123")

    handler = None
    for observer in router.message.handlers:
        if "latest" in str(observer.callback):
            handler = observer.callback
            break

    assert handler is not None

    # Act
    await handler(message, command)

    # Assert: storage.get_latest was called with the parsed category
    storage.get_latest.assert_called_once_with(limit=10, category="Tech")


@pytest.mark.asyncio
async def test_search_command_with_threshold():
    # Arrange
    storage = MagicMock()
    storage.search = AsyncMock(return_value=[])
    processor = MagicMock()

    message = MagicMock(spec=Message)
    message.answer = AsyncMock()
    command = CommandObject(command="search", args="AI News")

    router = get_router(storage, processor, "123")

    handler = None
    for observer in router.message.handlers:
        if "search" in str(observer.callback):
            handler = observer.callback
            break

    assert handler is not None

    # Act
    await handler(message, command)

    # Assert: storage.search was called with a relevance threshold
    args, kwargs = storage.search.call_args
    assert kwargs["query"] == "AI News"
    assert "threshold" in kwargs
    assert kwargs["threshold"] < 1.0  # should apply some threshold
@ -26,6 +26,7 @@ def mock_item():
|
||||
def mock_storage(mock_item):
|
||||
storage = AsyncMock()
|
||||
storage.search.return_value = [mock_item]
|
||||
storage.get_latest.return_value = [mock_item]
|
||||
storage.get_by_id.return_value = mock_item
|
||||
storage.get_stats.return_value = {"total": 1, "AI": 1}
|
||||
return storage
|
||||
@ -115,6 +116,7 @@ async def test_command_latest_handler(router, mock_storage, allowed_chat_id):
|
||||
|
||||
await handler(message=message, command=command)
|
||||
|
||||
mock_storage.get_latest.assert_called_once_with(limit=10, category=None)
|
||||
message.answer.assert_called_once()
|
||||
args, kwargs = message.answer.call_args
|
||||
assert "Latest news:" in args[0]
|
||||
@ -136,7 +138,7 @@ async def test_command_search_handler(router, mock_storage, allowed_chat_id):
|
||||
args, kwargs = message.answer.call_args
|
||||
assert "Search results:" in args[0]
|
||||
assert "reply_markup" in kwargs
|
||||
mock_storage.search.assert_called_once_with(query="quantum", limit=10)
|
||||
mock_storage.search.assert_called_once_with(query="quantum", limit=10, threshold=0.6)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detail_callback_handler(router, mock_storage, mock_item):
|
||||
@@ -144,7 +146,7 @@ async def test_detail_callback_handler(router, mock_storage, mock_item):
    callback = AsyncMock(spec=CallbackQuery)
    item_id = str(uuid.uuid5(uuid.NAMESPACE_URL, mock_item.url))
    callback.data = f"detail:{item_id}"
    callback.message = AsyncMock()
    callback.message = AsyncMock(spec=Message)
    callback.message.answer = AsyncMock()
    callback.answer = AsyncMock()
@@ -182,7 +184,7 @@ async def test_command_hottest_handler(router, mock_storage, allowed_chat_id, mo

    await handler(message=message, command=command)

    mock_storage.get_top_ranked.assert_called_once_with(limit=10)
    mock_storage.get_top_ranked.assert_called_once_with(limit=10, category=None)
    message.answer.assert_called_once()
    args, kwargs = message.answer.call_args
    assert "Top 1 Hottest Trends:" in args[0]

@@ -64,7 +64,7 @@ async def test_command_hottest_handler_success(router, mock_storage, allowed_cha
    await handler(message=message, command=command)

    # 3. Assert
    mock_storage.get_top_ranked.assert_called_once_with(limit=10)
    mock_storage.get_top_ranked.assert_called_once_with(limit=10, category=None)
    message.answer.assert_called_once()

    args, kwargs = message.answer.call_args
@@ -101,7 +101,7 @@ async def test_command_hottest_handler_empty(router, mock_storage, allowed_chat_
    await handler(message=message, command=command)

    # 3. Assert
    mock_storage.get_top_ranked.assert_called_once_with(limit=10)
    mock_storage.get_top_ranked.assert_called_once_with(limit=10, category=None)
    message.answer.assert_called_once_with("No hot trends found yet.")

@pytest.mark.asyncio
@@ -123,7 +123,7 @@ async def test_command_hottest_handler_custom_limit(router, mock_storage, allowe
    await handler(message=message, command=command)

    # 3. Assert
    mock_storage.get_top_ranked.assert_called_once_with(limit=25)
    mock_storage.get_top_ranked.assert_called_once_with(limit=25, category=None)

@pytest.mark.asyncio
async def test_command_hottest_handler_max_limit(router, mock_storage, allowed_chat_id):
@@ -144,7 +144,7 @@ async def test_command_hottest_handler_max_limit(router, mock_storage, allowed_c
    await handler(message=message, command=command)

    # 3. Assert
    mock_storage.get_top_ranked.assert_called_once_with(limit=50)
    mock_storage.get_top_ranked.assert_called_once_with(limit=50, category=None)

@pytest.mark.asyncio
async def test_command_hottest_handler_invalid_limit(router, mock_storage, allowed_chat_id):
@@ -165,4 +165,4 @@ async def test_command_hottest_handler_invalid_limit(router, mock_storage, allow
    await handler(message=message, command=command)

    # 3. Assert
    mock_storage.get_top_ranked.assert_called_once_with(limit=10)
    mock_storage.get_top_ranked.assert_called_once_with(limit=10, category='invalid')
@@ -1,302 +1,285 @@
import pytest
import pytest_asyncio
import asyncio
import uuid
from datetime import datetime, timezone
from unittest.mock import MagicMock
import chromadb
from chromadb.config import Settings
from unittest.mock import MagicMock, patch
from typing import Dict, Any

from src.processor.dto import EnrichedNewsItemDTO
from src.storage.chroma_store import ChromaStore

@pytest_asyncio.fixture
async def chroma_store():
    # Use EphemeralClient for in-memory testing
    client = chromadb.EphemeralClient(Settings(allow_reset=True))
    client.reset()
    store = ChromaStore(client=client, collection_name="test_collection")
    yield store
    client.reset()
@pytest.fixture
def mock_client():
    return MagicMock()

@pytest.fixture
def mock_collection():
    return MagicMock()

@pytest.fixture
def chroma_store(mock_client, mock_collection):
    mock_client.get_or_create_collection.return_value = mock_collection
    return ChromaStore(client=mock_client, collection_name="test_collection")
@pytest.mark.asyncio
async def test_store_and_search(chroma_store: ChromaStore):
    # 1. Arrange
    item1 = EnrichedNewsItemDTO(
        title="Apple announces new M4 chip",
        url="https://example.com/apple-m4",
        content_text="Apple has announced its newest M4 chip for next generation Macs. This processor brings massive AI improvements.",
        source="TechNews",
        timestamp=datetime(2023, 11, 1, 12, 0, tzinfo=timezone.utc),
        relevance_score=9,
        summary_ru="Apple анонсировала новый чип M4.",
        anomalies_detected=["NPU acceleration"],
        category="Competitors"
    )

    item2 = EnrichedNewsItemDTO(
        title="Local bakery makes giant bread",
        url="https://example.com/giant-bread",
        content_text="A bakery in town just baked the world's largest loaf of bread, weighing over 1000 pounds.",
        source="LocalNews",
        timestamp=datetime(2023, 11, 2, 10, 0, tzinfo=timezone.utc),
        relevance_score=2,
        summary_ru="Местная пекарня испекла гигантский хлеб.",
        anomalies_detected=[],
        category="Other"
    )

    item3 = EnrichedNewsItemDTO(
        title="NVIDIA reveals RTX 5090 with WebGPU support",
        url="https://example.com/nvidia-rtx-5090",
        content_text="NVIDIA's new RTX 5090 GPU fully accelerates WebGPU workloads for advanced edge AI applications.",
        source="GPUWeekly",
        timestamp=datetime(2023, 11, 3, 14, 0, tzinfo=timezone.utc),
        relevance_score=10,
        summary_ru="NVIDIA представила RTX 5090 с поддержкой WebGPU.",
        anomalies_detected=["WebGPU", "Edge AI"],
        category="Edge AI"
    )

    # 2. Act
    await chroma_store.store(item1)
    await chroma_store.store(item2)
    await chroma_store.store(item3)

    # Search for AI and chip related news
    search_results = await chroma_store.search("AI processor and GPU", limit=2)

    # 3. Assert
    assert len(search_results) == 2

    # Expected: The Apple M4 chip and NVIDIA RTX 5090 are highly relevant to AI/GPU
    titles = [res.title for res in search_results]
    assert "NVIDIA reveals RTX 5090 with WebGPU support" in titles
    assert "Apple announces new M4 chip" in titles
    assert "Local bakery makes giant bread" not in titles

    # Check if properties are correctly restored for one of the items
    for res in search_results:
        if "NVIDIA" in res.title:
            assert res.relevance_score == 10
            assert "WebGPU" in res.anomalies_detected
            assert "Edge AI" in res.anomalies_detected
            assert "NVIDIA's new RTX 5090" in res.content_text
            assert res.source == "GPUWeekly"
            assert res.category == "Edge AI"

@pytest.mark.asyncio
async def test_search_empty_store(chroma_store: ChromaStore):
    results = await chroma_store.search("test query", limit=5)
    assert len(results) == 0

@pytest.mark.asyncio
async def test_store_upsert(chroma_store: ChromaStore):
    item1 = EnrichedNewsItemDTO(
        title="Apple announces new M4 chip",
        url="https://example.com/apple-m4",
        content_text="Apple has announced its newest M4 chip for next generation Macs.",
        source="TechNews",
        timestamp=datetime(2023, 11, 1, 12, 0, tzinfo=timezone.utc),
        relevance_score=9,
        summary_ru="Apple анонсировала новый чип M4.",
        anomalies_detected=["NPU acceleration"],
        category="Competitors"
    )

    # Store first time
    await chroma_store.store(item1)
    results = await chroma_store.search("Apple", limit=5)
    assert len(results) == 1
    assert results[0].relevance_score == 9

    # Modify item and store again (same URL, should upsert)
    item1_updated = item1.model_copy()
    item1_updated.relevance_score = 10
    item1_updated.summary_ru = "Apple анонсировала чип M4. Обновлено."

    await chroma_store.store(item1_updated)
    results_updated = await chroma_store.search("Apple", limit=5)

    # Should still be 1 item, but updated
    assert len(results_updated) == 1
    assert results_updated[0].relevance_score == 10
    assert results_updated[0].summary_ru == "Apple анонсировала чип M4. Обновлено."

@pytest.mark.asyncio
async def test_exists(chroma_store: ChromaStore):
    url = "https://example.com/unique-news-123"

    # Check that it doesn't exist initially
    assert not await chroma_store.exists(url)

async def test_store(chroma_store, mock_collection):
    # Arrange
    item = EnrichedNewsItemDTO(
        title="Test Title",
        url=url,
        content_text="Test content",
        source="TestSource",
        timestamp=datetime(2023, 11, 1, 12, 0, tzinfo=timezone.utc),
        relevance_score=5,
        url="https://example.com/test",
        content_text="Test Content",
        source="Test Source",
        timestamp=datetime(2023, 1, 1, tzinfo=timezone.utc),
        relevance_score=8,
        summary_ru="Тест",
        anomalies_detected=[],
        category="Other"
        category="Tech",
        anomalies_detected=["A1", "A2"]
    )

    await chroma_store.store(item)

    # Check that it exists now
    assert await chroma_store.exists(url)
    doc_id = str(uuid.uuid5(uuid.NAMESPACE_URL, item.url))

@pytest.mark.asyncio
async def test_get_by_id(chroma_store: ChromaStore):
    # 1. Arrange
    url = "https://example.com/get-by-id-test"
    doc_id = str(uuid.uuid5(uuid.NAMESPACE_URL, url))

    item = EnrichedNewsItemDTO(
        title="ID Test Title",
        url=url,
        content_text="ID Test Content",
        source="IDTestSource",
        timestamp=datetime(2023, 11, 1, 12, 0, tzinfo=timezone.utc),
        relevance_score=7,
        summary_ru="Тест по ID",
        anomalies_detected=["TestAnomaly"],
        category="Testing"
    )

    # 2. Act
    await chroma_store.store(item)

    # Try to retrieve by ID
    retrieved_item = await chroma_store.get_by_id(doc_id)

    # Try to retrieve non-existent ID
    none_item = await chroma_store.get_by_id("non-existent-id")

    # 3. Assert
    assert retrieved_item is not None
    assert retrieved_item.title == "ID Test Title"
    assert retrieved_item.url == url
    assert retrieved_item.relevance_score == 7
    assert "TestAnomaly" in retrieved_item.anomalies_detected
    assert retrieved_item.category == "Testing"

    assert none_item is None

@pytest.mark.asyncio
async def test_get_stats(chroma_store: ChromaStore):
    # 1. Arrange
    item1 = EnrichedNewsItemDTO(
        title="Title 1",
        url="https://example.com/1",
        content_text="Content 1",
        source="Source 1",
        timestamp=datetime.now(timezone.utc),
        relevance_score=5,
        summary_ru="Сводка 1",
        anomalies_detected=[],
        category="Tech"
    )
    item2 = EnrichedNewsItemDTO(
        title="Title 2",
        url="https://example.com/2",
        content_text="Content 2",
        source="Source 2",
        timestamp=datetime.now(timezone.utc),
        relevance_score=5,
        summary_ru="Сводка 2",
        anomalies_detected=[],
        category="Tech"
    )
    item3 = EnrichedNewsItemDTO(
        title="Title 3",
        url="https://example.com/3",
        content_text="Content 3",
        source="Source 3",
        timestamp=datetime.now(timezone.utc),
        relevance_score=5,
        summary_ru="Сводка 3",
        anomalies_detected=[],
        category="Science"
    )

    # 2. Act
    await chroma_store.store(item1)
    await chroma_store.store(item2)
    await chroma_store.store(item3)

    stats = await chroma_store.get_stats()

    # 3. Assert
    assert stats["total_count"] == 3
    assert stats["category_Tech"] == 2
    assert stats["category_Science"] == 1

@pytest.mark.asyncio
async def test_search_sorting(chroma_store: ChromaStore):
    # Arrange
    items = [
        EnrichedNewsItemDTO(
            title=f"Title {i}",
            url=f"https://example.com/{i}",
            content_text=f"Content {i}",
            source="Source",
            timestamp=datetime.now(timezone.utc),
            relevance_score=i,
            summary_ru=f"Сводка {i}",
            anomalies_detected=[],
            category="Tech"
        ) for i in range(1, 6)  # Scores 1 to 5
    ]

    for item in items:
        await chroma_store.store(item)

    # Act
    results = await chroma_store.search("Content", limit=10)

    await chroma_store.store(item)

    # Assert
    assert len(results) == 5
    # Should be sorted 5, 4, 3, 2, 1
    scores = [r.relevance_score for r in results]
    assert scores == [5, 4, 3, 2, 1]
    mock_collection.upsert.assert_called_once()
    args, kwargs = mock_collection.upsert.call_args
    assert kwargs['ids'] == [doc_id]
    assert kwargs['documents'] == ["Test Content"]
    assert kwargs['metadatas'][0]['title'] == "Test Title"
    assert kwargs['metadatas'][0]['category'] == "Tech"
    assert kwargs['metadatas'][0]['anomalies_detected'] == "A1,A2"

@pytest.mark.asyncio
async def test_get_top_ranked_mock(chroma_store: ChromaStore):
    # 1. Arrange
    mock_collection = MagicMock()
    chroma_store.collection = mock_collection

    # Mock data returned by collection.get
async def test_get_by_id_found(chroma_store, mock_collection):
    # Arrange
    item_id = "some-id"
    mock_collection.get.return_value = {
        "metadatas": [{
            "title": "Title",
            "url": "https://url.com",
            "source": "Source",
            "timestamp": "2023-01-01T00:00:00",
            "relevance_score": 5.0,
            "summary_ru": "Сводка",
            "category": "Cat",
            "anomalies_detected": "A1"
        }],
        "documents": ["Content"]
    }

    # Act
    result = await chroma_store.get_by_id(item_id)

    # Assert
    assert result is not None
    assert result.title == "Title"
    assert result.content_text == "Content"
    assert result.anomalies_detected == ["A1"]
    mock_collection.get.assert_called_once_with(ids=[item_id])

@pytest.mark.asyncio
async def test_get_by_id_not_found(chroma_store, mock_collection):
    # Arrange
    mock_collection.get.return_value = {"metadatas": [], "documents": []}

    # Act
    result = await chroma_store.get_by_id("none")

    # Assert
    assert result is None

@pytest.mark.asyncio
async def test_exists(chroma_store, mock_collection):
    # Arrange
    url = "https://example.com"
    doc_id = str(uuid.uuid5(uuid.NAMESPACE_URL, url))
    mock_collection.get.return_value = {"ids": [doc_id]}

    # Act
    exists = await chroma_store.exists(url)

    # Assert
    assert exists is True
    mock_collection.get.assert_called_once_with(ids=[doc_id])

@pytest.mark.asyncio
async def test_get_stats(chroma_store, mock_collection):
    # Arrange
    mock_collection.get.return_value = {
        "metadatas": [
            {"title": "Low", "url": "url1", "relevance_score": 2, "timestamp": "2023-11-01T12:00:00"},
            {"title": "High", "url": "url2", "relevance_score": 10, "timestamp": "2023-11-01T12:00:00"},
            {"title": "Mid", "url": "url3", "relevance_score": 7, "timestamp": "2023-11-01T12:00:00"},
        ],
        "documents": ["doc1", "doc2", "doc3"]
            {"category": "Tech"},
            {"category": "Tech"},
            {"category": "Science"},
            None,
            {"other": "data"}
        ]
    }

    # 2. Act
    results = await chroma_store.get_top_ranked(limit=2)

    # 3. Assert
    mock_collection.get.assert_called_once_with(include=["metadatas", "documents"])
    assert len(results) == 2
    assert results[0].title == "High"
    assert results[0].relevance_score == 10
    assert results[1].title == "Mid"
    assert results[1].relevance_score == 7

    # Act
    stats = await chroma_store.get_stats()

    # Assert
    assert stats["total_count"] == 5
    assert stats["category_Tech"] == 2
    assert stats["category_Science"] == 1
    assert stats["category_Uncategorized"] == 1  # for the dict without category

@pytest.mark.asyncio
async def test_get_top_ranked_empty(chroma_store: ChromaStore):
    # 1. Arrange
    mock_collection = MagicMock()
    chroma_store.collection = mock_collection
async def test_get_latest(chroma_store, mock_collection):
    # Arrange
    mock_collection.get.return_value = {
        "metadatas": [
            {"title": "Old", "timestamp": "2023-01-01T00:00:00", "url": "u1", "relevance_score": 1},
            {"title": "New", "timestamp": "2023-01-02T00:00:00", "url": "u2", "relevance_score": 1},
        ],
        "documents": ["doc1", "doc2"]
    }

    # Act
    results = await chroma_store.get_latest(limit=10, category="Tech")

    # Assert
    assert len(results) == 2
    assert results[0].title == "New"
    assert results[1].title == "Old"
    mock_collection.get.assert_called_once_with(
        include=["metadatas", "documents"],
        where={"category": "Tech"}
    )

@pytest.mark.asyncio
async def test_get_top_ranked(chroma_store, mock_collection):
    # Arrange
    mock_collection.get.return_value = {
        "metadatas": [
            {"title": "Low", "timestamp": "2023-01-01T00:00:00", "url": "u1", "relevance_score": 2},
            {"title": "High", "timestamp": "2023-01-01T00:00:00", "url": "u2", "relevance_score": 10},
        ],
        "documents": ["doc1", "doc2"]
    }

    # Act
    results = await chroma_store.get_top_ranked(limit=1, category="Tech")

    # Assert
    assert len(results) == 1
    assert results[0].title == "High"
    mock_collection.get.assert_called_once_with(
        include=["metadatas", "documents"],
        where={"category": "Tech"}
    )

@pytest.mark.asyncio
async def test_search_hybrid_exact_match_fills_limit(chroma_store, mock_collection):
    # Arrange
    query = "Apple"
    mock_collection.get.return_value = {
        "metadatas": [
            {"title": "Apple M4", "url": "u1", "timestamp": "2023-01-01T00:00:00", "relevance_score": 10},
            {"title": "Apple Vision", "url": "u2", "timestamp": "2023-01-01T00:00:00", "relevance_score": 9},
        ],
        "documents": ["doc1", "doc2"]
    }

    # Act
    results = await chroma_store.search(query, limit=2)

    # Assert
    assert len(results) == 2
    assert results[0].title == "Apple M4"
    assert results[1].title == "Apple Vision"
    mock_collection.get.assert_called_once()
    mock_collection.query.assert_not_called()

@pytest.mark.asyncio
async def test_search_hybrid_falls_back_to_semantic(chroma_store, mock_collection):
    # Arrange
    query = "Apple"
    # Exact match finds 1 item
    mock_collection.get.return_value = {
        "metadatas": [{"title": "Apple M4", "url": "u1", "timestamp": "2023-01-01T00:00:00", "relevance_score": 10}],
        "documents": ["doc1"]
    }
    # Semantic match finds more items, including the same one
    mock_collection.query.return_value = {
        "metadatas": [[
            {"title": "Apple M4", "url": "u1", "timestamp": "2023-01-01T00:00:00", "relevance_score": 10},
            {"title": "M3 Chip", "url": "u2", "timestamp": "2023-01-01T00:00:00", "relevance_score": 8},
        ]],
        "documents": [["doc1", "doc2"]],
        "distances": [[0.1, 0.5]]
    }

    # Act
    results = await chroma_store.search(query, limit=2)

    # Assert
    assert len(results) == 2
    assert results[0].title == "Apple M4"
    assert results[1].title == "M3 Chip"
    assert mock_collection.get.called
    assert mock_collection.query.called

@pytest.mark.asyncio
async def test_search_with_category_and_threshold(chroma_store, mock_collection):
    # Arrange
    query = "AI"
    mock_collection.get.return_value = {"metadatas": [], "documents": []}

    # 2. Act
    results = await chroma_store.get_top_ranked(limit=10)

    # 3. Assert
    assert len(results) == 0
    mock_collection.query.return_value = {
        "metadatas": [[
            {"title": "Good match", "url": "u1", "timestamp": "2023-01-01T00:00:00", "relevance_score": 10},
            {"title": "Bad match", "url": "u2", "timestamp": "2023-01-01T00:00:00", "relevance_score": 5},
        ]],
        "documents": [["doc1", "doc2"]],
        "distances": [[0.2, 0.8]]
    }

    # Act
    results = await chroma_store.search(query, limit=5, category="Tech", threshold=0.5)

    # Assert
    assert len(results) == 1
    assert results[0].title == "Good match"
    mock_collection.get.assert_called_with(
        where_document={"$contains": "AI"},
        where={"category": "Tech"},
        include=["metadatas", "documents"]
    )
    mock_collection.query.assert_called_with(
        query_texts=["AI"],
        n_results=5,
        where={"category": "Tech"}
    )

@pytest.mark.asyncio
async def test_search_exception_handling(chroma_store, mock_collection):
    # Arrange
    mock_collection.get.side_effect = Exception("Get failed")
    mock_collection.query.side_effect = Exception("Query failed")

    # Act
    results = await chroma_store.search("query")

    # Assert
    assert results == []  # Should not crash

@pytest.mark.asyncio
async def test_search_empty_query(chroma_store, mock_collection):
    # Arrange
    mock_collection.get.return_value = {"metadatas": [], "documents": []}
    mock_collection.query.return_value = {"metadatas": [[]], "documents": [[]], "distances": [[]]}

    # Act
    await chroma_store.search("")

    # Assert
    mock_collection.get.assert_called_with(
        where_document=None,
        where=None,
        include=["metadatas", "documents"]
    )
    mock_collection.query.assert_called_with(
        query_texts=["*"],
        n_results=5,
        where=None
    )

120
tests/storage/test_semantic_search.py
Normal file
@@ -0,0 +1,120 @@
import pytest
import pytest_asyncio
from unittest.mock import MagicMock
from datetime import datetime, timezone
from src.storage.chroma_store import ChromaStore
from src.processor.dto import EnrichedNewsItemDTO

@pytest.fixture
def mock_chroma_client():
    client = MagicMock()
    collection = MagicMock()
    client.get_or_create_collection.return_value = collection
    return client, collection

@pytest_asyncio.fixture
async def chroma_store(mock_chroma_client):
    client, collection = mock_chroma_client
    store = ChromaStore(client=client, collection_name="test_collection")
    return store

@pytest.mark.asyncio
async def test_search_with_category_filter(chroma_store, mock_chroma_client):
    client, collection = mock_chroma_client

    # Mock return value for collection.query
    collection.query.return_value = {
        "ids": [["id1"]],
        "metadatas": [[{
            "title": "AI in Robotics",
            "url": "https://example.com/robotics",
            "source": "Tech",
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "relevance_score": 8,
            "summary_ru": "AI в робототехнике",
            "category": "Robotics",
            "anomalies_detected": ""
        }]],
        "documents": [["Full content here"]],
        "distances": [[0.1]]
    }

    # We want to test that 'category' is passed as a 'where' clause to ChromaDB
    # Note: We need to update the search method signature in the next step
    results = await chroma_store.search(query="AI", limit=5, category="Robotics")

    # Assert collection.query was called with correct 'where' filter
    args, kwargs = collection.query.call_args
    assert kwargs["where"] == {"category": "Robotics"}
    assert len(results) == 1
    assert results[0].category == "Robotics"

@pytest.mark.asyncio
async def test_search_with_relevance_threshold(chroma_store, mock_chroma_client):
    client, collection = mock_chroma_client

    # Mock return value: one relevant (low distance), one irrelevant (high distance)
    collection.query.return_value = {
        "ids": [["id-rel", "id-irrel"]],
        "metadatas": [[
            {
                "title": "Relevant News",
                "url": "url1",
                "source": "s",
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "relevance_score": 9,
                "summary_ru": "Р",
                "category": "C",
                "anomalies_detected": ""
            },
            {
                "title": "Irrelevant News",
                "url": "url2",
                "source": "s",
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "relevance_score": 1,
                "summary_ru": "И",
                "category": "C",
                "anomalies_detected": ""
            }
        ]],
        "documents": [["doc1", "doc2"]],
        "distances": [[0.2, 0.8]]  # Lower distance means more similar
    }

    # threshold=0.5 means distances <= 0.5 are kept
    results = await chroma_store.search(query="test", limit=10, threshold=0.5)

    assert len(results) == 1
    assert results[0].title == "Relevant News"

@pytest.mark.asyncio
async def test_get_latest_semantic_threshold(chroma_store, mock_chroma_client):
    """
    Test that /latest uses semantic search if a category is provided,
    but also respects the threshold even for plain searches.
    """
    client, collection = mock_chroma_client

    collection.query.return_value = {
        "ids": [["id1"]],
        "metadatas": [[{
            "title": "Latest News",
            "url": "url",
            "source": "s",
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "relevance_score": 5,
            "summary_ru": "L",
            "category": "Tech",
            "anomalies_detected": ""
        }]],
        "documents": [["doc"]],
        "distances": [[0.05]]
    }

    # If category is provided, we should use category filter
    results = await chroma_store.search(query="", limit=10, category="Tech")

    args, kwargs = collection.query.call_args
    assert kwargs["where"] == {"category": "Tech"}
    assert len(results) == 1
7
update_chroma_store.py
Normal file
@@ -0,0 +1,7 @@
import re

with open("src/storage/chroma_store.py", "r") as f:
    content = f.read()

# I will rewrite the class completely because there are many changes to make.