218 lines
6.6 KiB
Python
218 lines
6.6 KiB
Python
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
|
|
def normalize_content(content: Any) -> str:
    """Render message content as a deterministic string.

    Strings are returned unchanged.  Any other value is serialized as compact
    JSON with sorted keys and non-ASCII characters left unescaped.  If the
    JSON encoder rejects the value, ``str(content)`` is returned instead.
    """
    if not isinstance(content, str):
        try:
            return json.dumps(
                content,
                ensure_ascii=False,
                sort_keys=True,
                separators=(",", ":"),
            )
        except Exception:
            # Best-effort: anything the JSON encoder cannot handle still
            # yields some string rather than an error.
            return str(content)
    return content
|
|
|
|
|
|
def _dump_tool_call(item: Any) -> str:
    """Serialize one tool call as compact, key-sorted JSON, else ``str(item)``."""
    try:
        return json.dumps(item, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
    except (TypeError, ValueError, RecursionError):
        # The encoder rejected the item (unsupported type, circular reference,
        # excessive nesting); degrade to its string form instead of failing.
        return str(item)


def serialize_tool_calls(tool_calls: Any) -> str:
    """Render a message's ``tool_calls`` value as a deterministic string.

    * ``None`` gives the empty string.
    * A single dict is treated as a one-element list.
    * A list gives one compact, key-sorted JSON document per item, joined
      with newlines.  An item that cannot be JSON-encoded is rendered with
      ``str()`` rather than raising, matching the fallback that
      ``normalize_content`` applies.
    * Any other value is delegated to ``normalize_content``.
    """
    if tool_calls is None:
        return ""
    if isinstance(tool_calls, dict):
        tool_calls = [tool_calls]
    if not isinstance(tool_calls, list):
        return normalize_content(tool_calls)
    return "\n".join(_dump_tool_call(item) for item in tool_calls)
|
|
|
|
|
|
def stable_message_fingerprint(message: dict[str, Any]) -> str:
    """Return a 32-hex-character BLAKE2b fingerprint of one message.

    The fingerprint covers three fields, in this order: the role (``"unknown"``
    when the key is absent), the normalized content, and the serialized tool
    calls.  The fields are UTF-8 encoded and separated by the ASCII unit
    separator (0x1F) before hashing.
    """
    fields = (
        str(message.get("role", "unknown")),
        normalize_content(message.get("content")),
        serialize_tool_calls(message.get("tool_calls")),
    )
    # The "ignore" handler drops characters that cannot be encoded as UTF-8
    # instead of raising.
    payload = b"\x1f".join(field.encode("utf-8", "ignore") for field in fields)
    return hashlib.blake2b(payload, digest_size=16).hexdigest()
|
|
|
|
|
|
def build_prefix_hashes(messages: list[dict[str, Any]]) -> list[str]:
    """Return one cumulative hash per message prefix.

    Element ``i`` is the hex digest of a single running BLAKE2b hash after it
    has absorbed the fingerprints of ``messages[0..i]``, each followed by a
    newline.  The result has the same length as ``messages``.
    """
    running = hashlib.blake2b(digest_size=16)
    hashes: list[str] = []
    for message in messages:
        line = stable_message_fingerprint(message).encode("ascii") + b"\n"
        running.update(line)
        # hexdigest() does not finalize the hash, so it can keep absorbing.
        hashes.append(running.hexdigest())
    return hashes
|
|
|
|
|
|
def build_message_fingerprints(messages: list[dict[str, Any]]) -> list[str]:
    """Return the fingerprint of every message, in input order."""
    return list(map(stable_message_fingerprint, messages))
|
|
|
|
|
|
def build_sequence_hashes(message_fingerprints: list[str]) -> list[str]:
    """Return one cumulative hash per fingerprint prefix.

    Element ``i`` is the hex digest of a single running BLAKE2b hash after it
    has absorbed ``message_fingerprints[0..i]``, each followed by a newline.
    Fingerprints must be ASCII; anything else raises ``UnicodeEncodeError``.
    """
    running = hashlib.blake2b(digest_size=16)

    def absorb(fingerprint: str) -> str:
        # Feed one more fingerprint and report the digest of everything so far.
        running.update(fingerprint.encode("ascii") + b"\n")
        return running.hexdigest()

    return [absorb(fingerprint) for fingerprint in message_fingerprints]
|
|
|
|
|
|
def encode_prefix_hashes(prefix_hashes: list[str]) -> str:
    """Pack prefix hashes into one comma-separated string."""
    separator = ","
    return separator.join(prefix_hashes)
|
|
|
|
|
|
def decode_prefix_hashes(encoded: str) -> list[str]:
    """Split a comma-separated string of prefix hashes.

    Empty segments (from leading, trailing or doubled commas) are dropped; a
    falsy ``encoded`` gives an empty list.
    """
    if encoded:
        return list(filter(None, encoded.split(",")))
    return []
|
|
|
|
|
|
def encode_message_fingerprints(message_fingerprints: list[str]) -> str:
    """Pack message fingerprints into one comma-separated string."""
    encoded = ",".join(message_fingerprints)
    return encoded
|
|
|
|
|
|
def decode_message_fingerprints(encoded: str) -> list[str]:
    """Split a comma-separated string of message fingerprints.

    Empty segments are dropped; a falsy ``encoded`` gives an empty list.
    """
    segments = encoded.split(",") if encoded else []
    return [segment for segment in segments if segment]
|
|
|
|
|
|
def encode_roles(roles: list[str]) -> str:
    """Pack role names into one comma-separated string."""
    return str.join(",", roles)
|
|
|
|
|
|
def decode_roles(encoded: str) -> list[str]:
    """Split a comma-separated string of role names.

    Empty segments are dropped; a falsy ``encoded`` gives an empty list.
    """
    roles: list[str] = []
    if encoded:
        roles.extend(name for name in encoded.split(",") if name)
    return roles
|
|
|
|
|
|
def extract_user_id(request_params: dict[str, Any]) -> str:
    """Return ``request_params["header"]["attributes"]["user_id"]`` as a string.

    The lookup is tolerant at every level: if ``request_params``, the header,
    or the attributes value is not a dict, or the user id is missing or falsy,
    the empty string is returned.
    """
    header = request_params.get("header", {}) if isinstance(request_params, dict) else {}
    attributes = header.get("attributes", {}) if isinstance(header, dict) else {}
    if not isinstance(attributes, dict):
        # e.g. "attributes": null in the payload; treat like a missing user id.
        return ""
    return str(attributes.get("user_id", "") or "")
|
|
|
|
|
|
def build_root_session_id(user_id: str, request_id: str) -> str:
    """Derive a session id from a user id and a request id.

    The result is ``"ls-"`` followed by the 20-hex-character BLAKE2b digest of
    the UTF-8 encoded user id, a NUL byte, and the UTF-8 encoded request id.
    """
    parts = (user_id, request_id)
    # The "ignore" handler drops characters that cannot be encoded as UTF-8.
    payload = b"\x00".join(part.encode("utf-8", "ignore") for part in parts)
    token = hashlib.blake2b(payload, digest_size=10).hexdigest()
    return "ls-" + token
|
|
|
|
|
|
@dataclass
class SessionAssignment:
    """Outcome of placing one request into a logical session.

    Instances are returned by ``LogicalSessionizer.assign_precomputed``.
    """

    # Id of the session the request was placed in.
    session_id: str
    # Request id of the matched parent; "" when no parent was found.
    parent_request_id: str
    # Chat id of the matched parent; -1 when no parent was found.
    parent_chat_id: int
    # Sequential id handed out by the sessionizer for this request.
    chat_id: int
    # 1 for a request without a parent, otherwise the parent's turn plus one.
    turn: int
|
|
|
|
|
|
@dataclass
class _SessionNode:
    """Index entry remembering one already-processed request."""

    # Id of the request this entry was created for.
    request_id: str
    # Session the request was placed in.
    session_id: str
    # Sequential chat id assigned to the request.
    chat_id: int
    # Turn number of the request within its session.
    turn: int
    # Number of sequence hashes (messages) the request carried.
    message_count: int
|
|
|
|
|
|
class LogicalSessionizer:
    """Group requests into logical sessions by matching message prefixes.

    Every processed request is remembered under the cumulative hash of its
    message sequence, scoped per user.  A later request whose sequence starts
    with a remembered one is attached to that request's session as its child.
    Entries are never removed from the index by this class.
    """

    def __init__(self) -> None:
        # (scope user id, cumulative sequence hash) -> most recent node.
        self._index: dict[tuple[str, str], _SessionNode] = {}
        # Next sequential chat id to hand out.
        self._next_chat_id = 0

    def assign_precomputed(
        self,
        *,
        user_id: str,
        request_id: str,
        sequence_hashes: list[str],
        roles: list[str],
    ) -> SessionAssignment:
        """Assign a request to a session using precomputed cumulative hashes.

        Args:
            user_id: Owner of the request.  When empty, the request is scoped
                to ``"missing-user:<request_id>"``, i.e. to itself.
            request_id: Identifier of this request.
            sequence_hashes: Cumulative hash per message prefix, so that
                ``sequence_hashes[i]`` covers messages ``0..i``.
            roles: Role of each message, parallel to ``sequence_hashes``.

        Returns:
            The session, parent linkage, chat id and turn for this request.
        """
        scope = user_id or f"missing-user:{request_id}"
        total = len(sequence_hashes)

        # seen_user[i] is True once a "user" role occurred at or before i.
        seen_user: list[bool] = []
        user_seen = False
        for role in roles:
            if role == "user":
                user_seen = True
            seen_user.append(user_seen)

        def covers_user(length: int) -> bool:
            # True when the first `length` messages include a user message.
            return 0 < length <= len(seen_user) and seen_user[length - 1]

        # Look for the longest already-known prefix, longest first.
        parent: _SessionNode | None = None
        for length in range(total, 0, -1):
            if not covers_user(length):
                continue
            hit = self._index.get((scope, sequence_hashes[length - 1]))
            # A match on the full sequence only counts when the stored request
            # carried more messages than this one; presumably this keeps an
            # identical repeat from becoming a child of its twin.
            if hit is not None and (length < total or hit.message_count > length):
                parent = hit
                break

        if parent is not None:
            session_id = parent.session_id
            parent_request_id = parent.request_id
            parent_chat_id = parent.chat_id
            turn = parent.turn + 1
        else:
            session_id = build_root_session_id(scope, request_id)
            parent_request_id, parent_chat_id, turn = "", -1, 1

        chat_id = self._next_chat_id
        self._next_chat_id = chat_id + 1

        if sequence_hashes:
            node = _SessionNode(
                request_id=request_id,
                session_id=session_id,
                chat_id=chat_id,
                turn=turn,
                message_count=total,
            )
            self._index[(scope, sequence_hashes[-1])] = node

            # Count non-user messages at the tail, looking at the last three
            # roles at most, and register the node under the prefix without
            # them as well.
            trailing = 0
            for role in reversed(roles[-3:]):
                if role == "user":
                    break
                trailing += 1
            trimmed = total - trailing
            if covers_user(trimmed):
                self._index[(scope, sequence_hashes[trimmed - 1])] = node

        return SessionAssignment(
            session_id=session_id,
            parent_request_id=parent_request_id,
            parent_chat_id=parent_chat_id,
            chat_id=chat_id,
            turn=turn,
        )

    def assign(
        self,
        *,
        user_id: str,
        request_id: str,
        message_fingerprints: list[str],
        roles: list[str],
    ) -> SessionAssignment:
        """Assign a request to a session from per-message fingerprints.

        Builds the cumulative sequence hashes from ``message_fingerprints``
        and delegates to ``assign_precomputed``.
        """
        sequence_hashes = build_sequence_hashes(message_fingerprints)
        return self.assign_precomputed(
            user_id=user_id,
            request_id=request_id,
            sequence_hashes=sequence_hashes,
            roles=roles,
        )
|