278 changes: 150 additions & 128 deletions backend/modules/metadata_store/prisma_store.py
@@ -30,9 +30,8 @@
from prisma.models import RagApps as PrismaRagApplication

# TODO (chiragjn):
# - Use transactions!
# - Some methods are using json.dumps - not sure if this is the right way to send data via prisma client
# - primsa generates its own DB entity classes - ideally we should be using those instead of call
# - prisma generates its own DB entity classes - ideally we should be using those instead of call
# .model_dump() on the pydantic objects. See prisma.models and prisma.actions
#
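The first TODO item ("Use transactions!") is what this change addresses: prisma-client-py exposes Prisma.tx() as an async context manager, and every query issued through the client it yields is committed together or rolled back together if the block raises. A minimal standalone sketch of that pattern follows; the connection handling, the "my-collection" value, and the update payload are illustrative assumptions, while the tx() / find_first_or_raise / update calls mirror the ones used in this diff.

import asyncio
import json

from prisma import Prisma


async def main() -> None:
    db = Prisma()
    await db.connect()
    try:
        # All queries inside the block run in a single transaction; an exception
        # anywhere in the block rolls back every write made through `transaction`.
        async with db.tx() as transaction:
            collection = await transaction.collection.find_first_or_raise(
                where={"name": "my-collection"}  # hypothetical collection name
            )
            await transaction.collection.update(
                where={"name": collection.name},
                data={"associated_data_sources": json.dumps({})},  # payload shape assumed
            )
    finally:
        await db.disconnect()


if __name__ == "__main__":
    asyncio.run(main())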

@@ -156,97 +155,112 @@ async def aassociate_data_sources_with_collection(
collection_name: str,
data_source_associations: List[AssociateDataSourceWithCollection],
) -> Collection:
# Get the collection by name
collection = await self.aget_collection_by_name(collection_name)
# Use transaction to ensure atomicity
async with self.db.tx() as transaction:
# Get the collection by name
collection_record = await transaction.collection.find_first_or_raise(
where={"name": collection_name}
)
collection = Collection.model_validate(collection_record.model_dump())

# Get the existing associated data sources. If not found, initialize an empty dict.
existing_associated_data_sources = collection.associated_data_sources
# Get the existing associated data sources. If not found, initialize an empty dict.
existing_associated_data_sources = collection.associated_data_sources

# Fetch all data sources in a single query
data_source_fqns = [assoc.data_source_fqn for assoc in data_source_associations]
data_sources = await self.db.datasource.find_many(
where={"fqn": {"in": data_source_fqns}}
)
data_sources_dict = {ds.fqn: ds for ds in data_sources}

# Create AssociatedDataSources objects and update the existing_associated_data_sources
for assoc in data_source_associations:
data_source = data_sources_dict.get(assoc.data_source_fqn)
if not data_source:
raise HTTPException(
status_code=404,
detail=f"Data source with fqn {assoc.data_source_fqn} not found",
# Fetch all data sources in a single query
data_source_fqns = [
assoc.data_source_fqn for assoc in data_source_associations
]
data_sources = await transaction.datasource.find_many(
where={"fqn": {"in": data_source_fqns}}
)
data_sources_dict = {ds.fqn: ds for ds in data_sources}

# Create AssociatedDataSources objects and update the existing_associated_data_sources
for assoc in data_source_associations:
data_source = data_sources_dict.get(assoc.data_source_fqn)
if not data_source:
raise HTTPException(
status_code=404,
detail=f"Data source with fqn {assoc.data_source_fqn} not found",
)

data_src_to_associate = AssociatedDataSources(
data_source_fqn=assoc.data_source_fqn,
parser_config=assoc.parser_config,
data_source=DataSource.model_validate(data_source.model_dump()),
)
existing_associated_data_sources[assoc.data_source_fqn] = (
data_src_to_associate
)

data_src_to_associate = AssociatedDataSources(
data_source_fqn=assoc.data_source_fqn,
parser_config=assoc.parser_config,
data_source=DataSource.model_validate(data_source.model_dump()),
)
existing_associated_data_sources[
assoc.data_source_fqn
] = data_src_to_associate

# Convert the existing associated data sources to a dictionary
associated_data_sources = {
fqn: data_source.model_dump()
for fqn, data_source in existing_associated_data_sources.items()
}

# Update the collection with the new associated data sources
updated_collection = await self.db.collection.update(
where={"name": collection_name},
data={"associated_data_sources": json.dumps(associated_data_sources)},
)
# Convert the existing associated data sources to a dictionary
associated_data_sources = {
fqn: data_source.model_dump()
for fqn, data_source in existing_associated_data_sources.items()
}

# If the update fails, raise an HTTPException
if not updated_collection:
raise HTTPException(
status_code=404,
detail=f"Failed to associate data sources with collection {collection_name!r}. No such record found",
# Update the collection with the new associated data sources
updated_collection = await transaction.collection.update(
where={"name": collection_name},
data={"associated_data_sources": json.dumps(associated_data_sources)},
)

# Validate the updated collection and return it
return Collection.model_validate(updated_collection.model_dump())
# If the update fails, raise an HTTPException
if not updated_collection:
raise HTTPException(
status_code=404,
detail=f"Failed to associate data sources with collection {collection_name!r}. No such record found",
)

# Validate the updated collection and return it
return Collection.model_validate(updated_collection.model_dump())
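Besides the transaction, the rewritten association flow replaces per-fqn lookups with one find_many call using an `in` filter and then indexes the results by fqn. A small self-contained sketch of that batched-lookup shape is below; the function name, the Any-typed return, and the LookupError policy (reporting every missing fqn at once instead of a 404 per fqn) are assumptions for illustration, not what this PR implements.

from typing import Any, Dict, List

from prisma import Prisma


async def fetch_data_sources_by_fqn(db: Prisma, fqns: List[str]) -> Dict[str, Any]:
    # One query with an `in` filter instead of N single-record lookups.
    records = await db.datasource.find_many(where={"fqn": {"in": fqns}})
    found = {record.fqn: record for record in records}
    missing = [fqn for fqn in fqns if fqn not in found]
    if missing:
        # Alternative to raising on the first unknown fqn: surface all misses together.
        raise LookupError(f"Unknown data source fqns: {missing}")
    return found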

async def aunassociate_data_source_with_collection(
self, collection_name: str, data_source_fqn: str
) -> Collection:
# Get the collection by name
collection = await self.aget_collection_by_name(collection_name)
# Get the existing associated data sources
associated_data_sources = collection.associated_data_sources or {}
# If the data source is not associated with the collection, deletion is not possible and an error is raised
if data_source_fqn not in associated_data_sources:
raise HTTPException(
status_code=400,
detail=f"Data source with fqn {data_source_fqn!r} not associated with collection {collection_name!r}",
# Use transaction to ensure atomicity
async with self.db.tx() as transaction:
# Get the collection by name
collection_record = await transaction.collection.find_first_or_raise(
where={"name": collection_name}
)
# Remove the data source from associated data sources
associated_data_sources.pop(data_source_fqn)

# Convert the associated data sources to a dictionary
updated_associated_data_sources = {
fqn: ds.model_dump() for fqn, ds in associated_data_sources.items()
}

# Update the collection with the new associated data sources
updated_collection = await self.db.collection.update(
where={"name": collection_name},
data={
"associated_data_sources": json.dumps(updated_associated_data_sources)
},
)
collection = Collection.model_validate(collection_record.model_dump())

# If the update fails, raise an HTTPException
if not updated_collection:
raise HTTPException(
status_code=404,
detail=f"Failed to unassociate data source from collection {collection_name!r}. No such record found",
# Get the existing associated data sources
associated_data_sources = collection.associated_data_sources or {}
# If the data source is not associated with the collection, deletion is not possible and an error is raised
if data_source_fqn not in associated_data_sources:
raise HTTPException(
status_code=400,
detail=f"Data source with fqn {data_source_fqn!r} not associated with collection {collection_name!r}",
)
# Remove the data source from associated data sources
associated_data_sources.pop(data_source_fqn)

# Convert the associated data sources to a dictionary
updated_associated_data_sources = {
fqn: ds.model_dump() for fqn, ds in associated_data_sources.items()
}

# Update the collection with the new associated data sources
updated_collection = await transaction.collection.update(
where={"name": collection_name},
data={
"associated_data_sources": json.dumps(
updated_associated_data_sources
)
},
)

# Validate the updated collection and return it
return Collection.model_validate(updated_collection.model_dump())
# If the update fails, raise an HTTPException
if not updated_collection:
raise HTTPException(
status_code=404,
detail=f"Failed to unassociate data source from collection {collection_name!r}. No such record found",
)

# Validate the updated collection and return it
return Collection.model_validate(updated_collection.model_dump())

async def alist_data_sources(
self,
@@ -255,36 +269,44 @@
return [data_source.model_dump() for data_source in data_sources]

async def adelete_data_source(self, data_source_fqn: str) -> None:
# Fetch all collections
collections = await self.aget_collections()
# Check if data source is associated with any collection
for collection in collections:
associated_data_sources = collection.associated_data_sources or {}
# If data source is associated with any collection, raise an error prompting the user to either delete the collection or unassociate the data source from the collection
if data_source_fqn in associated_data_sources:
logger.error(
f"Data source with fqn {data_source_fqn} is already associated with "
f"collection {collection.name}"
)
# Use transaction to ensure atomicity
async with self.db.tx() as transaction:
# Fetch all collections
collections_records = await transaction.collection.find_many(
order={"id": "desc"}
)
collections = [
Collection.model_validate(c.model_dump()) for c in collections_records
]

# Check if data source is associated with any collection
for collection in collections:
associated_data_sources = collection.associated_data_sources or {}
# If data source is associated with any collection, raise an error prompting the user to either delete the collection or unassociate the data source from the collection
if data_source_fqn in associated_data_sources:
logger.error(
f"Data source with fqn {data_source_fqn} is already associated with "
f"collection {collection.name}"
)
raise HTTPException(
status_code=400,
detail=f"Data source with fqn {data_source_fqn} is associated "
f"with collection {collection.name}. Delete the necessary collections "
f"or unassociate them from the collection(s) before deleting the data source",
)

# Delete the data source
deleted_datasource: Optional["PrismaDataSource"] = (
await transaction.datasource.delete(where={"fqn": data_source_fqn})
)

if not deleted_datasource:
raise HTTPException(
status_code=400,
detail=f"Data source with fqn {data_source_fqn} is associated "
f"with collection {collection.name}. Delete the necessary collections "
f"or unassociate them from the collection(s) before deleting the data source",
status_code=404,
detail=f"Failed to delete data source {data_source_fqn!r}. No such record found",
)

# Delete the data source
deleted_datasource: Optional[
PrismaDataSource
] = await self.db.datasource.delete(where={"fqn": data_source_fqn})

if not deleted_datasource:
raise HTTPException(
status_code=404,
detail=f"Failed to delete data source {data_source_fqn!r}. No such record found",
)

return DataSource.model_validate(deleted_datasource.model_dump())
return DataSource.model_validate(deleted_datasource.model_dump())

######
# DATA INGESTION RUN APIS
@@ -319,10 +341,10 @@ async def acreate_data_ingestion_run(
async def aget_data_ingestion_run(
self, data_ingestion_run_name: str, no_cache: bool = False
) -> Optional[DataIngestionRun]:
data_ingestion_run: Optional[
"PrismaDataIngestionRun"
] = await self.db.ingestionruns.find_first(
where={"name": data_ingestion_run_name}
data_ingestion_run: Optional["PrismaDataIngestionRun"] = (
await self.db.ingestionruns.find_first(
where={"name": data_ingestion_run_name}
)
)
logger.info(f"Data ingestion run: {data_ingestion_run}")
if data_ingestion_run:
@@ -333,10 +355,10 @@ async def aget_data_ingestion_runs(
self, collection_name: str, data_source_fqn: str = None
) -> List[DataIngestionRun]:
"""Get all data ingestion runs for a collection"""
data_ingestion_runs: List[
"PrismaDataIngestionRun"
] = await self.db.ingestionruns.find_many(
where={"collection_name": collection_name}, order={"id": "desc"}
data_ingestion_runs: List["PrismaDataIngestionRun"] = (
await self.db.ingestionruns.find_many(
where={"collection_name": collection_name}, order={"id": "desc"}
)
)
return [
DataIngestionRun.model_validate(data_ir.model_dump())
@@ -347,10 +369,10 @@ async def aupdate_data_ingestion_run_status(
self, data_ingestion_run_name: str, status: DataIngestionRunStatus
) -> DataIngestionRun:
"""Update the status of a data ingestion run"""
updated_data_ingestion_run: Optional[
"PrismaDataIngestionRun"
] = await self.db.ingestionruns.update(
where={"name": data_ingestion_run_name}, data={"status": status}
updated_data_ingestion_run: Optional["PrismaDataIngestionRun"] = (
await self.db.ingestionruns.update(
where={"name": data_ingestion_run_name}, data={"status": status}
)
)
if not updated_data_ingestion_run:
raise HTTPException(
@@ -364,11 +386,11 @@ async def alog_errors_for_data_ingestion_run(
self, data_ingestion_run_name: str, errors: Dict[str, Any]
) -> None:
"""Log errors for the given data ingestion run"""
updated_data_ingestion_run: Optional[
"PrismaDataIngestionRun"
] = await self.db.ingestionruns.update(
where={"name": data_ingestion_run_name},
data={"errors": json.dumps(errors)},
updated_data_ingestion_run: Optional["PrismaDataIngestionRun"] = (
await self.db.ingestionruns.update(
where={"name": data_ingestion_run_name},
data={"errors": json.dumps(errors)},
)
)
if not updated_data_ingestion_run:
raise HTTPException(
Expand All @@ -381,9 +403,9 @@ async def alog_errors_for_data_ingestion_run(
######
async def aget_rag_app(self, app_name: str) -> Optional[RagApplication]:
"""Get a RAG application from the metadata store"""
rag_app: Optional[
"PrismaRagApplication"
] = await self.db.ragapps.find_first_or_raise(where={"name": app_name})
rag_app: Optional["PrismaRagApplication"] = (
await self.db.ragapps.find_first_or_raise(where={"name": app_name})
)

return RagApplication.model_validate(rag_app.model_dump())

@@ -403,9 +425,9 @@ async def alist_rag_apps(self) -> List[str]:

async def adelete_rag_app(self, app_name: str):
"""Delete a RAG application from the metadata store"""
deleted_rag_app: Optional[
"PrismaRagApplication"
] = await self.db.ragapps.delete(where={"name": app_name})
deleted_rag_app: Optional["PrismaRagApplication"] = (
await self.db.ragapps.delete(where={"name": app_name})
)
if not deleted_rag_app:
raise HTTPException(
status_code=404,