"""Fingerprint chat messages and group requests into logical sessions."""

from __future__ import annotations

import hashlib
import json
from dataclasses import dataclass
from typing import Any


def _canonical_json_or_str(value: Any) -> str:
    """Serialize *value* as compact, key-sorted JSON, falling back to ``str``.

    Sorting keys makes the output identical for dicts that differ only in key
    order. Values ``json`` cannot encode (for example sets, bytes or arbitrary
    objects) are rendered with ``str(value)`` instead of raising.
    """
    try:
        return json.dumps(value, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
    except Exception:
        # Deliberate best-effort: fingerprinting must not fail on odd payloads.
        return str(value)


def normalize_content(content: Any) -> str:
    """Return a deterministic string form of a message ``content`` value.

    Strings are returned unchanged. Any other value becomes canonical JSON
    (sorted keys, compact separators), or ``str(content)`` when it is not
    JSON-serializable.
    """
    if isinstance(content, str):
        return content
    return _canonical_json_or_str(content)


def serialize_tool_calls(tool_calls: Any) -> str:
    """Return a deterministic string form of a message's ``tool_calls`` field.

    * ``None`` becomes ``""``.
    * A single dict is treated as a one-element list.
    * A list becomes one canonical-JSON line per item, joined with newlines.
      Items that are not JSON-serializable fall back to ``str(item)``, the
      same policy as :func:`normalize_content`, instead of raising.
    * Any other value is delegated to :func:`normalize_content`.
    """
    if tool_calls is None:
        return ""
    if isinstance(tool_calls, dict):
        tool_calls = [tool_calls]
    if not isinstance(tool_calls, list):
        return normalize_content(tool_calls)
    return "\n".join(_canonical_json_or_str(item) for item in tool_calls)


def stable_message_fingerprint(message: dict[str, Any]) -> str:
    """Return a 32-hex-character BLAKE2b digest identifying one chat message.

    The digest covers three fields, separated by ``0x1f`` bytes:

    * the role, stringified, defaulting to ``"unknown"``;
    * the normalized content;
    * the serialized tool calls.

    NOTE(review): the separator is not escaped, so field values that contain
    ``\\x1f`` can collide. For example, content ``"a\\x1fb"`` with no tool calls
    hashes the same as content ``"a"`` with ``tool_calls="b"``. Text is also
    encoded with ``errors="ignore"``, so lone surrogates are silently dropped.
    Changing either would change every fingerprint (and anything persisted via
    the encode helpers), so both are left as-is.
    """
    role = str(message.get("role", "unknown"))
    content = normalize_content(message.get("content"))
    tool_calls = serialize_tool_calls(message.get("tool_calls"))
    digest = hashlib.blake2b(digest_size=16)
    digest.update(role.encode("utf-8", "ignore"))
    digest.update(b"\x1f")
    digest.update(content.encode("utf-8", "ignore"))
    digest.update(b"\x1f")
    digest.update(tool_calls.encode("utf-8", "ignore"))
    return digest.hexdigest()


def build_prefix_hashes(messages: list[dict[str, Any]]) -> list[str]:
    """Return one cumulative hash per prefix of *messages*.

    Element ``i`` identifies ``messages[: i + 1]``. This delegates to
    :func:`build_sequence_hashes` over the per-message fingerprints, so both
    entry points always produce the same hashes.
    """
    return build_sequence_hashes(build_message_fingerprints(messages))


def build_message_fingerprints(messages: list[dict[str, Any]]) -> list[str]:
    """Return :func:`stable_message_fingerprint` for each message, in order."""
    return [stable_message_fingerprint(message) for message in messages]


def build_sequence_hashes(message_fingerprints: list[str]) -> list[str]:
    """Return cumulative BLAKE2b hashes over a list of message fingerprints.

    A single running digest is fed ``fingerprint + "\\n"`` for each entry, and
    its hex digest is snapshotted after every step. Element ``i`` therefore
    depends on fingerprints ``0..i``. Calling ``hexdigest()`` does not finalize
    a ``hashlib`` object, so updating can continue afterwards.

    Fingerprints must be ASCII (hex digests are); otherwise
    ``UnicodeEncodeError`` is raised.
    """
    digest = hashlib.blake2b(digest_size=16)
    prefixes: list[str] = []
    for fingerprint in message_fingerprints:
        digest.update(fingerprint.encode("ascii"))
        digest.update(b"\n")
        prefixes.append(digest.hexdigest())
    return prefixes


def _split_csv(encoded: str) -> list[str]:
    """Split a comma-joined string, dropping empty items; falsy input gives ``[]``."""
    if not encoded:
        return []
    return [item for item in encoded.split(",") if item]


# The encode/decode pairs below join with "," without escaping, so items must
# not themselves contain commas. Hashes are hex; roles presumably never contain
# commas (TODO confirm with callers).


def encode_prefix_hashes(prefix_hashes: list[str]) -> str:
    """Join prefix hashes with commas."""
    return ",".join(prefix_hashes)


def decode_prefix_hashes(encoded: str) -> list[str]:
    """Inverse of :func:`encode_prefix_hashes`; empty items are discarded."""
    return _split_csv(encoded)


def encode_message_fingerprints(message_fingerprints: list[str]) -> str:
    """Join message fingerprints with commas."""
    return ",".join(message_fingerprints)


def decode_message_fingerprints(encoded: str) -> list[str]:
    """Inverse of :func:`encode_message_fingerprints`; empty items are discarded."""
    return _split_csv(encoded)


def encode_roles(roles: list[str]) -> str:
    """Join message roles with commas."""
    return ",".join(roles)


def decode_roles(encoded: str) -> list[str]:
    """Inverse of :func:`encode_roles`; empty items are discarded."""
    return _split_csv(encoded)


def extract_user_id(request_params: dict[str, Any]) -> str:
    """Return ``request_params["header"]["attributes"]["user_id"]`` as a string.

    Every level is optional. A missing key, or a level that is not a dict
    (including ``None``), yields ``""`` instead of raising. Falsy user ids
    (``None``, ``""``, ``0``) are also reported as ``""``.
    """
    header = request_params.get("header") if isinstance(request_params, dict) else None
    attributes = header.get("attributes") if isinstance(header, dict) else None
    if not isinstance(attributes, dict):
        return ""
    return str(attributes.get("user_id", "") or "")


def build_root_session_id(user_id: str, request_id: str) -> str:
    """Return a new session id: ``"ls-"`` plus a 20-hex-char BLAKE2b digest.

    The digest covers ``user_id``, a NUL byte and ``request_id``. Text is
    encoded as UTF-8 with ``errors="ignore"``.
    """
    digest = hashlib.blake2b(digest_size=10)
    digest.update(user_id.encode("utf-8", "ignore"))
    digest.update(b"\x00")
    digest.update(request_id.encode("utf-8", "ignore"))
    return f"ls-{digest.hexdigest()}"


@dataclass
class SessionAssignment:
    """Result of placing one request into a logical session."""

    # Shared by every request in the same logical session.
    session_id: str
    # request_id of the parent request; "" when this request starts a session.
    parent_request_id: str
    # chat_id of the parent request; -1 when this request starts a session.
    parent_chat_id: int
    # Sequential id handed out by the LogicalSessionizer instance, starting at 0.
    chat_id: int
    # 1 for a session root, parent's turn + 1 otherwise.
    turn: int


@dataclass
class _SessionNode:
    """Index entry remembering an assigned request for later prefix matches."""

    request_id: str
    session_id: str
    chat_id: int
    turn: int
    # Number of messages (sequence hashes) in the request that created the node.
    message_count: int


class LogicalSessionizer:
    """Group requests into logical sessions by longest shared message prefix.

    Each assigned request is indexed by ``(user scope, prefix hash)``. A later
    request from the same user scope whose conversation extends an indexed
    prefix becomes that request's child in the same session.

    The instance holds mutable in-memory state with no locking. Entries are
    overwritten but never removed, so the index grows with the number of
    distinct prefixes seen.
    """

    def __init__(self) -> None:
        """Create an empty sessionizer whose first ``chat_id`` will be 0."""
        # (scope_user_id, sequence hash) -> most recent request indexed there.
        self._index: dict[tuple[str, str], _SessionNode] = {}
        self._next_chat_id = 0

    def assign_precomputed(
        self,
        *,
        user_id: str,
        request_id: str,
        sequence_hashes: list[str],
        roles: list[str],
    ) -> SessionAssignment:
        """Assign a request to a session using precomputed prefix hashes.

        Args:
            user_id: Owner of the request. An empty value is replaced by
                ``"missing-user:<request_id>"``, so such requests only match
                entries indexed under that same request id.
            request_id: Identifier of this request.
            sequence_hashes: Cumulative prefix hashes. Element ``i`` covers
                messages ``0..i``, as produced by :func:`build_sequence_hashes`.
            roles: Role of each message, expected to align with
                ``sequence_hashes``.

        Returns:
            The assignment. The parent is the indexed request that matches the
            longest prefix that both contains at least one ``"user"`` message
            and is either a strict prefix of this conversation or was indexed
            by a longer request. With no parent, a new root session starts at
            turn 1.

        Side effects:
            Consumes one ``chat_id``. When ``sequence_hashes`` is non-empty,
            the request is indexed under its full hash. It is also indexed
            under the prefix that drops up to 3 trailing non-user messages,
            provided that prefix contains a user message.
        """
        parent: _SessionNode | None = None
        scope_user_id = user_id or f"missing-user:{request_id}"

        # user_prefix_flags[i] is True once any of roles[0..i] equals "user".
        has_user_prefix = False
        user_prefix_flags: list[bool] = []
        for role in roles:
            if role == "user":
                has_user_prefix = True
            user_prefix_flags.append(has_user_prefix)

        # Longest-prefix match, from the full conversation down to one message.
        for prefix_len in range(len(sequence_hashes), 0, -1):
            # Skip prefixes without a user message or without role information.
            if prefix_len - 1 >= len(user_prefix_flags) or not user_prefix_flags[prefix_len - 1]:
                continue
            candidate = self._index.get((scope_user_id, sequence_hashes[prefix_len - 1]))
            if candidate is None:
                continue
            # Entries at hash k come from requests with >= k messages. A full-length
            # hit with an equal message_count is an identical earlier conversation,
            # which is not treated as a parent; the search keeps shortening.
            if prefix_len < len(sequence_hashes) or candidate.message_count > prefix_len:
                parent = candidate
                break

        if parent is None:
            session_id = build_root_session_id(scope_user_id, request_id)
            parent_request_id = ""
            parent_chat_id = -1
            turn = 1
        else:
            session_id = parent.session_id
            parent_request_id = parent.request_id
            parent_chat_id = parent.chat_id
            turn = parent.turn + 1

        chat_id = self._next_chat_id
        self._next_chat_id += 1

        if sequence_hashes:
            node = _SessionNode(
                request_id=request_id,
                session_id=session_id,
                chat_id=chat_id,
                turn=turn,
                message_count=len(sequence_hashes),
            )
            self._index[(scope_user_id, sequence_hashes[-1])] = node

            # Count trailing non-user roles, capped at 3 (stops once the count exceeds 2).
            trailing_non_user = 0
            for role in reversed(roles):
                if role == "user":
                    break
                trailing_non_user += 1
                if trailing_non_user > 2:
                    break
            # Also index under the prefix without those trailing messages, so a
            # later request that resends that shorter prefix can attach here.
            # NOTE(review): counts come from `roles` but are applied to
            # `sequence_hashes`; a length mismatch between them shifts this prefix.
            prefix_len = len(sequence_hashes) - trailing_non_user
            if prefix_len > 0 and prefix_len - 1 < len(user_prefix_flags) and user_prefix_flags[prefix_len - 1]:
                self._index[(scope_user_id, sequence_hashes[prefix_len - 1])] = node

        return SessionAssignment(
            session_id=session_id,
            parent_request_id=parent_request_id,
            parent_chat_id=parent_chat_id,
            chat_id=chat_id,
            turn=turn,
        )

    def assign(
        self,
        *,
        user_id: str,
        request_id: str,
        message_fingerprints: list[str],
        roles: list[str],
    ) -> SessionAssignment:
        """Assign a request from per-message fingerprints.

        Converts the fingerprints with :func:`build_sequence_hashes`, then
        delegates to :meth:`assign_precomputed`.
        """
        return self.assign_precomputed(
            user_id=user_id,
            request_id=request_id,
            sequence_hashes=build_sequence_hashes(message_fingerprints),
            roles=roles,
        )