RL training
train_with_realtime_ticks.py (new file, 704 lines)
@@ -0,0 +1,704 @@
#!/usr/bin/env python
"""
Real-time training with tick data and multiple timeframes for context.

This script uses streaming tick data for fast adaptation while maintaining
higher-timeframe context.
"""

import os
import sys
import logging
import numpy as np
import torch
import time
import json
from datetime import datetime, timedelta
import signal
import threading
import asyncio
import websockets
from collections import deque
import pandas as pd
from typing import Dict, List, Optional
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import io

# Add the project root to the import path
sys.path.append(os.path.abspath('.'))

# Configure logging with a timestamp in the filename
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"realtime_ticks_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger('realtime_ticks_training')

# Import the model and data interfaces
from NN.models.cnn_model_pytorch import CNNModelPyTorch
from NN.utils.data_interface import DataInterface
from NN.utils.signal_interpreter import SignalInterpreter

# Global variables for graceful shutdown
running = True
training_stats = {
    "epochs_completed": 0,
    "best_val_pnl": -float('inf'),
    "best_epoch": 0,
    "best_win_rate": 0,
    "training_started": datetime.now().isoformat(),
    "last_update": datetime.now().isoformat(),
    "epochs": [],
    "cumulative_pnl": {
        "train": 0.0,
        "val": 0.0
    },
    "total_trades": {
        "train": 0,
        "val": 0
    },
    "total_wins": {
        "train": 0,
        "val": 0
    }
}

class TickDataProcessor:
    """Process and store real-time tick data"""

    def __init__(self, symbol: str, max_ticks: int = 10000):
        self.symbol = symbol
        self.ticks = deque(maxlen=max_ticks)
        self.candle_cache = {
            '1s': deque(maxlen=5000),
            '1m': deque(maxlen=5000),
            '5m': deque(maxlen=5000),
            '15m': deque(maxlen=5000)
        }
        self.last_tick = None
        self.ws_url = f"wss://stream.binance.com:9443/ws/{symbol.replace('/', '').lower()}@trade"
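        # e.g. symbol "BTC/USDT" -> wss://stream.binance.com:9443/ws/btcusdt@trade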
        self.ws = None
        self.running = False
        self.data_queue = asyncio.Queue()

    async def start_websocket(self):
        """Start WebSocket connection and receive tick data"""
        while self.running:
            try:
                async with websockets.connect(self.ws_url) as ws:
                    self.ws = ws
                    logger.info("WebSocket connected")

                    while self.running:
                        try:
                            message = await ws.recv()
                            data = json.loads(message)
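
                            # A trade event from Binance's @trade stream looks
                            # roughly like this (abridged; numeric fields arrive
                            # as strings, hence the float() casts below):
                            #   {"e": "trade", "T": 1700000000000,
                            #    "p": "50000.00", "q": "0.001", ...}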
                            if 'e' in data and data['e'] == 'trade':
                                tick = {
                                    'timestamp': data['T'],
                                    'price': float(data['p']),
                                    'volume': float(data['q']),
                                    'symbol': self.symbol
                                }
                                self.ticks.append(tick)
                                await self.data_queue.put(tick)

                        except websockets.exceptions.ConnectionClosed:
                            logger.warning("WebSocket connection closed")
                            break
                        except Exception as e:
                            logger.error(f"Error receiving WebSocket message: {str(e)}")
                            await asyncio.sleep(1)

            except Exception as e:
                logger.error(f"WebSocket connection error: {str(e)}")
                await asyncio.sleep(5)

    def process_tick(self, tick: Dict):
        """Process a single tick into candles for all timeframes"""
        timestamp = tick['timestamp']
        price = tick['price']
        volume = tick['volume']

        for timeframe in self.candle_cache.keys():
            interval = self._get_interval_seconds(timeframe)
            if interval is None:
                continue

            # Floor the millisecond timestamp to the start of its candle interval
            candle_ts = int(timestamp // (interval * 1000)) * (interval * 1000)
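            # Worked example (illustrative values): for a 1m candle
            # (interval = 60) and a tick at timestamp 1700000123456 ms,
            # the bucket is 1700000123456 // 60000 * 60000 = 1700000100000 ms.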

            # Get or create the candle for this timeframe
            if not self.candle_cache[timeframe]:
                # First candle for this timeframe
                candle = {
                    'timestamp': candle_ts,
                    'open': price,
                    'high': price,
                    'low': price,
                    'close': price,
                    'volume': volume
                }
                self.candle_cache[timeframe].append(candle)
            else:
                last_candle = self.candle_cache[timeframe][-1]

                if last_candle['timestamp'] == candle_ts:
                    # Update the current candle in place
                    last_candle['high'] = max(last_candle['high'], price)
                    last_candle['low'] = min(last_candle['low'], price)
                    last_candle['close'] = price
                    last_candle['volume'] += volume
                else:
                    # Tick falls in a new interval: open a new candle
                    candle = {
                        'timestamp': candle_ts,
                        'open': price,
                        'high': price,
                        'low': price,
                        'close': price,
                        'volume': volume
                    }
                    self.candle_cache[timeframe].append(candle)

    def _get_interval_seconds(self, timeframe: str) -> Optional[int]:
        """Convert a timeframe string such as '1m' or '15m' to seconds"""
        try:
            value = int(timeframe[:-1])
            unit = timeframe[-1]
            if unit == 's':
                return value
            elif unit == 'm':
                return value * 60
            elif unit == 'h':
                return value * 3600
            elif unit == 'd':
                return value * 86400
            return None
        except (ValueError, IndexError):
            return None

    def get_candles(self, timeframe: str) -> pd.DataFrame:
        """Get candles for a specific timeframe"""
        if timeframe not in self.candle_cache:
            return pd.DataFrame()

        candles = list(self.candle_cache[timeframe])
        if not candles:
            return pd.DataFrame()

        df = pd.DataFrame(candles)
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
        return df
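
# Usage sketch (illustrative values only): feed ticks in, read candles back out.
#
#   processor = TickDataProcessor("BTC/USDT")
#   processor.process_tick({'timestamp': 1700000000000, 'price': 50000.0,
#                           'volume': 0.5, 'symbol': "BTC/USDT"})
#   df_1m = processor.get_candles('1m')  # one OHLCV row per 1-minute bucket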

def signal_handler(sig, frame):
    """Handle CTRL+C to gracefully exit training"""
    global running
    logger.info("Received interrupt signal. Finishing current epoch and saving model...")
    running = False

# Register signal handler
signal.signal(signal.SIGINT, signal_handler)

def save_training_stats(stats, filepath="NN/models/saved/realtime_ticks_training_stats.json"):
    """Save training statistics to file"""
    os.makedirs(os.path.dirname(filepath), exist_ok=True)

    with open(filepath, 'w') as f:
        json.dump(stats, f, indent=2)

    logger.info(f"Training statistics saved to {filepath}")

def calculate_pnl_with_fees(predictions, prices, position_size=1.0, fee_rate=0.0002, initial_balance=100.0):
    """
    Calculate PnL including trading fees and track the USD balance.

    fee_rate: 0.02% per side (charged on both entry and exit)
    initial_balance: starting balance in USD (default: 100.0)
    """
    trades = []
    pnl = 0
    win_count = 0
    total_trades = 0
    current_balance = initial_balance
    balance_history = [initial_balance]

    for i in range(len(predictions)):
        if predictions[i] == 2:  # BUY
            entry_price = prices[i]
            # Look ahead up to 7 steps for a SELL signal to close the trade
            for j in range(i + 1, min(i + 8, len(prices))):
                if predictions[j] == 0:  # SELL
                    exit_price = prices[j]
                    # Position size in USD
                    position_usd = current_balance * position_size

                    # Raw PnL in USD
                    raw_pnl_usd = position_usd * ((exit_price - entry_price) / entry_price)

                    # Fees in USD (charged on both entry and exit)
                    entry_fee_usd = position_usd * fee_rate
                    exit_fee_usd = position_usd * fee_rate
                    total_fees_usd = entry_fee_usd + exit_fee_usd

                    # Net PnL in USD after fees
                    net_pnl_usd = raw_pnl_usd - total_fees_usd

                    # Update balance
                    current_balance += net_pnl_usd
                    balance_history.append(current_balance)

                    trades.append({
                        'entry_idx': i,
                        'exit_idx': j,
                        'entry_price': entry_price,
                        'exit_price': exit_price,
                        'position_size_usd': position_usd,
                        'raw_pnl_usd': raw_pnl_usd,
                        'fees_usd': total_fees_usd,
                        'net_pnl_usd': net_pnl_usd,
                        'balance': current_balance
                    })

                    pnl += net_pnl_usd / initial_balance  # Fraction of the initial balance
                    if net_pnl_usd > 0:
                        win_count += 1
                    total_trades += 1
                    break

    win_rate = win_count / total_trades if total_trades > 0 else 0
    final_balance = current_balance
    total_return = (final_balance - initial_balance) / initial_balance * 100  # Percentage return

    return pnl, win_rate, trades, balance_history, total_return

def calculate_max_drawdown(balance_history):
    """Calculate the maximum drawdown (in percent) from a balance history"""
    if not balance_history:
        return 0.0

    peak = balance_history[0]
    max_drawdown = 0.0

    for balance in balance_history:
        if balance > peak:
            peak = balance
        drawdown = (peak - balance) / peak * 100
        max_drawdown = max(max_drawdown, drawdown)

    return max_drawdown

async def run_realtime_training():
    """
    Run continuous training with real-time tick data and multiple timeframes
    """
    global running, training_stats

    # Configuration parameters
    symbol = "BTC/USDT"
    timeframes = ["1s", "1m", "5m", "15m"]  # Include 1s for tick-based training
    window_size = 24  # Larger window size for capturing more patterns
    output_size = 3  # BUY/HOLD/SELL
    batch_size = 64  # Batch size for training

    # Real-time configuration
    data_refresh_interval = 60  # Refresh data every minute
    checkpoint_interval = 3600  # Save a checkpoint every hour
    max_training_time = float('inf')  # Run indefinitely

    # Initialize TensorBoard writer
    tensorboard_dir = "runs/realtime_ticks_training"
    os.makedirs(tensorboard_dir, exist_ok=True)
    writer = SummaryWriter(tensorboard_dir)

    # Initialize training start time
    start_time = time.time()
    last_checkpoint_time = start_time
    last_data_refresh_time = start_time

    logger.info(f"Starting continuous real-time training with tick data for {symbol}")
    logger.info(f"Configuration: timeframes={timeframes}, window_size={window_size}, batch_size={batch_size}")
    logger.info(f"Data will refresh every {data_refresh_interval} seconds")
    logger.info(f"Checkpoints will be saved every {checkpoint_interval} seconds")
    logger.info(f"TensorBoard logs will be saved to {tensorboard_dir}")

    try:
        # Initialize tick data processor
        tick_processor = TickDataProcessor(symbol)
        tick_processor.running = True

        # Start WebSocket connection in the background
        websocket_task = asyncio.create_task(tick_processor.start_websocket())

        # Initialize data interface
        logger.info("Initializing data interface...")
        data_interface = DataInterface(
            symbol=symbol,
            timeframes=timeframes
        )

        # Initialize model
        num_features = data_interface.get_feature_count()
        logger.info(f"Initializing model with {num_features} features")

        model = CNNModelPyTorch(
            window_size=window_size,
            num_features=num_features,
            output_size=output_size,
            timeframes=timeframes
        )

        # Try to load an existing model
        model_path = "NN/models/saved/optimized_short_term_model_best.pt"
        try:
            if os.path.exists(model_path):
                logger.info(f"Loading existing model from {model_path}")
                model.load(model_path)
                logger.info("Model loaded successfully")
            else:
                logger.info("No existing model found. Starting with a new model.")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            logger.info("Starting with a new model.")

        # Initialize signal interpreter
        signal_interpreter = SignalInterpreter(config={
            'buy_threshold': 0.55,  # Lower threshold to catch more opportunities
            'sell_threshold': 0.55,  # Lower threshold to catch more opportunities
            'hold_threshold': 0.65,  # Lower threshold to reduce missed trades
            'trend_filter_enabled': True,
            'volume_filter_enabled': True,
            'min_confidence': 0.45  # Minimum confidence to consider a trade
        })

        # Create checkpoint directory
        checkpoint_dir = "NN/models/saved/realtime_ticks_checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Track metrics
        epoch = 0
        best_val_pnl = -float('inf')
        best_win_rate = 0
        best_epoch = 0
        consecutive_failures = 0
        max_consecutive_failures = 5

        # Training loop
        while running:
            try:
                epoch += 1
                epoch_start = time.time()

                logger.info(f"Epoch {epoch} - Starting at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

                # Drain and process any new ticks
                while not tick_processor.data_queue.empty():
                    tick = await tick_processor.data_queue.get()
                    tick_processor.process_tick(tick)

                # Check if we need to refresh data
                if time.time() - last_data_refresh_time > data_refresh_interval:
                    logger.info("Refreshing training data...")
                    last_data_refresh_time = time.time()

                # Get candles for all timeframes
                candles_data = {}
                for timeframe in timeframes:
                    df = tick_processor.get_candles(timeframe)
                    if not df.empty:
                        candles_data[timeframe] = df

                # Prepare training data with multiple timeframes
                X_train, y_train, X_val, y_val, train_prices, val_prices = data_interface.prepare_training_data(
                    refresh=True,
                    refresh_interval=data_refresh_interval
                )

                if X_train is None or y_train is None:
                    logger.warning("Failed to prepare training data. Using previous data.")
                    consecutive_failures += 1
                    if consecutive_failures >= max_consecutive_failures:
                        logger.error("Too many consecutive failures. Stopping training.")
                        break
                    await asyncio.sleep(5)  # Wait before retrying
                    continue

                consecutive_failures = 0  # Reset failure counter on success
                logger.info(f"Training data prepared - X shape: {X_train.shape}, y shape: {y_train.shape}")

                # Calculate future prices for the profitability-focused loss function
                train_future_prices = data_interface.get_future_prices(train_prices, n_candles=8)
                val_future_prices = data_interface.get_future_prices(val_prices, n_candles=8)

                # Train one epoch
                train_action_loss, train_price_loss, train_acc = model.train_epoch(
                    X_train, y_train, train_future_prices, batch_size
                )

                # Evaluate
                val_action_loss, val_price_loss, val_acc = model.evaluate(
                    X_val, y_val, val_future_prices
                )

                logger.info(f"Epoch {epoch} results:")
                logger.info(f" Train - Loss: {train_action_loss:.4f}, Accuracy: {train_acc:.4f}")
                logger.info(f" Valid - Loss: {val_action_loss:.4f}, Accuracy: {val_acc:.4f}")

                # Get predictions for PnL calculation
                train_action_probs, train_price_preds = model.predict(X_train)
                val_action_probs, val_price_preds = model.predict(X_val)

                # Convert probabilities to actions
                train_preds = np.argmax(train_action_probs, axis=1)
                val_preds = np.argmax(val_action_probs, axis=1)
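                # Class-index convention used throughout this script:
                #   0 = SELL, 1 = HOLD, 2 = BUY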

                # Track signal distribution
                train_buy_count = np.sum(train_preds == 2)
                train_sell_count = np.sum(train_preds == 0)
                train_hold_count = np.sum(train_preds == 1)

                val_buy_count = np.sum(val_preds == 2)
                val_sell_count = np.sum(val_preds == 0)
                val_hold_count = np.sum(val_preds == 1)

                signal_dist = {
                    "train": {
                        "BUY": float(train_buy_count / len(train_preds)) if len(train_preds) > 0 else 0,
                        "SELL": float(train_sell_count / len(train_preds)) if len(train_preds) > 0 else 0,
                        "HOLD": float(train_hold_count / len(train_preds)) if len(train_preds) > 0 else 0
                    },
                    "val": {
                        "BUY": float(val_buy_count / len(val_preds)) if len(val_preds) > 0 else 0,
                        "SELL": float(val_sell_count / len(val_preds)) if len(val_preds) > 0 else 0,
                        "HOLD": float(val_hold_count / len(val_preds)) if len(val_preds) > 0 else 0
                    }
                }

                # Calculate PnL and win rates with different position sizes
                position_sizes = [0.1, 0.25, 0.5, 1.0, 2.0]
                best_position_train_pnl = -float('inf')
                best_position_val_pnl = -float('inf')
                best_position_train_wr = 0
                best_position_val_wr = 0
                best_position_size = 1.0

                for position_size in position_sizes:
                    train_pnl, train_win_rate, train_trades, train_balance_history, train_total_return = calculate_pnl_with_fees(
                        train_preds, train_prices, position_size=position_size
                    )
                    val_pnl, val_win_rate, val_trades, val_balance_history, val_total_return = calculate_pnl_with_fees(
                        val_preds, val_prices, position_size=position_size
                    )

                    # Update cumulative PnL and trade statistics
                    # (note: these accumulate inside the position-size sweep,
                    # so they aggregate across every position size tested)
                    training_stats["cumulative_pnl"]["train"] += train_pnl
                    training_stats["cumulative_pnl"]["val"] += val_pnl
                    training_stats["total_trades"]["train"] += len(train_trades)
                    training_stats["total_trades"]["val"] += len(val_trades)
                    training_stats["total_wins"]["train"] += sum(1 for t in train_trades if t['net_pnl_usd'] > 0)
                    training_stats["total_wins"]["val"] += sum(1 for t in val_trades if t['net_pnl_usd'] > 0)

                    # Calculate average fees per trade
                    avg_train_fees = np.mean([t['fees_usd'] for t in train_trades]) if train_trades else 0
                    avg_val_fees = np.mean([t['fees_usd'] for t in val_trades]) if val_trades else 0

                    # Calculate max drawdown
                    train_drawdown = calculate_max_drawdown(train_balance_history)
                    val_drawdown = calculate_max_drawdown(val_balance_history)

                    # Calculate overall win rate
                    overall_train_wr = training_stats["total_wins"]["train"] / training_stats["total_trades"]["train"] if training_stats["total_trades"]["train"] > 0 else 0
                    overall_val_wr = training_stats["total_wins"]["val"] / training_stats["total_trades"]["val"] if training_stats["total_trades"]["val"] > 0 else 0

                    logger.info(f" Position Size: {position_size}")
                    logger.info(f" Train - PnL: {train_pnl:.4f}, Win Rate: {train_win_rate:.4f}, Trades: {len(train_trades)}")
                    logger.info(f" Train - Total Return: {train_total_return:.2f}%, Max Drawdown: {train_drawdown:.2f}%")
                    logger.info(f" Train - Avg Fees: ${avg_train_fees:.2f} per trade")
                    logger.info(f" Train - Cumulative PnL: {training_stats['cumulative_pnl']['train']:.4f}, Overall WR: {overall_train_wr:.4f}")
                    logger.info(f" Valid - PnL: {val_pnl:.4f}, Win Rate: {val_win_rate:.4f}, Trades: {len(val_trades)}")
                    logger.info(f" Valid - Total Return: {val_total_return:.2f}%, Max Drawdown: {val_drawdown:.2f}%")
                    logger.info(f" Valid - Avg Fees: ${avg_val_fees:.2f} per trade")
                    logger.info(f" Valid - Cumulative PnL: {training_stats['cumulative_pnl']['val']:.4f}, Overall WR: {overall_val_wr:.4f}")

                    # Log to TensorBoard
                    writer.add_scalar(f'Balance/train/position_{position_size}', train_balance_history[-1], epoch)
                    writer.add_scalar(f'Balance/validation/position_{position_size}', val_balance_history[-1], epoch)
                    writer.add_scalar(f'Return/train/position_{position_size}', train_total_return, epoch)
                    writer.add_scalar(f'Return/validation/position_{position_size}', val_total_return, epoch)
                    writer.add_scalar(f'Drawdown/train/position_{position_size}', train_drawdown, epoch)
                    writer.add_scalar(f'Drawdown/validation/position_{position_size}', val_drawdown, epoch)
                    writer.add_scalar(f'CumulativePnL/train/position_{position_size}', training_stats["cumulative_pnl"]["train"], epoch)
                    writer.add_scalar(f'CumulativePnL/validation/position_{position_size}', training_stats["cumulative_pnl"]["val"], epoch)
                    writer.add_scalar(f'OverallWinRate/train/position_{position_size}', overall_train_wr, epoch)
                    writer.add_scalar(f'OverallWinRate/validation/position_{position_size}', overall_val_wr, epoch)

                    # Track best position size for this epoch
                    if val_pnl > best_position_val_pnl:
                        best_position_val_pnl = val_pnl
                        best_position_val_wr = val_win_rate
                        best_position_size = position_size

                    if train_pnl > best_position_train_pnl:
                        best_position_train_pnl = train_pnl
                        best_position_train_wr = train_win_rate

                    # Track best model overall (using position size 1.0 as reference)
                    if val_pnl > best_val_pnl and position_size == 1.0:
                        best_val_pnl = val_pnl
                        best_win_rate = val_win_rate
                        best_epoch = epoch
                        logger.info(f" New best validation PnL: {best_val_pnl:.4f} at epoch {best_epoch}")

                        # Save the best model
                        model.save("NN/models/saved/optimized_short_term_model_ticks_best")

                # Store epoch metrics with cumulative statistics
                epoch_metrics = {
                    "epoch": epoch,
                    "train_loss": float(train_action_loss),
                    "val_loss": float(val_action_loss),
                    "train_acc": float(train_acc),
                    "val_acc": float(val_acc),
                    "train_pnl": float(best_position_train_pnl),
                    "val_pnl": float(best_position_val_pnl),
                    "train_win_rate": float(best_position_train_wr),
                    "val_win_rate": float(best_position_val_wr),
                    "best_position_size": float(best_position_size),
                    "signal_distribution": signal_dist,
                    "timestamp": datetime.now().isoformat(),
                    "data_age": int(time.time() - last_data_refresh_time),
                    "cumulative_pnl": {
                        "train": float(training_stats["cumulative_pnl"]["train"]),
                        "val": float(training_stats["cumulative_pnl"]["val"])
                    },
                    "total_trades": {
                        "train": int(training_stats["total_trades"]["train"]),
                        "val": int(training_stats["total_trades"]["val"])
                    },
                    "overall_win_rate": {
                        "train": float(overall_train_wr),
                        "val": float(overall_val_wr)
                    }
                }

                # Update training stats
                training_stats["epochs_completed"] = epoch
                training_stats["best_val_pnl"] = float(best_val_pnl)
                training_stats["best_epoch"] = best_epoch
                training_stats["best_win_rate"] = float(best_win_rate)
                training_stats["last_update"] = datetime.now().isoformat()
                training_stats["epochs"].append(epoch_metrics)

                # Check if we need to save a checkpoint
                if time.time() - last_checkpoint_time > checkpoint_interval:
                    logger.info(f"Saving checkpoint at epoch {epoch}")
                    # Save model checkpoint
                    model.save(f"{checkpoint_dir}/checkpoint_epoch_{epoch}")
                    # Save training statistics
                    save_training_stats(training_stats)
                    last_checkpoint_time = time.time()

                # Test trade signal generation with a random sample
                random_idx = np.random.randint(0, len(X_val))
                sample_X = X_val[random_idx:random_idx+1]
                sample_probs, sample_price_pred = model.predict(sample_X)

                # Process with signal interpreter
                signal = signal_interpreter.interpret_signal(
                    sample_probs[0],
                    float(sample_price_pred[0][0]) if hasattr(sample_price_pred, "__getitem__") else float(sample_price_pred[0]),
                    market_data={'price': float(val_prices[random_idx]) if random_idx < len(val_prices) else 50000.0}
                )

                logger.info(f" Sample trade signal: {signal['action']} with confidence {signal['confidence']:.4f}")

                # Log trading statistics
                logger.info(f" Train - Actions: BUY={train_buy_count}, SELL={train_sell_count}, HOLD={train_hold_count}")
                logger.info(f" Valid - Actions: BUY={val_buy_count}, SELL={val_sell_count}, HOLD={val_hold_count}")

                # Log epoch timing
                epoch_time = time.time() - epoch_start
                total_elapsed = time.time() - start_time

                logger.info(f" Epoch completed in {epoch_time:.2f} seconds")
                logger.info(f" Total training time: {total_elapsed/3600:.2f} hours")

                # Small delay to prevent CPU overload
                await asyncio.sleep(0.1)

            except Exception as e:
                logger.error(f"Error during epoch {epoch}: {str(e)}")
                import traceback
                logger.error(traceback.format_exc())
                consecutive_failures += 1
                if consecutive_failures >= max_consecutive_failures:
                    logger.error("Too many consecutive failures. Stopping training.")
                    break
                await asyncio.sleep(5)  # Wait before retrying
                continue

        # Cleanup
        tick_processor.running = False
        websocket_task.cancel()

        # Save final model and performance metrics
        logger.info("Saving final optimized model...")
        model.save("NN/models/saved/optimized_short_term_model_ticks_final")

        # Save performance metrics to file
        save_training_stats(training_stats)

        # Generate performance plots
        try:
            model.plot_training_history("NN/models/saved/realtime_ticks_training_stats.json")
            logger.info("Performance plots generated successfully")
        except Exception as e:
            logger.error(f"Error generating plots: {str(e)}")

        # Calculate total training time
        total_time = time.time() - start_time
        hours, remainder = divmod(total_time, 3600)
        minutes, seconds = divmod(remainder, 60)

        logger.info(f"Continuous training completed in {int(hours)}h {int(minutes)}m {int(seconds)}s")
        logger.info(f"Best model performance - Epoch: {best_epoch}, PnL: {best_val_pnl:.4f}, Win Rate: {best_win_rate:.4f}")

        # Close TensorBoard writer
        writer.close()

    except Exception as e:
        logger.error(f"Error during real-time training: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())

        # Try to save the model and stats in case of error
        try:
            if 'model' in locals():
                model.save("NN/models/saved/optimized_short_term_model_ticks_emergency")
                logger.info("Emergency model save completed")
            if 'training_stats' in locals():
                save_training_stats(training_stats, "NN/models/saved/realtime_ticks_training_stats_emergency.json")
            if 'writer' in locals():
                writer.close()
        except Exception as e2:
            logger.error(f"Failed to save emergency checkpoint: {str(e2)}")

if __name__ == "__main__":
    # Print startup banner
    print("=" * 80)
    print("CONTINUOUS REALTIME TICKS TRAINING SESSION")
    print("This script will continuously train the model using real-time tick data")
    print("Press Ctrl+C to safely stop training and save the model")
    print("TensorBoard logs will be saved to runs/realtime_ticks_training")
    print("To view TensorBoard, run: tensorboard --logdir=runs/realtime_ticks_training")
    print("=" * 80)

    # Run the async training loop
    asyncio.run(run_realtime_training())