Source code for folderbot.session_manager

"""Conversation session management with SQLite storage."""

import json
import re
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import TypedDict

# Shape of a fully normalized topic name: one or more lowercase ASCII
# letters, digits, or underscores.
_TOPIC_PATTERN = re.compile(r"^[a-z0-9_]+$")


def validate_topic_name(name: str) -> str:
    """Return *name* normalized into a canonical topic label.

    The input is lowercased and trimmed. Each run of whitespace and/or
    hyphens becomes one underscore. Every character outside ``[a-z0-9_]``
    is dropped, repeated underscores are collapsed, and leading and trailing
    underscores are trimmed. For example, ``"My Topic"`` becomes
    ``"my_topic"``.

    Raises:
        ValueError: if nothing is left once normalization is done.
    """
    cleaned = name.lower().strip()
    # Order matters: separators must become "_" before disallowed characters
    # are stripped, and collapsing runs of "_" must come after both.
    for pattern, replacement in (
        (r"[\s\-]+", "_"),
        (r"[^a-z0-9_]", ""),
        (r"_+", "_"),
    ):
        cleaned = re.sub(pattern, replacement, cleaned)
    cleaned = cleaned.strip("_")
    if cleaned:
        return cleaned
    raise ValueError(f"Topic name '{name}' is empty after normalization")
class Message(TypedDict):
    """One chat message as stored in a session's JSON-encoded history."""

    # Speaker of the message: "user" or "assistant".
    role: str
    # Raw message text.
    content: str
    # ISO-8601 string produced by ``datetime.isoformat()`` when stored.
    timestamp: str
    # Conversation topic label (messages saved without one are read as "general").
    topic: str
@dataclass(frozen=True)
class TopicReassignment:
    """Instruction to move one message of a user's history to another topic."""

    # Position of the message in the user's full (unfiltered) history list.
    message_index: int
    # Target topic; it is normalized before being applied.
    new_topic: str
@dataclass(frozen=True)
class UploadRecord:
    """Metadata row describing one file a user uploaded (table ``uploads``)."""

    id: int  # primary key of the uploads row
    user_id: int  # owner of the upload
    original_filename: str  # name the file was uploaded with
    hash_filename: str  # presumably the content-hash-based stored name — confirm
    extension: str  # file extension, stored separately ('' by schema default)
    file_size: int  # size as reported by the caller; units not enforced here
    mime_type: str  # MIME type ('application/pdf' by default)
    created_at: str  # ISO-8601 timestamp string of when the row was saved
@dataclass(frozen=True)
class TokenUsageRecord:
    """Token counts logged for a single LLM call (table ``token_usage``)."""

    id: int  # primary key of the token_usage row
    user_id: int  # user the call was made for
    input_tokens: int  # tokens sent to the model
    output_tokens: int  # tokens produced by the model
    model: str  # model identifier as passed by the caller
    topic: str  # conversation topic the call belonged to
    created_at: str  # ISO-8601 timestamp string of when the row was saved
class SessionManager:
    """Manages conversation history and related per-user records in SQLite.

    Besides chat history (table ``sessions``, one JSON-encoded message list
    per user), the database holds per-user settings (``user_settings``),
    upload metadata (``uploads``) and LLM token usage (``token_usage``).

    Every method opens a short-lived connection and runs its statements in
    one transaction. The transaction is committed on success and rolled back
    on error, and the connection is closed before the method returns.
    """

    # Idempotent DDL executed on startup; one statement per entry.
    _SCHEMA: tuple[str, ...] = (
        """
        CREATE TABLE IF NOT EXISTS sessions (
            user_id INTEGER PRIMARY KEY,
            messages TEXT NOT NULL DEFAULT '[]',
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS user_settings (
            user_id INTEGER PRIMARY KEY,
            last_notified_version TEXT,
            file_notifications_enabled INTEGER NOT NULL DEFAULT 0
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS uploads (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id INTEGER NOT NULL,
            original_filename TEXT NOT NULL,
            hash_filename TEXT NOT NULL,
            extension TEXT NOT NULL DEFAULT '',
            file_size INTEGER NOT NULL,
            mime_type TEXT NOT NULL DEFAULT 'application/pdf',
            created_at TEXT NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS token_usage (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id INTEGER NOT NULL,
            input_tokens INTEGER NOT NULL,
            output_tokens INTEGER NOT NULL,
            model TEXT NOT NULL,
            topic TEXT NOT NULL DEFAULT 'general',
            created_at TEXT NOT NULL
        )
        """,
    )

    # Shared SELECT prefix; column order must match _row_to_upload.
    _UPLOAD_SELECT = (
        "SELECT id, user_id, original_filename, hash_filename, "
        "extension, file_size, mime_type, created_at FROM uploads"
    )

    def __init__(self, db_path: Path):
        """Bind to the SQLite file at *db_path*, creating it and its schema if needed."""
        self.db_path = db_path
        self._ensure_db()

    def _ensure_db(self) -> None:
        """Create the database directory, tables, and apply column migrations."""
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        with self._connect() as conn:
            for statement in self._SCHEMA:
                conn.execute(statement)
            # Migration: older databases lack this column. Check for it
            # explicitly instead of swallowing every OperationalError from
            # ALTER TABLE, which would also hide e.g. "database is locked".
            columns = {
                row[1] for row in conn.execute("PRAGMA table_info(user_settings)")
            }
            if "file_notifications_enabled" not in columns:
                conn.execute(
                    "ALTER TABLE user_settings ADD COLUMN "
                    "file_notifications_enabled INTEGER NOT NULL DEFAULT 0"
                )

    def _get_connection(self) -> sqlite3.Connection:
        """Open a new, caller-owned connection to ``db_path``."""
        return sqlite3.connect(self.db_path)

    @contextmanager
    def _connect(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection inside a transaction, then always close it.

        ``sqlite3.Connection`` used as a context manager only commits or
        rolls back; it does not close the connection. Using it directly
        therefore leaves one open handle per call until garbage collection.
        """
        conn = self._get_connection()
        try:
            with conn:  # commit on success, roll back on exception
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _ensure_topic(messages: list[dict]) -> list[Message]:
        """Backfill 'general' topic on old messages that lack it (mutates in place)."""
        for msg in messages:
            if "topic" not in msg:
                msg["topic"] = "general"
        return messages  # type: ignore[return-value]

    @staticmethod
    def _load_messages(conn: sqlite3.Connection, user_id: int) -> list[dict] | None:
        """Return the decoded message list for *user_id*, or None if no session row exists."""
        row = conn.execute(
            "SELECT messages FROM sessions WHERE user_id = ?", (user_id,)
        ).fetchone()
        return json.loads(row[0]) if row else None

    def get_history(self, user_id: int, topic: str | None = None) -> list[Message]:
        """Get conversation history for a user, optionally filtered by topic.

        Messages stored without a topic are reported as 'general'. The *topic*
        filter is compared verbatim; it is not normalized.
        """
        with self._connect() as conn:
            raw = self._load_messages(conn, user_id)
        if raw is None:
            return []
        messages = self._ensure_topic(raw)
        if topic is not None:
            return [m for m in messages if m["topic"] == topic]
        return messages

    def add_message(
        self, user_id: int, role: str, content: str, topic: str = "general"
    ) -> None:
        """Append a message to the user's history, creating the session on first use.

        *topic* is normalized with ``validate_topic_name``.

        Raises:
            ValueError: if *topic* is empty after normalization.
        """
        topic = validate_topic_name(topic)
        now = datetime.now().isoformat()  # naive local-time ISO-8601 string
        with self._connect() as conn:
            messages = self._load_messages(conn, user_id) or []
            messages.append(
                {
                    "role": role,
                    "content": content,
                    "timestamp": now,
                    "topic": topic,
                }
            )
            # Insert a new session, or overwrite the message list while
            # keeping the original created_at.
            conn.execute(
                "INSERT INTO sessions (user_id, messages, created_at, updated_at) "
                "VALUES (?, ?, ?, ?) "
                "ON CONFLICT(user_id) DO UPDATE SET "
                "messages = excluded.messages, updated_at = excluded.updated_at",
                (user_id, json.dumps(messages), now, now),
            )

    def get_topics(self, user_id: int) -> list[dict]:
        """Get unique topics with message counts for a user.

        Each entry has ``topic``, ``message_count`` and ``last_activity``. The
        last of these is the timestamp of the topic's last message in stored
        order. Entries are sorted by ``last_activity``, newest first.
        """
        topic_stats: dict[str, dict] = {}
        for msg in self.get_history(user_id):
            stats = topic_stats.setdefault(
                msg["topic"],
                {
                    "topic": msg["topic"],
                    "message_count": 0,
                    "last_activity": msg["timestamp"],
                },
            )
            stats["message_count"] += 1
            stats["last_activity"] = msg["timestamp"]
        return sorted(
            topic_stats.values(), key=lambda s: s["last_activity"], reverse=True
        )

    def update_message_topics(
        self, user_id: int, reassignments: list[TopicReassignment]
    ) -> int:
        """Bulk-update topic labels on messages by index.

        Returns the number of distinct messages whose topic actually changed.
        A message listed several times counts once. A message reassigned back
        to its current topic does not count. Returns 0 without validating
        indices when the user has no session row.

        Raises:
            ValueError: if any topic name is invalid or an index is out of
                bounds. Nothing is written in that case.
        """
        validated = [
            (r.message_index, validate_topic_name(r.new_topic)) for r in reassignments
        ]
        with self._connect() as conn:
            raw = self._load_messages(conn, user_id)
            if raw is None:
                return 0
            messages = self._ensure_topic(raw)
            for idx, _ in validated:
                if not 0 <= idx < len(messages):
                    valid = (
                        f"0-{len(messages) - 1}"
                        if messages
                        else "session has no messages"
                    )
                    raise ValueError(f"Message index {idx} out of range ({valid})")

            # Snapshot topics before applying anything. Duplicate indices and
            # no-op reassignments then do not inflate the count.
            before = {idx: messages[idx]["topic"] for idx, _ in validated}
            for idx, new_topic in validated:
                messages[idx]["topic"] = new_topic
            updated_count = sum(
                1 for idx, old in before.items() if messages[idx]["topic"] != old
            )

            if updated_count:
                conn.execute(
                    "UPDATE sessions SET messages = ?, updated_at = ? WHERE user_id = ?",
                    (json.dumps(messages), datetime.now().isoformat(), user_id),
                )
        return updated_count

    def clear_session(self, user_id: int) -> None:
        """Clear conversation history for a user (no-op if no session row exists)."""
        now = datetime.now().isoformat()
        with self._connect() as conn:
            conn.execute(
                "UPDATE sessions SET messages = '[]', updated_at = ? WHERE user_id = ?",
                (now, user_id),
            )

    def get_session_info(self, user_id: int) -> dict:
        """Get session metadata: message_count, created_at, updated_at (None if no session)."""
        with self._connect() as conn:
            row = conn.execute(
                "SELECT messages, created_at, updated_at FROM sessions WHERE user_id = ?",
                (user_id,),
            ).fetchone()
        if row is None:
            return {
                "message_count": 0,
                "created_at": None,
                "updated_at": None,
            }
        messages_json, created_at, updated_at = row
        return {
            "message_count": len(json.loads(messages_json)),
            "created_at": created_at,
            "updated_at": updated_at,
        }

    def get_last_notified_version(self, user_id: int) -> str | None:
        """Get the last version the user was notified about, or None."""
        with self._connect() as conn:
            row = conn.execute(
                "SELECT last_notified_version FROM user_settings WHERE user_id = ?",
                (user_id,),
            ).fetchone()
        return row[0] if row else None

    def set_last_notified_version(self, user_id: int, version: str) -> None:
        """Set the last version the user was notified about (upsert)."""
        with self._connect() as conn:
            conn.execute(
                """INSERT INTO user_settings (user_id, last_notified_version)
                   VALUES (?, ?)
                   ON CONFLICT(user_id) DO UPDATE SET
                   last_notified_version = excluded.last_notified_version""",
                (user_id, version),
            )

    def get_file_notifications_enabled(self, user_id: int) -> bool:
        """Check if file notifications are enabled for a user (False if no settings row)."""
        with self._connect() as conn:
            row = conn.execute(
                "SELECT file_notifications_enabled FROM user_settings WHERE user_id = ?",
                (user_id,),
            ).fetchone()
        return bool(row[0]) if row else False

    def set_file_notifications_enabled(self, user_id: int, enabled: bool) -> None:
        """Enable or disable file notifications for a user (stored as 0/1, upsert)."""
        with self._connect() as conn:
            conn.execute(
                """INSERT INTO user_settings (user_id, file_notifications_enabled)
                   VALUES (?, ?)
                   ON CONFLICT(user_id) DO UPDATE SET
                   file_notifications_enabled = excluded.file_notifications_enabled""",
                (user_id, int(enabled)),
            )

    def get_users_with_file_notifications(self) -> list[int]:
        """Get all user IDs that have file notifications enabled."""
        with self._connect() as conn:
            rows = conn.execute(
                "SELECT user_id FROM user_settings WHERE file_notifications_enabled = 1"
            ).fetchall()
        return [row[0] for row in rows]

    def save_upload(
        self,
        user_id: int,
        original_filename: str,
        hash_filename: str,
        extension: str,
        file_size: int,
        mime_type: str = "application/pdf",
    ) -> UploadRecord:
        """Save upload metadata to the database and return the stored record.

        Raises:
            RuntimeError: if SQLite does not report the new row id.
        """
        now = datetime.now().isoformat()
        with self._connect() as conn:
            cursor = conn.execute(
                """INSERT INTO uploads (user_id, original_filename, hash_filename,
                   extension, file_size, mime_type, created_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?)""",
                (
                    user_id,
                    original_filename,
                    hash_filename,
                    extension,
                    file_size,
                    mime_type,
                    now,
                ),
            )
            upload_id = cursor.lastrowid
        # Explicit check instead of ``assert``, which ``python -O`` strips.
        if upload_id is None:
            raise RuntimeError("SQLite did not report a row id for the new upload")
        return UploadRecord(
            id=upload_id,
            user_id=user_id,
            original_filename=original_filename,
            hash_filename=hash_filename,
            extension=extension,
            file_size=file_size,
            mime_type=mime_type,
            created_at=now,
        )

    def get_uploads(self, user_id: int) -> list[UploadRecord]:
        """Get all uploads for a user, newest first (ties broken by id, newest first)."""
        with self._connect() as conn:
            rows = conn.execute(
                f"{self._UPLOAD_SELECT} WHERE user_id = ? "
                "ORDER BY created_at DESC, id DESC",
                (user_id,),
            ).fetchall()
        return [self._row_to_upload(row) for row in rows]

    def get_upload_by_id(self, upload_id: int, user_id: int) -> UploadRecord | None:
        """Get a single upload by ID, scoped to a user; None if not found."""
        with self._connect() as conn:
            row = conn.execute(
                f"{self._UPLOAD_SELECT} WHERE id = ? AND user_id = ?",
                (upload_id, user_id),
            ).fetchone()
        return self._row_to_upload(row) if row else None

    def delete_upload(self, upload_id: int, user_id: int) -> bool:
        """Delete an upload record. Returns True if a row was deleted.

        Only the database row is removed; this method does not touch any
        file on disk.
        """
        with self._connect() as conn:
            deleted = conn.execute(
                "DELETE FROM uploads WHERE id = ? AND user_id = ?",
                (upload_id, user_id),
            ).rowcount
        return deleted > 0

    @staticmethod
    def _row_to_upload(row: tuple) -> UploadRecord:
        """Build an UploadRecord from a row in ``_UPLOAD_SELECT`` column order."""
        return UploadRecord(
            id=row[0],
            user_id=row[1],
            original_filename=row[2],
            hash_filename=row[3],
            extension=row[4],
            file_size=row[5],
            mime_type=row[6],
            created_at=row[7],
        )

    def record_token_usage(
        self,
        user_id: int,
        input_tokens: int,
        output_tokens: int,
        model: str,
        topic: str,
    ) -> None:
        """Record token usage from a single LLM call (topic is stored as given)."""
        now = datetime.now().isoformat()
        with self._connect() as conn:
            conn.execute(
                """INSERT INTO token_usage
                   (user_id, input_tokens, output_tokens, model, topic, created_at)
                   VALUES (?, ?, ?, ?, ?, ?)""",
                (user_id, input_tokens, output_tokens, model, topic, now),
            )

    def get_token_usage(
        self, user_id: int, since: str | None = None
    ) -> list[TokenUsageRecord]:
        """Get token usage records for a user, oldest first, optionally filtered by date.

        *since* is compared as text against the stored ISO-8601 timestamps.
        This assumes it uses the same format, e.g. ``datetime.isoformat()``.
        """
        sql = (
            "SELECT id, user_id, input_tokens, output_tokens, "
            "model, topic, created_at "
            "FROM token_usage WHERE user_id = ?"
        )
        params: list[object] = [user_id]
        if since is not None:
            sql += " AND created_at >= ?"
            params.append(since)
        sql += " ORDER BY created_at, id"
        with self._connect() as conn:
            rows = conn.execute(sql, params).fetchall()
        return [self._row_to_token_usage(row) for row in rows]

    @staticmethod
    def _row_to_token_usage(row: tuple) -> TokenUsageRecord:
        """Build a TokenUsageRecord from a row of the get_token_usage SELECT."""
        return TokenUsageRecord(
            id=row[0],
            user_id=row[1],
            input_tokens=row[2],
            output_tokens=row[3],
            model=row[4],
            topic=row[5],
            created_at=row[6],
        )