"""Backend-agnostic LLM client using instructor for structured extraction."""
from __future__ import annotations
import asyncio
import json
import logging
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any
import instructor
from pydantic import BaseModel, Field, field_validator
from .bot import BotContext
from .config import Config
from .prompt_metadata import PromptMetadata
from .prompts import get_system_prompt
from .session_manager import Message, validate_topic_name
from .tools import FolderTools
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
# History budget, measured in characters of message content.
# NOTE(review): not referenced in the visible part of this file (chat() reads
# config.max_history_chars) — presumably Config uses it as the default; confirm.
DEFAULT_MAX_HISTORY_CHARS = 50_000
[docs]
def trim_history_to_budget(history: list[Message], max_chars: int) -> list[Message]:
    """Return the newest slice of *history* that fits a character budget.

    The budget is measured as the summed length of each message's
    ``"content"``. Messages are dropped from the oldest end. The two newest
    messages are kept even when they alone exceed the budget, and a leading
    assistant message is removed from a trimmed window (when more than one
    message remains) so the window does not start with an orphaned reply.

    Args:
        history: Conversation messages, oldest first.
        max_chars: Maximum total content length to keep.

    Returns:
        ``[]`` for empty input or a non-positive budget; *history* itself
        (same object) when it already fits; otherwise a new, shorter list.
    """
    if not history or max_chars <= 0:
        return []

    if sum(len(entry["content"]) for entry in history) <= max_chars:
        return history

    # Walk from newest to oldest, collecting until the budget would overflow.
    newest_first: list[Message] = []
    used = 0
    for entry in history[::-1]:
        size = len(entry["content"])
        would_overflow = used + size > max_chars
        # The first two messages are admitted unconditionally.
        if would_overflow and len(newest_first) >= 2:
            break
        newest_first.append(entry)
        used += size

    window = newest_first[::-1]
    if window and window[0]["role"] == "assistant" and len(window) > 1:
        window = window[1:]
    return window
_RECENCY_WINDOW = 4  # Always keep last N messages regardless of topic


def build_topic_history(history: list[Message], max_chars: int) -> list[Message]:
    """Select history for the prompt: the recent tail plus same-topic backfill.

    The last ``_RECENCY_WINDOW`` messages are always returned, whatever their
    size or topic. The current topic is the ``"topic"`` of the final message
    (``"general"`` when absent). Whatever budget is left after the tail is
    spent on older messages with that same topic, newest first; the first
    such message is admitted even if it does not fit, and selection stops at
    the next one that would overflow.

    Args:
        history: Conversation messages, oldest first.
        max_chars: Character budget, counted over message ``"content"``.

    Returns:
        A new list in chronological order (backfill, then the recent tail),
        or ``[]`` for empty input or a non-positive budget.
    """
    if max_chars <= 0 or not history:
        return []

    active_topic = history[-1].get("topic", "general")

    # Split into the always-kept tail and everything before it.
    split_at = max(len(history) - _RECENCY_WINDOW, 0)
    earlier, tail = history[:split_at], history[split_at:]

    budget_left = max_chars - sum(len(entry["content"]) for entry in tail)

    same_topic_newest_first = (
        entry
        for entry in reversed(earlier)
        if entry.get("topic", "general") == active_topic
    )
    picked: list[Message] = []
    for entry in same_topic_newest_first:
        size = len(entry["content"])
        # An empty selection always accepts its first candidate.
        if picked and size > budget_left:
            break
        picked.append(entry)
        budget_left -= size

    picked.reverse()
    return picked + tail
[docs]
class AgentResponse(BaseModel, frozen=True):
    """Structured response from the LLM.

    Either provide tool_calls (to request tool execution) or answer (final response).
    When tool_calls is non-empty, tools will be executed and results fed back.
    When answer is set and tool_calls is empty, the answer is returned to the user.
    """

    # The docstring above and the Field descriptions below are part of the
    # model's schema, so they are kept verbatim rather than reworded.

    # NOTE(review): ToolCallRequest is neither defined nor imported in the
    # visible part of this file — confirm it is in scope where the model is built.
    tool_calls: list[ToolCallRequest] = Field(
        default_factory=list,
        description="Tools to call. Empty when providing a final answer.",
    )
    answer: str | None = Field(
        default=None,
        description="Final answer to the user. Only set when no tools are needed.",
    )
    topic: str = Field(
        default="general",
        description=(
            "Short topic label for this conversation thread. "
            "MUST use only lowercase letters, numbers, and underscores "
            "(e.g. 'weather', 'project_planning', 'recipes'). "
            "No spaces or special characters. "
            "Use consistent names across related messages."
        ),
    )

    @field_validator("topic")
    @classmethod
    def normalize_topic(cls, v: str) -> str:
        # Best-effort: a label that validate_topic_name rejects with
        # ValueError is replaced by "general" instead of failing validation.
        try:
            cleaned = validate_topic_name(v)
        except ValueError:
            cleaned = "general"
        return cleaned
[docs]
class AskUserRequest(BaseModel, frozen=True):
    """Request to ask the user a question with interactive UI.

    Used by the agent loop when the LLM calls the ask_user tool.
    The Telegram handler renders this as inline keyboards, location
    pickers, or text prompts depending on input_type.
    """

    # Required: the only field without a default.
    question: str = Field(description="The question to ask the user.")
    # Defaults to an empty list when the caller supplies no choices.
    options: list[str] = Field(
        default_factory=list,
        description="Choices to present (for 'choice' and 'confirm' types).",
    )
    # Annotated as a plain str, so this model does not itself reject values
    # outside the set named in the description.
    input_type: str = Field(
        default="choice",
        description="One of: choice, confirm, text, location.",
    )
[docs]
@dataclass(frozen=True)
class TokenUsage:
    """Immutable token totals reported for a single chat() call."""

    # Tokens counted on the request side, summed over the call's LLM round-trips.
    input_tokens: int
    # Tokens counted on the response side, summed the same way.
    output_tokens: int
# Phrases that indicate the LLM claims to have performed an action.
# LLMClient._claims_tool_use lower-cases the answer and looks for any of these
# as a plain substring, so every entry here must be lowercase.
_ACTION_PHRASES = [
    # "I've <verb>" forms
    "i've updated",
    "i've written",
    "i've created",
    "i've deleted",
    "i've modified",
    "i've saved",
    "i've added",
    "i've removed",
    "i've scheduled",
    "i've set",
    # "I <verb>" simple-past forms
    "i updated",
    "i wrote",
    "i created",
    "i deleted",
    "i modified",
    "i saved",
    "i added",
    "i removed",
    # Passive forms
    "file has been",
    "task has been",
    "has been created",
    "has been updated",
    "has been written",
    "has been deleted",
    "has been scheduled",
]
[docs]
class LLMClient:
    """Backend-agnostic LLM client using instructor for structured extraction.

    The model string is handed to ``instructor.from_provider()`` and must be
    of the form ``"<provider>/<model>"``. The agent loop in :meth:`chat`
    repeatedly asks the model for an ``AgentResponse`` — either tool calls to
    execute or a final answer — and feeds tool results back as text appended
    to the user message.
    """

    # Upper bound on LLM round-trips within a single chat() call.
    MAX_TOOL_ITERATIONS = 10

    # Prompt text for the ask_user pseudo-tool, which is handled by the agent
    # loop itself instead of FolderTools.
    _ASK_USER_TOOL_TEXT = (
        "### ask_user\n"
        "Ask the user a question with interactive UI. Use this when you need\n"
        "user input, confirmation, or a choice before proceeding.\n"
        "Parameters:\n"
        " - question: string (required) \u2014 The question to ask\n"
        " - options: list[string] \u2014 Choices to present as buttons\n"
        " - input_type: string \u2014 One of: choice, confirm, text, location "
        "(default: choice)\n"
    )

    def __init__(self, config: Config):
        """Create the tool registry and the async instructor client.

        Args:
            config: Application configuration; ``model``, ``api_key``,
                ``user_name`` and ``max_history_chars`` are read from it.

        Raises:
            ValueError: If ``config.model`` has no ``provider/`` prefix.
        """
        self.config = config
        self.tools = FolderTools(config)
        model = config.model
        if "/" not in model:
            raise ValueError(
                f"Model '{model}' must include a provider prefix "
                f"(e.g., 'anthropic/{model}', 'openai/{model}'). "
                f"See instructor docs for supported providers."
            )
        # Everything before the first "/" names the provider.
        self._provider = model.split("/", 1)[0]
        self._client = instructor.from_provider(
            model,
            async_client=True,
            # An empty key is passed as None rather than as "".
            api_key=config.api_key or None,
        )

    def _format_image_block(self, img: dict[str, str]) -> dict[str, Any]:
        """Return one image content block in the shape used for the provider.

        Args:
            img: Dict with ``"data"`` (base64 payload) and ``"media_type"``.

        Returns:
            An ``image_url`` block carrying a data URL when the provider is
            ``"openai"``; otherwise an ``image`` block with a base64 source.
        """
        if self._provider == "openai":
            return {
                "type": "image_url",
                "image_url": {"url": f"data:{img['media_type']};base64,{img['data']}"},
            }
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": img["media_type"],
                "data": img["data"],
            },
        }

    def _build_system_prompt(self) -> str:
        """Build system prompt with dynamic metadata and tool descriptions."""
        metadata = PromptMetadata.build(
            user_name=self.config.user_name,
            confirmation_tools=self.tools.get_tools_requiring_confirmation(),
        )
        base_prompt = get_system_prompt().format(**metadata.format_dict())
        # Add structured tool descriptions for the LLM
        tools_text = self._format_tools_for_prompt(self.tools.get_tool_definitions())
        # Append ask_user tool (not in FolderTools, handled by agent loop)
        tools_text += "\n" + self._ASK_USER_TOOL_TEXT
        return f"{base_prompt}\n\n{tools_text}"

    @staticmethod
    def _format_tools_for_prompt(tool_defs: list[dict[str, Any]]) -> str:
        """Format tool definitions as text for the system prompt.

        Each definition needs a ``"name"``; ``"description"`` and
        ``"input_schema"`` (with ``"properties"`` / ``"required"``) are
        optional.
        """
        parts = ["## Available Tools\n"]
        for td in tool_defs:
            parts.append(f"### {td['name']}")
            parts.append(td.get("description", ""))
            schema = td.get("input_schema", {})
            props = schema.get("properties", {})
            required = schema.get("required", [])
            if props:
                parts.append("Parameters:")
                parts.extend(
                    f" - {pname}: {pschema.get('type', 'any')}"
                    f"{' (required)' if pname in required else ''}"
                    f" — {pschema.get('description', '')}"
                    for pname, pschema in props.items()
                )
            # Blank line between tools.
            parts.append("")
        return "\n".join(parts)

    @staticmethod
    def _usage_from_completion(completion: Any) -> tuple[int, int]:
        """Return ``(input_tokens, output_tokens)`` reported by *completion*.

        Reads ``input_tokens`` / ``output_tokens`` and falls back to
        ``prompt_tokens`` / ``completion_tokens``, the field names used by
        OpenAI-style usage objects. A missing usage object, or a missing or
        ``None`` counter, counts as 0.
        """
        usage = getattr(completion, "usage", None)
        if usage is None:
            return 0, 0

        def first_present(*names: str) -> int:
            for name in names:
                value = getattr(usage, name, None)
                if value is not None:
                    return value
            return 0

        return (
            first_present("input_tokens", "prompt_tokens"),
            first_present("output_tokens", "completion_tokens"),
        )

    @staticmethod
    def _tool_result_block(
        name: str, arguments: Any, result: Any, is_error: bool = False
    ) -> str:
        """Render one tool invocation as the text block fed back to the LLM."""
        error_suffix = "\n(Error)" if is_error else ""
        return (
            f"[Tool: {name}]\n"
            f"Input: {json.dumps(arguments)}\n"
            f"Result: {result}{error_suffix}"
        )

    async def _run_ask_user(
        self,
        arguments: Any,
        on_ask_user: Callable[[AskUserRequest], Awaitable[str]] | None,
    ) -> str:
        """Handle an ask_user tool call and return its result block.

        No callback, malformed arguments, and a timeout while waiting for the
        user are all reported to the LLM as tool errors rather than raised.
        """
        if on_ask_user is None:
            return self._tool_result_block(
                "ask_user",
                arguments,
                "ask_user is not available in this context.",
                is_error=True,
            )
        # Arguments come from the LLM; a missing or mistyped field must not
        # abort the whole chat. pydantic's ValidationError is a ValueError.
        try:
            ask_request = AskUserRequest(**arguments)
        except (TypeError, ValueError) as exc:
            return self._tool_result_block(
                "ask_user",
                arguments,
                f"Invalid ask_user arguments: {exc}",
                is_error=True,
            )
        try:
            user_answer = await on_ask_user(ask_request)
        except asyncio.TimeoutError:
            return self._tool_result_block(
                "ask_user",
                arguments,
                "User did not respond in time.",
                is_error=True,
            )
        return self._tool_result_block(
            "ask_user", arguments, f"User chose: {user_answer}"
        )

    async def chat(
        self,
        user_message: str,
        context: BotContext,
        history: list[Message],
        on_tool_use: Callable[[str], Awaitable[None]] | None = None,
        on_ask_user: Callable[[AskUserRequest], Awaitable[str]] | None = None,
        chat_id: int = 0,
        images: list[dict[str, str]] | None = None,
    ) -> tuple[str, list[str], str, TokenUsage]:
        """Send a message and get a response through the agent loop.

        Each iteration sends the trimmed history plus the user message (with
        all tool results gathered so far appended) and receives an
        ``AgentResponse``. Requested tools are executed and the loop repeats,
        up to ``MAX_TOOL_ITERATIONS`` times. An answer that claims an action
        although no tool has run is rejected once per iteration and the model
        is asked again. If a response carries both tool calls and an answer,
        the tools run and that answer is returned without another round.

        Args:
            user_message: The user's message
            context: BotContext with services and user info for tool execution;
                its ``query`` attribute is set to *user_message*.
            history: Conversation history
            on_tool_use: Optional async callback when tools are being used
            on_ask_user: Optional async callback for ask_user tool (returns user answer)
            chat_id: Telegram chat ID (for scheduler tools)
            images: Optional list of image dicts with 'data' (base64) and 'media_type'

        Returns:
            Tuple of (response text, list of tools used, topic label, token usage).
            When the iteration limit is hit, the last response's answer and
            topic are returned if it had a non-empty answer, else ``""`` and
            ``"general"``.
        """
        context.query = user_message
        system = self._build_system_prompt()

        # Build topic-aware history
        trimmed_history = build_topic_history(
            history, max_chars=self.config.max_history_chars
        )
        if len(trimmed_history) < len(history):
            logger.info(
                "History trimmed: %d -> %d messages",
                len(history),
                len(trimmed_history),
            )

        # Only role and content are forwarded; other message keys are dropped.
        base_messages: list[dict[str, str]] = [
            {"role": msg["role"], "content": msg["content"]}
            for msg in trimmed_history
        ]
        # Image blocks do not change between iterations, so build them once.
        image_blocks: list[dict[str, Any]] = (
            [self._format_image_block(img) for img in images] if images else []
        )

        gathered_context: list[str] = []
        tools_used: list[str] = []
        last_response: AgentResponse | None = None
        total_input_tokens = 0
        total_output_tokens = 0

        for _ in range(self.MAX_TOOL_ITERATIONS):
            # Build user message with accumulated tool results
            user_text = user_message
            if gathered_context:
                user_text += "\n\n--- Tool Results ---\n" + "\n\n".join(
                    gathered_context
                )

            # Build user content: multimodal (with images) or plain text
            user_content: str | list[dict[str, Any]]
            if image_blocks:
                user_content = [*image_blocks, {"type": "text", "text": user_text}]
            else:
                user_content = user_text
            messages = base_messages + [{"role": "user", "content": user_content}]

            response, completion = await self._client.create_with_completion(
                response_model=AgentResponse,
                messages=messages,  # type: ignore[arg-type]
                system=system,
                max_tokens=4096,
            )
            last_response = response

            # Accumulate token usage
            used_in, used_out = self._usage_from_completion(completion)
            total_input_tokens += used_in
            total_output_tokens += used_out

            # If answer provided and no tool calls → done
            if response.answer is not None and not response.tool_calls:
                # Hallucination guard: check if answer claims tool use
                if not tools_used and self._claims_tool_use(response.answer):
                    logger.warning(
                        "Hallucination guard: answer claims action without tools"
                    )
                    gathered_context.append(
                        "[SYSTEM] Your answer claimed to perform an action but no "
                        "tools were called. Use the appropriate tool or answer "
                        "honestly that you cannot perform the action."
                    )
                    continue
                usage = TokenUsage(total_input_tokens, total_output_tokens)
                return response.answer, tools_used, response.topic, usage

            # Execute tool calls
            for tc in response.tool_calls:
                tools_used.append(tc.name)
                if on_tool_use:
                    await on_tool_use(tc.name)

                if tc.name == "ask_user":
                    # ask_user is handled via callback, not FolderTools
                    gathered_context.append(
                        await self._run_ask_user(tc.arguments, on_ask_user)
                    )
                    continue

                result = await self.tools.execute_async(
                    tc.name, tc.arguments, context=context, chat_id=chat_id
                )
                gathered_context.append(
                    self._tool_result_block(
                        tc.name, tc.arguments, result.content, result.is_error
                    )
                )

            # If the response also had an answer alongside tool calls, return it
            if response.answer is not None:
                usage = TokenUsage(total_input_tokens, total_output_tokens)
                return response.answer, tools_used, response.topic, usage

        # Max iterations exhausted
        usage = TokenUsage(total_input_tokens, total_output_tokens)
        logger.warning("Max tool iterations (%s) reached", self.MAX_TOOL_ITERATIONS)
        if last_response and last_response.answer:
            return last_response.answer, tools_used, last_response.topic, usage
        return "", tools_used, "general", usage

    @staticmethod
    def _claims_tool_use(text: str) -> bool:
        """Check if text claims to have performed an action (heuristic guard).

        Case-insensitive substring match against ``_ACTION_PHRASES``.
        """
        text_lower = text.lower()
        return any(phrase in text_lower for phrase in _ACTION_PHRASES)