# Persistence layer for NeoSession records (name, model info, message list),
# stored in the "sessions" table of a TinyDB database.
from __future__ import annotations
import os
import re
import uuid
from datetime import datetime, timezone
from dataclasses import dataclass, field
from typing import Optional
from tinydb import TinyDB, Query
import config


# Make sure the directory that will hold the session database exists before
# TinyDB opens the file. os.path.dirname() returns "" for a bare filename
# (e.g. "sessions.json"), and os.makedirs("") raises FileNotFoundError, so
# only create a directory when the configured path actually has one.
DB_DIR = os.path.dirname(config.SESSION_DB_PATH)
if DB_DIR:
    os.makedirs(DB_DIR, exist_ok=True)


@dataclass()
class NeoSession:
    """In-memory view of one stored session document.

    ``NeoSessionManager`` builds these from TinyDB documents; the field names
    mirror the document keys it writes.
    """

    # Identifier string; NeoSessionManager.create() assigns a UUID4 here.
    session_id: str
    # Display name of the session.
    name: str
    # Timestamps as ISO-8601 strings (create() writes UTC isoformat values).
    created_at: str
    updated_at: str
    # Model metadata, stored verbatim as supplied by the caller.
    model_info: dict
    # Message dicts in the order they were appended; each gets its own list.
    messages: list = field(default_factory=list)
    # TinyDB document id of the backing document, or None (the default).
    doc_id: Optional[int] = field(default=None)


class NeoSessionManager:
    """CRUD and search operations for sessions stored in a TinyDB table.

    Each session is one document in the ``"sessions"`` table of the TinyDB
    database at ``db_path``. Sessions are addressed by their TinyDB document
    id (``doc_id``); ``session_id`` is an extra UUID4 string assigned on
    creation.

    The database stays open for the lifetime of the manager; call
    :meth:`close` (or use the manager in a ``with`` statement) to release it.

    NOTE: the public method :meth:`list` shadows the builtin ``list`` in this
    class body only. Method bodies still see the builtin, and annotations are
    lazy strings thanks to ``from __future__ import annotations``.
    """

    def __init__(self, db_path: str = config.SESSION_DB_PATH):
        """Open (creating if necessary) the TinyDB database at *db_path*.

        The parent directory of *db_path* is created when it has one, so
        paths other than the configured default work too.
        """
        db_dir = os.path.dirname(db_path)
        if db_dir:  # "" for a bare filename; os.makedirs("") would raise
            os.makedirs(db_dir, exist_ok=True)
        self._db = TinyDB(db_path)
        self._table = self._db.table("sessions")

    def close(self) -> None:
        """Close the underlying TinyDB database and its storage."""
        self._db.close()

    def __enter__(self) -> NeoSessionManager:
        return self

    def __exit__(self, *exc_info) -> None:
        # Always release the database; exceptions propagate (returns None).
        self.close()

    def create(self, name: str, model_info: dict) -> NeoSession:
        """Insert a new session with no messages and return it.

        :param name: display name, stored as given.
        :param model_info: model metadata, stored verbatim.
        :returns: the stored session, including its new ``doc_id``.
        """
        now = self._utc_now()
        doc_id = self._table.insert({
            "session_id": str(uuid.uuid4()),
            "name": name,
            "created_at": now,
            "updated_at": now,
            "model_info": model_info,
            "messages": [],
        })
        # Re-read so the returned object reflects exactly what was stored.
        return self._doc_to_session(self._table.get(doc_id=doc_id))

    def list(self) -> list[dict]:
        """Return summaries of all sessions, most recently updated first."""
        return self._sorted_summaries(self._table.all())

    def get(self, doc_id: int) -> Optional[NeoSession]:
        """Return the session with *doc_id*, or None if there is none."""
        doc = self._table.get(doc_id=doc_id)
        if doc is None:
            return None
        return self._doc_to_session(doc)

    def rename(self, doc_id: int, new_name: str) -> bool:
        """Rename a session, storing the name with surrounding whitespace removed.

        :returns: False if the stripped name is empty or no session has
            *doc_id*; True once the new name is stored.
        """
        cleaned = new_name.strip()
        if not cleaned or not self._exists(doc_id):
            return False
        self._table.update({"name": cleaned}, doc_ids=[doc_id])
        return True

    def delete(self, doc_id: int) -> bool:
        """Delete a session.

        :returns: True if a session was removed, False if *doc_id* is unknown.
        """
        if not self._exists(doc_id):
            return False
        self._table.remove(doc_ids=[doc_id])
        return True

    def search(self, query: str) -> list[dict]:
        """Return summaries of sessions whose name contains *query*.

        Matching is a case-insensitive *literal* substring test, like
        :meth:`search_messages`. *query* is regex-escaped so that user input
        such as ``"C++"`` or ``"v1.0"`` matches literally instead of raising
        ``re.error`` or acting as a wildcard. Results are newest first.
        """
        pattern = re.escape(query)
        matches = self._table.search(
            Query().name.search(pattern, flags=re.IGNORECASE)
        )
        return self._sorted_summaries(matches)

    def search_messages(self, query: str) -> list[dict]:
        """Find sessions containing a message whose content includes *query*.

        Matching is a case-insensitive substring test. Only the first matching
        message of each session is reported, in table order. Messages whose
        ``content`` is not a string (e.g. ``None``) are skipped.
        """
        needle = query.lower()  # hoisted: invariant across all messages
        hits = []
        for doc in self._table.all():
            for msg in doc.get("messages", []):
                content = msg.get("content", "")
                if not isinstance(content, str):
                    continue
                if needle in content.lower():
                    hits.append({
                        "doc_id": doc.doc_id,
                        "session_name": doc["name"],
                        "role": msg.get("role"),
                        "content": content,
                        "session_id": doc["session_id"],
                    })
                    break  # report only the first match per session
        return hits

    def add_message(self, doc_id: int, role: str, content: str, **kwargs) -> bool:
        """Append a message to a session and refresh its ``updated_at``.

        The stored message is ``{"role": role, "content": content}`` plus any
        extra keyword arguments as additional keys.

        :returns: False if no session has *doc_id*, True otherwise.
        """
        doc = self._table.get(doc_id=doc_id)
        if doc is None:
            return False
        message = {"role": role, "content": content, **kwargs}
        # Build a new list instead of mutating the one held by ``doc``.
        messages = [*doc.get("messages", []), message]
        self._table.update(
            {"messages": messages, "updated_at": self._utc_now()},
            doc_ids=[doc_id],
        )
        return True

    def update_model_info(self, doc_id: int, model_info: dict) -> bool:
        """Replace a session's ``model_info`` (``updated_at`` is not touched).

        :returns: False if no session has *doc_id*, True otherwise.
        """
        if not self._exists(doc_id):
            return False
        self._table.update({"model_info": model_info}, doc_ids=[doc_id])
        return True

    def _exists(self, doc_id: int) -> bool:
        """Return True if the table holds a document with *doc_id*."""
        # TinyDB's update()/remove() with doc_ids index the table directly
        # (KeyError for unknown ids) and echo back the requested ids, so their
        # return value cannot signal "not found"; check explicitly instead.
        return self._table.get(doc_id=doc_id) is not None

    def _sorted_summaries(self, docs) -> list[dict]:
        """Summarize *docs* and order them by ``updated_at``, newest first."""
        summaries = [self._summarize(doc) for doc in docs]
        return sorted(summaries, key=lambda s: s["updated_at"], reverse=True)

    @staticmethod
    def _summarize(doc) -> dict:
        """Build the lightweight summary dict returned by list() and search()."""
        return {
            "doc_id": doc.doc_id,
            "session_id": doc["session_id"],
            "name": doc["name"],
            "created_at": doc["created_at"],
            "updated_at": doc["updated_at"],
            "model_info": doc["model_info"],
            "message_count": len(doc.get("messages", [])),
        }

    @staticmethod
    def _utc_now() -> str:
        """Return the current UTC time as an ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def _doc_to_session(doc) -> NeoSession:
        """Convert a TinyDB document into a :class:`NeoSession`."""
        return NeoSession(
            doc_id=doc.doc_id,
            session_id=doc["session_id"],
            name=doc["name"],
            created_at=doc["created_at"],
            updated_at=doc["updated_at"],
            model_info=doc["model_info"],
            messages=doc.get("messages", []),
        )