#!/usr/bin/env python
"""
Hybrid Training Script - Combining Supervised and Reinforcement Learning

This script provides a hybrid approach that:
1. Performs supervised learning on market data using CNN models
2. Uses reinforcement learning to optimize trading strategies
3. Only uses real market data (never synthetic)

The script enables both approaches to complement each other:
- CNN model learns patterns from historical data (supervised)
- RL agent optimizes actual trading decisions (reinforcement)
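
Example usage (the file name below is illustrative; use this script's actual
file name):

    python train_hybrid.py --iterations 10 --sv-epochs 5 --rl-episodes 2 \
        --symbol BTC/USDT --timeframes 1m 5m 15m --visualize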
"""

import os
import sys
import logging
import argparse
import numpy as np
import torch
import time
import json
import asyncio
import signal
import threading
from datetime import datetime
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

# Add project root to path if needed
project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
    sys.path.append(project_root)

# Import configurations
import train_config

# Import key components
from NN.models.cnn_model_pytorch import CNNModelPyTorch
from NN.models.dqn_agent import DQNAgent
from dataprovider_realtime import MultiTimeframeDataInterface, RealTimeChart
from NN.utils.signal_interpreter import SignalInterpreter

# Global variables for graceful shutdown
running = True
training_stats = {
    "supervised": {
        "epochs_completed": 0,
        "best_val_pnl": -float('inf'),
        "best_epoch": 0,
        "best_win_rate": 0
    },
    "reinforcement": {
        "episodes_completed": 0,
        "best_reward": -float('inf'),
        "best_episode": 0,
        "best_win_rate": 0
    },
    "hybrid": {
        "iterations_completed": 0,
        "best_combined_score": -float('inf'),
        "training_started": datetime.now().isoformat(),
        "last_update": datetime.now().isoformat()
    }
}

# Configure signal handler for graceful shutdown
def signal_handler(sig, frame):
    global running
    logging.info("Received interrupt signal. Finishing current training cycle and saving models...")
    running = False

# Register signal handler
signal.signal(signal.SIGINT, signal_handler)

class HybridModel:
    """
    Hybrid model that combines supervised CNN learning with RL-based decision optimization
    """
    def __init__(self, config):
        self.config = config
        self.device = torch.device(config['hardware']['device'])
        self.supervised_model = None
        self.rl_agent = None
        self.data_interface = None
        self.signal_interpreter = None
        self.chart = None

        # Training stats
        self.tensorboard_writer = None
        self.iter_count = 0
        self.supervised_epochs = 0
        self.rl_episodes = 0

        # Initialize logging
        self.logger = logging.getLogger('hybrid_model')

        # Paths
        self.models_dir = Path(config['paths']['models_dir'])
        self.models_dir.mkdir(exist_ok=True, parents=True)

    def initialize(self):
        """Initialize all components of the hybrid model"""
        # Set up TensorBoard
        log_dir = Path(self.config['paths']['tensorboard_dir']) / f"hybrid_{int(time.time())}"
        self.tensorboard_writer = SummaryWriter(log_dir=str(log_dir))
        self.logger.info(f"TensorBoard initialized at {log_dir}")

        # Initialize data interface
        symbol = self.config['market_data']['symbol']
        timeframes = self.config['market_data']['timeframes']
        window_size = self.config['market_data']['window_size']

        self.logger.info(f"Initializing data interface for {symbol} with timeframes {timeframes}")
        self.data_interface = MultiTimeframeDataInterface(
            symbol=symbol,
            timeframes=timeframes
        )

        # Initialize supervised model (CNN)
        self._initialize_supervised_model(window_size)

        # Initialize RL agent
        self._initialize_rl_agent(window_size)

        # Initialize signal interpreter
        self.signal_interpreter = SignalInterpreter(config={
            'buy_threshold': 0.65,
            'sell_threshold': 0.65,
            'hold_threshold': 0.75,
            'trend_filter_enabled': True,
            'volume_filter_enabled': True
        })

        # Initialize chart if visualization is enabled
        if self.config.get('visualization', {}).get('enabled', False):
            self._initialize_chart()

        return True

    def _initialize_supervised_model(self, window_size):
        """Initialize the supervised CNN model"""
        try:
            # Get data shape information
            X_train_dict, y_train, X_val_dict, y_val, _, _ = self.data_interface.prepare_training_data(
                window_size=window_size,
                refresh=True
            )

            if X_train_dict is None or y_train is None:
                raise ValueError("Failed to load training data")

            # Get reference timeframe (lowest timeframe)
            reference_tf = min(
                self.config['market_data']['timeframes'],
                key=lambda x: self.data_interface.timeframe_to_seconds.get(x, 3600)
            )

            # Get feature count from the data
            num_features = X_train_dict[reference_tf].shape[2]

            # Initialize model
            self.logger.info(f"Initializing CNN model with {num_features} features")

            self.supervised_model = CNNModelPyTorch(
                window_size=window_size,
                num_features=num_features,
                output_size=3,  # BUY/HOLD/SELL
                timeframes=self.config['market_data']['timeframes']
            )

            # Load existing model if available
            model_path = self.models_dir / "supervised_model_best.pt"
            if model_path.exists():
                self.logger.info(f"Loading existing CNN model from {model_path}")
                self.supervised_model.load(str(model_path))
                self.logger.info("CNN model loaded successfully")
            else:
                self.logger.info("No existing CNN model found. Starting with a new model.")

        except Exception as e:
            self.logger.error(f"Error initializing supervised model: {str(e)}")
            import traceback
            self.logger.error(traceback.format_exc())
            raise

    def _initialize_rl_agent(self, window_size):
        """Initialize the RL agent"""
        try:
            # Get data for RL training
            X_train_dict, _, _, _, _, _ = self.data_interface.prepare_training_data(
                window_size=window_size,
                refresh=True
            )

            if X_train_dict is None:
                raise ValueError("Failed to load training data for RL agent")

            # Get reference timeframe features
            reference_tf = min(
                self.config['market_data']['timeframes'],
                key=lambda x: self.data_interface.timeframe_to_seconds.get(x, 3600)
            )

            # Calculate state size - this is more complex for RL
            # For simplicity, we'll use the CNN's feature representation + position info
            state_size = window_size * X_train_dict[reference_tf].shape[2] + 3  # +3 for position, equity, unrealized_pnl

            # Initialize RL agent
            self.logger.info(f"Initializing RL agent with state size {state_size}")

            self.rl_agent = DQNAgent(
                state_size=state_size,
                n_actions=3,  # BUY/HOLD/SELL
                epsilon=1.0,
                epsilon_decay=0.995,
                epsilon_min=0.01,
                learning_rate=self.config['training']['learning_rate'],
                gamma=0.99,
                buffer_size=10000,
                batch_size=self.config['training']['batch_size'],
                device=self.device
            )

            # Load existing agent if available
            agent_path = self.models_dir / "rl_agent_best.pth"
            if agent_path.exists():
                self.logger.info(f"Loading existing RL agent from {agent_path}")
                self.rl_agent.load(str(agent_path))
                self.logger.info("RL agent loaded successfully")
            else:
                self.logger.info("No existing RL agent found. Starting with a new agent.")

        except Exception as e:
            self.logger.error(f"Error initializing RL agent: {str(e)}")
            import traceback
            self.logger.error(traceback.format_exc())
            raise

    def _initialize_chart(self):
        """Initialize the RealTimeChart for visualization"""
        try:
            from dataprovider_realtime import RealTimeChart

            symbol = self.config['market_data']['symbol']
            self.logger.info(f"Initializing RealTimeChart for {symbol}")

            self.chart = RealTimeChart(symbol=symbol)

            # TODO: Start chart server in a background thread
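
            # Sketch (assumption): RealTimeChart's blocking entry point is not shown in
            # this file, so the call below uses a hypothetical run() method and is left
            # commented out. If such a method exists, running it on a daemon thread
            # keeps the chart from blocking the training loop:
            #
            #     chart_thread = threading.Thread(target=self.chart.run, daemon=True)
            #     chart_thread.start()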

        except Exception as e:
            self.logger.error(f"Error initializing chart: {str(e)}")
            self.chart = None

    async def train_hybrid(self, iterations=10, sv_epochs_per_iter=5, rl_episodes_per_iter=2):
        """
        Main hybrid training loop

        Args:
            iterations: Number of hybrid iterations to run
            sv_epochs_per_iter: Number of supervised epochs per iteration
            rl_episodes_per_iter: Number of RL episodes per iteration

        Returns:
            dict: Training statistics
        """
        self.logger.info(f"Starting hybrid training with {iterations} iterations")
        self.logger.info(f"Each iteration includes {sv_epochs_per_iter} supervised epochs and {rl_episodes_per_iter} RL episodes")

        # Training loop
        for iteration in range(iterations):
            if not running:
                self.logger.info("Training stopped by user")
                break

            self.logger.info(f"Iteration {iteration+1}/{iterations}")
            self.iter_count += 1

            # 1. Supervised learning phase
            self.logger.info("Starting supervised learning phase")
            sv_stats = await self.train_supervised(epochs=sv_epochs_per_iter)

            # 2. Reinforcement learning phase
            self.logger.info("Starting reinforcement learning phase")
            rl_stats = await self.train_reinforcement(episodes=rl_episodes_per_iter)

            # 3. Update global training stats
            self._update_training_stats(sv_stats, rl_stats)

            # 4. Save models and stats
            self._save_models_and_stats()

            # 5. Log to TensorBoard
            if self.tensorboard_writer:
                self._log_to_tensorboard(iteration, sv_stats, rl_stats)

        self.logger.info("Hybrid training completed")
        return training_stats

    async def train_supervised(self, epochs=5):
        """
        Run supervised training for a specified number of epochs

        Args:
            epochs: Number of epochs to train

        Returns:
            dict: Training statistics
        """
        # Get fresh data
        window_size = self.config['market_data']['window_size']
        X_train_dict, y_train, X_val_dict, y_val, train_prices, val_prices = self.data_interface.prepare_training_data(
            window_size=window_size,
            refresh=True
        )

        if X_train_dict is None or y_train is None:
            self.logger.error("Failed to load training data")
            return {}

        # Get reference timeframe (lowest timeframe)
        reference_tf = min(
            self.config['market_data']['timeframes'],
            key=lambda x: self.data_interface.timeframe_to_seconds.get(x, 3600)
        )

        # Calculate future prices for profitability-focused loss function
        train_future_prices = self.data_interface.get_future_prices(train_prices, n_candles=8)
        val_future_prices = self.data_interface.get_future_prices(val_prices, n_candles=8)

        # For now, we use only the reference timeframe
        X_train = X_train_dict[reference_tf]
        X_val = X_val_dict[reference_tf]

        # Training stats
        stats = {
            "train_losses": [],
            "val_losses": [],
            "train_accuracies": [],
            "val_accuracies": [],
            "train_pnls": [],
            "val_pnls": [],
            "best_val_pnl": -float('inf'),
            "best_epoch": -1
        }

        batch_size = self.config['training']['batch_size']

        # Training loop
        for epoch in range(epochs):
            if not running:
                break

            epoch_start = time.time()

            # Train one epoch
            train_action_loss, train_price_loss, train_acc = self.supervised_model.train_epoch(
                X_train, y_train, train_future_prices, batch_size
            )

            # Evaluate
            val_action_loss, val_price_loss, val_acc = self.supervised_model.evaluate(
                X_val, y_val, val_future_prices
            )

            # Get predictions for PnL calculation
            train_action_probs, _ = self.supervised_model.predict(X_train)
            val_action_probs, _ = self.supervised_model.predict(X_val)

            # Convert probabilities to actions
            train_preds = np.argmax(train_action_probs, axis=1)
            val_preds = np.argmax(val_action_probs, axis=1)

            # Calculate PnL
            train_pnl, train_win_rate, _ = self.data_interface.calculate_pnl(
                train_preds, train_prices, position_size=1.0
            )
            val_pnl, val_win_rate, _ = self.data_interface.calculate_pnl(
                val_preds, val_prices, position_size=1.0
            )

            # Update stats
            stats["train_losses"].append(train_action_loss)
            stats["val_losses"].append(val_action_loss)
            stats["train_accuracies"].append(train_acc)
            stats["val_accuracies"].append(val_acc)
            stats["train_pnls"].append(train_pnl)
            stats["val_pnls"].append(val_pnl)

            # Check if this is the best model
            if val_pnl > stats["best_val_pnl"]:
                stats["best_val_pnl"] = val_pnl
                stats["best_epoch"] = epoch
                stats["best_win_rate"] = val_win_rate

                # Save the best model
                self.supervised_model.save(str(self.models_dir / "supervised_model_best.pt"))

            # Log epoch results
            self.logger.info(f"Supervised Epoch {epoch+1}/{epochs}")
            self.logger.info(f" Train Loss: {train_action_loss:.4f}, Accuracy: {train_acc:.4f}, PnL: {train_pnl:.4f}")
            self.logger.info(f" Val Loss: {val_action_loss:.4f}, Accuracy: {val_acc:.4f}, PnL: {val_pnl:.4f}")

            # Log timing
            epoch_time = time.time() - epoch_start
            self.logger.info(f" Epoch completed in {epoch_time:.2f} seconds")

            # Update global epoch counter
            self.supervised_epochs += 1

            # Small delay to allow for interruption
            await asyncio.sleep(0.1)

        return stats

    async def train_reinforcement(self, episodes=2):
        """
        Run reinforcement learning for a specified number of episodes

        Args:
            episodes: Number of episodes to train

        Returns:
            dict: Training statistics
        """
        from NN.train_rl import RLTradingEnvironment

        # Get data for RL environment
        window_size = self.config['market_data']['window_size']

        # Get all timeframes data
        data_dict = self.data_interface.get_multi_timeframe_data(refresh=True)

        if not data_dict:
            self.logger.error("Failed to fetch data for any timeframe")
            return {}

        # Extract key timeframes
        timeframes = self.config['market_data']['timeframes']

        # Extract features from dataframes
        features = {}
        for tf in timeframes:
            if tf in data_dict:
                df = data_dict[tf]
                # Add indicators if not already added
                if 'rsi' not in df.columns:
                    df = self.data_interface.add_indicators(df)

                # Convert to numpy array with close price as the last column
                features[tf] = np.hstack([
                    df.drop(['timestamp', 'close'], axis=1).values,
                    df['close'].values.reshape(-1, 1)
                ])

        # Ensure we have all needed timeframes
        required_tfs = ['1m', '5m', '15m']  # Most common timeframes used by RL
        for tf in required_tfs:
            if tf not in features and tf in timeframes:
                self.logger.error(f"Missing features for timeframe {tf}")
                return {}

        # Create environment with our feature data
        env = RLTradingEnvironment(
            features_1m=features.get('1m'),
            features_1h=features.get('1h', features.get('5m')),  # Use 5m as fallback
            features_1d=features.get('1d', features.get('15m'))  # Use 15m as fallback
        )

        # Training stats
        stats = {
            "rewards": [],
            "win_rates": [],
            "trades": [],
            "best_reward": -float('inf'),
            "best_episode": -1
        }

        # RL training loop
        for episode in range(episodes):
            if not running:
                break

            episode_start = time.time()
            self.logger.info(f"RL Episode {episode+1}/{episodes}")

            # Reset environment
            state = env.reset()
            total_reward = 0
            trades = 0
            wins = 0

            # Run one episode
            done = False
            max_steps = 1000
            step = 0

            while not done and step < max_steps:
                # Use CNN model to enhance state representation if available
                enhanced_state = self._enhance_state_with_cnn(state)

                # Select action using the RL agent
                action = self.rl_agent.act(enhanced_state)

                # Take step in environment
                next_state, reward, done, info = env.step(action)

                # Store in replay buffer
                self.rl_agent.remember(enhanced_state, action, reward,
                                       self._enhance_state_with_cnn(next_state), done)

                # Update episode statistics
                total_reward += reward
                state = next_state
                step += 1

                # Track trades and wins
                if action != 2:  # Not HOLD
                    trades += 1
                    if reward > 0:
                        wins += 1

                # Train the agent on a batch of experiences
                if len(self.rl_agent.memory) > self.config['training']['batch_size']:
                    self.rl_agent.replay(self.config['training']['batch_size'])

                # Allow for interruption
                if step % 100 == 0:
                    await asyncio.sleep(0.1)
                    if not running:
                        break

            # Calculate win rate
            win_rate = wins / max(1, trades)

            # Update stats
            stats["rewards"].append(total_reward)
            stats["win_rates"].append(win_rate)
            stats["trades"].append(trades)

            # Check if this is the best agent
            if total_reward > stats["best_reward"]:
                stats["best_reward"] = total_reward
                stats["best_episode"] = episode

                # Save the best agent
                self.rl_agent.save(str(self.models_dir / "rl_agent_best.pth"))

            # Log episode results
            self.logger.info(f" Reward: {total_reward:.4f}, Win Rate: {win_rate:.4f}, Trades: {trades}")

            # Log timing
            episode_time = time.time() - episode_start
            self.logger.info(f" Episode completed in {episode_time:.2f} seconds")

            # Update global episode counter
            self.rl_episodes += 1

            # Reduce exploration rate
            self.rl_agent.adjust_epsilon()

            # Small delay to allow for interruption
            await asyncio.sleep(0.1)

        return stats

    def _enhance_state_with_cnn(self, state):
        """
        Enhance the RL state with CNN feature extraction

        Args:
            state: The original state from the environment

        Returns:
            numpy.ndarray: Enhanced state representation
        """
        # This is a placeholder - in a real implementation, you would:
        # 1. Format the state for the CNN
        # 2. Get the CNN's feature representation
        # 3. Combine with the original state features
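        #
        # A minimal sketch of that enhancement (assumptions: `state` is the flat window
        # of market features plus the 3 position values implied by state_size above, and
        # predict() accepts a single window as it does in train_supervised):
        #
        #     window_size = self.config['market_data']['window_size']
        #     market_part, position_part = state[:-3], state[-3:]
        #     window = market_part.reshape(1, window_size, -1)
        #     action_probs, _ = self.supervised_model.predict(window)
        #     return np.concatenate([market_part, action_probs.flatten(), position_part])
        #
        # Note that this would enlarge the state vector, so the DQN agent's state_size
        # would need to account for the extra values. For now the state is returned
        # unchanged.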
        return state

    def _update_training_stats(self, sv_stats, rl_stats):
        """Update global training statistics"""
        global training_stats

        # Update supervised stats
        if sv_stats:
            training_stats["supervised"]["epochs_completed"] = self.supervised_epochs
            if "best_val_pnl" in sv_stats and sv_stats["best_val_pnl"] > training_stats["supervised"]["best_val_pnl"]:
                training_stats["supervised"]["best_val_pnl"] = sv_stats["best_val_pnl"]
                training_stats["supervised"]["best_epoch"] = sv_stats["best_epoch"] + training_stats["supervised"]["epochs_completed"] - len(sv_stats["train_losses"])
                training_stats["supervised"]["best_win_rate"] = sv_stats.get("best_win_rate", 0)

        # Update reinforcement stats
        if rl_stats:
            training_stats["reinforcement"]["episodes_completed"] = self.rl_episodes
            if "best_reward" in rl_stats and rl_stats["best_reward"] > training_stats["reinforcement"]["best_reward"]:
                training_stats["reinforcement"]["best_reward"] = rl_stats["best_reward"]
                training_stats["reinforcement"]["best_episode"] = rl_stats["best_episode"] + training_stats["reinforcement"]["episodes_completed"] - len(rl_stats["rewards"])

        # Update hybrid stats
        training_stats["hybrid"]["iterations_completed"] = self.iter_count
        training_stats["hybrid"]["last_update"] = datetime.now().isoformat()

        # Calculate combined score (simple formula, can be adjusted)
        sv_score = training_stats["supervised"]["best_val_pnl"]
        rl_score = training_stats["reinforcement"]["best_reward"]
        combined_score = sv_score * 0.7 + rl_score * 0.3  # Weight supervised more

        if combined_score > training_stats["hybrid"]["best_combined_score"]:
            training_stats["hybrid"]["best_combined_score"] = combined_score

    def _save_models_and_stats(self):
        """Save models and training statistics"""
        # Save training stats
        try:
            stats_file = self.models_dir / "hybrid_training_stats.json"
            with open(stats_file, 'w') as f:
                json.dump(training_stats, f, indent=2)
            self.logger.info(f"Training statistics saved to {stats_file}")
        except Exception as e:
            self.logger.error(f"Error saving training stats: {str(e)}")

        # Models are already saved in their respective training functions

    def _log_to_tensorboard(self, iteration, sv_stats, rl_stats):
        """Log training metrics to TensorBoard"""
        if not self.tensorboard_writer:
            return

        # Log supervised metrics
        if sv_stats and "train_losses" in sv_stats:
            for i, loss in enumerate(sv_stats["train_losses"]):
                step = (iteration * len(sv_stats["train_losses"])) + i
                self.tensorboard_writer.add_scalar('supervised/train_loss', loss, step)
                self.tensorboard_writer.add_scalar('supervised/val_loss', sv_stats["val_losses"][i], step)
                self.tensorboard_writer.add_scalar('supervised/train_accuracy', sv_stats["train_accuracies"][i], step)
                self.tensorboard_writer.add_scalar('supervised/val_accuracy', sv_stats["val_accuracies"][i], step)
                self.tensorboard_writer.add_scalar('supervised/train_pnl', sv_stats["train_pnls"][i], step)
                self.tensorboard_writer.add_scalar('supervised/val_pnl', sv_stats["val_pnls"][i], step)

        # Log reinforcement metrics
        if rl_stats and "rewards" in rl_stats:
            for i, reward in enumerate(rl_stats["rewards"]):
                step = (iteration * len(rl_stats["rewards"])) + i
                self.tensorboard_writer.add_scalar('reinforcement/reward', reward, step)
                self.tensorboard_writer.add_scalar('reinforcement/win_rate', rl_stats["win_rates"][i], step)
                self.tensorboard_writer.add_scalar('reinforcement/trades', rl_stats["trades"][i], step)

        # Log hybrid metrics
        self.tensorboard_writer.add_scalar('hybrid/iterations', self.iter_count, iteration)
        self.tensorboard_writer.add_scalar('hybrid/combined_score', training_stats["hybrid"]["best_combined_score"], iteration)

        # Flush to ensure data is written
        self.tensorboard_writer.flush()

async def main():
    """Main entry point for the hybrid training script"""
    parser = argparse.ArgumentParser(description='Hybrid Training Script')
    parser.add_argument('--iterations', type=int, default=10, help='Number of hybrid iterations to run')
    parser.add_argument('--sv-epochs', type=int, default=5, help='Supervised epochs per iteration')
    parser.add_argument('--rl-episodes', type=int, default=2, help='RL episodes per iteration')
    parser.add_argument('--symbol', type=str, default='BTC/USDT', help='Trading symbol')
    parser.add_argument('--timeframes', type=str, nargs='+', default=['1m', '5m', '15m'], help='Timeframes to use')
    parser.add_argument('--window-size', type=int, default=24, help='Window size for models')
    parser.add_argument('--visualize', action='store_true', help='Enable visualization')
    parser.add_argument('--config', type=str, help='Path to custom configuration file')

    args = parser.parse_args()

    # Load configuration
    if args.config:
        config = train_config.load_config(args.config)
    else:
        # Create custom config from command-line arguments
        custom_config = {
            'market_data': {
                'symbol': args.symbol,
                'timeframes': args.timeframes,
                'window_size': args.window_size
            },
            'visualization': {
                'enabled': args.visualize
            }
        }
        config = train_config.get_config('hybrid', custom_config)

    # Print startup banner
    print("=" * 80)
    print("HYBRID TRAINING SESSION")
    print("Combining supervised learning (CNN) with reinforcement learning (RL)")
    print(f"Symbol: {config['market_data']['symbol']}")
    print(f"Timeframes: {config['market_data']['timeframes']}")
    print(f"Iterations: {args.iterations} (SV epochs: {args.sv_epochs}, RL episodes: {args.rl_episodes})")
    print("Press Ctrl+C to safely stop training and save the models")
    print("=" * 80)

    # Initialize the hybrid model
    hybrid_model = HybridModel(config)
    initialized = hybrid_model.initialize()

    if not initialized:
        print("Failed to initialize hybrid model. Exiting.")
        return 1

    try:
        # Run training
        await hybrid_model.train_hybrid(
            iterations=args.iterations,
            sv_epochs_per_iter=args.sv_epochs,
            rl_episodes_per_iter=args.rl_episodes
        )

        print("Training completed successfully.")
        return 0

    except KeyboardInterrupt:
        print("Training interrupted by user.")
        return 0

    except Exception as e:
        print(f"Error during training: {str(e)}")
        import traceback
        traceback.print_exc()
        return 1

if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status
    sys.exit(asyncio.run(main()))