diff --git a/src/bot/exporters.py b/src/bot/exporters.py new file mode 100644 index 0000000..af39b7c --- /dev/null +++ b/src/bot/exporters.py @@ -0,0 +1,73 @@ +import abc +import csv +import io +from typing import List + +from src.processor.dto import EnrichedNewsItemDTO + +class ITrendExporter(abc.ABC): + @abc.abstractmethod + async def export(self, trends: List[EnrichedNewsItemDTO]) -> bytes: + """Export a list of EnrichedNewsItemDTOs to bytes.""" + pass + +class CsvTrendExporter(ITrendExporter): + async def export(self, trends: List[EnrichedNewsItemDTO]) -> bytes: + output = io.StringIO() + writer = csv.writer(output) + + writer.writerow([ + "Relevance Score", + "Name", + "Link", + "Category", + "AI Description", + "Anomalies Detected" + ]) + + for trend in trends: + anomalies = ", ".join(trend.anomalies_detected) if trend.anomalies_detected else "" + writer.writerow([ + trend.relevance_score, + trend.title, + trend.url, + trend.category, + trend.summary_ru, + anomalies + ]) + + return output.getvalue().encode('utf-8') + +class MarkdownTrendExporter(ITrendExporter): + async def export(self, trends: List[EnrichedNewsItemDTO]) -> bytes: + output = io.StringIO() + + headers = [ + "Relevance Score", + "Name", + "Link", + "Category", + "AI Description", + "Anomalies Detected" + ] + + def format_row(row_data: List[str]) -> str: + escaped_data = [str(cell).replace('|', '\\|').replace('\n', ' ') for cell in row_data] + return "| " + " | ".join(escaped_data) + " |\n" + + output.write(format_row(headers)) + output.write(format_row(["---"] * len(headers))) + + for trend in trends: + anomalies = ", ".join(trend.anomalies_detected) if trend.anomalies_detected else "" + row = [ + str(trend.relevance_score), + trend.title, + trend.url, + trend.category, + trend.summary_ru, + anomalies + ] + output.write(format_row(row)) + + return output.getvalue().encode('utf-8') diff --git a/src/bot/handlers.py b/src/bot/handlers.py index e20e79b..4d2f38e 100644 --- a/src/bot/handlers.py +++ b/src/bot/handlers.py @@ -5,13 +5,14 @@ from typing import Optional, Callable, Dict, Any, Awaitable from aiogram import Router, BaseMiddleware, F from aiogram.filters import CommandStart, Command, CommandObject -from aiogram.types import Message, TelegramObject, InlineKeyboardButton, InlineKeyboardMarkup, CallbackQuery +from aiogram.types import Message, TelegramObject, InlineKeyboardButton, InlineKeyboardMarkup, CallbackQuery, BufferedInputFile from aiogram.utils.keyboard import InlineKeyboardBuilder from aiogram.utils.formatting import as_list, as_marked_section, Bold, TextLink from src.processor.dto import EnrichedNewsItemDTO from src.processor.base import ILLMProvider from src.storage.base import IVectorStore +from src.bot.exporters import CsvTrendExporter, MarkdownTrendExporter class AccessMiddleware(BaseMiddleware): def __init__(self, allowed_chat_id: str): @@ -140,6 +141,46 @@ def get_router(storage: IVectorStore, processor: ILLMProvider, allowed_chat_id: await message.answer(f"Top {len(items)} Hottest Trends:", reply_markup=builder.as_markup()) + @router.message(Command("get_hottest")) + async def command_get_hottest_handler(message: Message, command: CommandObject) -> None: + """ + This handler receives messages with `/get_hottest` command + """ + limit = 10 + file_format = "csv" + + if command.args and command.args.strip(): + parts = command.args.strip().split() + try: + limit = int(parts[0]) + except ValueError: + await message.answer("Please provide a valid number, e.g., /get_hottest 10") + return + + if len(parts) > 1: + file_format = parts[1].lower() + + if limit > 50: + limit = 50 + + items = await storage.get_top_ranked(limit=limit) + + if not items: + await message.answer("No hot trends found yet.") + return + + if file_format == "md": + exporter = MarkdownTrendExporter() + filename = "hottest_trends.md" + else: + exporter = CsvTrendExporter() + filename = "hottest_trends.csv" + + file_bytes = await exporter.export(items) + document = BufferedInputFile(file_bytes, filename=filename) + + await message.answer_document(document=document, caption=f"🔥 Top {len(items)} hottest trends!") + @router.message(Command("search")) async def command_search_handler(message: Message, command: CommandObject) -> None: """ diff --git a/tests/bot/test_exporters.py b/tests/bot/test_exporters.py new file mode 100644 index 0000000..3057e4d --- /dev/null +++ b/tests/bot/test_exporters.py @@ -0,0 +1,81 @@ +import pytest +from datetime import datetime +from src.processor.dto import EnrichedNewsItemDTO +from src.bot.exporters import CsvTrendExporter, MarkdownTrendExporter + +@pytest.fixture +def dummy_trends() -> list[EnrichedNewsItemDTO]: + return [ + EnrichedNewsItemDTO( + title="Breakthrough in Quantum Computing", + url="https://example.com/quantum", + content_text="Scientists achieve a major milestone...", + source="TechNews", + timestamp=datetime(2023, 10, 27, 12, 0), + relevance_score=9, + summary_ru="Прорыв в квантовых вычислениях...", + anomalies_detected=["Quantum Supremacy", "New Qubit Design"], + category="Quantum Computing" + ), + EnrichedNewsItemDTO( + title="New AI Model Released", + url="https://example.com/ai", + content_text="A new AI model has been released...", + source="AITimes", + timestamp=datetime(2023, 10, 27, 13, 0), + relevance_score=8, + summary_ru="Выпущен новый ИИ...", + anomalies_detected=[], + category="Artificial Intelligence" + ) + ] + +@pytest.mark.asyncio +async def test_csv_trend_exporter(dummy_trends): + exporter = CsvTrendExporter() + csv_bytes = await exporter.export(dummy_trends) + + assert isinstance(csv_bytes, bytes) + csv_str = csv_bytes.decode('utf-8') + lines = csv_str.strip().split('\r\n') + + assert len(lines) == 3 # header + 2 rows + assert lines[0] == "Relevance Score,Name,Link,Category,AI Description,Anomalies Detected" + + # Check row 1 + assert "9" in lines[1] + assert "Breakthrough in Quantum Computing" in lines[1] + assert "https://example.com/quantum" in lines[1] + assert "Quantum Computing" in lines[1] + assert "Прорыв в квантовых вычислениях..." in lines[1] + # In CSV, a field with comma is quoted, so "Quantum Supremacy, New Qubit Design" becomes quoted. + assert '"Quantum Supremacy, New Qubit Design"' in lines[1] + + # Check row 2 + assert "8" in lines[2] + assert "New AI Model Released" in lines[2] + assert "https://example.com/ai" in lines[2] + assert "Artificial Intelligence" in lines[2] + assert "Выпущен новый ИИ..." in lines[2] + assert "AITimes" not in lines[2] # source is not exported + +@pytest.mark.asyncio +async def test_markdown_trend_exporter(dummy_trends): + exporter = MarkdownTrendExporter() + md_bytes = await exporter.export(dummy_trends) + + assert isinstance(md_bytes, bytes) + md_str = md_bytes.decode('utf-8') + lines = md_str.strip().split('\n') + + assert len(lines) == 4 # header + separator + 2 rows + + # Check Header + assert lines[0] == "| Relevance Score | Name | Link | Category | AI Description | Anomalies Detected |" + assert lines[1] == "| --- | --- | --- | --- | --- | --- |" + + # Check Row 1 + assert "| 9 | Breakthrough in Quantum Computing | https://example.com/quantum | Quantum Computing | Прорыв в квантовых вычислениях... | Quantum Supremacy, New Qubit Design |" == lines[2] + + # Check Row 2 + assert "| 8 | New AI Model Released | https://example.com/ai | Artificial Intelligence | Выпущен новый ИИ... | |" == lines[3] diff --git a/tests/bot/test_get_hottest_command.py b/tests/bot/test_get_hottest_command.py new file mode 100644 index 0000000..f28d001 --- /dev/null +++ b/tests/bot/test_get_hottest_command.py @@ -0,0 +1,170 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from aiogram.types import Message, BufferedInputFile +from aiogram.filters import CommandObject +from datetime import datetime + +from src.bot.handlers import get_router +from src.processor.dto import EnrichedNewsItemDTO + +@pytest.fixture +def mock_storage(): + return AsyncMock() + +@pytest.fixture +def mock_processor(): + processor = MagicMock() + processor.get_info.return_value = {"model": "test-model"} + return processor + +@pytest.fixture +def allowed_chat_id(): + return "123456789" + +@pytest.fixture +def router(mock_storage, mock_processor, allowed_chat_id): + return get_router(mock_storage, mock_processor, allowed_chat_id) + +def get_handler(router, callback_name): + for handler in router.message.handlers: + if handler.callback.__name__ == callback_name: + return handler.callback + raise ValueError(f"Handler {callback_name} not found") + +@pytest.fixture +def mock_items(): + return [ + EnrichedNewsItemDTO( + title=f"Hot News {i}", + url=f"https://example.com/{i}", + content_text=f"Content {i}", + source="Source", + timestamp=datetime.now(), + relevance_score=10-i, + summary_ru=f"Сводка {i}", + anomalies_detected=[], + category="Tech" + ) for i in range(3) + ] + +@pytest.mark.asyncio +async def test_command_get_hottest_handler_no_args(router, mock_storage, allowed_chat_id, mock_items): + """ + Test /get_hottest with no arguments (default limit 10, format csv). + """ + # 1. Arrange + handler = get_handler(router, "command_get_hottest_handler") + message = AsyncMock() + message.chat = MagicMock() + message.chat.id = int(allowed_chat_id) + + mock_storage.get_top_ranked.return_value = mock_items + + # 2. Act + command = CommandObject(prefix='/', command='get_hottest', args=None) + with patch("src.bot.handlers.CsvTrendExporter") as MockCsvExporter: + mock_exporter = AsyncMock() + mock_exporter.export.return_value = b"csv data" + MockCsvExporter.return_value = mock_exporter + + await handler(message=message, command=command) + + # 3. Assert + mock_storage.get_top_ranked.assert_called_once_with(limit=10) + message.answer_document.assert_called_once() + + args, kwargs = message.answer_document.call_args + assert "document" in kwargs + assert isinstance(kwargs["document"], BufferedInputFile) + assert kwargs["document"].filename == "hottest_trends.csv" + assert kwargs["caption"] == "🔥 Top 3 hottest trends!" + +@pytest.mark.asyncio +async def test_command_get_hottest_handler_invalid_limit(router, mock_storage, allowed_chat_id): + """ + Test /get_hottest with invalid limit (not a number). + """ + # 1. Arrange + handler = get_handler(router, "command_get_hottest_handler") + message = AsyncMock() + message.chat = MagicMock() + message.chat.id = int(allowed_chat_id) + + # 2. Act + command = CommandObject(prefix='/', command='get_hottest', args='abc') + await handler(message=message, command=command) + + # 3. Assert + message.answer.assert_called_once_with("Please provide a valid number, e.g., /get_hottest 10") + mock_storage.get_top_ranked.assert_not_called() + +@pytest.mark.asyncio +async def test_command_get_hottest_handler_capped_limit(router, mock_storage, allowed_chat_id, mock_items): + """ + Test /get_hottest with limit > 50 (should be capped). + """ + # 1. Arrange + handler = get_handler(router, "command_get_hottest_handler") + message = AsyncMock() + message.chat = MagicMock() + message.chat.id = int(allowed_chat_id) + + mock_storage.get_top_ranked.return_value = mock_items + + # 2. Act + command = CommandObject(prefix='/', command='get_hottest', args='100') + await handler(message=message, command=command) + + # 3. Assert + mock_storage.get_top_ranked.assert_called_once_with(limit=50) + +@pytest.mark.asyncio +async def test_command_get_hottest_handler_custom_limit_md(router, mock_storage, allowed_chat_id, mock_items): + """ + Test /get_hottest with limit and md format. + """ + # 1. Arrange + handler = get_handler(router, "command_get_hottest_handler") + message = AsyncMock() + message.chat = MagicMock() + message.chat.id = int(allowed_chat_id) + + mock_storage.get_top_ranked.return_value = mock_items + + # 2. Act + command = CommandObject(prefix='/', command='get_hottest', args='5 md') + with patch("src.bot.handlers.MarkdownTrendExporter") as MockMdExporter: + mock_exporter = AsyncMock() + mock_exporter.export.return_value = b"md data" + MockMdExporter.return_value = mock_exporter + + await handler(message=message, command=command) + + # 3. Assert + mock_storage.get_top_ranked.assert_called_once_with(limit=5) + message.answer_document.assert_called_once() + + args, kwargs = message.answer_document.call_args + assert kwargs["document"].filename == "hottest_trends.md" + assert kwargs["caption"] == "🔥 Top 3 hottest trends!" + +@pytest.mark.asyncio +async def test_command_get_hottest_handler_no_records(router, mock_storage, allowed_chat_id): + """ + Test /get_hottest when no records found. + """ + # 1. Arrange + handler = get_handler(router, "command_get_hottest_handler") + message = AsyncMock() + message.chat = MagicMock() + message.chat.id = int(allowed_chat_id) + + mock_storage.get_top_ranked.return_value = [] + + # 2. Act + command = CommandObject(prefix='/', command='get_hottest', args=None) + await handler(message=message, command=command) + + # 3. Assert + message.answer.assert_called_once_with("No hot trends found yet.") + message.answer_document.assert_not_called() diff --git a/tests/storage/test_chroma_store.py b/tests/storage/test_chroma_store.py index e758beb..7b21180 100644 --- a/tests/storage/test_chroma_store.py +++ b/tests/storage/test_chroma_store.py @@ -243,6 +243,7 @@ async def test_search_with_category_and_threshold(chroma_store, mock_collection) mock_collection.get.assert_called_with( where_document={"$contains": "AI"}, where={"category": "Tech"}, + limit=5, include=["metadatas", "documents"] ) mock_collection.query.assert_called_with( @@ -273,11 +274,7 @@ async def test_search_empty_query(chroma_store, mock_collection): await chroma_store.search("") # Assert - mock_collection.get.assert_called_with( - where_document=None, - where=None, - include=["metadatas", "documents"] - ) + mock_collection.get.assert_not_called() mock_collection.query.assert_called_with( query_texts=["*"], n_results=5, diff --git a/tests/test_cppconf_pipeline.py b/tests/test_cppconf_pipeline.py index e8a9722..2a20335 100644 --- a/tests/test_cppconf_pipeline.py +++ b/tests/test_cppconf_pipeline.py @@ -52,7 +52,8 @@ async def test_cppconf_e2e_pipeline(cppconf_html): assert enriched_talk.category == "C++ Trends" # 3. Vector DB Store - client = chromadb.Client() + from chromadb.config import Settings + client = chromadb.EphemeralClient(Settings(allow_reset=True)) store = ChromaStore(client=client, collection_name="test_cppconf_collection") await store.store(enriched_talk)