|
1 | 1 | import fastapi |
| 2 | +from sqlalchemy import select |
| 3 | +from sqlalchemy.ext.asyncio import async_sessionmaker |
2 | 4 |
|
3 | | -from .api_models import ChatRequest |
4 | | -from .globals import global_storage |
5 | | -from .postgres_searcher import PostgresSearcher |
6 | | -from .rag_advanced import AdvancedRAGChat |
7 | | -from .rag_simple import SimpleRAGChat |
| 5 | +from fastapi_app.api_models import ChatRequest |
| 6 | +from fastapi_app.globals import global_storage |
| 7 | +from fastapi_app.postgres_models import Item |
| 8 | +from fastapi_app.postgres_searcher import PostgresSearcher |
| 9 | +from fastapi_app.rag_advanced import AdvancedRAGChat |
| 10 | +from fastapi_app.rag_simple import SimpleRAGChat |
8 | 11 |
|
9 | 12 | router = fastapi.APIRouter() |
10 | 13 |
|
11 | 14 |
|
| 15 | +@router.get("/items/{id}") |
| 16 | +async def item_handler(id: int): |
| 17 | + """A simple API to get an item by ID.""" |
| 18 | + async_session_maker = async_sessionmaker(global_storage.engine, expire_on_commit=False) |
| 19 | + async with async_session_maker() as session: |
| 20 | + item = (await session.scalars(select(Item).where(Item.id == id))).first() |
| 21 | + return item.to_dict() |
| 22 | + |
| 23 | + |
| 24 | +@router.get("/similar") |
| 25 | +async def similar_handler(id: int, n: int = 5): |
| 26 | + """A similarity API to find items similar to items with given ID.""" |
| 27 | + async_session_maker = async_sessionmaker(global_storage.engine, expire_on_commit=False) |
| 28 | + async with async_session_maker() as session: |
| 29 | + item = (await session.scalars(select(Item).where(Item.id == id))).first() |
| 30 | + closest = await session.execute( |
| 31 | + select(Item, Item.embedding.l2_distance(item.embedding)) |
| 32 | + .filter(Item.id != id) |
| 33 | + .order_by(Item.embedding.l2_distance(item.embedding)) |
| 34 | + .limit(n) |
| 35 | + ) |
| 36 | + return [item.to_dict() | {"distance": round(distance, 2)} for item, distance in closest] |
| 37 | + |
| 38 | + |
12 | 39 | @router.post("/chat") |
13 | 40 | async def chat_handler(chat_request: ChatRequest): |
14 | 41 | messages = [message.model_dump() for message in chat_request.messages] |
|
0 commit comments