massive cleanup
utils/model_utils.py (new file, 241 additions)
@@ -0,0 +1,241 @@
#!/usr/bin/env python
"""
Model utilities for robust saving and loading of PyTorch models
"""

import os
import logging
import torch
import shutil
import gc
import json
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)


def robust_save(model: Any, path: str, include_optimizer: bool = True) -> bool:
    """
    Robust model saving with multiple fallback approaches

    Args:
        model: The model object to save (should have policy_net, target_net, optimizer, epsilon attributes)
        path: Path to save the model
        include_optimizer: Whether to include optimizer state in the save

    Returns:
        bool: True if successful, False otherwise
    """
    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    # Backup path in case the main save fails
    backup_path = f"{path}.backup"

    # Clean up GPU memory before saving
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    # Prepare checkpoint data
    checkpoint = {
        'policy_net': model.policy_net.state_dict(),
        'target_net': model.target_net.state_dict(),
        'epsilon': getattr(model, 'epsilon', 0.0),
        'state_size': getattr(model, 'state_size', None),
        'action_size': getattr(model, 'action_size', None),
        'hidden_size': getattr(model, 'hidden_size', None),
    }

    # Add optimizer state if requested and available
    if include_optimizer and hasattr(model, 'optimizer') and model.optimizer is not None:
        checkpoint['optimizer'] = model.optimizer.state_dict()

    # Attempt 1: Try with default settings in a separate file first
    try:
        logger.info(f"Saving model to {backup_path} (attempt 1)")
        torch.save(checkpoint, backup_path)
        logger.info(f"Successfully saved to {backup_path}")

        # If the backup worked, copy it to the actual path
        if os.path.exists(backup_path):
            shutil.copy(backup_path, path)
            logger.info(f"Copied backup to {path}")
            return True
    except Exception as e:
        logger.warning(f"First save attempt failed: {e}")

    # Attempt 2: Try with pickle protocol 2 (more compatible)
    try:
        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
        torch.save(checkpoint, path, pickle_protocol=2)
        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
        return True
    except Exception as e:
        logger.warning(f"Second save attempt failed: {e}")

    # Attempt 3: Try without optimizer state (which can be large and cause issues)
    try:
        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
        checkpoint_no_opt = {k: v for k, v in checkpoint.items() if k != 'optimizer'}
        torch.save(checkpoint_no_opt, path)
        logger.info(f"Successfully saved to {path} without optimizer state")
        return True
    except Exception as e:
        logger.warning(f"Third save attempt failed: {e}")

    # Attempt 4: Try with torch.jit.save instead
    try:
        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
        # Save policy network using jit
        scripted_policy = torch.jit.script(model.policy_net)
        torch.jit.save(scripted_policy, f"{path}.policy.jit")

        # Save target network using jit
        scripted_target = torch.jit.script(model.target_net)
        torch.jit.save(scripted_target, f"{path}.target.jit")

        # Save scalar parameters separately as JSON
        params = {
            'epsilon': float(getattr(model, 'epsilon', 0.0)),
            'state_size': int(getattr(model, 'state_size', 0)),
            'action_size': int(getattr(model, 'action_size', 0)),
            'hidden_size': int(getattr(model, 'hidden_size', 0))
        }
        with open(f"{path}.params.json", "w") as f:
            json.dump(params, f)

        logger.info("Successfully saved model components with jit.save")
        return True
    except Exception as e:
        logger.error(f"All save attempts failed: {e}")
        return False


def robust_load(model: Any, path: str, device: Optional[torch.device] = None) -> bool:
    """
    Robust model loading with fallback approaches

    Args:
        model: The model object to load into
        path: Path to load the model from
        device: Device to load the model on

    Returns:
        bool: True if successful, False otherwise
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Try a regular PyTorch load first
    try:
        logger.info(f"Loading model from {path}")
        if os.path.exists(path):
            checkpoint = torch.load(path, map_location=device)

            # Load network states
            if 'policy_net' in checkpoint:
                model.policy_net.load_state_dict(checkpoint['policy_net'])
            if 'target_net' in checkpoint:
                model.target_net.load_state_dict(checkpoint['target_net'])

            # Load other attributes
            if 'epsilon' in checkpoint:
                model.epsilon = checkpoint['epsilon']
            if 'optimizer' in checkpoint and hasattr(model, 'optimizer') and model.optimizer is not None:
                try:
                    model.optimizer.load_state_dict(checkpoint['optimizer'])
                except Exception as e:
                    logger.warning(f"Failed to load optimizer state: {e}")

            logger.info("Successfully loaded model")
            return True
    except Exception as e:
        logger.warning(f"Regular load failed: {e}")

    # Try loading JIT-saved components
    try:
        policy_path = f"{path}.policy.jit"
        target_path = f"{path}.target.jit"
        params_path = f"{path}.params.json"

        if all(os.path.exists(p) for p in [policy_path, target_path, params_path]):
            logger.info("Loading JIT model components")

            # Loading JIT models is more complex and may need model reconstruction.
            # For now, just log that we found JIT files.
            logger.info("Found JIT model files, but loading them requires special handling")
            with open(params_path, 'r') as f:
                params = json.load(f)
            logger.info(f"Model parameters: {params}")

            # Note: actually loading JIT models would require recreating the model architecture.
            # This is a placeholder for future implementation.
            return False
    except Exception as e:
        logger.error(f"JIT load failed: {e}")

    logger.error(f"All load attempts failed for {path}")
    return False


def get_model_info(path: str) -> Dict[str, Any]:
    """
    Get information about a saved model

    Args:
        path: Path to the model file

    Returns:
        dict: Model information
    """
    info = {
        'exists': False,
        'size_bytes': 0,
        'has_optimizer': False,
        'parameters': {}
    }

    try:
        if os.path.exists(path):
            info['exists'] = True
            info['size_bytes'] = os.path.getsize(path)

            # Try to load and inspect the checkpoint
            checkpoint = torch.load(path, map_location='cpu')
            info['has_optimizer'] = 'optimizer' in checkpoint

            # Extract parameter info
            for key in ['epsilon', 'state_size', 'action_size', 'hidden_size']:
                if key in checkpoint:
                    info['parameters'][key] = checkpoint[key]
    except Exception as e:
        logger.warning(f"Failed to get model info for {path}: {e}")

    return info


def verify_save_load_cycle(model: Any, test_path: str) -> bool:
    """
    Test that a model can be saved and loaded correctly

    Args:
        model: Model to test
        test_path: Path for the test file

    Returns:
        bool: True if the save/load cycle is successful
    """
    try:
        # Save the model
        if not robust_save(model, test_path):
            return False

        # Creating a fresh model instance would need model construction logic;
        # for now, just verify the file exists and has content.
        if os.path.exists(test_path) and os.path.getsize(test_path) > 0:
            logger.info("Save/load cycle verification successful")
            # Clean up the test file
            os.remove(test_path)
            return True
        else:
            return False
    except Exception as e:
        logger.error(f"Save/load cycle verification failed: {e}")
        return False
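For reference, a minimal usage sketch of these helpers. The Agent class and the layer sizes below are placeholders, not the repo's actual agent; they just expose the policy_net, target_net, optimizer, and epsilon attributes that robust_save's docstring expects.

# Sketch only: "Agent" is a hypothetical stand-in for the real agent object.
import torch
import torch.nn as nn
from utils.model_utils import robust_save, robust_load, get_model_info

class Agent:
    def __init__(self, state_size=40, action_size=3, hidden_size=64):
        self.state_size, self.action_size, self.hidden_size = state_size, action_size, hidden_size
        self.policy_net = nn.Sequential(nn.Linear(state_size, hidden_size), nn.ReLU(),
                                        nn.Linear(hidden_size, action_size))
        self.target_net = nn.Sequential(nn.Linear(state_size, hidden_size), nn.ReLU(),
                                        nn.Linear(hidden_size, action_size))
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=1e-3)
        self.epsilon = 0.1

agent = Agent()
if robust_save(agent, "models/test_agent.pt"):
    print(get_model_info("models/test_agent.pt"))  # exists, size_bytes, has_optimizer, parameters
    robust_load(agent, "models/test_agent.pt")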
utils/reward_calculator.py (new file, 224 additions)
@@ -0,0 +1,224 @@
"""
Improved Reward Function for RL Trading Agent

This module provides a more sophisticated reward function for the RL trading agent
that incorporates realistic trading fees, penalties for excessive trading, and
rewards for successful holding of positions.
"""

import numpy as np
from datetime import datetime, timedelta
from collections import deque


class ImprovedRewardCalculator:
    def __init__(self,
                 max_drawdown_pct=0.1,          # Maximum drawdown %
                 risk_reward_ratio=1.5,         # Risk-reward ratio
                 base_fee_rate=0.0002,          # 0.02% per transaction
                 max_frequency_penalty=0.005,   # Maximum 0.5% penalty for frequent trading
                 holding_reward_rate=0.0001,    # Small reward for holding profitable positions
                 risk_adjusted=True,            # Use Sharpe ratio for risk adjustment
                 base_reward=1.0,               # Base reward scale
                 profit_factor=2.0,             # Profit reward multiplier
                 loss_factor=1.0,               # Loss penalty multiplier
                 trade_frequency_penalty=0.3,   # Penalty for frequent trading
                 position_duration_factor=0.05  # Reward for longer positions
                 ):

        self.base_fee_rate = base_fee_rate
        self.max_frequency_penalty = max_frequency_penalty
        self.holding_reward_rate = holding_reward_rate
        self.risk_adjusted = risk_adjusted

        # New parameters
        self.base_reward = base_reward
        self.profit_factor = profit_factor
        self.loss_factor = loss_factor
        self.trade_frequency_penalty = trade_frequency_penalty
        self.position_duration_factor = position_duration_factor

        # Keep track of recent trades
        self.recent_trades = deque(maxlen=1000)
        self.trade_pnls = deque(maxlen=100)  # For risk adjustment

        # Additional tracking metrics
        self.total_trades = 0
        self.profitable_trades = 0
        self.total_pnl = 0.0
        self.daily_pnl = {}
        self.hourly_pnl = {}

    def record_trade(self, timestamp=None, action=None, price=None):
        """Record a trade for frequency tracking"""
        if timestamp is None:
            timestamp = datetime.now()

        self.recent_trades.append({
            'timestamp': timestamp,
            'action': action,
            'price': price
        })

    def record_pnl(self, pnl):
        """Record a PnL result for risk adjustment and tracking metrics"""
        self.trade_pnls.append(pnl)

        # Update overall metrics
        self.total_trades += 1
        self.total_pnl += pnl

        if pnl > 0:
            self.profitable_trades += 1

        # Track daily and hourly PnL
        now = datetime.now()
        day_key = now.strftime('%Y-%m-%d')
        hour_key = now.strftime('%Y-%m-%d %H:00')

        # Update daily PnL
        if day_key not in self.daily_pnl:
            self.daily_pnl[day_key] = 0.0
        self.daily_pnl[day_key] += pnl

        # Update hourly PnL
        if hour_key not in self.hourly_pnl:
            self.hourly_pnl[hour_key] = 0.0
        self.hourly_pnl[hour_key] += pnl

    def _calculate_frequency_penalty(self):
        """Calculate penalty for trading too frequently"""
        if len(self.recent_trades) < 2:
            return 0.0

        # Count trades in the last minute
        now = datetime.now()
        one_minute_ago = now - timedelta(minutes=1)
        trades_last_minute = sum(1 for trade in self.recent_trades
                                 if trade['timestamp'] > one_minute_ago)

        # Apply a progressive penalty (more severe as frequency increases)
        if trades_last_minute <= 1:
            return 0.0  # No penalty for a normal trading rate

        # Progressive penalty based on trade frequency
        penalty = min(self.max_frequency_penalty,
                      self.base_fee_rate * trades_last_minute)

        return penalty

    def _calculate_holding_reward(self, position_held_time, price_change_pct):
        """Calculate reward for holding a position for some time"""
        if position_held_time <= 0 or price_change_pct <= 0:
            return 0.0  # No reward for unprofitable holds

        # Cap at 100 time units (seconds, minutes, etc.)
        capped_time = min(position_held_time, 100)

        # Scale reward by both time and price change
        reward = self.holding_reward_rate * capped_time * price_change_pct

        return reward

    def _calculate_risk_adjustment(self, reward):
        """Adjust rewards based on risk (simple Sharpe ratio implementation)"""
        if len(self.trade_pnls) < 5:
            return reward  # Not enough data for adjustment

        # Calculate mean and standard deviation of returns
        pnl_array = np.array(self.trade_pnls)
        mean_return = np.mean(pnl_array)
        std_return = np.std(pnl_array)

        if std_return == 0:
            return reward  # Avoid division by zero

        # Simplified Sharpe ratio
        sharpe = mean_return / std_return

        # Scale reward by the Sharpe ratio (normalized to be around 1.0)
        adjustment_factor = np.clip(1.0 + 0.5 * sharpe, 0.5, 2.0)

        return reward * adjustment_factor

    def calculate_reward(self, action, price_change, position_held_time=0,
                         volatility=None, is_profitable=False):
        """
        Calculate the improved reward

        Args:
            action (int): 0 = Buy, 1 = Sell, 2 = Hold
            price_change (float): Percent price change for the trade
            position_held_time (int): Time the position was held (in time units)
            volatility (float, optional): Market volatility measure
            is_profitable (bool): Whether the current position is profitable

        Returns:
            float: Calculated reward value
        """
        # Trading fee
        fee = self.base_fee_rate

        # Frequency penalty
        frequency_penalty = self._calculate_frequency_penalty()

        # Base reward calculation
        if action == 0:  # Buy
            # Small penalty for the transaction plus the frequency penalty
            reward = -fee - frequency_penalty

        elif action == 1:  # Sell
            # Profit percentage minus fees (both entry and exit)
            profit_pct = price_change
            net_profit = profit_pct - (fee * 2)

            # Scale reward and apply the frequency penalty
            reward = net_profit * 10  # Scale reward
            reward -= frequency_penalty

            # Record PnL for risk adjustment
            self.record_pnl(net_profit)

        else:  # Hold
            # Small reward for holding a profitable position, small cost otherwise
            if is_profitable:
                reward = self._calculate_holding_reward(position_held_time, price_change)
            else:
                reward = -0.0001  # Very small negative reward

        # Apply risk adjustment if enabled
        if self.risk_adjusted:
            reward = self._calculate_risk_adjustment(reward)

        # Record this action for future frequency calculations
        self.record_trade(action=action)

        return reward


# Example usage:
if __name__ == "__main__":
    # Create calculator instance
    reward_calc = ImprovedRewardCalculator()

    # Example reward for a buy action
    buy_reward = reward_calc.calculate_reward(action=0, price_change=0)
    print(f"Buy action reward: {buy_reward:.5f}")

    # Record a trade for frequency tracking
    reward_calc.record_trade(action=0)

    # Wait a bit and make another trade to test the frequency penalty
    import time
    time.sleep(0.1)

    # Example reward for a sell action with profit
    sell_reward = reward_calc.calculate_reward(action=1, price_change=0.015, position_held_time=60)
    print(f"Sell action reward (with profit): {sell_reward:.5f}")

    # Example reward for a hold action on a profitable position
    hold_reward = reward_calc.calculate_reward(action=2, price_change=0.01, position_held_time=30, is_profitable=True)
    print(f"Hold action reward (profitable): {hold_reward:.5f}")

    # Example reward for a hold action on an unprofitable position
    hold_reward_neg = reward_calc.calculate_reward(action=2, price_change=-0.01, position_held_time=30, is_profitable=False)
    print(f"Hold action reward (unprofitable): {hold_reward_neg:.5f}")
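As a sanity check on the sell branch of calculate_reward, the sell in the example above (price_change=0.015 with the default base_fee_rate=0.0002) works out as follows; at that point fewer than five PnLs have been recorded, so the risk adjustment leaves the value unchanged.

# Worked numbers for the sell example above (defaults from ImprovedRewardCalculator)
fee = 0.0002                          # base_fee_rate
price_change = 0.015
net_profit = price_change - 2 * fee   # 0.0146: entry and exit fees are both charged
base = net_profit * 10                # 0.146 before the frequency penalty
penalty = min(0.005, 2 * fee)         # 0.0004: two trades were recorded in the last minute
print(f"{base - penalty:.5f}")        # 0.14560, matching the printed sell reward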