diff --git a/CHANGELOG.md b/CHANGELOG.md index a654307..851d2dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,21 @@ # Changelog +## [0.3.0](https://github.com/googleapis/langchain-google-bigtable-python/compare/v0.2.3...v0.3.0) (2024-08-14) + + +### Features + +* Add comments to all exported functions ([#83](https://github.com/googleapis/langchain-google-bigtable-python/issues/83)) ([17f8f52](https://github.com/googleapis/langchain-google-bigtable-python/commit/17f8f52ad15b0ae156fdfdfa31c6d5a684e812c7)) +* Add init_chat_history_table function ([#86](https://github.com/googleapis/langchain-google-bigtable-python/issues/86)) ([d68d173](https://github.com/googleapis/langchain-google-bigtable-python/commit/d68d17329625401b4ba51a793a4ce4d00ea097c9)) +* **bigtable:** Add init_document_table function ([#87](https://github.com/googleapis/langchain-google-bigtable-python/issues/87)) ([e114de0](https://github.com/googleapis/langchain-google-bigtable-python/commit/e114de0c4ab0f28ed4c36f008ab6c39149861fb5)) +* Remove dependency on langchain-community ([#82](https://github.com/googleapis/langchain-google-bigtable-python/issues/82)) ([5d3d509](https://github.com/googleapis/langchain-google-bigtable-python/commit/5d3d50963ebfcac7c268e086b7e943bc738ad5e0)) +* Updated docs for recently added features ([#88](https://github.com/googleapis/langchain-google-bigtable-python/issues/88)) ([83495de](https://github.com/googleapis/langchain-google-bigtable-python/commit/83495decc4ea85a4f57563e8bc8d3c12f029e22e)) + + +### Documentation + +* Update README.rst ([#81](https://github.com/googleapis/langchain-google-bigtable-python/issues/81)) ([bc588b8](https://github.com/googleapis/langchain-google-bigtable-python/commit/bc588b8c041a6efdd8b9903ec6f0d0255195e402)) + ## [0.2.3](https://github.com/googleapis/langchain-google-bigtable-python/compare/v0.2.2...v0.2.3) (2024-07-30) diff --git a/README.rst b/README.rst index b3e667b..38528eb 100644 --- a/README.rst +++ b/README.rst @@ -88,7 +88,7 @@ See the full `Document Loader`_ tutorial. .. _`Document Loader`: https://github.com/googleapis/langchain-google-bigtable-python/blob/main/docs/document_loader.ipynb Chat Message History Usage --------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~ Use ``ChatMessageHistory`` to store messages and provide conversation history to LLMs. 
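For orientation, this is the flow the updated notebook below walks through; a minimal sketch assuming an existing Bigtable instance, with placeholder project, instance, table, and session IDs::

    # Sketch only: all IDs are placeholders; the table is created once up front.
    from google.cloud import bigtable
    from langchain_google_bigtable import (
        BigtableChatMessageHistory,
        init_chat_history_table,
    )

    client = bigtable.Client(admin=True, project="my-project")

    # New in 0.3.0: init_chat_history_table supersedes create_chat_history_table.
    init_chat_history_table(
        instance_id="my-instance",
        table_id="my-table",
        client=client,
    )

    history = BigtableChatMessageHistory(
        "my-instance",
        "my-table",
        "my-session",
        client=client,
    )
    history.add_user_message("hi!")
    history.add_ai_message("whats up?")
    print(history.messages)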
diff --git a/docs/chat_message_history.ipynb b/docs/chat_message_history.ipynb index 82188ad..5c337c7 100644 --- a/docs/chat_message_history.ipynb +++ b/docs/chat_message_history.ipynb @@ -160,9 +160,9 @@ "outputs": [], "source": [ "from google.cloud import bigtable\n", - "from langchain_google_bigtable import create_chat_history_table\n", + "from langchain_google_bigtable import init_chat_history_table\n", "\n", - "create_chat_history_table(\n", + "init_chat_history_table(\n", " instance_id=INSTANCE_ID,\n", " table_id=TABLE_ID,\n", ")" @@ -199,6 +199,39 @@ "message_history.add_ai_message(\"whats up?\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Adding multiple messages\n", + "\n", + "In order to add many messages efficiently, use the `add_messages()` function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.messages import AIMessage\n", + "\n", + "messages = []\n", + "messages.append(AIMessage(content=\"message 1\"))\n", + "messages.append(AIMessage(content=\"message 2\"))\n", + "messages.append(AIMessage(content=\"message 3\"))\n", + "message_history.add_messages(messages)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reading messages\n", + "\n", + "To read all the messages, use the `messages` property." + ] + }, { "cell_type": "code", "execution_count": null, @@ -251,9 +284,9 @@ "source": [ "from google.cloud import bigtable\n", "\n", - "client = (bigtable.Client(...),)\n", + "client = (bigtable.Client(admin=True, project=PROJECT_ID),)\n", "\n", - "create_chat_history_table(\n", + "init_chat_history_table(\n", " instance_id=\"my-instance\",\n", " table_id=\"my-table\",\n", " client=client,\n", diff --git a/docs/document_loader.ipynb b/docs/document_loader.ipynb index cba80a6..224c8d1 100644 --- a/docs/document_loader.ipynb +++ b/docs/document_loader.ipynb @@ -310,7 +310,7 @@ "custom_client_loader = BigtableLoader(\n", " INSTANCE_ID,\n", " TABLE_ID,\n", - " client=bigtable.Client(...),\n", + " client=bigtable.Client(admin=True, project=PROJECT_ID),\n", ")" ] }, @@ -446,6 +446,49 @@ " metadata_as_json_name=\"my_metadata_as_json_column_name\",\n", ")" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Creating a table\n", + "\n", + "To create a table, use `init_document_table` with the instance, table and bigtable client, along with the column families you wish to create.\n", + "If the table already exists, google.api_core.exceptions.AlreadyExists error is thrown." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_google_bigtable.loader import init_document_table\n", + "\n", + "init_document_table(\n", + " INSTANCE_ID,\n", + " TABLE_ID,\n", + " client=bigtable.Client(...),\n", + " content_column_family=\"my_content_family\",\n", + " metadata_mappings=[\n", + " MetadataMapping(\n", + " column_family=\"my_int_family\",\n", + " column_name=\"my_int_column\",\n", + " metadata_key=\"key_in_metadata_map\",\n", + " encoding=Encoding.INT_BIG_ENDIAN,\n", + " ),\n", + " MetadataMapping(\n", + " column_family=\"my_custom_family\",\n", + " column_name=\"my_custom_column\",\n", + " metadata_key=\"custom_key\",\n", + " encoding=Encoding.CUSTOM,\n", + " custom_decoding_func=lambda input: json.loads(input.decode()),\n", + " custom_encoding_func=lambda input: str.encode(json.dumps(input)),\n", + " ),\n", + " ],\n", + " metadata_as_json_family=\"my_metadata_as_json_family\",\n", + ")" + ] } ], "metadata": { diff --git a/pyproject.toml b/pyproject.toml index 9b9eaf8..34efb53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,8 +10,8 @@ authors = [ ] dependencies = [ "langchain-core>=0.1.1, <1.0.0", - "langchain-community>=0.0.18, <1.0.0", - "google-cloud-bigtable>=2.22.0, <3.0.0" + "google-cloud-bigtable>=2.22.0, <3.0.0", + "deprecated>=1.2.14, <2.0.0" ] classifiers = [ "Intended Audience :: Developers", @@ -36,9 +36,9 @@ Changelog = "https://github.com/googleapis/langchain-google-bigtable-python/blob%20%20[project.optional-dependencies]%20test%20=%20[-"black[jupyter]==24.4.2", + "black[jupyter]==24.8.0", "isort==5.13.2", - "mypy==1.11.0", + "mypy==1.11.1", "pytest-asyncio==0.23.8", "pytest==8.3.2", "pytest-cov==5.0.0" diff --git a/requirements.txt b/requirements.txt index 059837a..e03385c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -langchain-core==0.2.24 -langchain-community==0.2.10 -google-cloud-bigtable==2.25.0 +langchain-core==0.2.30 +google-cloud-bigtable==2.26.0 +deprecated==1.2.14 diff --git a/src/langchain_google_bigtable/__init__.py b/src/langchain_google_bigtable/__init__.py index b9d98a4..cc38366 100644 --- a/src/langchain_google_bigtable/__init__.py +++ b/src/langchain_google_bigtable/__init__.py @@ -13,13 +13,25 @@ # limitations under the License. 
-from .chat_message_history import BigtableChatMessageHistory, create_chat_history_table -from .loader import BigtableLoader, BigtableSaver, Encoding, MetadataMapping +from .chat_message_history import ( + BigtableChatMessageHistory, + create_chat_history_table, + init_chat_history_table, +) +from .loader import ( + BigtableLoader, + BigtableSaver, + Encoding, + MetadataMapping, + init_document_table, +) from .version import __version__ __all__ = [ "BigtableChatMessageHistory", "create_chat_history_table", + "init_document_table", + "init_chat_history_table", "BigtableLoader", "BigtableSaver", "MetadataMapping", diff --git a/src/langchain_google_bigtable/chat_message_history.py b/src/langchain_google_bigtable/chat_message_history.py index 4e06a40..39760fa 100644 --- a/src/langchain_google_bigtable/chat_message_history.py +++ b/src/langchain_google_bigtable/chat_message_history.py @@ -19,9 +19,11 @@ import re import time import uuid -from typing import List, Optional +from typing import List, Optional, Sequence +from deprecated import deprecated from google.cloud import bigtable # type: ignore +from google.cloud.bigtable.row import DirectRow from google.cloud.bigtable.row_filters import RowKeyRegexFilter # type: ignore from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import BaseMessage, messages_from_dict @@ -32,11 +34,17 @@ COLUMN_NAME = "history" -def create_chat_history_table( +def init_chat_history_table( instance_id: str, table_id: str, client: Optional[bigtable.Client] = None, ) -> None: + """Create a table to store chat history. + Args: + instance_id: The Bigtable instance to use for chat message history. + table_id: The Bigtable table to use for chat message history. + client : Optional. The pre-created client to query bigtable. + """ table_client = ( use_client_or_default(client, "chat_history") .instance(instance_id) @@ -52,6 +60,15 @@ def create_chat_history_table( ).create() +@deprecated(reason="Use init_chat_history_table") +def create_chat_history_table( + instance_id: str, + table_id: str, + client: Optional[bigtable.Client] = None, +) -> None: + init_chat_history_table(instance_id, table_id, client) + + class BigtableChatMessageHistory(BaseChatMessageHistory): """Chat message history that stores history in Bigtable. 
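The shim above keeps ``create_chat_history_table`` importable while steering callers to ``init_chat_history_table``; the runtime warning comes from the ``deprecated`` package that this release adds as a dependency. A minimal, self-contained sketch of that pattern (the demo function here is illustrative, not part of the library)::

    import warnings

    from deprecated import deprecated


    @deprecated(reason="Use init_chat_history_table")
    def create_chat_history_table_demo() -> None:
        # Stand-in for the real shim, which simply delegates to the new function.
        pass


    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        create_chat_history_table_demo()

    # Calling through the wrapper emits a DeprecationWarning carrying the reason.
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)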
@@ -102,9 +119,25 @@ def messages(self) -> List[BaseMessage]: # type: ignore ) return messages + def add_messages(self, messages: Sequence[BaseMessage]) -> None: + """Write messages to the table""" + batcher = self.table_client.mutations_batcher() + for message in messages: + row = self.__message_to_row(message) + batcher.mutate(row) + batcher.flush() + def add_message(self, message: BaseMessage) -> None: """Write a message to the table""" + row = self.__message_to_row(message) + row.commit() + + def clear(self) -> None: + """Clear session memory from DB""" + row_key_prefix = self.session_id + self.table_client.drop_by_prefix(row_key_prefix) + def __message_to_row(self, message: BaseMessage) -> DirectRow: row_key = str.encode( self.session_id + "#" @@ -115,9 +148,4 @@ def add_message(self, message: BaseMessage) -> None: row = self.table_client.direct_row(row_key) value = str.encode(message.json()) row.set_cell(COLUMN_FAMILY, COLUMN_NAME, value) - row.commit() - - def clear(self) -> None: - """Clear session memory from DB""" - row_key_prefix = self.session_id - self.table_client.drop_by_prefix(row_key_prefix) + return row diff --git a/src/langchain_google_bigtable/loader.py b/src/langchain_google_bigtable/loader.py index ff5c3ea..cfac301 100644 --- a/src/langchain_google_bigtable/loader.py +++ b/src/langchain_google_bigtable/loader.py @@ -18,10 +18,10 @@ import uuid from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Iterator, List, Optional +from typing import Any, Callable, Dict, Iterator, List, Optional from google.cloud import bigtable # type: ignore -from langchain_community.document_loaders.base import BaseLoader +from langchain_core.document_loaders.base import BaseLoader from langchain_core.documents import Document from .common import use_client_or_default @@ -153,6 +153,7 @@ def __init__( self.metadata_as_json_encoding = metadata_as_json_encoding def load(self) -> List[Document]: + """Load data into Document objects.""" return list(self.lazy_load()) def lazy_load( @@ -225,6 +226,36 @@ def _decode(self, value: bytes, mapping: MetadataMapping) -> Any: raise ValueError(f"Invalid encoding {mapping.encoding}") +def init_document_table( + instance_id: str, + table_id: str, + client: Optional[bigtable.Client] = None, + content_column_family: str = COLUMN_FAMILY, + metadata_mappings: List[MetadataMapping] = [], + metadata_as_json_column_family: Optional[str] = None, +) -> None: + """ + Create a table for saving of langchain documents. + If table already exists, a google.api_core.exceptions.AlreadyExists error is thrown. + """ + table_client = ( + use_client_or_default(client, "document_saver") + .instance(instance_id) + .table(table_id) + ) + + families: Dict[str, bigtable.column_family.gc_rule] = dict() + if content_column_family: + families[content_column_family] = bigtable.column_family.MaxVersionsGCRule(1) + if metadata_as_json_column_family: + families[metadata_as_json_column_family] = ( + bigtable.column_family.MaxVersionsGCRule(1) + ) + for mapping in metadata_mappings: + families[mapping.column_family] = bigtable.column_family.MaxVersionsGCRule(1) + table_client.create(column_families=families) + + class BigtableSaver: """Load from the Google Cloud Platform `Bigtable`.""" @@ -308,6 +339,13 @@ def __init__( self.metadata_as_json_encoding = metadata_as_json_encoding def add_documents(self, docs: List[Document]) -> None: + """ + Save documents in the DocumentSaver table. 
Document's metadata is added to columns if found or + stored in langchain_metadata JSON column. + + Args: + docs (List[langchain_core.documents.Document]): a list of documents to be saved. + """ batcher = self.client.mutations_batcher() for doc in docs: row_key = doc.metadata.get(ID_METADATA_KEY) or uuid.uuid4().hex @@ -340,6 +378,13 @@ def add_documents(self, docs: List[Document]) -> None: batcher.flush() def delete(self, docs: List[Document]) -> None: + """ + Delete all instances of a document from the DocumentSaver table by matching the entire Document + object. + + Args: + docs (List[langchain_core.documents.Document]): a list of documents to be deleted. + """ batcher = self.client.mutations_batcher() for doc in docs: row = self.client.direct_row(doc.metadata.get(ID_METADATA_KEY)) diff --git a/src/langchain_google_bigtable/version.py b/src/langchain_google_bigtable/version.py index f6bd330..2b5d97a 100644 --- a/src/langchain_google_bigtable/version.py +++ b/src/langchain_google_bigtable/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.2.3" +__version__ = "0.3.0" diff --git a/tests/test_bigtable_chat_message_history.py b/tests/test_bigtable_chat_message_history.py index f8d82e2..d97fb4f 100644 --- a/tests/test_bigtable_chat_message_history.py +++ b/tests/test_bigtable_chat_message_history.py @@ -19,7 +19,6 @@ import string import time import uuid -from multiprocessing import Process from typing import Iterator import pytest @@ -28,7 +27,7 @@ from langchain_google_bigtable.chat_message_history import ( BigtableChatMessageHistory, - create_chat_history_table, + init_chat_history_table, ) TABLE_ID_PREFIX = "test-table-history-" @@ -52,7 +51,7 @@ def table_id(instance_id: str, client: bigtable.Client) -> Iterator[str]: random.choice(string.ascii_lowercase) for _ in range(10) ) # Create table and column family - create_chat_history_table(instance_id=instance_id, table_id=table_id, client=client) + init_chat_history_table(instance_id=instance_id, table_id=table_id, client=client) yield table_id @@ -98,38 +97,13 @@ def test_bigtable_loads_of_messages( instance_id, table_id, session_id, client=client ) - def add_ai_message(history, i): - try: - history.add_ai_message(f"Hey! I am AI! Index: {2*i}") - except Exception as e: - print(e) - - def add_user_message(history, i): - try: - history.add_user_message(f"Hey! I am human! Index: {2*i+1}") - except Exception as e: - print(e) - - proc = [] + ai_messages = [] + human_messages = [] for i in range(NUM_MESSAGES): - proc.append( - Process( - target=lambda i: add_ai_message(history, i), - args=[i], - ) - ) - proc.append( - Process( - target=lambda i: add_user_message(history, i), - args=[i], - ) - ) - - for p in proc: - p.start() - - for p in proc: - p.join() + ai_messages.append(AIMessage(content=f"Hey! I am AI! Index: {2*i}")) + human_messages.append(HumanMessage(content=f"Hey! I am human! 
Index: {2*i+1}")) + history.add_messages(ai_messages) + history.add_messages(human_messages) # wait for eventual consistency time.sleep(5) diff --git a/tests/test_bigtable_loader.py b/tests/test_bigtable_loader.py index 4d7f44d..dc2b30b 100644 --- a/tests/test_bigtable_loader.py +++ b/tests/test_bigtable_loader.py @@ -20,6 +20,7 @@ from typing import Iterator import pytest +from google.api_core.exceptions import AlreadyExists from google.cloud import bigtable # type: ignore from google.cloud.bigtable import column_family, row_filters # type: ignore from langchain_core.documents import Document @@ -29,6 +30,7 @@ BigtableSaver, Encoding, MetadataMapping, + init_document_table, ) TABLE_ID_PREFIX = "test-table-loader-" @@ -666,6 +668,57 @@ def test_bigtable_metadata_as_json_execution_order( } +def test_table_creation( + instance_id: str, table_id: str, client: bigtable.Client +) -> None: + # Cleanup default table created by test framework. + table_client = client.instance(instance_id).table(table_id) + table_client.delete() + + # Create table. + init_document_table(instance_id, table_id, client) + + # Assert table exists. + assert table_client.exists() + assert sorted(table_client.list_column_families().keys()) == ["langchain"] + # Expect second creation to fail. + with pytest.raises(AlreadyExists): + init_document_table(instance_id, table_id, client) + + # Delete table. + table_client.delete() + assert not table_client.exists() + + # Create with column families. + content_column_family = "content_column_family" + first_column_from_mapping = "first_column_from_mapping" + second_column_from_mapping = "second_column_from_mapping" + metadata_as_json_column_family = "metadata_as_json_column_family" + init_document_table( + instance_id, + table_id, + client, + content_column_family, + [ + MetadataMapping(first_column_from_mapping, "", "", Encoding.ASCII), + MetadataMapping(second_column_from_mapping, "", "", Encoding.ASCII), + ], + metadata_as_json_column_family, + ) + expected_families = [ + content_column_family, + first_column_from_mapping, + second_column_from_mapping, + metadata_as_json_column_family, + ] + + # Assert successful creation. + assert table_client.exists() + created_families = table_client.list_column_families().keys() + assert len(created_families) == len(expected_families) + assert sorted(created_families) == sorted(expected_families) + + def get_env_var(key: str, desc: str) -> str: v = os.environ.get(key) if v is None: