import base64
import logging
import re
import secrets
from types import TracebackType
from typing import TYPE_CHECKING
from urllib.parse import urljoin
import aiohttp
from google.genai.types import HarmBlockThreshold, HarmCategory
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from manolo_bot.ai.config import BotConfig, LLMConfig
from manolo_bot.ai.tools import get_tool, get_tools
from manolo_bot.storage.base import BaseMessagesStorage
if TYPE_CHECKING:
from manolo_bot.ai.mcp_manager import MCPManager
[docs]
class LLMBuilder:
    """Factory that builds the LangChain chat model selected by an ``LLMConfig``."""

    def __init__(self, llm_config: LLMConfig) -> None:
        """
        Keep the configuration used to choose and build the backend.

        :param llm_config: Settings with the per-backend model names, keys and rate-limiter values.
        """
        self.llm_config = llm_config

    def _get_rate_limiter(self) -> InMemoryRateLimiter:
        """Build an in-memory rate limiter from the ``rate_limiter_*`` settings."""
        cfg = self.llm_config
        return InMemoryRateLimiter(
            requests_per_second=cfg.rate_limiter_requests_per_second,
            check_every_n_seconds=cfg.rate_limiter_check_every_n_seconds,
            max_bucket_size=cfg.rate_limiter_max_bucket_size,
        )

    def _get_chat_ollama(self) -> ChatOllama:
        """Build the Ollama chat model; no rate limiter is attached to it."""
        model_name = self.llm_config.ollama_model
        return ChatOllama(model=model_name)

    def _get_chat_google_generativeai(self) -> ChatGoogleGenerativeAI:
        """Build the Google Generative AI chat model with the listed harm categories set to ``BLOCK_NONE``."""
        unblocked_categories = (
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            HarmCategory.HARM_CATEGORY_HARASSMENT,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        )
        return ChatGoogleGenerativeAI(
            model=self.llm_config.google_api_model,
            safety_settings={category: HarmBlockThreshold.BLOCK_NONE for category in unblocked_categories},
            rate_limiter=self._get_rate_limiter(),
        )

    def _get_chat_openai(self) -> ChatOpenAI:
        """
        Build the OpenAI-compatible chat model.

        A placeholder API key is sent when none is configured; base URL and model are only
        passed when they are set.
        """
        cfg = self.llm_config
        kwargs = {"openai_api_key": cfg.openai_api_key or "not-needed"}
        if cfg.openai_api_base_url:
            kwargs["base_url"] = cfg.openai_api_base_url
        if cfg.openai_api_model:
            kwargs["model"] = cfg.openai_api_model
        return ChatOpenAI(rate_limiter=self._get_rate_limiter(), **kwargs)

    def get_llm(self) -> BaseChatModel:
        """
        Create and return an instance of the configured LLM.

        Backends are tried in this order: Ollama (model name set), Google (API key set),
        OpenAI-compatible (API key or base URL set).

        :return: A LangChain BaseChatModel instance.
        :raises Exception: If no LLM configuration is found.
        """
        cfg = self.llm_config
        if cfg.ollama_model:
            return self._get_chat_ollama()
        if cfg.google_api_key:
            return self._get_chat_google_generativeai()
        if cfg.openai_api_key or cfg.openai_api_base_url:
            return self._get_chat_openai()
        raise Exception("No LLM backend data found")
[docs]
class LLMBot:
    """
    Base class for a Telegram LLM Chat Bot.

    Handles interaction with the LLM, message processing, and context management.
    """

    # When True (and bot_config.use_tools is set), __init__ binds the tools to the LLM right away.
    bind_tools_on_init = True
def __init__(
    self,
    llm: BaseChatModel,
    bot_config: BotConfig,
    system_instructions: list[BaseMessage],
    messages_storage: BaseMessagesStorage,
    tools: list[BaseTool] | None = None,
) -> None:
    """
    Store the bot's collaborators and, when enabled, bind tools to the LLM.

    :param llm: Chat model used to answer messages.
    :param bot_config: Bot settings.
    :param system_instructions: Messages placed before the chat history on LLM calls.
    :param messages_storage: Storage holding the chat history.
    :param tools: Custom tools to bind; None means the default tool set is looked up when binding.
    """
    self.llm = llm
    self.bot_config = bot_config
    self.system_instructions = system_instructions
    self.messages_storage: BaseMessagesStorage = messages_storage
    self.tools = tools

    # Async resources (MCP) are created later by initialize_async_resources().
    self._mcp_manager: MCPManager | None = None
    self._async_resources_initialized = False

    should_bind_now = self.bind_tools_on_init and self.bot_config.use_tools
    if should_bind_now:
        self._load_tools()
def _get_langchain_config(self, chat_id: int) -> RunnableConfig:
    """
    Build the LangChain run config that tags a call with bot, user and chat identifiers.

    :param chat_id: Chat ID stored in the run metadata
    :return: Config carrying ``bot:``/``user:`` tags and the identifying metadata
    """
    bot_uuid = self.bot_config.bot_uuid
    user_id = self.bot_config.user_id
    run_tags = [f"bot:{bot_uuid}", f"user:{user_id}"]
    run_metadata = {
        "bot_uuid": bot_uuid,
        "bot_username": self.bot_config.bot_username,
        "user_id": user_id,
        "chat_id": chat_id,
    }
    return RunnableConfig(tags=run_tags, metadata=run_metadata)
def _get_session_timeout(self) -> aiohttp.ClientTimeout:
    """Return the aiohttp timeout whose total is ``bot_config.web_content_request_timeout``."""
    total_seconds = self.bot_config.web_content_request_timeout
    return aiohttp.ClientTimeout(total=total_seconds)
[docs]
async def initialize_async_resources(self) -> None:
    """
    Initialize all async resources (currently the optional MCP connection).

    Returns immediately when the resources were already initialized. An MCP failure is logged
    and swallowed so the bot keeps working without MCP; a manager that was created before the
    failure is disconnected (best effort) instead of being silently dropped.
    """
    if self._async_resources_initialized:
        return
    logging.debug("Initializing async resources...")
    # Initialize MCP if enabled
    if self.bot_config.enable_mcp:
        logging.info("Initializing MCP...")
        try:
            # Imported here because the module-level import only exists under TYPE_CHECKING.
            from manolo_bot.ai.mcp_manager import MCPManager

            # TODO: We probably don't want to initialize MCP in each call, maybe we can cache this somehow?
            self._mcp_manager = MCPManager(self.bot_config)
            await self._mcp_manager.connect()
            logging.info("MCP initialized successfully")
            # Reload tools to include MCP tools
            if self.bot_config.use_tools:
                await self._reload_tools_with_mcp()
        except Exception as e:
            logging.warning(
                f"MCP initialization failed, continuing without MCP: {e}",
                exc_info=True,
            )
            failed_manager = self._mcp_manager
            self._mcp_manager = None
            if failed_manager is not None:
                # connect() may already have succeeded (e.g. the tool reload failed afterwards),
                # so release whatever the manager holds before forgetting about it.
                try:
                    await failed_manager.disconnect()
                except Exception:
                    logging.debug("Ignoring error while disconnecting the failed MCP manager", exc_info=True)
    self._async_resources_initialized = True
# async def cleanup(self) -> None:
# """Clean up all async resources."""
# if not self._async_resources_initialized:
# return
#
# logging.debug("Cleaning up async resources...")
#
# # Clean up MCP if it was initialized
# if self._mcp_manager is not None:
# try:
# await self._mcp_manager.close()
# logging.info("MCP resources cleaned up successfully")
# except Exception as e:
# logging.error(f"Error cleaning up MCP resources: {e}", exc_info=True)
#
# self._async_resources_initialized = False
async def _reload_tools_with_mcp(self) -> None:
    """Fetch the full tool list (MCP tools included) and bind it to the LLM."""
    from manolo_bot.ai.tools import get_all_tools

    combined_tools = await get_all_tools(
        self._mcp_manager,
        self.bot_config,
        custom_tools=self.tools,
    )
    rebound_llm = self.llm.bind_tools(combined_tools)
    self.llm = rebound_llm
    logging.debug(f"Reloaded {len(combined_tools)} tools (including MCP)")
[docs]
async def close(self) -> None:
    """
    Close all async resources.

    Disconnects the MCP manager, if there is one, and resets the initialization state so that a
    later ``initialize_async_resources()`` (or ``async with``) sets the resources up again.
    Calling it more than once is safe: the manager is disconnected only once.
    """
    manager = self._mcp_manager
    # Reset the state before awaiting so a failing disconnect cannot leave a stale manager behind.
    self._mcp_manager = None
    self._async_resources_initialized = False
    if manager is not None:
        await manager.disconnect()
    logging.debug("Async resources closed")
async def __aenter__(self) -> "LLMBot":
    """Enter the async context: set up the async resources, then hand back this bot."""
    await self.initialize_async_resources()
    return self
async def __aexit__(
    self,
    exc_type: type[BaseException] | None,
    exc_val: BaseException | None,
    exc_tb: TracebackType | None,
) -> None:
    """
    Leave the async context by closing the async resources.

    The exception details are ignored; returning None lets any exception propagate.
    """
    await self.close()
def _extract_url(self, text: str) -> str | None:
"""
Extract the URL from the text.
:param text: Text to extract the URL from
:return: URL if found, None otherwise
"""
url = re.search(
r"https?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*(),]|%[0-9a-fA-F][0-9a-fA-F])+",
text,
)
return url.group(0) if url else None
def _remove_urls(self, text: str) -> str:
    """
    Remove every HTTP(S) URL from the text.

    :param text: Text to remove URLs from
    :return: Text without URLs (surrounding whitespace is left untouched)
    """
    # Accepted URL characters: "!", the ASCII range "#".."_" (digits, upper-case letters and
    # most punctuation, "%" escapes included), lower-case letters and "~". Including "#" and
    # "~" keeps fragments and tilde paths from being left behind in the text.
    return re.sub(r"https?://[!#-_a-z~]+", "", text)
[docs]
def truncate_chat_context(self) -> None:
    """
    Drop the oldest messages until the chat context fits into ``context_max_tokens``.

    The loop also stops once the storage is empty, so a limit that can never be met does not
    keep deleting from an empty storage.
    """
    storage = self.messages_storage
    max_tokens = self.bot_config.context_max_tokens
    removed = 0
    while storage.messages and self.count_tokens(storage.messages) > max_tokens:
        storage.delete_message(0)
        removed += 1
    # Only report a truncation when something was actually removed.
    if removed:
        logging.debug(f"Chat context truncated for chat {storage.chat_id}")
[docs]
async def clean_context(self) -> None:
    """Remove every stored message of this chat."""
    storage = self.messages_storage
    await storage.clear_messages()
    logging.debug(f"Chat context cleaned for chat {storage.chat_id}")
[docs]
async def answer_message(self, chat_id: int, message: str) -> BaseMessage:
    """
    Process a text message and return the LLM's response.

    The message is stored, the context is truncated and the LLM is called with the system
    instructions followed by the chat history. When the reply requests tool calls, the reply and
    every tool result are stored and the LLM is called once more; that second reply is returned.

    :param chat_id: The ID of the chat.
    :param message: The text of the message to process.
    :return: The response message from the LLM.
    """
    self.messages_storage.add_message(HumanMessage(content=message))
    self.truncate_chat_context()
    config = self._get_langchain_config(chat_id)
    # Pass the config on every call so the first LLM run carries the tags/metadata as well.
    ai_msg = await self.llm.ainvoke(self.system_instructions + self.messages_storage.messages, config=config)
    if ai_msg.tool_calls:
        self.messages_storage.add_message(ai_msg)
        for tool_call in ai_msg.tool_calls:
            selected_tool = get_tool(tool_call["name"])
            tool_msg = await selected_tool.ainvoke(tool_call, config=config)
            self.messages_storage.add_message(tool_msg)
        ai_msg = await self.llm.ainvoke(self.system_instructions + self.messages_storage.messages, config=config)
    return ai_msg
[docs]
async def answer_image_message(self, chat_id: int, text: str, image: str) -> BaseMessage:
    """
    Answer a message that carries an image.

    Downloads the image, stores it together with the text as one multimodal human message and
    asks the LLM for a reply, using the system instructions like the text and voice paths do.
    Any failure is logged and answered with a ``NO_ANSWER`` message instead of raising.

    :param chat_id: Chat ID
    :param text: Text that accompanies the image
    :param image: URL the image is downloaded from
    :return: LLM response, or a ``NO_ANSWER`` message on failure
    """
    logging.debug(f"Image message: {text}")
    try:
        async with aiohttp.ClientSession() as session:
            timeout = self._get_session_timeout()
            async with session.get(image, timeout=timeout) as http_response:
                http_response.raise_for_status()
                image_bytes = await http_response.read()
        image_data = base64.b64encode(image_bytes).decode("utf-8")
        llm_message = HumanMessage(
            content=[
                {
                    "type": "text",
                    "text": text,
                },
                {
                    "type": "image_url",
                    # NOTE(review): the payload is always labelled image/jpeg — confirm callers only send JPEGs.
                    "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
                },
            ]
        )
        self.messages_storage.add_message(llm_message)
        self.truncate_chat_context()
        response = await self.llm.ainvoke(
            self.system_instructions + self.messages_storage.messages,
            config=self._get_langchain_config(chat_id),
        )
    except Exception as e:
        # Download errors get an extra line naming the image; everything is logged with traceback.
        if isinstance(e, aiohttp.ClientError):
            logging.error(f"Failed to get image: {image}")
        logging.exception(e)
        response = BaseMessage(content="NO_ANSWER", type="text")
    logging.debug(f"Image message response: {response}")
    return response
[docs]
async def answer_voice_message(self, chat_id: int, text: str, audio: str) -> BaseMessage:
    """
    Answer a voice message.

    Downloads the audio, stores it together with the text as one multimodal human message and
    asks the LLM for a reply. Any failure is logged and answered with a ``NO_ANSWER`` message
    instead of raising.

    :param chat_id: Chat ID
    :param text: Text sent to the LLM along with the audio
    :param audio: URL the voice message audio is downloaded from
    :return: LLM response, or a ``NO_ANSWER`` message on failure
    """
    logging.debug(f"Voice message: {audio}")
    try:
        async with aiohttp.ClientSession() as session:
            timeout = self._get_session_timeout()
            async with session.get(audio, timeout=timeout) as http_response:
                http_response.raise_for_status()
                audio_bytes = await http_response.read()
        audio_data = base64.b64encode(audio_bytes).decode("utf-8")
        llm_message = HumanMessage(
            content=[
                {
                    "type": "text",
                    "text": text,
                },
                {
                    "type": "media",
                    "mime_type": "audio/ogg",
                    # The audio is sent inline as base64 text.
                    "data": audio_data,
                },
            ]
        )
        self.messages_storage.add_message(llm_message)
        self.truncate_chat_context()
        response = await self.llm.ainvoke(
            self.system_instructions + self.messages_storage.messages,
            config=self._get_langchain_config(chat_id),
        )
    except Exception as e:
        # Download errors get an extra line naming the audio; everything is logged with traceback.
        if isinstance(e, aiohttp.ClientError):
            logging.error(f"Failed to get audio: {audio}")
        logging.exception(e)
        response = BaseMessage(content="NO_ANSWER", type="text")
    logging.debug(f"Voice message response: {response}")
    return response
[docs]
async def postprocess_response(self, response: BaseMessage, message_text: str, chat_id: int) -> dict | None:
    """
    Turn the raw LLM response into the data that is sent back to the chat.

    The response text selects the action: a ``GENERATE_IMAGE`` prefix generates an image, a
    ``WEBCONTENT_RESUME``/``WEBCONTENT_OPINION`` marker answers about web content, ``NO_ANSWER``
    yields a random emoji, and anything else is returned as text. The resulting text is stored
    as an AI message in every case.

    :param response: Response from the LLM
    :param message_text: Text of the user message
    :param chat_id: Chat ID
    :return: Final response data (``type``/``data`` dict), or None when image generation failed
    """
    raw_content = response.content
    if isinstance(raw_content, list):
        # response.content is sometimes a list instead of a string, TODO: find why this happens and fix it
        response_content = "\n\n".join(
            part if isinstance(part, str) else part.get("text", "") for part in raw_content
        )
    else:
        response_content = raw_content

    final_response = None
    if response_content.startswith("GENERATE_IMAGE"):
        logging.debug(f"GENERATE_IMAGE response, generating image for chat {chat_id}")
        image = await self.generate_image(response_content[len("GENERATE_IMAGE ") :])
        if image:
            final_response = {"type": "image", "data": image}
    elif "WEBCONTENT_RESUME" in response_content or "WEBCONTENT_OPINION" in response_content:
        if "WEBCONTENT_RESUME" in response_content:
            logging.debug(f"WEBCONTENT_RESUME response, generating web content abstract for chat {chat_id}")
        else:
            logging.debug(f"WEBCONTENT_OPINION response, generating web content opinion for chat {chat_id}")
        # TODO: find a way to graciously handle failed web content requests
        response_content = await self.answer_webcontent(message_text, response_content) or "😐"
        final_response = {"type": "text", "data": response_content}
    elif "NO_ANSWER" in response_content:
        logging.debug(f"NO_ANSWER response for chat {chat_id}")
        emoji = secrets.choice(["😐", "😶", "😳", "😕", "😑"])
        final_response = {"type": "text", "data": emoji}
    else:
        logging.debug(f"Response for chat {chat_id}")
        final_response = {"type": "text", "data": response_content}

    self.messages_storage.add_message(AIMessage(content=response_content))
    return final_response
[docs]
async def answer_webcontent(self, message_text: str, response_content: str) -> str | None:
"""
Answer a web content message.
:param message_text: Text to answer
:param response_content: Response content
:param chat_id: Chat ID
:return: New response content if the call was successful, None otherwise
"""
url = self._extract_url(response_content)
try:
if url:
logging.debug(f"Obtaining web content for {url} using pseudotool")
loader = WebBaseLoader(web_path=url)
# alazy_load returns an async iterator of Document objects
docs = []
async for doc in loader.alazy_load():
docs.append(doc)
template = self._remove_urls(message_text) + "\n" + '"{text}"'
prompt = PromptTemplate.from_template(template)
logging.debug(f"Web content prompt: {prompt}")
self.truncate_chat_context()
# TODO: Add full chat context
stuff_chain = create_stuff_documents_chain(
llm=self.llm, prompt=prompt, document_variable_name="text", output_parser=StrOutputParser()
)
# The key should match the document_variable_name parameter
response = await stuff_chain.ainvoke({"text": docs})
logging.debug(f"Web content response: {response}")
return response
else:
logging.debug(f"No URL found for web content: {message_text}")
except aiohttp.ClientError as e:
logging.error("Connection error connecting to web content")
logging.exception(e)
error_prompt = (
f"Generate a brief response in {self.bot_config.preferred_language} "
f"explaining that you couldn't connect to the webpage {url}. "
f"Suggest checking the URL or trying again later. "
f"Keep your response under 150 characters and maintain your character's style."
)
return await self.generate_feedback_message(error_prompt)
except TimeoutError as e:
logging.error("Timeout error connecting to web content")
logging.exception(e)
error_prompt = (
f"Generate a brief response in {self.bot_config.preferred_language} "
f"explaining that the webpage {url} took too long to respond. "
f"Suggest it might be unavailable or too large. "
f"Keep your response under 150 characters and maintain your character's style."
)
return await self.generate_feedback_message(error_prompt)
except Exception as e:
logging.error("Error connecting to web content")
logging.exception(e)
error_prompt = (
f"Generate a brief response in {self.bot_config.preferred_language} "
f"explaining that you had trouble processing the webpage {url}. "
f"Suggest trying again later or trying a different URL. "
f"Keep your response under 150 characters and maintain your character's style."
)
return await self.generate_feedback_message(error_prompt)
return None
[docs]
async def call_sdapi(self, prompt: str) -> dict | None:
"""
Call the StableDiffusion API.
:param prompt: The prompt to send to the StableDiffusion API.
:return: The response from the StableDiffusion API.
"""
if self.bot_config.sdapi_url:
try:
params = self.bot_config.sdapi_params.copy()
params["prompt"] = prompt
if self.bot_config.sdapi_negative_prompt:
params["negative_prompt"] = self.bot_config.sdapi_negative_prompt
# Use aiohttp for async HTTP requests
async with aiohttp.ClientSession() as session:
async with session.post(
urljoin(self.bot_config.sdapi_url, "/sdapi/v1/txt2img"),
json=params,
) as response:
if response.status == 200:
return await response.json()
except Exception as e:
logging.error("Failed to call SDAPI")
logging.exception(e)
return None
[docs]
async def generate_image(self, prompt: str) -> str | None:
"""
Generate an image.
:param prompt: Prompt to generate the image
:return: Image representation in base64 format if the call was successful, None otherwise
"""
logging.debug(f"Generate image: {prompt}")
response = await self.call_sdapi(prompt)
if response and "images" in response:
return response["images"][0]
return None
[docs]
def count_tokens(self, messages: list[BaseMessage]) -> int:
    """
    Count the number of tokens in the messages.

    The text of every message is concatenated and counted with the LLM's tokenizer; each image
    part adds a fixed amount. In list content, plain strings count as text and a text part
    without a ``text`` value counts as empty. Other part types add nothing.

    :param messages: List of messages
    :return: Number of tokens
    """
    # TODO: Use an LLM-based method to get the image token count.
    image_tokens = 258  # using gemini image context size
    extra_tokens = 0
    text_parts: list[str] = []
    for message in messages:
        content = message.content
        if not isinstance(content, list):
            text_parts.append(content)
            continue
        for item in content:
            if isinstance(item, str):
                text_parts.append(item)
            elif item.get("type") == "text":
                text_parts.append(item.get("text") or "")
            elif item.get("type") == "image_url":
                extra_tokens += image_tokens
    # Every text fragment is prefixed with "\n " exactly as before.
    context_text = "".join("\n " + part for part in text_parts)
    return self.llm.get_num_tokens(context_text) + extra_tokens
[docs]
async def generate_feedback_message(self, prompt: str, max_length: int = 200, chat_id: int | None = None) -> str:
    """
    Generate a feedback message using the LLM.

    The LLM is called with the prompt only (no system instructions, no chat history).

    :param prompt: Prompt to generate the feedback message
    :param max_length: Maximum length of the feedback message; longer text is cut and ends in "..."
    :param chat_id: Optional chat ID for metadata
    :return: Generated feedback message
    """
    logging.debug("Generating feedback message")
    # Create a simple message list with just the prompt
    messages = [HumanMessage(content=prompt)]
    config = self._get_langchain_config(chat_id) if chat_id is not None else None
    response = await self.llm.ainvoke(messages, config=config)
    content = response.content
    if isinstance(content, list):
        # The content can be a list of parts instead of a string; keep the text of each part.
        content = "\n\n".join(part if isinstance(part, str) else part.get("text", "") for part in content)
    # Clean up the response if needed
    feedback_message = content.strip()
    # Ensure the message isn't too long
    if len(feedback_message) > max_length:
        feedback_message = feedback_message[: max_length - 3] + "..."
    logging.debug(f"Generated feedback message: {feedback_message}")
    return feedback_message
def _get_time_from_wpm(self, text: str, wpm: float) -> float:
    """
    Compute how long writing a text takes at a given speed.

    :param text: Text to write; its whitespace-separated words are counted
    :param wpm: Words per minute
    :return: Time in seconds
    """
    word_count = len(text.split())
    minutes = word_count / wpm
    return minutes * 60
def _load_tools(self) -> None:
    """Bind tools to the LLM: the custom tools when given, otherwise the default tool set."""
    if self.tools is None:
        selected_tools = get_tools(bot_config=self.bot_config)
    else:
        selected_tools = self.tools
    self.llm = self.llm.bind_tools(selected_tools)  # add wikipedia?