data normalizations

This commit is contained in:
Dobromir Popov
2025-09-02 18:51:49 +03:00
parent 1c013f2806
commit 6dcb82c184
4 changed files with 251 additions and 322 deletions


@@ -1,177 +0,0 @@
"""
Improved Reward Function for RL Trading Agent
This module provides a more sophisticated reward function for the RL trading agent
that incorporates realistic trading fees, penalties for excessive trading, and
rewards for successful holding of positions.
"""
import numpy as np
from datetime import datetime, timedelta
from collections import deque
import logging
logger = logging.getLogger(__name__)
class RewardCalculator:
    def __init__(self, base_fee_rate=0.001, reward_scaling=10.0, risk_aversion=0.1):
        self.base_fee_rate = base_fee_rate
        self.reward_scaling = reward_scaling
        self.risk_aversion = risk_aversion
        self.trade_pnls = []
        self.returns = []
        self.trade_timestamps = []
        self.frequency_threshold = 10  # Trades per minute threshold for penalty
        self.max_frequency_penalty = 0.05

    def record_pnl(self, pnl):
        """Record P&L for risk adjustment calculations"""
        self.trade_pnls.append(pnl)
        if len(self.trade_pnls) > 100:
            self.trade_pnls.pop(0)

    def record_trade(self, action):
        """Record trade action for frequency penalty calculations"""
        from time import time
        self.trade_timestamps.append(time())
        if len(self.trade_timestamps) > 100:
            self.trade_timestamps.pop(0)

    def _calculate_frequency_penalty(self):
        """Calculate penalty for high-frequency trading"""
        if len(self.trade_timestamps) < 2:
            return 0.0
        time_span = self.trade_timestamps[-1] - self.trade_timestamps[0]
        if time_span <= 0:
            return 0.0
        trades_per_minute = (len(self.trade_timestamps) / time_span) * 60
        if trades_per_minute > self.frequency_threshold:
            penalty = min(self.max_frequency_penalty, (trades_per_minute - self.frequency_threshold) * 0.001)
            return penalty
        return 0.0
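    # Illustrative check (added comment, not in the original code): with the 100 most recent
    # timestamps spanning 300 seconds, trades_per_minute = (100 / 300) * 60 = 20, so the
    # penalty is min(0.05, (20 - 10) * 0.001) = 0.01 per action.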
    def _calculate_risk_adjustment(self, reward):
        """Adjust rewards based on risk (simple Sharpe ratio implementation)"""
        if len(self.trade_pnls) < 5:
            return reward
        pnl_array = np.array(self.trade_pnls)
        mean_return = np.mean(pnl_array)
        std_return = np.std(pnl_array)
        if std_return == 0:
            return reward
        sharpe = mean_return / std_return
        adjustment_factor = np.clip(1.0 + 0.5 * sharpe, 0.5, 2.0)
        return reward * adjustment_factor

    def _calculate_holding_reward(self, position_held_time, price_change):
        """Calculate reward for holding a position"""
        base_holding_reward = 0.0005 * (position_held_time / 60.0)
        if price_change > 0:
            return base_holding_reward * 2
        elif price_change < 0:
            return base_holding_reward * 0.5
        return base_holding_reward
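    # Illustrative check (added comment): holding for 120 seconds gives a base of
    # 0.0005 * (120 / 60) = 0.001, doubled to 0.002 on a positive price change and
    # halved to 0.0005 on a negative one.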
    def calculate_basic_reward(self, pnl, confidence):
        """Calculate basic training reward based on P&L and confidence"""
        try:
            # Reward based on net PnL after fees and confidence alignment
            base_reward = pnl
            # Stronger penalty for confident wrong decisions
            if pnl < 0 and confidence >= 0.6:
                confidence_adjustment = -confidence * 3.0
            elif pnl > 0 and confidence >= 0.6:
                confidence_adjustment = confidence * 1.0
            else:
                confidence_adjustment = 0.0
            final_reward = base_reward + confidence_adjustment
            # Reduce tanh compression so small PnL changes are not flattened
            normalized_reward = np.tanh(final_reward / 2.5)
            logger.debug(f"Basic reward calculation: P&L={pnl:.4f}, confidence={confidence:.2f}, reward={normalized_reward:.4f}")
            return float(normalized_reward)
        except Exception as e:
            logger.error(f"Error calculating basic reward: {e}")
            return 0.0
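    # Illustrative check (added comment): pnl=0.5 with confidence=0.7 gives
    # final_reward = 0.5 + 0.7 = 1.2 and tanh(1.2 / 2.5) ≈ 0.45, while pnl=-0.5 at the same
    # confidence gives -0.5 - 2.1 = -2.6 and tanh(-2.6 / 2.5) ≈ -0.78, so confident losses
    # are penalized much harder than confident wins are rewarded.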
    def calculate_enhanced_reward(self, action, price_change, position_held_time=0, volatility=None, is_profitable=False, confidence=0.0, predicted_change=0.0, actual_change=0.0, current_pnl=0.0, symbol='UNKNOWN'):
        """Calculate enhanced reward for trading actions"""
        fee = self.base_fee_rate
        frequency_penalty = self._calculate_frequency_penalty()
        if action == 0:  # Buy
            reward = -fee - frequency_penalty
        elif action == 1:  # Sell
            profit_pct = price_change
            net_profit = profit_pct - (fee * 2)
            reward = net_profit * self.reward_scaling
            reward -= frequency_penalty
            self.record_pnl(net_profit)
        else:  # Hold
            if is_profitable:
                reward = self._calculate_holding_reward(position_held_time, price_change)
            else:
                reward = -0.0001
        if action in [0, 1] and predicted_change != 0:
            if (action == 0 and actual_change > 0) or (action == 1 and actual_change < 0):
                reward += abs(actual_change) * 5.0
            else:
                reward -= abs(predicted_change) * 2.0
        reward += current_pnl * 0.1
        if volatility is not None:
            reward -= abs(volatility) * 100
        if self.risk_aversion > 0 and len(self.returns) > 1:
            returns_std = np.std(self.returns)
            reward -= returns_std * self.risk_aversion
        self.record_trade(action)
        return reward
    def calculate_prediction_reward(self, symbol, predicted_direction, actual_direction, confidence, predicted_change, actual_change, current_pnl=0.0, position_duration=0.0):
        """Calculate reward for prediction accuracy"""
        reward = 0.0
        if predicted_direction == actual_direction:
            reward += 1.0 * confidence
        else:
            reward -= 0.5
        if predicted_direction == actual_direction and abs(predicted_change) > 0.001:
            reward += abs(actual_change) * 5.0
        if predicted_direction != actual_direction and abs(predicted_change) > 0.001:
            reward -= abs(predicted_change) * 2.0
        reward += current_pnl * 0.1
        # Dynamic adjustment based on recent PnL (loss cutting incentive)
        if hasattr(self, 'pnl_history') and symbol in self.pnl_history and self.pnl_history[symbol]:
            latest_pnl_entry = self.pnl_history[symbol][-1]
            latest_pnl_value = latest_pnl_entry.get('pnl', 0.0) if isinstance(latest_pnl_entry, dict) else 0.0
            if latest_pnl_value < 0 and position_duration > 60:
                reward -= (abs(latest_pnl_value) * 0.2)
            pnl_values = [entry.get('pnl', 0.0) for entry in self.pnl_history[symbol] if isinstance(entry, dict)]
            best_pnl = max(pnl_values) if pnl_values else 0.0
            if best_pnl < 0.0:
                reward -= 0.1
        return reward
# Example usage:
if __name__ == "__main__":
    # Create calculator instance
    reward_calc = RewardCalculator()
    # Example reward for a buy action
    buy_reward = reward_calc.calculate_enhanced_reward(action=0, price_change=0)
    print(f"Buy action reward: {buy_reward:.5f}")
    # Record a trade for frequency tracking
    reward_calc.record_trade(0)
    # Wait a bit and make another trade to test frequency penalty
    import time
    time.sleep(0.1)
    # Example reward for a sell action with profit
    sell_reward = reward_calc.calculate_enhanced_reward(action=1, price_change=0.015, position_held_time=60)
    print(f"Sell action reward (with profit): {sell_reward:.5f}")
    # Example reward for a hold action on profitable position
    hold_reward = reward_calc.calculate_enhanced_reward(action=2, price_change=0.01, position_held_time=30, is_profitable=True)
    print(f"Hold action reward (profitable): {hold_reward:.5f}")
    # Example reward for a hold action on unprofitable position
    hold_reward_neg = reward_calc.calculate_enhanced_reward(action=2, price_change=-0.01, position_held_time=30, is_profitable=False)
    print(f"Hold action reward (unprofitable): {hold_reward_neg:.5f}")


@@ -1,233 +0,0 @@
#!/usr/bin/env python3
"""
Training Integration for Checkpoint Management
"""
import logging
import torch
from datetime import datetime
from typing import Dict, Any, Optional
from pathlib import Path
from .checkpoint_manager import get_checkpoint_manager, save_checkpoint, load_best_checkpoint
logger = logging.getLogger(__name__)
class TrainingIntegration:
    def __init__(self, enable_wandb: bool = False):
        self.checkpoint_manager = get_checkpoint_manager()
        self.enable_wandb = enable_wandb
        if self.enable_wandb:
            self._init_wandb()

    def _init_wandb(self):
        # Disabled by default to avoid CLI prompts
        pass

    def save_cnn_checkpoint(self,
                            cnn_model,
                            model_name: str,
                            epoch: int,
                            train_accuracy: float,
                            val_accuracy: float,
                            train_loss: float,
                            val_loss: float,
                            training_time_hours: float = None) -> bool:
        try:
            performance_metrics = {
                'accuracy': train_accuracy,
                'val_accuracy': val_accuracy,
                'loss': train_loss,
                'val_loss': val_loss
            }
            training_metadata = {
                'epoch': epoch,
                'training_time_hours': training_time_hours,
                'total_parameters': self._count_parameters(cnn_model)
            }
            # W&B disabled
            metadata = save_checkpoint(
                model=cnn_model,
                model_name=model_name,
                model_type='cnn',
                performance_metrics=performance_metrics,
                training_metadata=training_metadata
            )
            if metadata:
                logger.info(f"CNN checkpoint saved: {metadata.checkpoint_id}")
                return True
            else:
                logger.info("CNN checkpoint not saved (performance not improved)")
                return False
        except Exception as e:
            logger.error(f"Error saving CNN checkpoint: {e}")
            return False
    def save_rl_checkpoint(self,
                           rl_agent,
                           model_name: str,
                           episode: int,
                           avg_reward: float,
                           best_reward: float,
                           epsilon: float,
                           total_pnl: float = None) -> bool:
        try:
            performance_metrics = {
                'reward': avg_reward,
                'best_reward': best_reward
            }
            if total_pnl is not None:
                performance_metrics['pnl'] = total_pnl
            training_metadata = {
                'episode': episode,
                'epsilon': epsilon,
                'total_parameters': self._count_parameters(rl_agent)
            }
            # W&B disabled
            metadata = save_checkpoint(
                model=rl_agent,
                model_name=model_name,
                model_type='rl',
                performance_metrics=performance_metrics,
                training_metadata=training_metadata
            )
            if metadata:
                logger.info(f"RL checkpoint saved: {metadata.checkpoint_id}")
                return True
            else:
                logger.info("RL checkpoint not saved (performance not improved)")
                return False
        except Exception as e:
            logger.error(f"Error saving RL checkpoint: {e}")
            return False
    def load_best_model(self, model_name: str, model_class=None):
        try:
            result = load_best_checkpoint(model_name)
            if not result:
                logger.warning(f"No checkpoint found for model: {model_name}")
                return None
            file_path, metadata = result
            checkpoint = torch.load(file_path, map_location='cpu')
            logger.info(f"Loaded best checkpoint for {model_name}:")
            logger.info(f"  Performance score: {metadata.performance_score:.4f}")
            logger.info(f"  Created: {metadata.created_at}")
            if model_class and 'model_state_dict' in checkpoint:
                model = model_class()
                model.load_state_dict(checkpoint['model_state_dict'])
                return model
            return checkpoint
        except Exception as e:
            logger.error(f"Error loading best model {model_name}: {e}")
            return None

    def _count_parameters(self, model) -> int:
        try:
            if hasattr(model, 'parameters'):
                return sum(p.numel() for p in model.parameters())
            elif hasattr(model, 'policy_net'):
                policy_params = sum(p.numel() for p in model.policy_net.parameters())
                target_params = sum(p.numel() for p in model.target_net.parameters()) if hasattr(model, 'target_net') else 0
                return policy_params + target_params
            else:
                return 0
        except Exception:
            return 0
_training_integration = None

def get_training_integration() -> TrainingIntegration:
    global _training_integration
    if _training_integration is None:
        _training_integration = TrainingIntegration()
    return _training_integration
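For context, a minimal sketch of how the singleton is typically used after a training epoch; the my_cnn model and the metric values are hypothetical, not from this commit:

    # Hypothetical usage sketch
    integration = get_training_integration()
    saved = integration.save_cnn_checkpoint(
        cnn_model=my_cnn,  # assumed: a torch.nn.Module trained elsewhere
        model_name='cnn_price_direction',
        epoch=12,
        train_accuracy=0.61,
        val_accuracy=0.58,
        train_loss=0.92,
        val_loss=0.97
    )
    if saved:
        best = integration.load_best_model('cnn_price_direction')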
# ---------------- Unified Training Manager ----------------
class UnifiedTrainingManager:
    """Single entry point to manage all training in the system.

    Coordinates EnhancedRealtimeTrainingSystem and provides start/stop/status.
    """

    def __init__(self, orchestrator, data_provider, dashboard=None):
        self.orchestrator = orchestrator
        self.data_provider = data_provider
        self.dashboard = dashboard
        self.training_system = None
        self.started = False

    def initialize(self) -> bool:
        try:
            # Import via project root shim to avoid path issues
            from enhanced_realtime_training import EnhancedRealtimeTrainingSystem
            self.training_system = EnhancedRealtimeTrainingSystem(
                orchestrator=self.orchestrator,
                data_provider=self.data_provider,
                dashboard=self.dashboard
            )
            return True
        except Exception as e:
            logger.error(f"UnifiedTrainingManager: failed to initialize training system: {e}")
            self.training_system = None
            return False

    def start(self) -> bool:
        try:
            if self.training_system is None:
                if not self.initialize():
                    return False
            self.training_system.start_training()
            self.started = True
            logger.info("UnifiedTrainingManager: training started")
            return True
        except Exception as e:
            logger.error(f"UnifiedTrainingManager: error starting training: {e}")
            return False

    def stop(self) -> bool:
        try:
            if self.training_system and self.started:
                self.training_system.stop_training()
                self.started = False
                logger.info("UnifiedTrainingManager: training stopped")
            return True
        except Exception as e:
            logger.error(f"UnifiedTrainingManager: error stopping training: {e}")
            return False

    def get_stats(self) -> Dict[str, Any]:
        try:
            if self.training_system and hasattr(self.training_system, 'get_training_stats'):
                return self.training_system.get_training_stats()
            return {}
        except Exception:
            return {}
_unified_training_manager = None

def get_unified_training_manager(orchestrator=None, data_provider=None, dashboard=None) -> UnifiedTrainingManager:
    global _unified_training_manager
    if _unified_training_manager is None:
        if orchestrator is None or data_provider is None:
            raise ValueError("orchestrator and data_provider are required for first-time initialization")
        _unified_training_manager = UnifiedTrainingManager(orchestrator, data_provider, dashboard)
    return _unified_training_manager
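The class docstring promises start/stop/status; a minimal lifecycle sketch, assuming an orchestrator and data_provider already exist elsewhere in the project:

    # Hypothetical lifecycle sketch; orchestrator and data_provider come from the rest of the system
    manager = get_unified_training_manager(orchestrator=orchestrator, data_provider=data_provider)
    if manager.start():
        stats = manager.get_stats()
        print(f"Realtime training running, stats: {stats}")
    manager.stop()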