Source code for manolo_bot.storage.base

import abc
import json
from abc import abstractmethod
from dataclasses import dataclass

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage


def get_messages_key(bot_uuid: str, chat_id: int) -> str:
    """
    Builds the database key under which the messages of one chat are stored.

    The key has the form ``<bot_uuid>:<chat_id>``.
    """
    key_template = "{}:{}"
    return key_template.format(bot_uuid, chat_id)
def convert_json_to_message(json_message: str) -> BaseMessage:
    """
    Rebuilds a message object from its JSON string representation.

    The decoded payload is expected to be a JSON object. Its ``"type"`` entry
    selects the message class (``"system"``, ``"human"`` or ``"ai"``); any other
    or missing type falls back to ``BaseMessage``. Every decoded field,
    including ``"type"``, is forwarded to the constructor as a keyword argument.
    """
    payload = json.loads(json_message)
    kind = payload.get("type")

    # First match wins; compared with == so any decoded value is acceptable here.
    known_types = (
        ("system", SystemMessage),
        ("human", HumanMessage),
        ("ai", AIMessage),
    )
    for type_name, message_cls in known_types:
        if kind == type_name:
            return message_cls(**payload)

    return BaseMessage(**payload)
@dataclass
class StorageMessage:
    """
    A chat message paired with the bookkeeping flags kept by the message storage.
    """

    # The wrapped chat message.
    message: BaseMessage
    # Set to True by BaseMessagesStorage.delete_message; flagged entries are
    # hidden from BaseMessagesStorage.messages.
    deleted: bool = False
    # Set to True by BaseMessagesStorage.add_message; presumably marks entries
    # not yet written to the database (confirm against commit() implementations).
    new: bool = False
class BaseDBHelper(abc.ABC):
    """
    Abstract base class for database connection helpers.

    Subclasses must implement ``disconnect``; ``connect`` is optional and does
    nothing unless overridden.
    """

    async def connect(self) -> None:
        """
        Connects to the database.

        The base implementation is a no-op.
        """
        return None

    @abc.abstractmethod
    async def disconnect(self) -> None:
        """
        Disconnects from the database.
        """
class BaseMessagesStorage(abc.ABC):
    """
    Abstract base class for message storage.

    Keeps an in-memory list of ``StorageMessage`` entries for one bot/chat pair
    and declares the asynchronous operations concrete backends must provide to
    reload, clear and persist them.
    """

    def __init__(self, bot_uuid: str, chat_id: int) -> None:
        self.bot_uuid = bot_uuid
        self.chat_id = chat_id
        # Every tracked entry, including those already flagged as deleted.
        self._messages: list[StorageMessage] = []

    @property
    def messages(self) -> list[BaseMessage]:
        """
        Returns the messages that have not been flagged as deleted, in order.
        """
        visible: list[BaseMessage] = []
        for entry in self._messages:
            if entry.deleted:
                continue
            visible.append(entry.message)
        return visible

    @abc.abstractmethod
    async def refresh_messages(self) -> None:
        """
        Updates the messages list from the database asynchronously.
        """

    def add_message(self, message: BaseMessage) -> None:
        """
        Appends a message to the in-memory list, flagged as new.
        """
        entry = StorageMessage(message=message, new=True)
        self._messages.append(entry)

    def delete_message(self, index: int) -> None:
        """
        Flags a message as deleted, addressed by its index among live messages.

        ``index`` uses the same numbering as the ``messages`` property: entries
        already flagged as deleted are skipped when counting. If no live message
        sits at that position, nothing happens.
        """
        # Lazily walk only the entries that are still live so positions line up
        # with what the ``messages`` property exposes.
        live_entries = (entry for entry in self._messages if not entry.deleted)
        for position, entry in enumerate(live_entries):
            if position == index:
                entry.deleted = True
                return

    @abc.abstractmethod
    async def clear_messages(self) -> None:
        """
        Clears all messages from the storage.
        """

    @abc.abstractmethod
    async def commit(self) -> None:
        """
        Include new messages and remove deleted messages from the database asynchronously.
        """