"""Conversation session management with SQLite storage."""
import json
import re
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import TypedDict
_TOPIC_PATTERN = re.compile(r"^[a-z0-9_]+$")
[docs]
def validate_topic_name(name: str) -> str:
"""Normalize and validate a topic name.
Converts to lowercase, replaces spaces/hyphens with underscores,
strips non-allowed characters, and validates the result.
Raises ValueError if the result is empty after normalization.
"""
normalized = name.lower().strip()
normalized = re.sub(r"[\s\-]+", "_", normalized)
normalized = re.sub(r"[^a-z0-9_]", "", normalized)
normalized = re.sub(r"_+", "_", normalized)
normalized = normalized.strip("_")
if not normalized:
raise ValueError(f"Topic name '{name}' is empty after normalization")
return normalized
# One entry of a session's stored message list (all keys required):
#   role      -- "user" or "assistant"
#   content   -- the message text
#   timestamp -- time string; SessionManager.add_message stores
#                datetime.now().isoformat()
#   topic     -- conversation topic label
Message = TypedDict(
    "Message",
    {
        "role": str,
        "content": str,
        "timestamp": str,
        "topic": str,
    },
)
@dataclass(frozen=True, eq=True, repr=True)
class TopicReassignment:
    """Immutable request to move one stored message to another topic.

    Consumed by ``SessionManager.update_message_topics``, which normalizes
    ``new_topic`` and bounds-checks ``message_index`` before applying it.
    """

    message_index: int  # position in the user's full message list
    new_topic: str  # raw label; normalized via validate_topic_name on use
@dataclass(frozen=True, eq=True, repr=True)
class UploadRecord:
    """Immutable metadata for one row of the ``uploads`` table."""

    id: int  # row id assigned by the database
    user_id: int  # owner of the upload
    original_filename: str  # filename as supplied by the caller
    hash_filename: str  # stored filename (as supplied by the caller)
    extension: str  # file extension string (DB default: '')
    file_size: int  # size as supplied by the caller (presumably bytes)
    mime_type: str  # MIME type (DB default: 'application/pdf')
    created_at: str  # datetime.now().isoformat() at insert time
@dataclass(frozen=True, eq=True, repr=True)
class TokenUsageRecord:
    """Immutable view of one row of the ``token_usage`` table."""

    id: int  # row id assigned by the database
    user_id: int  # user the call was made for
    input_tokens: int  # tokens reported for the request side
    output_tokens: int  # tokens reported for the response side
    model: str  # model identifier string
    topic: str  # topic label (DB default: 'general')
    created_at: str  # datetime.now().isoformat() at insert time
class SessionManager:
    """Persist per-user conversation history and related data in SQLite.

    Tables managed here:

    * ``sessions`` -- one row per user; the whole message list is stored as
      a JSON array in ``messages``.
    * ``user_settings`` -- per-user notification state (last notified
      version, file-notification opt-in).
    * ``uploads`` -- metadata for uploaded files.
    * ``token_usage`` -- one row per recorded LLM call.

    Every method opens its own short-lived connection and closes it before
    returning, so an instance holds no open database handle between calls.
    """

    def __init__(self, db_path: Path | str):
        """Open the database at *db_path*, creating file and schema if needed.

        Args:
            db_path: Location of the SQLite file. Missing parent directories
                are created. A ``str`` is accepted and converted to ``Path``.
        """
        self.db_path = Path(db_path)
        self._ensure_db()

    def _ensure_db(self) -> None:
        """Create the database file, all tables, and pending column migrations."""
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        with self._connect() as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS sessions (
                    user_id INTEGER PRIMARY KEY,
                    messages TEXT NOT NULL DEFAULT '[]',
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS user_settings (
                    user_id INTEGER PRIMARY KEY,
                    last_notified_version TEXT,
                    file_notifications_enabled INTEGER NOT NULL DEFAULT 0
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS uploads (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL,
                    original_filename TEXT NOT NULL,
                    hash_filename TEXT NOT NULL,
                    extension TEXT NOT NULL DEFAULT '',
                    file_size INTEGER NOT NULL,
                    mime_type TEXT NOT NULL DEFAULT 'application/pdf',
                    created_at TEXT NOT NULL
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS token_usage (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL,
                    input_tokens INTEGER NOT NULL,
                    output_tokens INTEGER NOT NULL,
                    model TEXT NOT NULL,
                    topic TEXT NOT NULL DEFAULT 'general',
                    created_at TEXT NOT NULL
                )
            """)
            # Migration for databases whose user_settings table predates the
            # file_notifications_enabled column. Inspect the schema instead of
            # swallowing OperationalError, which would also hide real failures
            # such as a locked or read-only database.
            columns = {
                info[1] for info in conn.execute("PRAGMA table_info(user_settings)")
            }
            if "file_notifications_enabled" not in columns:
                conn.execute(
                    "ALTER TABLE user_settings ADD COLUMN "
                    "file_notifications_enabled INTEGER NOT NULL DEFAULT 0"
                )

    def _get_connection(self) -> sqlite3.Connection:
        """Return a new connection to the database file; the caller must close it."""
        return sqlite3.connect(self.db_path)

    @contextmanager
    def _connect(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection scoped to a single transaction.

        ``sqlite3.Connection``'s own context manager only commits or rolls
        back; it never closes the connection. This wrapper commits on normal
        exit, rolls back if the body raises, and always closes.
        """
        conn = self._get_connection()
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _ensure_topic(messages: list[dict]) -> list[Message]:
        """Backfill 'general' topic on old messages that lack it (in place)."""
        for msg in messages:
            if "topic" not in msg:
                msg["topic"] = "general"
        return messages  # type: ignore[return-value]

    def get_history(self, user_id: int, topic: str | None = None) -> list[Message]:
        """Return the user's stored messages in insertion order.

        Args:
            user_id: Session owner.
            topic: If given, only messages whose ``topic`` equals this value
                exactly are returned (the value is not normalized).

        Returns:
            The messages, or an empty list if the user has no session.
            Messages stored without a topic are reported as ``"general"``.
        """
        with self._connect() as conn:
            row = conn.execute(
                "SELECT messages FROM sessions WHERE user_id = ?", (user_id,)
            ).fetchone()
        if row is None:
            return []
        messages = self._ensure_topic(json.loads(row[0]))
        if topic is None:
            return messages
        return [m for m in messages if m["topic"] == topic]

    def add_message(
        self, user_id: int, role: str, content: str, topic: str = "general"
    ) -> None:
        """Append one message to the user's history, creating the session if needed.

        Args:
            user_id: Session owner.
            role: Message author role (e.g. ``"user"`` or ``"assistant"``).
            content: Message text.
            topic: Topic label; normalized with ``validate_topic_name``.

        Raises:
            ValueError: If *topic* is empty after normalization.
        """
        topic = validate_topic_name(topic)
        now = datetime.now().isoformat()
        with self._connect() as conn:
            # Take the write lock before reading: without it two concurrent
            # appends can read the same list and one overwrites the other.
            conn.execute("BEGIN IMMEDIATE")
            row = conn.execute(
                "SELECT messages FROM sessions WHERE user_id = ?", (user_id,)
            ).fetchone()
            messages = json.loads(row[0]) if row else []
            messages.append(
                {
                    "role": role,
                    "content": content,
                    "timestamp": now,
                    "topic": topic,
                }
            )
            # Upsert: inserts a new session, or replaces messages/updated_at
            # while keeping the original created_at.
            conn.execute(
                "INSERT INTO sessions (user_id, messages, created_at, updated_at) "
                "VALUES (?, ?, ?, ?) "
                "ON CONFLICT(user_id) DO UPDATE SET "
                "messages = excluded.messages, updated_at = excluded.updated_at",
                (user_id, json.dumps(messages), now, now),
            )

    def get_topics(self, user_id: int) -> list[dict]:
        """Summarize the user's topics.

        Returns:
            One dict per distinct topic with keys ``topic``,
            ``message_count`` and ``last_activity`` (timestamp of the last
            message seen for that topic), sorted by ``last_activity``
            descending. Empty if the user has no messages.
        """
        topic_stats: dict[str, dict] = {}
        for msg in self.get_history(user_id):
            t = msg["topic"]
            if t not in topic_stats:
                topic_stats[t] = {
                    "topic": t,
                    "message_count": 0,
                    "last_activity": msg["timestamp"],
                }
            topic_stats[t]["message_count"] += 1
            topic_stats[t]["last_activity"] = msg["timestamp"]
        return sorted(
            topic_stats.values(), key=lambda s: s["last_activity"], reverse=True
        )

    def update_message_topics(
        self, user_id: int, reassignments: list[TopicReassignment]
    ) -> int:
        """Bulk-update topic labels on messages by index.

        All topic names are normalized (and validated) before the database
        is touched. If the same index appears more than once, the last
        reassignment for it wins.

        Returns:
            The number of distinct messages whose topic actually changed.
            Returns 0 without checking indices when the user has no session.

        Raises:
            ValueError: If any topic name is invalid or any index is out of
                bounds; nothing is written in that case.
        """
        # Collapse duplicate indices (last wins) so each message is counted
        # at most once in the returned total.
        pending: dict[int, str] = {}
        for r in reassignments:
            pending[r.message_index] = validate_topic_name(r.new_topic)
        with self._connect() as conn:
            # Read-modify-write: hold the write lock for the whole update.
            conn.execute("BEGIN IMMEDIATE")
            row = conn.execute(
                "SELECT messages FROM sessions WHERE user_id = ?", (user_id,)
            ).fetchone()
            if row is None:
                return 0
            messages = self._ensure_topic(json.loads(row[0]))
            for idx in pending:
                if not 0 <= idx < len(messages):
                    bounds = (
                        f"0-{len(messages) - 1}" if messages else "history is empty"
                    )
                    raise ValueError(f"Message index {idx} out of range ({bounds})")
            updated_count = 0
            for idx, new_topic in pending.items():
                if messages[idx]["topic"] != new_topic:
                    messages[idx]["topic"] = new_topic
                    updated_count += 1
            if updated_count:
                conn.execute(
                    "UPDATE sessions SET messages = ?, updated_at = ? WHERE user_id = ?",
                    (json.dumps(messages), datetime.now().isoformat(), user_id),
                )
        return updated_count

    def clear_session(self, user_id: int) -> None:
        """Empty the user's message list (no-op if the user has no session)."""
        now = datetime.now().isoformat()
        with self._connect() as conn:
            conn.execute(
                "UPDATE sessions SET messages = '[]', updated_at = ? WHERE user_id = ?",
                (now, user_id),
            )

    def get_session_info(self, user_id: int) -> dict:
        """Return ``message_count``, ``created_at`` and ``updated_at`` for a session.

        For a user without a session the count is 0 and both timestamps are
        ``None``.
        """
        with self._connect() as conn:
            row = conn.execute(
                "SELECT messages, created_at, updated_at FROM sessions WHERE user_id = ?",
                (user_id,),
            ).fetchone()
        if row is None:
            return {
                "message_count": 0,
                "created_at": None,
                "updated_at": None,
            }
        return {
            "message_count": len(json.loads(row[0])),
            "created_at": row[1],
            "updated_at": row[2],
        }

    def get_last_notified_version(self, user_id: int) -> str | None:
        """Get the last version the user was notified about, or ``None``."""
        with self._connect() as conn:
            row = conn.execute(
                "SELECT last_notified_version FROM user_settings WHERE user_id = ?",
                (user_id,),
            ).fetchone()
        return row[0] if row else None

    def set_last_notified_version(self, user_id: int, version: str) -> None:
        """Set the last version the user was notified about (upsert)."""
        with self._connect() as conn:
            conn.execute(
                """INSERT INTO user_settings (user_id, last_notified_version)
                VALUES (?, ?)
                ON CONFLICT(user_id) DO UPDATE
                SET last_notified_version = excluded.last_notified_version""",
                (user_id, version),
            )

    def get_file_notifications_enabled(self, user_id: int) -> bool:
        """Return whether file notifications are enabled (``False`` if unset)."""
        with self._connect() as conn:
            row = conn.execute(
                "SELECT file_notifications_enabled FROM user_settings WHERE user_id = ?",
                (user_id,),
            ).fetchone()
        return bool(row[0]) if row else False

    def set_file_notifications_enabled(self, user_id: int, enabled: bool) -> None:
        """Enable or disable file notifications for a user (upsert)."""
        with self._connect() as conn:
            conn.execute(
                """INSERT INTO user_settings (user_id, file_notifications_enabled)
                VALUES (?, ?)
                ON CONFLICT(user_id) DO UPDATE
                SET file_notifications_enabled = excluded.file_notifications_enabled""",
                (user_id, int(enabled)),
            )

    def get_users_with_file_notifications(self) -> list[int]:
        """Get all user IDs that have file notifications enabled."""
        with self._connect() as conn:
            rows = conn.execute(
                "SELECT user_id FROM user_settings WHERE file_notifications_enabled = 1"
            ).fetchall()
        return [row[0] for row in rows]

    def save_upload(
        self,
        user_id: int,
        original_filename: str,
        hash_filename: str,
        extension: str,
        file_size: int,
        mime_type: str = "application/pdf",
    ) -> UploadRecord:
        """Insert upload metadata and return it as an ``UploadRecord``.

        Raises:
            RuntimeError: If SQLite does not report the new row id.
        """
        now = datetime.now().isoformat()
        with self._connect() as conn:
            cursor = conn.execute(
                """INSERT INTO uploads
                (user_id, original_filename, hash_filename, extension,
                 file_size, mime_type, created_at)
                VALUES (?, ?, ?, ?, ?, ?, ?)""",
                (
                    user_id,
                    original_filename,
                    hash_filename,
                    extension,
                    file_size,
                    mime_type,
                    now,
                ),
            )
            upload_id = cursor.lastrowid
        # Explicit check instead of `assert`, which is stripped under -O.
        if upload_id is None:
            raise RuntimeError("INSERT into uploads did not report a row id")
        return UploadRecord(
            id=upload_id,
            user_id=user_id,
            original_filename=original_filename,
            hash_filename=hash_filename,
            extension=extension,
            file_size=file_size,
            mime_type=mime_type,
            created_at=now,
        )

    def get_uploads(self, user_id: int) -> list[UploadRecord]:
        """Get all uploads for a user, newest first (ties broken by id)."""
        with self._connect() as conn:
            rows = conn.execute(
                "SELECT id, user_id, original_filename, hash_filename, "
                "extension, file_size, mime_type, created_at "
                "FROM uploads WHERE user_id = ? ORDER BY created_at DESC, id DESC",
                (user_id,),
            ).fetchall()
        return [self._row_to_upload(row) for row in rows]

    def get_upload_by_id(self, upload_id: int, user_id: int) -> UploadRecord | None:
        """Get a single upload by ID, scoped to a user; ``None`` if not found."""
        with self._connect() as conn:
            row = conn.execute(
                "SELECT id, user_id, original_filename, hash_filename, "
                "extension, file_size, mime_type, created_at "
                "FROM uploads WHERE id = ? AND user_id = ?",
                (upload_id, user_id),
            ).fetchone()
        return self._row_to_upload(row) if row else None

    def delete_upload(self, upload_id: int, user_id: int) -> bool:
        """Delete an upload record. Returns True if a row was deleted."""
        with self._connect() as conn:
            cursor = conn.execute(
                "DELETE FROM uploads WHERE id = ? AND user_id = ?",
                (upload_id, user_id),
            )
            deleted = cursor.rowcount > 0
        return deleted

    @staticmethod
    def _row_to_upload(row: tuple) -> UploadRecord:
        """Map an ``uploads`` row (column order of the SELECTs above) to a record."""
        return UploadRecord(
            id=row[0],
            user_id=row[1],
            original_filename=row[2],
            hash_filename=row[3],
            extension=row[4],
            file_size=row[5],
            mime_type=row[6],
            created_at=row[7],
        )

    def record_token_usage(
        self,
        user_id: int,
        input_tokens: int,
        output_tokens: int,
        model: str,
        topic: str,
    ) -> None:
        """Record token usage from a single LLM call (topic stored as given)."""
        now = datetime.now().isoformat()
        with self._connect() as conn:
            conn.execute(
                """INSERT INTO token_usage
                (user_id, input_tokens, output_tokens, model, topic, created_at)
                VALUES (?, ?, ?, ?, ?, ?)""",
                (user_id, input_tokens, output_tokens, model, topic, now),
            )

    def get_token_usage(
        self, user_id: int, since: str | None = None
    ) -> list[TokenUsageRecord]:
        """Get a user's token usage records, oldest first.

        Args:
            user_id: User whose records to return.
            since: If given, only rows with ``created_at >= since`` are
                returned. The comparison is on the stored text, so *since*
                should use the same ``isoformat()`` layout.
        """
        # Only constant SQL fragments are concatenated; values stay bound.
        query = (
            "SELECT id, user_id, input_tokens, output_tokens, "
            "model, topic, created_at "
            "FROM token_usage WHERE user_id = ?"
        )
        params: list[object] = [user_id]
        if since is not None:
            query += " AND created_at >= ?"
            params.append(since)
        query += " ORDER BY created_at, id"
        with self._connect() as conn:
            rows = conn.execute(query, params).fetchall()
        return [self._row_to_token_usage(row) for row in rows]

    @staticmethod
    def _row_to_token_usage(row: tuple) -> TokenUsageRecord:
        """Map a ``token_usage`` row (column order of the SELECT above) to a record."""
        return TokenUsageRecord(
            id=row[0],
            user_id=row[1],
            input_tokens=row[2],
            output_tokens=row[3],
            model=row[4],
            topic=row[5],
            created_at=row[6],
        )