Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
self._create_table_if_not_exists()

def _create_table_if_not_exists(self) -> None:
create_table_query = f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
create_table_query = f"""CREATE TABLE IF NOT EXISTS `{self.table_name}` (
id INT AUTO_INCREMENT PRIMARY KEY,
session_id TEXT NOT NULL,
data JSON NOT NULL,
Expand All @@ -50,9 +50,11 @@ def _create_table_if_not_exists(self) -> None:
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from Cloud SQL"""
query = f"SELECT data, type FROM {self.table_name} WHERE session_id = '{self.session_id}' ORDER BY id;"
query = f"SELECT data, type FROM `{self.table_name}` WHERE session_id = :session_id ORDER BY id;"
with self.engine.connect() as conn:
results = conn.execute(sqlalchemy.text(query)).fetchall()
results = conn.execute(
sqlalchemy.text(query), {"session_id": self.session_id}
).fetchall()
# load SQLAlchemy row objects into dicts
items = [
{"data": json.loads(result[0]), "type": result[1]} for result in results
Expand All @@ -62,7 +64,7 @@ def messages(self) -> List[BaseMessage]: # type: ignore

def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in Cloud SQL"""
query = f"INSERT INTO {self.table_name} (session_id, data, type) VALUES (:session_id, :data, :type);"
query = f"INSERT INTO `{self.table_name}` (session_id, data, type) VALUES (:session_id, :data, :type);"
with self.engine.connect() as conn:
conn.execute(
sqlalchemy.text(query),
Expand All @@ -76,7 +78,7 @@ def add_message(self, message: BaseMessage) -> None:

def clear(self) -> None:
"""Clear session memory from Cloud SQL"""
query = f"DELETE FROM {self.table_name} WHERE session_id = :session_id;"
query = f"DELETE FROM `{self.table_name}` WHERE session_id = :session_id;"
with self.engine.connect() as conn:
conn.execute(sqlalchemy.text(query), {"session_id": self.session_id})
conn.commit()
20 changes: 20 additions & 0 deletions tests/integration/test_mysql_chat_message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,23 @@ def test_chat_message_history(memory_engine: MySQLEngine) -> None:
# verify clear() clears message history
history.clear()
assert len(history.messages) == 0


def test_chat_message_history_custom_table_name(memory_engine: MySQLEngine) -> None:
"""Test MySQLChatMessageHistory with custom table name"""
history = MySQLChatMessageHistory(
engine=memory_engine, session_id="test", table_name="message-store"
)
history.add_user_message("hi!")
history.add_ai_message("whats up?")
messages = history.messages

# verify messages are correct
assert messages[0].content == "hi!"
assert type(messages[0]) is HumanMessage
assert messages[1].content == "whats up?"
assert type(messages[1]) is AIMessage

# verify clear() clears message history
history.clear()
assert len(history.messages) == 0