checkbox manager and handling
This commit is contained in:
@ -14,6 +14,10 @@ import time
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
# Import checkpoint management
|
||||
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
|
||||
from utils.training_integration import get_training_integration
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -33,7 +37,18 @@ class DQNAgent:
|
||||
batch_size: int = 32,
|
||||
target_update: int = 100,
|
||||
priority_memory: bool = True,
|
||||
device=None):
|
||||
device=None,
|
||||
model_name: str = "dqn_agent",
|
||||
enable_checkpoints: bool = True):
|
||||
|
||||
# Checkpoint management
|
||||
self.model_name = model_name
|
||||
self.enable_checkpoints = enable_checkpoints
|
||||
self.training_integration = get_training_integration() if enable_checkpoints else None
|
||||
self.episode_count = 0
|
||||
self.best_reward = float('-inf')
|
||||
self.reward_history = deque(maxlen=100)
|
||||
self.checkpoint_frequency = 100 # Save checkpoint every 100 episodes
|
||||
|
||||
# Extract state dimensions
|
||||
if isinstance(state_shape, tuple) and len(state_shape) > 1:
|
||||
@ -90,7 +105,91 @@ class DQNAgent:
|
||||
'confidence': 0.0,
|
||||
'raw': None
|
||||
}
|
||||
self.extrema_memory = [] # Special memory for storing extrema points
|
||||
self.extrema_memory = []
|
||||
|
||||
# DQN hyperparameters
|
||||
self.gamma = 0.99 # Discount factor
|
||||
|
||||
# Load best checkpoint if available
|
||||
if self.enable_checkpoints:
|
||||
self.load_best_checkpoint()
|
||||
|
||||
logger.info(f"DQN Agent initialized with checkpoint management: {enable_checkpoints}")
|
||||
if enable_checkpoints:
|
||||
logger.info(f"Model name: {model_name}, Checkpoint frequency: {self.checkpoint_frequency}")
|
||||
|
||||
def load_best_checkpoint(self):
|
||||
"""Load the best checkpoint for this DQN agent"""
|
||||
try:
|
||||
if not self.enable_checkpoints:
|
||||
return
|
||||
|
||||
result = load_best_checkpoint(self.model_name)
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
checkpoint = torch.load(file_path, map_location=self.device)
|
||||
|
||||
# Load model states
|
||||
if 'policy_net_state_dict' in checkpoint:
|
||||
self.policy_net.load_state_dict(checkpoint['policy_net_state_dict'])
|
||||
if 'target_net_state_dict' in checkpoint:
|
||||
self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
|
||||
if 'optimizer_state_dict' in checkpoint:
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
|
||||
# Load training state
|
||||
if 'episode_count' in checkpoint:
|
||||
self.episode_count = checkpoint['episode_count']
|
||||
if 'epsilon' in checkpoint:
|
||||
self.epsilon = checkpoint['epsilon']
|
||||
if 'best_reward' in checkpoint:
|
||||
self.best_reward = checkpoint['best_reward']
|
||||
|
||||
logger.info(f"Loaded DQN checkpoint: {metadata.checkpoint_id}")
|
||||
logger.info(f"Episode: {self.episode_count}, Best reward: {self.best_reward:.4f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")
|
||||
|
||||
def save_checkpoint(self, episode_reward: float, force_save: bool = False):
|
||||
"""Save checkpoint if performance improved or forced"""
|
||||
try:
|
||||
if not self.enable_checkpoints:
|
||||
return False
|
||||
|
||||
self.episode_count += 1
|
||||
self.reward_history.append(episode_reward)
|
||||
|
||||
# Calculate average reward over recent episodes
|
||||
avg_reward = sum(self.reward_history) / len(self.reward_history)
|
||||
|
||||
# Update best reward
|
||||
if episode_reward > self.best_reward:
|
||||
self.best_reward = episode_reward
|
||||
|
||||
# Save checkpoint every N episodes or if forced
|
||||
should_save = (
|
||||
force_save or
|
||||
self.episode_count % self.checkpoint_frequency == 0 or
|
||||
episode_reward > self.best_reward * 0.95 # Within 5% of best
|
||||
)
|
||||
|
||||
if should_save and self.training_integration:
|
||||
return self.training_integration.save_rl_checkpoint(
|
||||
rl_agent=self,
|
||||
model_name=self.model_name,
|
||||
episode=self.episode_count,
|
||||
avg_reward=avg_reward,
|
||||
best_reward=self.best_reward,
|
||||
epsilon=self.epsilon,
|
||||
total_pnl=0.0 # Default to 0, can be set by calling code
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving DQN checkpoint: {e}")
|
||||
return False
|
||||
|
||||
# Price prediction tracking
|
||||
self.last_price_pred = {
|
||||
@ -117,7 +216,6 @@ class DQNAgent:
|
||||
# Performance tracking
|
||||
self.losses = []
|
||||
self.avg_reward = 0.0
|
||||
self.best_reward = -float('inf')
|
||||
self.no_improvement_count = 0
|
||||
|
||||
# Confidence tracking
|
||||
|
Reference in New Issue
Block a user