11from __future__ import annotations
22
3- from dataclasses import asdict
4-
53from pgvector .sqlalchemy import Vector
6- from sqlalchemy import Index
7- from sqlalchemy .orm import DeclarativeBase , Mapped , MappedAsDataclass , mapped_column
4+ from sqlalchemy import Index , String
5+ from sqlalchemy .dialects .postgresql import ARRAY
6+ from sqlalchemy .orm import DeclarativeBase , Mapped , mapped_column
87
98
109# Define the models
11- class Base (DeclarativeBase , MappedAsDataclass ):
10+ class Base (DeclarativeBase ):
1211 pass
1312
1413
1514class Item (Base ):
16- __tablename__ = "items "
17- id : Mapped [ int ] = mapped_column ( primary_key = True , autoincrement = True )
18- type : Mapped [str ] = mapped_column ()
19- brand : Mapped [ str ] = mapped_column ()
20- name : Mapped [str ] = mapped_column ()
15+ __tablename__ = "sessions "
16+ # An ID column should always be defined, but it can be int or string
17+ id : Mapped [str ] = mapped_column (primary_key = True )
18+ # Schema specific:
19+ title : Mapped [str ] = mapped_column ()
2120 description : Mapped [str ] = mapped_column ()
22- price : Mapped [float ] = mapped_column ()
23- embedding_ada002 : Mapped [Vector ] = mapped_column (Vector (1536 )) # ada-002
24- embedding_nomic : Mapped [Vector ] = mapped_column (Vector (768 )) # nomic-embed-text
21+ speakers : Mapped [list [str ]] = mapped_column (ARRAY (String ))
22+ tracks : Mapped [list [str ]] = mapped_column (ARRAY (String ))
23+ day : Mapped [str ] = mapped_column ()
24+ time : Mapped [str ] = mapped_column ()
25+ mode : Mapped [str ] = mapped_column ()
26+ # Embeddings for different models:
27+ embedding_ada002 : Mapped [Vector ] = mapped_column (Vector (1536 ), nullable = True ) # ada-002
28+ embedding_nomic : Mapped [Vector ] = mapped_column (Vector (768 ), nullable = True ) # nomic-embed-text
2529
2630 def to_dict (self , include_embedding : bool = False ):
27- model_dict = asdict (self )
31+ model_dict = { column . name : getattr (self , column . name ) for column in self . __table__ . columns }
2832 if include_embedding :
2933 model_dict ["embedding_ada002" ] = model_dict .get ("embedding_ada002" , [])
3034 model_dict ["embedding_nomic" ] = model_dict .get ("embedding_nomic" , [])
@@ -34,23 +38,24 @@ def to_dict(self, include_embedding: bool = False):
3438 return model_dict
3539
3640 def to_str_for_rag (self ):
37- return f"Name :{ self .name } Description:{ self .description } Price :{ self .price } Brand :{ self .brand } Type :{ self .type } "
41+ return f"Title :{ self .title } Description:{ self .description } Speakers :{ self .speakers } Tracks :{ self .tracks } Day :{ self .day } Time: { self . time } Mode: { self . mode } " # noqa
3842
3943 def to_str_for_embedding (self ):
40- return f"Name: { self .name } Description: { self .description } Type : { self .type } "
44+ return f"Name: { self .title } Description: { self .description } Tracks : { self .tracks } Day: { self . day } Mode: { self . mode } " # noqa
4145
4246
4347# Define HNSW index to support vector similarity search through the vector_cosine_ops access method (cosine distance).
4448index_ada002 = Index (
45- "hnsw_index_for_innerproduct_item_embedding_ada002" ,
49+ # TODO: generate based off table name
50+ "hnsw_index_for_innerproduct_session_embedding_ada002" ,
4651 Item .embedding_ada002 ,
4752 postgresql_using = "hnsw" ,
4853 postgresql_with = {"m" : 16 , "ef_construction" : 64 },
4954 postgresql_ops = {"embedding_ada002" : "vector_ip_ops" },
5055)
5156
5257index_nomic = Index (
53- "hnsw_index_for_innerproduct_item_embedding_nomic " ,
58+ "hnsw_index_for_innerproduct_session_embedding_nomic " ,
5459 Item .embedding_nomic ,
5560 postgresql_using = "hnsw" ,
5661 postgresql_with = {"m" : 16 , "ef_construction" : 64 },
0 commit comments