checkpoint manager and handling

This commit is contained in:
Dobromir Popov
2025-06-24 21:59:23 +03:00
parent 706eb13912
commit ab8c94d735
8 changed files with 1170 additions and 29 deletions

View File

@ -14,6 +14,10 @@ import time
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
# Import checkpoint management
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
from utils.training_integration import get_training_integration
# Configure logger
logger = logging.getLogger(__name__)
@ -33,7 +37,18 @@ class DQNAgent:
batch_size: int = 32,
target_update: int = 100,
priority_memory: bool = True,
device=None):
device=None,
model_name: str = "dqn_agent",
enable_checkpoints: bool = True):
# Checkpoint management
self.model_name = model_name
self.enable_checkpoints = enable_checkpoints
self.training_integration = get_training_integration() if enable_checkpoints else None
self.episode_count = 0
self.best_reward = float('-inf')
self.reward_history = deque(maxlen=100)
self.checkpoint_frequency = 100 # Save checkpoint every 100 episodes
# Extract state dimensions
if isinstance(state_shape, tuple) and len(state_shape) > 1:
@ -90,7 +105,91 @@ class DQNAgent:
'confidence': 0.0,
'raw': None
}
self.extrema_memory = [] # Special memory for storing extrema points
self.extrema_memory = []
# DQN hyperparameters
self.gamma = 0.99 # Discount factor
# Load best checkpoint if available
if self.enable_checkpoints:
self.load_best_checkpoint()
logger.info(f"DQN Agent initialized with checkpoint management: {enable_checkpoints}")
if enable_checkpoints:
logger.info(f"Model name: {model_name}, Checkpoint frequency: {self.checkpoint_frequency}")
def load_best_checkpoint(self):
    """Restore this DQN agent from its best saved checkpoint, if any.

    Loads network/optimizer state dicts and scalar training progress
    (episode count, epsilon, best reward) from the checkpoint file.
    No-op when checkpointing is disabled or no checkpoint is found;
    any failure is logged as a warning instead of being raised.
    """
    try:
        if not self.enable_checkpoints:
            return
        # Module-level helper (from utils.checkpoint_manager), not this method.
        found = load_best_checkpoint(self.model_name)
        if not found:
            return
        path, meta = found
        state = torch.load(path, map_location=self.device)
        # Restore network/optimizer states only when present in the file.
        restorers = (
            ('policy_net_state_dict', self.policy_net.load_state_dict),
            ('target_net_state_dict', self.target_net.load_state_dict),
            ('optimizer_state_dict', self.optimizer.load_state_dict),
        )
        for key, restore in restorers:
            if key in state:
                restore(state[key])
        # Restore scalar training progress, keeping current values as defaults.
        for attr in ('episode_count', 'epsilon', 'best_reward'):
            if attr in state:
                setattr(self, attr, state[attr])
        logger.info(f"Loaded DQN checkpoint: {meta.checkpoint_id}")
        logger.info(f"Episode: {self.episode_count}, Best reward: {self.best_reward:.4f}")
    except Exception as e:
        logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")
def save_checkpoint(self, episode_reward: float, force_save: bool = False):
    """Record an episode's reward and save a checkpoint when warranted.

    Increments the episode counter, appends *episode_reward* to the
    rolling reward history, and saves via the training integration when:
    *force_save* is set, the periodic checkpoint frequency is hit, or the
    reward is within 5% of the best reward seen so far.

    Args:
        episode_reward: Total reward obtained in the finished episode.
        force_save: Save regardless of performance/frequency criteria.

    Returns:
        bool: True if a checkpoint was saved, False otherwise (including
        when checkpointing is disabled or an error occurred).
    """
    try:
        if not self.enable_checkpoints:
            return False
        self.episode_count += 1
        self.reward_history.append(episode_reward)
        # Average over the recent window (deque bounds the history length).
        avg_reward = sum(self.reward_history) / len(self.reward_history)
        # BUGFIX: snapshot the previous best BEFORE updating it, so the
        # "within 5% of best" test below compares against the prior best.
        # (Previously best_reward was updated first, which reduced the
        # check to `r > 0.95*r` for every new best — the old best was
        # never actually consulted.)
        previous_best = self.best_reward
        if episode_reward > self.best_reward:
            self.best_reward = episode_reward
        # Save on demand, on the periodic schedule, or on near-best reward.
        # NOTE(review): with negative rewards the 0.95 factor loosens the
        # threshold toward zero — preserved as-is from the original intent.
        should_save = (
            force_save or
            self.episode_count % self.checkpoint_frequency == 0 or
            episode_reward > previous_best * 0.95  # Within 5% of best
        )
        if should_save and self.training_integration:
            return self.training_integration.save_rl_checkpoint(
                rl_agent=self,
                model_name=self.model_name,
                episode=self.episode_count,
                avg_reward=avg_reward,
                best_reward=self.best_reward,
                epsilon=self.epsilon,
                total_pnl=0.0  # Default to 0, can be set by calling code
            )
        return False
    except Exception as e:
        logger.error(f"Error saving DQN checkpoint: {e}")
        return False
# Price prediction tracking
self.last_price_pred = {
@ -117,7 +216,6 @@ class DQNAgent:
# Performance tracking
self.losses = []
self.avg_reward = 0.0
self.best_reward = -float('inf')
self.no_improvement_count = 0
# Confidence tracking