misc

train_hybrid.py (new file, 731 additions)

@@ -0,0 +1,731 @@
#!/usr/bin/env python
"""
Hybrid Training Script - Combining Supervised and Reinforcement Learning

This script provides a hybrid approach that:
1. Performs supervised learning on market data using CNN models
2. Uses reinforcement learning to optimize trading strategies
3. Only uses real market data (never synthetic)

The script enables both approaches to complement each other:
- CNN model learns patterns from historical data (supervised)
- RL agent optimizes actual trading decisions (reinforcement)
"""

import os
import sys
import logging
import argparse
import numpy as np
import torch
import time
import json
import asyncio
import signal
import threading
from datetime import datetime
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

# Add project root to path if needed
project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
    sys.path.append(project_root)

# Import configurations
import train_config

# Import key components
from NN.models.cnn_model_pytorch import CNNModelPyTorch
from NN.models.dqn_agent import DQNAgent
from realtime import MultiTimeframeDataInterface, RealTimeChart
from NN.utils.signal_interpreter import SignalInterpreter

# Global variables for graceful shutdown
running = True
training_stats = {
    "supervised": {
        "epochs_completed": 0,
        "best_val_pnl": -float('inf'),
        "best_epoch": 0,
        "best_win_rate": 0
    },
    "reinforcement": {
        "episodes_completed": 0,
        "best_reward": -float('inf'),
        "best_episode": 0,
        "best_win_rate": 0
    },
    "hybrid": {
        "iterations_completed": 0,
        "best_combined_score": -float('inf'),
        "training_started": datetime.now().isoformat(),
        "last_update": datetime.now().isoformat()
    }
}

# Configure signal handler for graceful shutdown
def signal_handler(sig, frame):
    global running
    logging.info("Received interrupt signal. Finishing current training cycle and saving models...")
    running = False

# Register signal handler
signal.signal(signal.SIGINT, signal_handler)

class HybridModel:
    """
    Hybrid model that combines supervised CNN learning with RL-based decision optimization
    """
    def __init__(self, config):
        self.config = config
        self.device = torch.device(config['hardware']['device'])
        self.supervised_model = None
        self.rl_agent = None
        self.data_interface = None
        self.signal_interpreter = None
        self.chart = None

        # Training stats
        self.tensorboard_writer = None
        self.iter_count = 0
        self.supervised_epochs = 0
        self.rl_episodes = 0

        # Initialize logging
        self.logger = logging.getLogger('hybrid_model')

        # Paths
        self.models_dir = Path(config['paths']['models_dir'])
        self.models_dir.mkdir(exist_ok=True, parents=True)

    def initialize(self):
        """Initialize all components of the hybrid model"""
        # Set up TensorBoard
        log_dir = Path(self.config['paths']['tensorboard_dir']) / f"hybrid_{int(time.time())}"
        self.tensorboard_writer = SummaryWriter(log_dir=str(log_dir))
        self.logger.info(f"TensorBoard initialized at {log_dir}")

        # Initialize data interface
        symbol = self.config['market_data']['symbol']
        timeframes = self.config['market_data']['timeframes']
        window_size = self.config['market_data']['window_size']

        self.logger.info(f"Initializing data interface for {symbol} with timeframes {timeframes}")
        self.data_interface = MultiTimeframeDataInterface(
            symbol=symbol,
            timeframes=timeframes
        )

        # Initialize supervised model (CNN)
        self._initialize_supervised_model(window_size)

        # Initialize RL agent
        self._initialize_rl_agent(window_size)

        # Initialize signal interpreter
        self.signal_interpreter = SignalInterpreter(config={
            'buy_threshold': 0.65,
            'sell_threshold': 0.65,
            'hold_threshold': 0.75,
            'trend_filter_enabled': True,
            'volume_filter_enabled': True
        })

        # Initialize chart if visualization is enabled
        if self.config.get('visualization', {}).get('enabled', False):
            self._initialize_chart()

        return True

    def _initialize_supervised_model(self, window_size):
        """Initialize the supervised CNN model"""
        try:
            # Get data shape information
            X_train_dict, y_train, X_val_dict, y_val, _, _ = self.data_interface.prepare_training_data(
                window_size=window_size,
                refresh=True
            )

            if X_train_dict is None or y_train is None:
                raise ValueError("Failed to load training data")

            # Get reference timeframe (lowest timeframe)
            reference_tf = min(
                self.config['market_data']['timeframes'],
                key=lambda x: self.data_interface.timeframe_to_seconds.get(x, 3600)
            )

            # Get feature count from the data
            num_features = X_train_dict[reference_tf].shape[2]

            # Initialize model
            self.logger.info(f"Initializing CNN model with {num_features} features")

            self.supervised_model = CNNModelPyTorch(
                window_size=window_size,
                num_features=num_features,
                output_size=3,  # BUY/HOLD/SELL
                timeframes=self.config['market_data']['timeframes']
            )

            # Load existing model if available
            model_path = self.models_dir / "supervised_model_best.pt"
            if model_path.exists():
                self.logger.info(f"Loading existing CNN model from {model_path}")
                self.supervised_model.load(str(model_path))
                self.logger.info("CNN model loaded successfully")
            else:
                self.logger.info("No existing CNN model found. Starting with a new model.")

        except Exception as e:
            self.logger.error(f"Error initializing supervised model: {str(e)}")
            import traceback
            self.logger.error(traceback.format_exc())
            raise

    def _initialize_rl_agent(self, window_size):
        """Initialize the RL agent"""
        try:
            # Get data for RL training
            X_train_dict, _, _, _, _, _ = self.data_interface.prepare_training_data(
                window_size=window_size,
                refresh=True
            )

            if X_train_dict is None:
                raise ValueError("Failed to load training data for RL agent")

            # Get reference timeframe features
            reference_tf = min(
                self.config['market_data']['timeframes'],
                key=lambda x: self.data_interface.timeframe_to_seconds.get(x, 3600)
            )

            # Calculate state size - this is more complex for RL.
            # For now the state is the flattened raw feature window plus position
            # info; a CNN-derived representation could be substituted later via
            # _enhance_state_with_cnn.
            state_size = window_size * X_train_dict[reference_tf].shape[2] + 3  # +3 for position, equity, unrealized_pnl
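            # Worked example (illustrative only): with the script's default
            # window size of 24 and, say, 10 features per candle,
            # state_size = 24 * 10 + 3 = 243.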

            # Initialize RL agent
            self.logger.info(f"Initializing RL agent with state size {state_size}")

            self.rl_agent = DQNAgent(
                state_size=state_size,
                n_actions=3,  # BUY/HOLD/SELL
                epsilon=1.0,
                epsilon_decay=0.995,
                epsilon_min=0.01,
                learning_rate=self.config['training']['learning_rate'],
                gamma=0.99,
                buffer_size=10000,
                batch_size=self.config['training']['batch_size'],
                device=self.device
            )

            # Load existing agent if available
            agent_path = self.models_dir / "rl_agent_best.pth"
            if agent_path.exists():
                self.logger.info(f"Loading existing RL agent from {agent_path}")
                self.rl_agent.load(str(agent_path))
                self.logger.info("RL agent loaded successfully")
            else:
                self.logger.info("No existing RL agent found. Starting with a new agent.")

        except Exception as e:
            self.logger.error(f"Error initializing RL agent: {str(e)}")
            import traceback
            self.logger.error(traceback.format_exc())
            raise

    def _initialize_chart(self):
        """Initialize the RealTimeChart for visualization"""
        try:
            from realtime import RealTimeChart

            symbol = self.config['market_data']['symbol']
            self.logger.info(f"Initializing RealTimeChart for {symbol}")

            self.chart = RealTimeChart(symbol=symbol)

            # TODO: Start chart server in a background thread
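            # Illustrative sketch only (not active code): one way to serve the
            # chart without blocking training would be a daemon thread; the
            # start() method name on RealTimeChart is an assumption here.
            #
            # chart_thread = threading.Thread(target=self.chart.start, daemon=True)
            # chart_thread.start()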

        except Exception as e:
            self.logger.error(f"Error initializing chart: {str(e)}")
            self.chart = None

    async def train_hybrid(self, iterations=10, sv_epochs_per_iter=5, rl_episodes_per_iter=2):
        """
        Main hybrid training loop

        Args:
            iterations: Number of hybrid iterations to run
            sv_epochs_per_iter: Number of supervised epochs per iteration
            rl_episodes_per_iter: Number of RL episodes per iteration

        Returns:
            dict: Training statistics
        """
        self.logger.info(f"Starting hybrid training with {iterations} iterations")
        self.logger.info(f"Each iteration includes {sv_epochs_per_iter} supervised epochs and {rl_episodes_per_iter} RL episodes")

        # Training loop
        for iteration in range(iterations):
            if not running:
                self.logger.info("Training stopped by user")
                break

            self.logger.info(f"Iteration {iteration+1}/{iterations}")
            self.iter_count += 1

            # 1. Supervised learning phase
            self.logger.info("Starting supervised learning phase")
            sv_stats = await self.train_supervised(epochs=sv_epochs_per_iter)

            # 2. Reinforcement learning phase
            self.logger.info("Starting reinforcement learning phase")
            rl_stats = await self.train_reinforcement(episodes=rl_episodes_per_iter)

            # 3. Update global training stats
            self._update_training_stats(sv_stats, rl_stats)

            # 4. Save models and stats
            self._save_models_and_stats()

            # 5. Log to TensorBoard
            if self.tensorboard_writer:
                self._log_to_tensorboard(iteration, sv_stats, rl_stats)

        self.logger.info("Hybrid training completed")
        return training_stats

    async def train_supervised(self, epochs=5):
        """
        Run supervised training for a specified number of epochs

        Args:
            epochs: Number of epochs to train

        Returns:
            dict: Training statistics
        """
        # Get fresh data
        window_size = self.config['market_data']['window_size']
        X_train_dict, y_train, X_val_dict, y_val, train_prices, val_prices = self.data_interface.prepare_training_data(
            window_size=window_size,
            refresh=True
        )

        if X_train_dict is None or y_train is None:
            self.logger.error("Failed to load training data")
            return {}

        # Get reference timeframe (lowest timeframe)
        reference_tf = min(
            self.config['market_data']['timeframes'],
            key=lambda x: self.data_interface.timeframe_to_seconds.get(x, 3600)
        )

        # Calculate future prices for profitability-focused loss function
        train_future_prices = self.data_interface.get_future_prices(train_prices, n_candles=8)
        val_future_prices = self.data_interface.get_future_prices(val_prices, n_candles=8)

        # For now, we use only the reference timeframe
        X_train = X_train_dict[reference_tf]
        X_val = X_val_dict[reference_tf]

        # Training stats
        stats = {
            "train_losses": [],
            "val_losses": [],
            "train_accuracies": [],
            "val_accuracies": [],
            "train_pnls": [],
            "val_pnls": [],
            "best_val_pnl": -float('inf'),
            "best_epoch": -1
        }

        batch_size = self.config['training']['batch_size']

        # Training loop
        for epoch in range(epochs):
            if not running:
                break

            epoch_start = time.time()

            # Train one epoch
            train_action_loss, train_price_loss, train_acc = self.supervised_model.train_epoch(
                X_train, y_train, train_future_prices, batch_size
            )

            # Evaluate
            val_action_loss, val_price_loss, val_acc = self.supervised_model.evaluate(
                X_val, y_val, val_future_prices
            )

            # Get predictions for PnL calculation
            train_action_probs, _ = self.supervised_model.predict(X_train)
            val_action_probs, _ = self.supervised_model.predict(X_val)

            # Convert probabilities to actions
            train_preds = np.argmax(train_action_probs, axis=1)
            val_preds = np.argmax(val_action_probs, axis=1)

            # Calculate PnL
            train_pnl, train_win_rate, _ = self.data_interface.calculate_pnl(
                train_preds, train_prices, position_size=1.0
            )
            val_pnl, val_win_rate, _ = self.data_interface.calculate_pnl(
                val_preds, val_prices, position_size=1.0
            )

            # Update stats
            stats["train_losses"].append(train_action_loss)
            stats["val_losses"].append(val_action_loss)
            stats["train_accuracies"].append(train_acc)
            stats["val_accuracies"].append(val_acc)
            stats["train_pnls"].append(train_pnl)
            stats["val_pnls"].append(val_pnl)

            # Check if this is the best model
            if val_pnl > stats["best_val_pnl"]:
                stats["best_val_pnl"] = val_pnl
                stats["best_epoch"] = epoch
                stats["best_win_rate"] = val_win_rate

                # Save the best model
                self.supervised_model.save(str(self.models_dir / "supervised_model_best.pt"))

            # Log epoch results
            self.logger.info(f"Supervised Epoch {epoch+1}/{epochs}")
            self.logger.info(f"  Train Loss: {train_action_loss:.4f}, Accuracy: {train_acc:.4f}, PnL: {train_pnl:.4f}")
            self.logger.info(f"  Val Loss: {val_action_loss:.4f}, Accuracy: {val_acc:.4f}, PnL: {val_pnl:.4f}")

            # Log timing
            epoch_time = time.time() - epoch_start
            self.logger.info(f"  Epoch completed in {epoch_time:.2f} seconds")

            # Update global epoch counter
            self.supervised_epochs += 1

            # Small delay to allow for interruption
            await asyncio.sleep(0.1)

        return stats

    async def train_reinforcement(self, episodes=2):
        """
        Run reinforcement learning for a specified number of episodes

        Args:
            episodes: Number of episodes to train

        Returns:
            dict: Training statistics
        """
        from NN.train_rl import RLTradingEnvironment

        # Get data for RL environment
        window_size = self.config['market_data']['window_size']

        # Get all timeframes data
        data_dict = self.data_interface.get_multi_timeframe_data(refresh=True)

        if not data_dict:
            self.logger.error("Failed to fetch data for any timeframe")
            return {}

        # Extract key timeframes
        timeframes = self.config['market_data']['timeframes']

        # Extract features from dataframes
        features = {}
        for tf in timeframes:
            if tf in data_dict:
                df = data_dict[tf]
                # Add indicators if not already added
                if 'rsi' not in df.columns:
                    df = self.data_interface.add_indicators(df)

                # Convert to numpy array with close price as the last column
                features[tf] = np.hstack([
                    df.drop(['timestamp', 'close'], axis=1).values,
                    df['close'].values.reshape(-1, 1)
                ])

        # Ensure we have all needed timeframes
        required_tfs = ['1m', '5m', '15m']  # Most common timeframes used by RL
        for tf in required_tfs:
            if tf not in features and tf in timeframes:
                self.logger.error(f"Missing features for timeframe {tf}")
                return {}

        # Create environment with our feature data
        env = RLTradingEnvironment(
            features_1m=features.get('1m'),
            features_1h=features.get('1h', features.get('5m')),  # Use 5m as fallback
            features_1d=features.get('1d', features.get('15m'))  # Use 15m as fallback
        )

        # Training stats
        stats = {
            "rewards": [],
            "win_rates": [],
            "trades": [],
            "best_reward": -float('inf'),
            "best_episode": -1
        }

        # RL training loop
        for episode in range(episodes):
            if not running:
                break

            episode_start = time.time()
            self.logger.info(f"RL Episode {episode+1}/{episodes}")

            # Reset environment
            state = env.reset()
            total_reward = 0
            trades = 0
            wins = 0

            # Run one episode
            done = False
            max_steps = 1000
            step = 0

            while not done and step < max_steps:
                # Use CNN model to enhance state representation if available
                enhanced_state = self._enhance_state_with_cnn(state)

                # Select action using the RL agent
                action = self.rl_agent.act(enhanced_state)

                # Take step in environment
                next_state, reward, done, info = env.step(action)

                # Store in replay buffer
                self.rl_agent.remember(enhanced_state, action, reward,
                                       self._enhance_state_with_cnn(next_state), done)

                # Update episode statistics
                total_reward += reward
                state = next_state
                step += 1

                # Track trades and wins
                if action != 2:  # Not HOLD
                    trades += 1
                    if reward > 0:
                        wins += 1

                # Train the agent on a batch of experiences
                if len(self.rl_agent.memory) > self.config['training']['batch_size']:
                    self.rl_agent.replay(self.config['training']['batch_size'])

                # Allow for interruption
                if step % 100 == 0:
                    await asyncio.sleep(0.1)
                    if not running:
                        break

            # Calculate win rate
            win_rate = wins / max(1, trades)

            # Update stats
            stats["rewards"].append(total_reward)
            stats["win_rates"].append(win_rate)
            stats["trades"].append(trades)

            # Check if this is the best agent
            if total_reward > stats["best_reward"]:
                stats["best_reward"] = total_reward
                stats["best_episode"] = episode

                # Save the best agent
                self.rl_agent.save(str(self.models_dir / "rl_agent_best.pth"))

            # Log episode results
            self.logger.info(f"  Reward: {total_reward:.4f}, Win Rate: {win_rate:.4f}, Trades: {trades}")

            # Log timing
            episode_time = time.time() - episode_start
            self.logger.info(f"  Episode completed in {episode_time:.2f} seconds")

            # Update global episode counter
            self.rl_episodes += 1

            # Reduce exploration rate
            self.rl_agent.adjust_epsilon()

            # Small delay to allow for interruption
            await asyncio.sleep(0.1)

        return stats

    def _enhance_state_with_cnn(self, state):
        """
        Enhance the RL state with CNN feature extraction

        Args:
            state: The original state from the environment

        Returns:
            numpy.ndarray: Enhanced state representation
        """
        # This is a placeholder - in a real implementation, you would:
        # 1. Format the state for the CNN
        # 2. Get the CNN's feature representation
        # 3. Combine with the original state features
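        # Illustrative sketch only (not active code); the feature-extraction
        # hook on CNNModelPyTorch is hypothetical, and the state would need to
        # be reshaped to the (window_size, num_features) layout the CNN expects:
        #
        # cnn_features = self.supervised_model.extract_features(state)  # hypothetical method
        # return np.concatenate([state.flatten(), cnn_features.flatten()])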
        return state

    def _update_training_stats(self, sv_stats, rl_stats):
        """Update global training statistics"""
        global training_stats

        # Update supervised stats
        if sv_stats:
            training_stats["supervised"]["epochs_completed"] = self.supervised_epochs
            if "best_val_pnl" in sv_stats and sv_stats["best_val_pnl"] > training_stats["supervised"]["best_val_pnl"]:
                training_stats["supervised"]["best_val_pnl"] = sv_stats["best_val_pnl"]
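                # sv_stats["best_epoch"] is an index within this run only;
                # offsetting by the epochs completed before this run converts
                # it to a global epoch number.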
                training_stats["supervised"]["best_epoch"] = sv_stats["best_epoch"] + training_stats["supervised"]["epochs_completed"] - len(sv_stats["train_losses"])
                training_stats["supervised"]["best_win_rate"] = sv_stats.get("best_win_rate", 0)

        # Update reinforcement stats
        if rl_stats:
            training_stats["reinforcement"]["episodes_completed"] = self.rl_episodes
            if "best_reward" in rl_stats and rl_stats["best_reward"] > training_stats["reinforcement"]["best_reward"]:
                training_stats["reinforcement"]["best_reward"] = rl_stats["best_reward"]
                training_stats["reinforcement"]["best_episode"] = rl_stats["best_episode"] + training_stats["reinforcement"]["episodes_completed"] - len(rl_stats["rewards"])

        # Update hybrid stats
        training_stats["hybrid"]["iterations_completed"] = self.iter_count
        training_stats["hybrid"]["last_update"] = datetime.now().isoformat()

        # Calculate combined score (simple formula, can be adjusted)
        sv_score = training_stats["supervised"]["best_val_pnl"]
        rl_score = training_stats["reinforcement"]["best_reward"]
        combined_score = sv_score * 0.7 + rl_score * 0.3  # Weight supervised more
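        # Worked example with made-up values: best_val_pnl = 10.0 and
        # best_reward = 2.0 give 0.7 * 10.0 + 0.3 * 2.0 = 7.6.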

        if combined_score > training_stats["hybrid"]["best_combined_score"]:
            training_stats["hybrid"]["best_combined_score"] = combined_score

    def _save_models_and_stats(self):
        """Save models and training statistics"""
        # Save training stats
        try:
            stats_file = self.models_dir / "hybrid_training_stats.json"
            with open(stats_file, 'w') as f:
                json.dump(training_stats, f, indent=2)
            self.logger.info(f"Training statistics saved to {stats_file}")
        except Exception as e:
            self.logger.error(f"Error saving training stats: {str(e)}")

        # Models are already saved in their respective training functions

    def _log_to_tensorboard(self, iteration, sv_stats, rl_stats):
        """Log training metrics to TensorBoard"""
        if not self.tensorboard_writer:
            return

        # Log supervised metrics
        if sv_stats and "train_losses" in sv_stats:
            for i, loss in enumerate(sv_stats["train_losses"]):
                step = (iteration * len(sv_stats["train_losses"])) + i
                self.tensorboard_writer.add_scalar('supervised/train_loss', loss, step)
                self.tensorboard_writer.add_scalar('supervised/val_loss', sv_stats["val_losses"][i], step)
                self.tensorboard_writer.add_scalar('supervised/train_accuracy', sv_stats["train_accuracies"][i], step)
                self.tensorboard_writer.add_scalar('supervised/val_accuracy', sv_stats["val_accuracies"][i], step)
                self.tensorboard_writer.add_scalar('supervised/train_pnl', sv_stats["train_pnls"][i], step)
                self.tensorboard_writer.add_scalar('supervised/val_pnl', sv_stats["val_pnls"][i], step)

        # Log reinforcement metrics
        if rl_stats and "rewards" in rl_stats:
            for i, reward in enumerate(rl_stats["rewards"]):
                step = (iteration * len(rl_stats["rewards"])) + i
                self.tensorboard_writer.add_scalar('reinforcement/reward', reward, step)
                self.tensorboard_writer.add_scalar('reinforcement/win_rate', rl_stats["win_rates"][i], step)
                self.tensorboard_writer.add_scalar('reinforcement/trades', rl_stats["trades"][i], step)

        # Log hybrid metrics
        self.tensorboard_writer.add_scalar('hybrid/iterations', self.iter_count, iteration)
        self.tensorboard_writer.add_scalar('hybrid/combined_score', training_stats["hybrid"]["best_combined_score"], iteration)

        # Flush to ensure data is written
        self.tensorboard_writer.flush()

async def main():
    """Main entry point for the hybrid training script"""
    parser = argparse.ArgumentParser(description='Hybrid Training Script')
    parser.add_argument('--iterations', type=int, default=10, help='Number of hybrid iterations to run')
    parser.add_argument('--sv-epochs', type=int, default=5, help='Supervised epochs per iteration')
    parser.add_argument('--rl-episodes', type=int, default=2, help='RL episodes per iteration')
    parser.add_argument('--symbol', type=str, default='BTC/USDT', help='Trading symbol')
    parser.add_argument('--timeframes', type=str, nargs='+', default=['1m', '5m', '15m'], help='Timeframes to use')
    parser.add_argument('--window-size', type=int, default=24, help='Window size for models')
    parser.add_argument('--visualize', action='store_true', help='Enable visualization')
    parser.add_argument('--config', type=str, help='Path to custom configuration file')

    args = parser.parse_args()

    # Load configuration
    if args.config:
        config = train_config.load_config(args.config)
    else:
        # Create custom config from command-line arguments
        custom_config = {
            'market_data': {
                'symbol': args.symbol,
                'timeframes': args.timeframes,
                'window_size': args.window_size
            },
            'visualization': {
                'enabled': args.visualize
            }
        }
        config = train_config.get_config('hybrid', custom_config)

    # Print startup banner
    print("=" * 80)
    print("HYBRID TRAINING SESSION")
    print("Combining supervised learning (CNN) with reinforcement learning (RL)")
    print(f"Symbol: {config['market_data']['symbol']}")
    print(f"Timeframes: {config['market_data']['timeframes']}")
    print(f"Iterations: {args.iterations} (SV epochs: {args.sv_epochs}, RL episodes: {args.rl_episodes})")
    print("Press Ctrl+C to safely stop training and save the models")
    print("=" * 80)

    # Initialize the hybrid model
    hybrid_model = HybridModel(config)
    initialized = hybrid_model.initialize()

    if not initialized:
        print("Failed to initialize hybrid model. Exiting.")
        return 1

    try:
        # Run training
        await hybrid_model.train_hybrid(
            iterations=args.iterations,
            sv_epochs_per_iter=args.sv_epochs,
            rl_episodes_per_iter=args.rl_episodes
        )

        print("Training completed successfully.")
        return 0

    except KeyboardInterrupt:
        print("Training interrupted by user.")
        return 0

    except Exception as e:
        print(f"Error during training: {str(e)}")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    asyncio.run(main())