"""
Session Manager
===============
Coordinates a single negotiation “session” between two (or more) LLM agents
playing a specific game. Responsibilities:
1. Game bootstrap – Creates the game instance and initial state.
2. Turn scheduling – Alternates prompts/actions between players.
3. Action validation – Uses each game’s `is_valid_action` method.
4. State transition & logging – Calls `process_actions` and keeps history.
5. Stopping criteria – Ends when the game says it is over.
6. Metric computation – Delegates to MetricsCalculator once done.
7. Fault tolerance – Catches malformed LLM outputs & retries.
"""
from __future__ import annotations
import logging
import random
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from typing import Dict, List, Any, Optional

from negotiation_platform.models.base_model import BaseLLMModel

from .game_engine import GameEngine
from .llm_manager import LLMManager
from .metrics_calculator import MetricsCalculator
class SessionManager:
    """
    High-level driver that runs one negotiation session between LLM agents.

    A session loads the participating models, builds the game, lets the
    players act round by round (in a freshly shuffled order each round),
    validates every action against the game's rules, records the action
    history, and finally delegates metric computation to the metrics
    calculator.

    Workflow:
        1. Load every model named in ``players`` and create the game instance.
        2. Initialise the game state (optionally seeding system prompts).
        3. While the game is not over: collect one action per player
           (with retries on parse/validation failure), log the round, and
           let the game process the actions.
        4. Compute metrics and assemble the enriched result.
        5. Unload every model that was loaded, even if an error occurred.

    Attributes:
        llm_manager (LLMManager): Used to load and unload the models.
        game_engine (GameEngine): Used to create the game instance.
        metrics_calculator (MetricsCalculator): Its ``calculate_all`` is
            called once the game is over.
        max_turn_retries (int): Maximum attempts per player turn before the
            fallback ``{"type": "noop"}`` action is recorded.
        logger (logging.Logger): Logger for session events and debugging.

    Example:
        >>> session = SessionManager(llm_manager, game_engine, metrics_calc)
        >>> result = session.run_negotiation(
        ...     game_type="price_bargaining",
        ...     players=["model_a", "model_b"],
        ...     game_config={"max_rounds": 5},
        ... )
    """

    # Action fields that games expect to be integers. Only these keys are
    # eligible for float -> int coercion (e.g. 40.0 -> 40).
    _INT_ACTION_FIELDS = frozenset({"gpu_hours", "cpu_hours", "price", "quantity"})

    # ------------------------------------------------------------------ #
    # CONSTRUCTION                                                       #
    # ------------------------------------------------------------------ #
    def __init__(
        self,
        llm_manager: "LLMManager",
        game_engine: "GameEngine",
        metrics_calculator: "MetricsCalculator",
        *,
        max_turn_retries: int = 3,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """
        Store the collaborators used to run negotiation sessions.

        Args:
            llm_manager: Manager used to load/unload the participating models.
            game_engine: Engine used to create game instances by game type.
            metrics_calculator: Calculator invoked after the game has ended.
            max_turn_retries: Maximum attempts a player gets per turn to
                produce a parseable, valid action. Defaults to 3.
            logger: Logger for session events. If None, a logger named after
                the class is used.
        """
        self.llm_manager = llm_manager
        self.game_engine = game_engine
        self.metrics_calculator = metrics_calculator
        self.max_turn_retries = max_turn_retries
        self.logger = logger or logging.getLogger(self.__class__.__name__)

    # ------------------------------------------------------------------ #
    # ACTION HELPERS                                                     #
    # ------------------------------------------------------------------ #
    def _coerce_action_numeric_fields(
        self,
        action: Any,
        *,
        session_id: Optional[str] = None,
        player_name: Optional[str] = None,
    ) -> Any:
        """
        Convert integral floats to ints for whitelisted action fields.

        Dicts and lists are walked recursively and rebuilt (the input is not
        mutated). A float is converted only when the nearest enclosing dict
        key is in ``_INT_ACTION_FIELDS`` and the value has no fractional
        part; every other value is returned unchanged.

        Args:
            action: Parsed action (typically a dict) to normalise.
            session_id: Optional session identifier, used only for logging.
            player_name: Optional player name, used only for logging.

        Returns:
            A copy of ``action`` with eligible floats replaced by ints.
        """

        def _walk(obj: Any, key_name: Optional[str] = None) -> Any:
            if isinstance(obj, dict):
                return {k: _walk(v, key_name=k) for k, v in obj.items()}
            if isinstance(obj, list):
                # List items inherit the key of the field that holds the list.
                return [_walk(v, key_name=key_name) for v in obj]
            if (
                isinstance(obj, float)
                and key_name in self._INT_ACTION_FIELDS
                and obj.is_integer()
            ):
                coerced = int(obj)
                self.logger.debug(
                    "Coerced float->int: %s -> %s (session=%s, player=%s, field=%s)",
                    obj, coerced, session_id, player_name, key_name,
                )
                return coerced
            return obj

        return _walk(action)

    @staticmethod
    def _extract_action_type(action: Any) -> Any:
        """
        Return the action's type from either supported response shape.

        A structured response carries the type in ``action["decision"]["type"]``;
        a direct action carries it in ``action["type"]``. Returns None when
        ``action`` is not a dict or has no type.
        """
        if not isinstance(action, dict):
            return None
        decision = action.get("decision")
        if isinstance(decision, dict):
            return decision.get("type")
        # No usable structured decision (absent, or not a dict): fall back to
        # the direct format instead of raising on ``decision.get``.
        return action.get("type")

    def _request_action(
        self,
        *,
        agent: "BaseLLMModel",
        game_instance: Any,
        game_state: Dict[str, Any],
        game_type: str,
        p_name: str,
        prompt: str,
        session_id: str,
    ) -> Dict[str, Any]:
        """
        Ask one player for an action, retrying on parse or validation failure.

        Up to ``max_turn_retries`` attempts are made. Exceptions raised by
        ``parse_action`` are logged and count as a failed attempt; exceptions
        raised by ``generate_response`` are not caught here.

        Returns:
            The first parsed, coerced and valid action, or ``{"type": "noop"}``
            when every attempt failed (including when no attempt is allowed).
        """
        for attempt in range(1, self.max_turn_retries + 1):
            raw_reply = agent.generate_response(prompt, game_state=game_state)
            try:
                # game_type / player_name are forwarded for the model's own
                # validation and logging.
                parsed = agent.parse_action(raw_reply, game_type=game_type, player_name=p_name)
            except Exception as exc:  # noqa: BLE001
                self.logger.warning(
                    f"[{session_id}] Parse failure by {p_name} (attempt {attempt}): {exc}"
                )
                continue

            # Be tolerant of e.g. 40.0 where the game expects 40.
            parsed = self._coerce_action_numeric_fields(
                parsed, session_id=session_id, player_name=p_name
            )
            if game_instance.is_valid_action(p_name, parsed, game_state):
                return parsed
            self.logger.warning(
                f"[{session_id}] Invalid action by {p_name} (attempt {attempt}): {parsed}"
            )

        # A fresh dict each time so callers can never share a mutable fallback.
        return {"type": "noop"}

    def _play_round(
        self,
        *,
        game_instance: Any,
        game_state: Dict[str, Any],
        game_type: str,
        players: List[str],
        loaded_agents: Dict[str, "BaseLLMModel"],
        session_id: str,
        current_round: Any,
    ) -> Dict[str, Dict[str, Any]]:
        """
        Collect the actions of one round.

        Players act in a random order (reshuffled every round to remove any
        first-move advantage). The round stops early as soon as a player
        accepts, so later players are not prompted.

        Returns:
            Mapping of player name to the action recorded for this round.

        Raises:
            RuntimeError: If the game instance has no ``get_game_prompt``.
        """
        round_actions: Dict[str, Dict[str, Any]] = {}
        round_players = list(players)
        random.shuffle(round_players)
        self.logger.info(f"🎲 [TURN ORDER] Round {current_round}: {round_players}")

        for p_name in round_players:
            self.logger.info(f"🔄 [PLAYER TURN] {p_name} is making an offer")

            # Prompts come exclusively from the game itself.
            if not hasattr(game_instance, 'get_game_prompt'):
                raise RuntimeError(f"Game instance {game_type} must implement get_game_prompt method")
            # Expose the latest state to the game before it builds the prompt.
            game_instance.game_data = game_state
            prompt = game_instance.get_game_prompt(p_name)
            self.logger.info(f"🎯 [PROMPT DEBUG] Prompt for {p_name}:")
            self.logger.info(f"📝 FULL PROMPT: {prompt}")

            action = self._request_action(
                agent=loaded_agents[p_name],
                game_instance=game_instance,
                game_state=game_state,
                game_type=game_type,
                p_name=p_name,
                prompt=prompt,
                session_id=session_id,
            )
            round_actions[p_name] = action

            # An acceptance ends the round immediately: the remaining
            # players must not be asked to act on an already-accepted offer.
            if self._extract_action_type(action) == "accept":
                self.logger.info(f"🎯 [ACCEPTANCE DETECTED] Player {p_name} accepted - terminating round early")
                break

        return round_actions

    # ------------------------------------------------------------------ #
    # RESULT HELPERS                                                     #
    # ------------------------------------------------------------------ #
    @staticmethod
    def _build_model_mapping(
        players: List[str], loaded_agents: Dict[str, "BaseLLMModel"]
    ) -> Dict[str, Any]:
        """
        Map each player to the model name reported by its agent.

        The agent's ``model_name`` attribute is preferred, then
        ``config['model_name']``; ``"unknown"`` is used when neither exists.
        """
        mapping: Dict[str, Any] = {}
        for player_name in players:
            agent = loaded_agents[player_name]
            config = getattr(agent, 'config', None)
            if hasattr(agent, 'model_name'):
                mapping[player_name] = agent.model_name
            elif config is not None and 'model_name' in config:
                mapping[player_name] = config['model_name']
            else:
                mapping[player_name] = "unknown"
        return mapping

    def _announce(self, message: str) -> None:
        """Log ``message`` at INFO level and echo it on stdout."""
        self.logger.info(message)
        print(message)

    def _log_outcome(
        self,
        game_state: Dict[str, Any],
        players: List[str],
        llm_model_mapping: Dict[str, Any],
    ) -> None:
        """
        Report the winner (if any) and the player -> model mapping.

        Nothing is reported unless the game state says an agreement was
        reached. When BATNA data is available, only players whose utility
        exceeds their BATNA can win and the largest surplus wins; without
        BATNA data the highest final utility wins.
        """
        if not game_state.get('agreement_reached', False):
            return

        final_utilities = game_state.get('final_utilities', {})
        batnas = game_state.get('batnas_at_agreement', {})

        winner = None
        if final_utilities and batnas:
            surpluses = {
                player: utility - batnas.get(player, 0.0)
                for player, utility in final_utilities.items()
            }
            positive_surplus_players = {
                player: surplus for player, surplus in surpluses.items() if surplus > 0
            }
            if positive_surplus_players:
                winner = max(positive_surplus_players, key=positive_surplus_players.get)
            else:
                self._announce("🤝 [NO WINNER] Agreement reached but both players have negative surplus")
        elif final_utilities:
            # Backward-compatible fallback when no BATNA data is present.
            winner = max(final_utilities, key=final_utilities.get)

        if winner is not None:
            winner_llm = llm_model_mapping.get(winner, "unknown")
            # May be empty (self-play, single player): never index blindly.
            opponents = [p for p in players if p != winner]
            if opponents:
                loser_llm = ", ".join(
                    str(llm_model_mapping.get(p, "unknown")) for p in opponents
                )
            else:
                loser_llm = "unknown"
            self._announce(f"🏆 [LLM WINNER] {winner_llm} beat {loser_llm} (player {winner} won)")

        # str() guards against agents whose model name is not a string.
        model_info = ", ".join(
            f"{player}={str(llm_model_mapping[player]).split('/')[-1]}" for player in players
        )
        self._announce(f"🤖 [MODEL MAPPING] {model_info}")

    # ------------------------------------------------------------------ #
    # PUBLIC DRIVER                                                      #
    # ------------------------------------------------------------------ #
    def run_negotiation(
        self,
        *,
        game_type: str,
        players: List[str],  # list of model names registered in LLMManager
        game_config: Optional[Dict[str, Any]] = None,
        session_id: Optional[str] = None,
        seed_messages: Optional[Dict[str, str]] = None,  # optional system prompts
    ) -> Dict[str, Any]:
        """
        Execute a full negotiation session and return the enriched result.

        Args:
            game_type: Key passed to ``GameEngine.create_game``.
            players: Ordered list of model names; each is loaded through the
                LLM manager and also used as the player's name in the game.
            game_config: Per-game configuration; an empty dict is passed to
                the engine when None.
            session_id: Optional external identifier; a UUID4 string is
                generated when omitted.
            seed_messages: Optional ``{player_name: system_prompt}`` mapping
                stored under ``game_state["system_prompts"]``.

        Returns:
            The final game state extended with:
              - ``"actions_history"``: chronological list of
                ``{"round", "actions"}`` dicts,
              - ``"metrics"``: the value returned by
                ``metrics_calculator.calculate_all``,
              - ``"llm_model_mapping"``: player name -> model name,
              - ``"session_metadata"``: session id, game type, players and a
                naive-UTC ISO timestamp.

        Raises:
            RuntimeError: If the game instance lacks ``get_game_prompt``.
            Exception: Any error raised while loading models or creating the
                game is logged and re-raised unchanged.

        Note:
            Every model that was successfully loaded is unloaded before this
            method returns or raises; each distinct model is unloaded once.
        """
        # 0. House-keeping
        session_id = session_id or str(uuid.uuid4())
        self.logger.info(f"[{session_id}] ➜ Starting new session for game '{game_type}'")
        actions_history: List[Dict[str, Any]] = []
        loaded_agents: Dict[str, BaseLLMModel] = {}

        try:
            # 1. Load models & build game
            try:
                for model_name in players:
                    loaded_model = self.llm_manager.load_model(model_name)
                    loaded_agents[model_name] = loaded_model
                    self.logger.debug(
                        "[%s] Loaded model %r (%s)",
                        session_id, model_name, type(loaded_model).__name__,
                    )
                game_instance = self.game_engine.create_game(game_type, game_config or {})
            except Exception as exc:  # noqa: BLE001
                self.logger.exception(f"[{session_id}] Failed during initialisation: {exc}")
                raise

            # 2. Initialise game state
            game_state = game_instance.initialize_game(players)
            if seed_messages:
                system_prompts = game_state.setdefault("system_prompts", {})
                for p_name, prompt in seed_messages.items():
                    system_prompts[p_name] = prompt

            # 3. Interaction loop
            while not game_instance.is_game_over(game_state):
                current_round = game_state["current_round"]
                self.logger.debug(f"[{session_id}] ─ Round {current_round} ─")
                round_actions = self._play_round(
                    game_instance=game_instance,
                    game_state=game_state,
                    game_type=game_type,
                    players=players,
                    loaded_agents=loaded_agents,
                    session_id=session_id,
                    current_round=current_round,
                )
                actions_history.append({"round": current_round, "actions": round_actions})
                game_state = game_instance.process_actions(round_actions, game_state)

            # 4. Compute metrics & enrich result
            metric_results = self.metrics_calculator.calculate_all(game_state, actions_history)
            llm_model_mapping = self._build_model_mapping(players, loaded_agents)
            # Same naive-UTC string format as before, without the deprecated
            # datetime.utcnow().
            timestamp_utc = (
                datetime.now(timezone.utc).replace(tzinfo=None).isoformat(timespec="seconds")
            )
            enriched_result: Dict[str, Any] = {
                **game_state,
                "actions_history": actions_history,
                "metrics": metric_results,
                "llm_model_mapping": llm_model_mapping,
                "session_metadata": {
                    "session_id": session_id,
                    "game_type": game_type,
                    "players": players,
                    "timestamp_utc": timestamp_utc,
                },
            }
            self.logger.info(
                f"[{session_id}] Finished – agreement={game_state.get('agreement_reached', False)}"
            )
            self._log_outcome(game_state, players, llm_model_mapping)
            return enriched_result
        finally:
            # 5. Resource clean-up: runs on success and on any failure, and
            # only touches models that were actually loaded.
            for model_name in loaded_agents:
                self.llm_manager.unload_model(model_name)