Initial commit
This commit is contained in:
217
trace_formatter/sessionization.py
Normal file
217
trace_formatter/sessionization.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
def normalize_content(content: Any) -> str:
    """Return a deterministic string form of a message's content.

    Strings are returned untouched.  Anything else is rendered as compact
    JSON with sorted keys and non-ASCII characters preserved; when JSON
    serialization fails for any reason the value's ``str()`` is used instead.
    """
    if not isinstance(content, str):
        try:
            content = json.dumps(
                content,
                ensure_ascii=False,
                sort_keys=True,
                separators=(",", ":"),
            )
        except Exception:
            # Best-effort fallback so callers always get a string back.
            content = str(content)
    return content
|
||||
|
||||
|
||||
def serialize_tool_calls(tool_calls: Any) -> str:
    """Render a message's tool calls as a deterministic string.

    * ``None`` yields the empty string.
    * A single dict is treated as a one-element list.
    * A list yields one compact, key-sorted JSON document per item, joined
      with newlines.
    * Any other value is handed to ``normalize_content``.

    An item that cannot be serialized as JSON (unsupported type, mixed-type
    dict keys under ``sort_keys``, circular reference, excessive nesting) is
    rendered with ``str()`` rather than raising.  This matches the fallback
    ``normalize_content`` applies to message content.
    """
    if tool_calls is None:
        return ""
    if isinstance(tool_calls, dict):
        tool_calls = [tool_calls]
    if not isinstance(tool_calls, list):
        return normalize_content(tool_calls)

    def _dump(item: Any) -> str:
        try:
            return json.dumps(
                item,
                ensure_ascii=False,
                sort_keys=True,
                separators=(",", ":"),
            )
        except (TypeError, ValueError, RecursionError):
            # Keep fingerprinting best-effort instead of failing the request.
            return str(item)

    return "\n".join(_dump(item) for item in tool_calls)
|
||||
|
||||
|
||||
def stable_message_fingerprint(message: dict[str, Any]) -> str:
    """Return a 16-byte BLAKE2b hex digest identifying one message.

    The digest covers, in order, the role (``"unknown"`` when the key is
    absent), the normalized content and the serialized tool calls.  The three
    fields are UTF-8 encoded (unencodable characters are dropped) and
    separated by the ``0x1f`` byte.
    """
    fields = (
        str(message.get("role", "unknown")),
        normalize_content(message.get("content")),
        serialize_tool_calls(message.get("tool_calls")),
    )
    payload = b"\x1f".join(field.encode("utf-8", "ignore") for field in fields)
    return hashlib.blake2b(payload, digest_size=16).hexdigest()
|
||||
|
||||
|
||||
def build_prefix_hashes(messages: list[dict[str, Any]]) -> list[str]:
    """Return one cumulative hash per message prefix.

    Entry ``i`` is the hex digest of a running 16-byte BLAKE2b hash that has
    been fed the fingerprint of every message up to and including index
    ``i``, each followed by a newline byte.
    """
    running = hashlib.blake2b(digest_size=16)

    def _advance(message: dict[str, Any]) -> str:
        # Feed the next fingerprint and snapshot the running digest.
        running.update(stable_message_fingerprint(message).encode("ascii") + b"\n")
        return running.hexdigest()

    return [_advance(message) for message in messages]
|
||||
|
||||
|
||||
def build_message_fingerprints(messages: list[dict[str, Any]]) -> list[str]:
    """Return the stable fingerprint of each message, in input order."""
    return list(map(stable_message_fingerprint, messages))
|
||||
|
||||
|
||||
def build_sequence_hashes(message_fingerprints: list[str]) -> list[str]:
    """Turn per-message fingerprints into cumulative prefix hashes.

    Entry ``i`` is the hex digest of a running 16-byte BLAKE2b hash that has
    consumed fingerprints ``0..i``, each terminated by a newline byte.
    Fingerprints must be ASCII; encoding any other text raises
    ``UnicodeEncodeError``.
    """
    running = hashlib.blake2b(digest_size=16)
    cumulative: list[str] = []
    for fp in message_fingerprints:
        running.update(fp.encode("ascii") + b"\n")
        # hexdigest() does not finalize the hash, so it can keep accumulating.
        cumulative.append(running.hexdigest())
    return cumulative
|
||||
|
||||
|
||||
def encode_prefix_hashes(prefix_hashes: list[str]) -> str:
    """Pack prefix hashes into a single comma-separated string."""
    separator = ","
    return separator.join(prefix_hashes)
|
||||
|
||||
|
||||
def decode_prefix_hashes(encoded: str) -> list[str]:
    """Split a comma-separated prefix-hash string back into a list.

    A falsy input yields an empty list, and empty segments (for example from
    a doubled or trailing comma) are discarded.
    """
    return [piece for piece in encoded.split(",") if piece] if encoded else []
|
||||
|
||||
|
||||
def encode_message_fingerprints(message_fingerprints: list[str]) -> str:
    """Pack message fingerprints into a single comma-separated string."""
    joined = ",".join(message_fingerprints)
    return joined
|
||||
|
||||
|
||||
def decode_message_fingerprints(encoded: str) -> list[str]:
    """Split a comma-separated fingerprint string back into a list.

    A falsy input yields an empty list; empty segments are discarded.
    """
    if encoded:
        # filter(None, ...) drops the empty strings left by stray commas.
        return list(filter(None, encoded.split(",")))
    return []
|
||||
|
||||
|
||||
def encode_roles(roles: list[str]) -> str:
    """Pack message roles into a single comma-separated string."""
    delimiter = ","
    return delimiter.join(roles)
|
||||
|
||||
|
||||
def decode_roles(encoded: str) -> list[str]:
    """Split a comma-separated role string back into a list.

    A falsy input yields an empty list.  Empty segments are discarded, so an
    empty role string does not survive an encode/decode round trip.
    """
    if not encoded:
        return []
    segments = encoded.split(",")
    return [role for role in segments if role]
|
||||
|
||||
|
||||
def extract_user_id(request_params: dict[str, Any]) -> str:
    """Return ``request_params["header"]["attributes"]["user_id"]`` as a string.

    Every level of the lookup is tolerant: if ``request_params``, its
    ``header`` or the header's ``attributes`` is missing or is not a dict,
    the result is the empty string.  A falsy ``user_id`` (``None``, ``""``,
    ``0``) also yields the empty string; any other value is passed through
    ``str()``.
    """
    header = request_params.get("header") if isinstance(request_params, dict) else None
    attributes = header.get("attributes") if isinstance(header, dict) else None
    if not isinstance(attributes, dict):
        # Covers an explicit null or a non-mapping "attributes" value, which
        # would otherwise raise AttributeError on .get().
        return ""
    return str(attributes.get("user_id") or "")
|
||||
|
||||
|
||||
def build_root_session_id(user_id: str, request_id: str) -> str:
    """Derive the id of a new root session from a user id and a request id.

    The id is ``"ls-"`` followed by the hex form of a 10-byte BLAKE2b digest
    over the UTF-8 bytes of ``user_id``, a NUL byte, and the UTF-8 bytes of
    ``request_id``.  Characters that cannot be encoded are dropped.
    """
    payload = b"\x00".join(
        (
            user_id.encode("utf-8", "ignore"),
            request_id.encode("utf-8", "ignore"),
        )
    )
    return "ls-" + hashlib.blake2b(payload, digest_size=10).hexdigest()
|
||||
|
||||
|
||||
@dataclass
class SessionAssignment:
    """Result of placing one request into a logical session."""

    # Id of the logical session the request was placed in.
    session_id: str
    # Request id of the matched parent; LogicalSessionizer uses "" for a root.
    parent_request_id: str
    # chat_id of the matched parent; LogicalSessionizer uses -1 for a root.
    parent_chat_id: int
    # Id allocated to this request by the sessionizer's counter.
    chat_id: int
    # Position in the session chain: 1 for a root, parent's turn + 1 otherwise.
    turn: int
|
||||
|
||||
|
||||
@dataclass
class _SessionNode:
    """Index entry describing a request already seen by LogicalSessionizer."""

    # Request that produced this node.
    request_id: str
    # Logical session the request was assigned to.
    session_id: str
    # Id allocated to the request by the sessionizer's counter.
    chat_id: int
    # Turn number the request received within its session.
    turn: int
    # Number of sequence hashes (i.e. messages) the request carried.
    message_count: int
|
||||
|
||||
|
||||
class LogicalSessionizer:
    """Chain requests into logical sessions by matching cumulative message hashes.

    The instance keeps an in-memory index from ``(scope user id, sequence
    hash)`` to the node last registered under that hash, plus a counter that
    hands out ``chat_id`` values in call order.  A new request whose message
    prefix matches an indexed hash joins that node's session as its child;
    otherwise it starts a new root session.
    """

    def __init__(self) -> None:
        # (scope user id, sequence hash) -> node last registered under that hash.
        self._index: dict[tuple[str, str], _SessionNode] = {}
        # Next chat_id to hand out; advanced on every assignment.
        self._next_chat_id = 0

    def assign_precomputed(
        self,
        *,
        user_id: str,
        request_id: str,
        sequence_hashes: list[str],
        roles: list[str],
    ) -> SessionAssignment:
        """Assign a request to a session using already-computed prefix hashes.

        Args:
            user_id: Owner of the request.  When falsy, a scope unique to
                ``request_id`` is used instead.
            request_id: Identifier of the request being assigned.
            sequence_hashes: Cumulative hash per message prefix (entry ``i``
                covers messages ``0..i``).
            roles: Role of each message; read positionally alongside
                ``sequence_hashes``.

        Returns:
            The session id, parent linkage, allocated chat id and turn number.
        """
        parent: _SessionNode | None = None
        # Requests without a user id get a scope keyed by their own request id.
        scope_user_id = user_id or f"missing-user:{request_id}"
        # user_prefix_flags[i] is True once a "user" role has appeared at or
        # before index i.
        has_user_prefix = False
        user_prefix_flags: list[bool] = []
        for role in roles:
            if role == "user":
                has_user_prefix = True
            user_prefix_flags.append(has_user_prefix)

        # Look for a parent, trying the longest prefix first.
        for prefix_len in range(len(sequence_hashes), 0, -1):
            # Skip prefixes that extend past `roles` or contain no user message.
            if prefix_len - 1 >= len(user_prefix_flags) or not user_prefix_flags[prefix_len - 1]:
                continue
            candidate = self._index.get((scope_user_id, sequence_hashes[prefix_len - 1]))
            if candidate is None:
                continue
            # A proper prefix always qualifies.  A full-length match qualifies
            # only when the candidate carried more messages than this request.
            if prefix_len < len(sequence_hashes) or candidate.message_count > prefix_len:
                parent = candidate
                break

        if parent is None:
            # No match: start a new root session.
            session_id = build_root_session_id(scope_user_id, request_id)
            parent_request_id = ""
            parent_chat_id = -1
            turn = 1
        else:
            # Match: continue the parent's session one turn later.
            session_id = parent.session_id
            parent_request_id = parent.request_id
            parent_chat_id = parent.chat_id
            turn = parent.turn + 1

        # A chat_id is consumed even when the request has no messages.
        chat_id = self._next_chat_id
        self._next_chat_id += 1

        if sequence_hashes:
            node = _SessionNode(
                request_id=request_id,
                session_id=session_id,
                chat_id=chat_id,
                turn=turn,
                message_count=len(sequence_hashes),
            )
            # Register under the hash of the full message sequence.
            self._index[(scope_user_id, sequence_hashes[-1])] = node
            # Count non-user roles at the tail; counting stops once it exceeds 2.
            trailing_non_user = 0
            for role in reversed(roles):
                if role == "user":
                    break
                trailing_non_user += 1
                if trailing_non_user > 2:
                    break
            # Also register under the prefix with those trailing messages removed.
            # NOTE(review): with more than two trailing non-user roles the count
            # stops at 3, so this prefix may not end on a user message — confirm
            # that is intended.
            prefix_len = len(sequence_hashes) - trailing_non_user
            if prefix_len > 0 and prefix_len - 1 < len(user_prefix_flags) and user_prefix_flags[prefix_len - 1]:
                self._index[(scope_user_id, sequence_hashes[prefix_len - 1])] = node

        return SessionAssignment(
            session_id=session_id,
            parent_request_id=parent_request_id,
            parent_chat_id=parent_chat_id,
            chat_id=chat_id,
            turn=turn,
        )

    def assign(
        self,
        *,
        user_id: str,
        request_id: str,
        message_fingerprints: list[str],
        roles: list[str],
    ) -> SessionAssignment:
        """Assign a request to a session from its per-message fingerprints.

        Builds the cumulative sequence hashes with ``build_sequence_hashes``
        and delegates to ``assign_precomputed``.
        """
        return self.assign_precomputed(
            user_id=user_id,
            request_id=request_id,
            sequence_hashes=build_sequence_hashes(message_fingerprints),
            roles=roles,
        )
|
||||
Reference in New Issue
Block a user