data normalizations
@@ -1,177 +0,0 @@
"""
Improved Reward Function for RL Trading Agent

This module provides a more sophisticated reward function for the RL trading agent
that incorporates realistic trading fees, penalties for excessive trading, and
rewards for successful holding of positions.
"""

import numpy as np
from datetime import datetime, timedelta
from collections import deque
import logging

logger = logging.getLogger(__name__)

class RewardCalculator:
    def __init__(self, base_fee_rate=0.001, reward_scaling=10.0, risk_aversion=0.1):
        self.base_fee_rate = base_fee_rate
        self.reward_scaling = reward_scaling
        self.risk_aversion = risk_aversion
        self.trade_pnls = []
        self.returns = []
        self.trade_timestamps = []
        self.frequency_threshold = 10  # Trades per minute threshold for penalty
        self.max_frequency_penalty = 0.05

    def record_pnl(self, pnl):
        """Record P&L for risk adjustment calculations"""
        self.trade_pnls.append(pnl)
        if len(self.trade_pnls) > 100:
            self.trade_pnls.pop(0)

    def record_trade(self, action):
        """Record trade action for frequency penalty calculations"""
        from time import time
        self.trade_timestamps.append(time())
        if len(self.trade_timestamps) > 100:
            self.trade_timestamps.pop(0)

    def _calculate_frequency_penalty(self):
        """Calculate penalty for high-frequency trading"""
        if len(self.trade_timestamps) < 2:
            return 0.0
        time_span = self.trade_timestamps[-1] - self.trade_timestamps[0]
        if time_span <= 0:
            return 0.0
        trades_per_minute = (len(self.trade_timestamps) / time_span) * 60
        if trades_per_minute > self.frequency_threshold:
            penalty = min(self.max_frequency_penalty, (trades_per_minute - self.frequency_threshold) * 0.001)
            return penalty
        return 0.0
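    # Illustrative walk-through (assumed values, not from the original file):
    # with 100 recorded timestamps spanning 200 seconds, trades_per_minute =
    # (100 / 200) * 60 = 30, which exceeds the threshold of 10, so the penalty
    # is min(0.05, (30 - 10) * 0.001) = 0.02 per action.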

    def _calculate_risk_adjustment(self, reward):
        """Adjust rewards based on risk (simple Sharpe ratio implementation)"""
        if len(self.trade_pnls) < 5:
            return reward
        pnl_array = np.array(self.trade_pnls)
        mean_return = np.mean(pnl_array)
        std_return = np.std(pnl_array)
        if std_return == 0:
            return reward
        sharpe = mean_return / std_return
        adjustment_factor = np.clip(1.0 + 0.5 * sharpe, 0.5, 2.0)
        return reward * adjustment_factor

    def _calculate_holding_reward(self, position_held_time, price_change):
        """Calculate reward for holding a position"""
        base_holding_reward = 0.0005 * (position_held_time / 60.0)
        if price_change > 0:
            return base_holding_reward * 2
        elif price_change < 0:
            return base_holding_reward * 0.5
        return base_holding_reward

    def calculate_basic_reward(self, pnl, confidence):
        """Calculate basic training reward based on P&L and confidence"""
        try:
            # Reward based on net PnL after fees and confidence alignment
            base_reward = pnl
            # Stronger penalty for confident wrong decisions
            if pnl < 0 and confidence >= 0.6:
                confidence_adjustment = -confidence * 3.0
            elif pnl > 0 and confidence >= 0.6:
                confidence_adjustment = confidence * 1.0
            else:
                confidence_adjustment = 0.0
            final_reward = base_reward + confidence_adjustment
            # Reduce tanh compression so small PnL changes are not flattened
            normalized_reward = np.tanh(final_reward / 2.5)
            logger.debug(f"Basic reward calculation: P&L={pnl:.4f}, confidence={confidence:.2f}, reward={normalized_reward:.4f}")
            return float(normalized_reward)
        except Exception as e:
            logger.error(f"Error calculating basic reward: {e}")
            return 0.0

    def calculate_enhanced_reward(self, action, price_change, position_held_time=0, volatility=None, is_profitable=False, confidence=0.0, predicted_change=0.0, actual_change=0.0, current_pnl=0.0, symbol='UNKNOWN'):
        """Calculate enhanced reward for trading actions"""
        fee = self.base_fee_rate
        frequency_penalty = self._calculate_frequency_penalty()
        if action == 0:  # Buy
            reward = -fee - frequency_penalty
        elif action == 1:  # Sell
            profit_pct = price_change
            net_profit = profit_pct - (fee * 2)
            reward = net_profit * self.reward_scaling
            reward -= frequency_penalty
            self.record_pnl(net_profit)
        else:  # Hold
            if is_profitable:
                reward = self._calculate_holding_reward(position_held_time, price_change)
            else:
                reward = -0.0001
        if action in [0, 1] and predicted_change != 0:
            if (action == 0 and actual_change > 0) or (action == 1 and actual_change < 0):
                reward += abs(actual_change) * 5.0
            else:
                reward -= abs(predicted_change) * 2.0
        reward += current_pnl * 0.1
        if volatility is not None:
            reward -= abs(volatility) * 100
        if self.risk_aversion > 0 and len(self.returns) > 1:
            returns_std = np.std(self.returns)
            reward -= returns_std * self.risk_aversion
        self.record_trade(action)
        return reward
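    # Illustrative walk-through (assumed values): a sell (action=1) with
    # price_change=0.015 and the default base_fee_rate=0.001 gives
    # net_profit = 0.015 - 0.002 = 0.013, scaled by reward_scaling=10.0 to 0.13
    # before the frequency penalty, prediction bonus/penalty, and current_pnl
    # adjustment are applied.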

    def calculate_prediction_reward(self, symbol, predicted_direction, actual_direction, confidence, predicted_change, actual_change, current_pnl=0.0, position_duration=0.0):
        """Calculate reward for prediction accuracy"""
        reward = 0.0
        if predicted_direction == actual_direction:
            reward += 1.0 * confidence
        else:
            reward -= 0.5
        if predicted_direction == actual_direction and abs(predicted_change) > 0.001:
            reward += abs(actual_change) * 5.0
        if predicted_direction != actual_direction and abs(predicted_change) > 0.001:
            reward -= abs(predicted_change) * 2.0
        reward += current_pnl * 0.1
        # Dynamic adjustment based on recent PnL (loss cutting incentive)
        if hasattr(self, 'pnl_history') and symbol in self.pnl_history and self.pnl_history[symbol]:
            latest_pnl_entry = self.pnl_history[symbol][-1]
            latest_pnl_value = latest_pnl_entry.get('pnl', 0.0) if isinstance(latest_pnl_entry, dict) else 0.0
            if latest_pnl_value < 0 and position_duration > 60:
                reward -= (abs(latest_pnl_value) * 0.2)
            pnl_values = [entry.get('pnl', 0.0) for entry in self.pnl_history[symbol] if isinstance(entry, dict)]
            best_pnl = max(pnl_values) if pnl_values else 0.0
            if best_pnl < 0.0:
                reward -= 0.1
        return reward

# Example usage:
if __name__ == "__main__":
    # Create calculator instance
    reward_calc = RewardCalculator()

    # Example reward for a buy action
    buy_reward = reward_calc.calculate_enhanced_reward(action=0, price_change=0)
    print(f"Buy action reward: {buy_reward:.5f}")

    # Record a trade for frequency tracking
    reward_calc.record_trade(0)

    # Wait a bit and make another trade to test frequency penalty
    import time
    time.sleep(0.1)

    # Example reward for a sell action with profit
    sell_reward = reward_calc.calculate_enhanced_reward(action=1, price_change=0.015, position_held_time=60)
    print(f"Sell action reward (with profit): {sell_reward:.5f}")

    # Example reward for a hold action on profitable position
    hold_reward = reward_calc.calculate_enhanced_reward(action=2, price_change=0.01, position_held_time=30, is_profitable=True)
    print(f"Hold action reward (profitable): {hold_reward:.5f}")

    # Example reward for a hold action on unprofitable position
    hold_reward_neg = reward_calc.calculate_enhanced_reward(action=2, price_change=-0.01, position_held_time=30, is_profitable=False)
    print(f"Hold action reward (unprofitable): {hold_reward_neg:.5f}")
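The example block above only exercises calculate_enhanced_reward. As a rough sketch of the two other entry points, assuming the default constructor (the numeric inputs and the 'ETH/USDT' symbol are illustrative placeholders, not values from the original file):

calc = RewardCalculator()

# A confident losing trade: confidence_adjustment = -0.8 * 3.0 = -2.4,
# so the result is tanh((-1.0 - 2.4) / 2.5), roughly -0.88
print(calc.calculate_basic_reward(pnl=-1.0, confidence=0.8))

# Correct direction call with confidence 0.7 and a realized move of 0.006:
# 0.7 + 0.006 * 5.0 = 0.73 (a fresh instance has no pnl_history, so no PnL adjustment)
print(calc.calculate_prediction_reward(
    symbol='ETH/USDT', predicted_direction=1, actual_direction=1,
    confidence=0.7, predicted_change=0.004, actual_change=0.006))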
@@ -1,233 +0,0 @@
#!/usr/bin/env python3
"""
Training Integration for Checkpoint Management
"""

import logging
import torch
from datetime import datetime
from typing import Dict, Any, Optional
from pathlib import Path

from .checkpoint_manager import get_checkpoint_manager, save_checkpoint, load_best_checkpoint

logger = logging.getLogger(__name__)

class TrainingIntegration:
    def __init__(self, enable_wandb: bool = False):
        self.checkpoint_manager = get_checkpoint_manager()
        self.enable_wandb = enable_wandb

        if self.enable_wandb:
            self._init_wandb()

    def _init_wandb(self):
        # Disabled by default to avoid CLI prompts
        pass

    def save_cnn_checkpoint(self,
                            cnn_model,
                            model_name: str,
                            epoch: int,
                            train_accuracy: float,
                            val_accuracy: float,
                            train_loss: float,
                            val_loss: float,
                            training_time_hours: float = None) -> bool:
        try:
            performance_metrics = {
                'accuracy': train_accuracy,
                'val_accuracy': val_accuracy,
                'loss': train_loss,
                'val_loss': val_loss
            }

            training_metadata = {
                'epoch': epoch,
                'training_time_hours': training_time_hours,
                'total_parameters': self._count_parameters(cnn_model)
            }

            # W&B disabled

            metadata = save_checkpoint(
                model=cnn_model,
                model_name=model_name,
                model_type='cnn',
                performance_metrics=performance_metrics,
                training_metadata=training_metadata
            )

            if metadata:
                logger.info(f"CNN checkpoint saved: {metadata.checkpoint_id}")
                return True
            else:
                logger.info("CNN checkpoint not saved (performance not improved)")
                return False

        except Exception as e:
            logger.error(f"Error saving CNN checkpoint: {e}")
            return False

    def save_rl_checkpoint(self,
                           rl_agent,
                           model_name: str,
                           episode: int,
                           avg_reward: float,
                           best_reward: float,
                           epsilon: float,
                           total_pnl: float = None) -> bool:
        try:
            performance_metrics = {
                'reward': avg_reward,
                'best_reward': best_reward
            }

            if total_pnl is not None:
                performance_metrics['pnl'] = total_pnl

            training_metadata = {
                'episode': episode,
                'epsilon': epsilon,
                'total_parameters': self._count_parameters(rl_agent)
            }

            # W&B disabled

            metadata = save_checkpoint(
                model=rl_agent,
                model_name=model_name,
                model_type='rl',
                performance_metrics=performance_metrics,
                training_metadata=training_metadata
            )

            if metadata:
                logger.info(f"RL checkpoint saved: {metadata.checkpoint_id}")
                return True
            else:
                logger.info("RL checkpoint not saved (performance not improved)")
                return False

        except Exception as e:
            logger.error(f"Error saving RL checkpoint: {e}")
            return False

    def load_best_model(self, model_name: str, model_class=None):
        try:
            result = load_best_checkpoint(model_name)
            if not result:
                logger.warning(f"No checkpoint found for model: {model_name}")
                return None

            file_path, metadata = result

            checkpoint = torch.load(file_path, map_location='cpu')

            logger.info(f"Loaded best checkpoint for {model_name}:")
            logger.info(f"  Performance score: {metadata.performance_score:.4f}")
            logger.info(f"  Created: {metadata.created_at}")

            if model_class and 'model_state_dict' in checkpoint:
                model = model_class()
                model.load_state_dict(checkpoint['model_state_dict'])
                return model

            return checkpoint

        except Exception as e:
            logger.error(f"Error loading best model {model_name}: {e}")
            return None

    def _count_parameters(self, model) -> int:
        try:
            if hasattr(model, 'parameters'):
                return sum(p.numel() for p in model.parameters())
            elif hasattr(model, 'policy_net'):
                policy_params = sum(p.numel() for p in model.policy_net.parameters())
                target_params = sum(p.numel() for p in model.target_net.parameters()) if hasattr(model, 'target_net') else 0
                return policy_params + target_params
            else:
                return 0
        except Exception:
            return 0

_training_integration = None

def get_training_integration() -> TrainingIntegration:
    global _training_integration
    if _training_integration is None:
        _training_integration = TrainingIntegration()
    return _training_integration

# ---------------- Unified Training Manager ----------------

class UnifiedTrainingManager:
    """Single entry point to manage all training in the system.

    Coordinates EnhancedRealtimeTrainingSystem and provides start/stop/status.
    """

    def __init__(self, orchestrator, data_provider, dashboard=None):
        self.orchestrator = orchestrator
        self.data_provider = data_provider
        self.dashboard = dashboard
        self.training_system = None
        self.started = False

    def initialize(self) -> bool:
        try:
            # Import via project root shim to avoid path issues
            from enhanced_realtime_training import EnhancedRealtimeTrainingSystem
            self.training_system = EnhancedRealtimeTrainingSystem(
                orchestrator=self.orchestrator,
                data_provider=self.data_provider,
                dashboard=self.dashboard
            )
            return True
        except Exception as e:
            logger.error(f"UnifiedTrainingManager: failed to initialize training system: {e}")
            self.training_system = None
            return False

    def start(self) -> bool:
        try:
            if self.training_system is None:
                if not self.initialize():
                    return False
            self.training_system.start_training()
            self.started = True
            logger.info("UnifiedTrainingManager: training started")
            return True
        except Exception as e:
            logger.error(f"UnifiedTrainingManager: error starting training: {e}")
            return False

    def stop(self) -> bool:
        try:
            if self.training_system and self.started:
                self.training_system.stop_training()
                self.started = False
                logger.info("UnifiedTrainingManager: training stopped")
            return True
        except Exception as e:
            logger.error(f"UnifiedTrainingManager: error stopping training: {e}")
            return False

    def get_stats(self) -> Dict[str, Any]:
        try:
            if self.training_system and hasattr(self.training_system, 'get_training_stats'):
                return self.training_system.get_training_stats()
            return {}
        except Exception:
            return {}

_unified_training_manager = None

def get_unified_training_manager(orchestrator=None, data_provider=None, dashboard=None) -> UnifiedTrainingManager:
    global _unified_training_manager
    if _unified_training_manager is None:
        if orchestrator is None or data_provider is None:
            raise ValueError("orchestrator and data_provider are required for first-time initialization")
        _unified_training_manager = UnifiedTrainingManager(orchestrator, data_provider, dashboard)
    return _unified_training_manager
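For orientation, a minimal wiring sketch of the two accessors defined above; the import path, the stub classes, and the metric values are assumed placeholders, not part of the original module:

import torch.nn as nn
from training_integration import get_training_integration, get_unified_training_manager  # assumed module path

class StubOrchestrator:  # placeholder stand-ins for the real components
    pass

class StubDataProvider:
    pass

# The first call must supply orchestrator and data_provider; later calls reuse the singleton.
manager = get_unified_training_manager(orchestrator=StubOrchestrator(), data_provider=StubDataProvider())
if manager.start():
    print(manager.get_stats())

# Checkpoint saving goes through the TrainingIntegration singleton; save_checkpoint
# decides whether the checkpoint is kept based on the reported metrics.
integration = get_training_integration()
saved = integration.save_cnn_checkpoint(
    cnn_model=nn.Linear(4, 2),  # tiny stand-in model
    model_name="example_cnn", epoch=3,
    train_accuracy=0.71, val_accuracy=0.68, train_loss=0.52, val_loss=0.57)
print(f"Checkpoint saved: {saved}")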