gogo2/train_with_realtime_ticks.py
Dobromir Popov 73c5ecb0d2 enhancements
2025-04-01 13:46:53 +03:00

704 lines
31 KiB
Python

#!/usr/bin/env python
"""
Real-time training with tick data and multiple timeframes for context
This script uses streaming tick data for fast adaptation while maintaining higher timeframe context
"""
import os
import sys
import logging
import numpy as np
import torch
import time
import json
from datetime import datetime, timedelta
import signal
import threading
import asyncio
import websockets
from collections import deque
import pandas as pd
from typing import Dict, List, Optional
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import io
# Make the project root importable when this file is run as a script.
sys.path.append(os.path.abspath('.'))

# Route log output to both the console and a per-run, timestamped file.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = os.path.join(log_dir, f"realtime_ticks_training_{run_stamp}.log")
log_handlers = [
    logging.FileHandler(log_file),
    logging.StreamHandler(),
]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=log_handlers,
)
logger = logging.getLogger('realtime_ticks_training')
# Import the model and data interfaces
from NN.models.cnn_model_pytorch import CNNModelPyTorch
from NN.utils.data_interface import DataInterface
from NN.utils.signal_interpreter import SignalInterpreter
# Global flag cleared by the SIGINT handler to request a graceful stop.
running = True

# Run-wide training statistics, periodically persisted via save_training_stats().
training_stats = {
    "epochs_completed": 0,
    "best_val_pnl": -float('inf'),
    "best_epoch": 0,
    "best_win_rate": 0,
    "training_started": datetime.now().isoformat(),
    "last_update": datetime.now().isoformat(),
    "epochs": [],
    # Per-split accumulators updated each epoch by the training loop.
    "cumulative_pnl": dict(train=0.0, val=0.0),
    "total_trades": dict(train=0, val=0),
    "total_wins": dict(train=0, val=0),
}
class TickDataProcessor:
    """Process and store real-time tick data.

    Streams raw trades from the Binance trade WebSocket and aggregates them
    into OHLCV candles for several timeframes. Ticks are also handed off to a
    consumer (the training loop) through an asyncio queue.
    """

    def __init__(self, symbol: str, max_ticks: int = 10000):
        self.symbol = symbol
        # Most recent raw ticks, bounded to cap memory usage.
        self.ticks = deque(maxlen=max_ticks)
        # Aggregated candles per timeframe, newest last.
        self.candle_cache: Dict[str, deque] = {
            '1s': deque(maxlen=5000),
            '1m': deque(maxlen=5000),
            '5m': deque(maxlen=5000),
            '15m': deque(maxlen=5000)
        }
        self.last_tick = None
        # e.g. "BTC/USDT" -> wss://stream.binance.com:9443/ws/btcusdt@trade
        self.ws_url = f"wss://stream.binance.com:9443/ws/{symbol.replace('/', '').lower()}@trade"
        self.ws = None
        self.running = False
        # Hand-off queue between the websocket reader and the consumer.
        self.data_queue = asyncio.Queue()

    async def start_websocket(self):
        """Start WebSocket connection and receive tick data.

        Reconnects automatically on connection loss while ``self.running``
        is True. Each trade event is appended to ``self.ticks`` and pushed
        onto ``self.data_queue``.
        """
        while self.running:
            try:
                async with websockets.connect(self.ws_url) as ws:
                    self.ws = ws
                    logger.info("WebSocket connected")
                    while self.running:
                        try:
                            message = await ws.recv()
                            data = json.loads(message)
                            # Binance trade events carry event type 'e' == 'trade'.
                            if 'e' in data and data['e'] == 'trade':
                                tick = {
                                    'timestamp': data['T'],  # trade time, epoch ms
                                    'price': float(data['p']),
                                    'volume': float(data['q']),
                                    'symbol': self.symbol
                                }
                                self.ticks.append(tick)
                                await self.data_queue.put(tick)
                        except websockets.exceptions.ConnectionClosed:
                            logger.warning("WebSocket connection closed")
                            break  # leave inner loop; outer loop reconnects
                        except Exception as e:
                            logger.error(f"Error receiving WebSocket message: {str(e)}")
                            await asyncio.sleep(1)
            except Exception as e:
                logger.error(f"WebSocket connection error: {str(e)}")
                await asyncio.sleep(5)  # back off before reconnecting

    @staticmethod
    def _new_candle(candle_ts: int, price: float, volume: float) -> Dict:
        """Create a fresh OHLCV candle seeded from a single tick."""
        return {
            'timestamp': candle_ts,
            'open': price,
            'high': price,
            'low': price,
            'close': price,
            'volume': volume
        }

    def process_tick(self, tick: Dict):
        """Fold a single tick into the candles of every cached timeframe."""
        timestamp = tick['timestamp']  # epoch milliseconds
        price = tick['price']
        volume = tick['volume']
        for timeframe in self.candle_cache.keys():
            interval = self._get_interval_seconds(timeframe)
            if interval is None:
                continue
            # Truncate the tick timestamp to the start of its candle interval.
            candle_ts = int(timestamp // (interval * 1000)) * (interval * 1000)
            cache = self.candle_cache[timeframe]
            if cache and cache[-1]['timestamp'] == candle_ts:
                # Tick belongs to the current candle: update OHLCV in place.
                last_candle = cache[-1]
                last_candle['high'] = max(last_candle['high'], price)
                last_candle['low'] = min(last_candle['low'], price)
                last_candle['close'] = price
                last_candle['volume'] += volume
            else:
                # First tick overall, or first tick of a new interval.
                cache.append(self._new_candle(candle_ts, price, volume))

    def _get_interval_seconds(self, timeframe: str) -> Optional[int]:
        """Convert a timeframe string such as '1m' or '15s' to seconds.

        Returns None for unparseable or unsupported values instead of raising.
        """
        try:
            value = int(timeframe[:-1])
            unit = timeframe[-1]
        except (ValueError, IndexError, TypeError):
            return None
        multipliers = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}
        if unit not in multipliers:
            return None
        return value * multipliers[unit]

    def get_candles(self, timeframe: str) -> pd.DataFrame:
        """Return the cached candles for `timeframe` as a DataFrame.

        The 'timestamp' column is converted from epoch ms to pandas datetimes.
        An empty DataFrame is returned for unknown timeframes or empty caches.
        """
        if timeframe not in self.candle_cache:
            return pd.DataFrame()
        candles = list(self.candle_cache[timeframe])
        if not candles:
            return pd.DataFrame()
        df = pd.DataFrame(candles)
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
        return df
def signal_handler(sig, frame):
    """SIGINT handler: clear the global run flag so the training loop can
    finish its current epoch and save the model before exiting."""
    global running
    logger.info("Received interrupt signal. Finishing current epoch and saving model...")
    running = False


# Install the handler so Ctrl+C triggers a graceful shutdown.
signal.signal(signal.SIGINT, signal_handler)
def save_training_stats(stats, filepath="NN/models/saved/realtime_ticks_training_stats.json"):
    """Persist training statistics as pretty-printed JSON.

    Args:
        stats: JSON-serializable dict of training statistics.
        filepath: Destination path; parent directories are created as needed.
    """
    directory = os.path.dirname(filepath)
    # Guard: os.makedirs('') raises FileNotFoundError when filepath is a
    # bare filename with no directory component.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(filepath, 'w') as f:
        json.dump(stats, f, indent=2)
    logger.info(f"Training statistics saved to {filepath}")
def calculate_pnl_with_fees(predictions, prices, position_size=1.0, fee_rate=0.0002, initial_balance=100.0):
    """
    Calculate PnL including trading fees and track USD balance
    fee_rate: 0.02% per trade (both entry and exit)
    initial_balance: Starting balance in USD (default: 100.0)
    """
    trades = []
    pnl = 0
    win_count = 0
    total_trades = 0
    balance = initial_balance
    balance_history = [initial_balance]

    for entry_idx in range(len(predictions)):
        if predictions[entry_idx] != 2:  # only a BUY signal opens a trade
            continue
        entry_price = prices[entry_idx]
        # Scan forward (at most 7 steps) for the first SELL signal to exit on.
        for exit_idx in range(entry_idx + 1, min(entry_idx + 8, len(prices))):
            if predictions[exit_idx] != 0:
                continue
            exit_price = prices[exit_idx]
            # Stake a fraction of the current balance on this trade.
            stake = balance * position_size
            gross = stake * ((exit_price - entry_price) / entry_price)
            # Commission is charged on both entry and exit.
            entry_fee = stake * fee_rate
            exit_fee = stake * fee_rate
            fees = entry_fee + exit_fee
            net = gross - fees
            balance += net
            balance_history.append(balance)
            trades.append({
                'entry_idx': entry_idx,
                'exit_idx': exit_idx,
                'entry_price': entry_price,
                'exit_price': exit_price,
                'position_size_usd': stake,
                'raw_pnl_usd': gross,
                'fees_usd': fees,
                'net_pnl_usd': net,
                'balance': balance
            })
            pnl += net / initial_balance  # express as a fraction of starting capital
            if net > 0:
                win_count += 1
            total_trades += 1
            break  # one exit per entry

    win_rate = win_count / total_trades if total_trades else 0
    total_return = (balance - initial_balance) / initial_balance * 100  # percent
    return pnl, win_rate, trades, balance_history, total_return
def calculate_max_drawdown(balance_history):
    """Return the largest peak-to-trough decline, in percent, over a balance
    series. An empty series has zero drawdown."""
    if not balance_history:
        return 0.0
    peak = balance_history[0]
    worst = 0.0
    for value in balance_history:
        peak = max(peak, value)  # running high-water mark
        worst = max(worst, (peak - value) / peak * 100)
    return worst
async def run_realtime_training():
    """
    Run continuous training with real-time tick data and multiple timeframes.

    Orchestration per loop iteration ("epoch"):
      * drain queued ticks from the Binance websocket into candle caches,
      * (re)prepare training data through DataInterface,
      * train and evaluate CNNModelPyTorch,
      * backtest predictions across several position sizes, logging metrics
        to the log file and TensorBoard,
      * checkpoint hourly and save the best model by validation PnL.

    Runs until the global `running` flag is cleared (SIGINT handler) or too
    many consecutive failures occur.
    """
    global running, training_stats
    # Configuration parameters
    symbol = "BTC/USDT"
    timeframes = ["1s", "1m", "5m", "15m"]  # Include 1s for tick-based training
    window_size = 24  # Larger window size for capturing more patterns
    output_size = 3  # BUY/HOLD/SELL
    batch_size = 64  # Batch size for training
    # Real-time configuration
    data_refresh_interval = 60  # Refresh data every minute
    checkpoint_interval = 3600  # Save checkpoint every hour
    max_training_time = float('inf')  # Run indefinitely.  NOTE(review): currently unused.
    # Initialize TensorBoard writer
    tensorboard_dir = "runs/realtime_ticks_training"
    os.makedirs(tensorboard_dir, exist_ok=True)
    writer = SummaryWriter(tensorboard_dir)
    # Initialize training start time
    start_time = time.time()
    last_checkpoint_time = start_time
    last_data_refresh_time = start_time
    logger.info(f"Starting continuous real-time training with tick data for {symbol}")
    logger.info(f"Configuration: timeframes={timeframes}, window_size={window_size}, batch_size={batch_size}")
    logger.info(f"Data will refresh every {data_refresh_interval} seconds")
    logger.info(f"Checkpoints will be saved every {checkpoint_interval} seconds")
    logger.info(f"TensorBoard logs will be saved to {tensorboard_dir}")
    try:
        # Initialize tick data processor and start streaming in the background.
        tick_processor = TickDataProcessor(symbol)
        tick_processor.running = True
        # Start WebSocket connection in background
        websocket_task = asyncio.create_task(tick_processor.start_websocket())
        # Initialize data interface
        logger.info("Initializing data interface...")
        data_interface = DataInterface(
            symbol=symbol,
            timeframes=timeframes
        )
        # Initialize model
        num_features = data_interface.get_feature_count()
        logger.info(f"Initializing model with {num_features} features")
        model = CNNModelPyTorch(
            window_size=window_size,
            timeframes=timeframes,
            output_size=output_size,
            num_pairs=1  # Single trading pair
        )
        # Try to load existing model; fall back to a fresh one on any failure.
        model_path = "NN/models/saved/optimized_short_term_model_best.pt"
        try:
            if os.path.exists(model_path):
                logger.info(f"Loading existing model from {model_path}")
                model.load(model_path)
                logger.info("Model loaded successfully")
            else:
                logger.info("No existing model found. Starting with a new model.")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            logger.info("Starting with a new model.")
        # Initialize signal interpreter
        signal_interpreter = SignalInterpreter(config={
            'buy_threshold': 0.55,  # Lower threshold to catch more opportunities
            'sell_threshold': 0.55,  # Lower threshold to catch more opportunities
            'hold_threshold': 0.65,  # Lower threshold to reduce missed trades
            'trend_filter_enabled': True,
            'volume_filter_enabled': True,
            'min_confidence': 0.45  # Minimum confidence to consider a trade
        })
        # Create checkpoint directory
        checkpoint_dir = "NN/models/saved/realtime_ticks_checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)
        # Track metrics
        epoch = 0
        best_val_pnl = -float('inf')
        best_win_rate = 0
        best_epoch = 0
        consecutive_failures = 0
        max_consecutive_failures = 5
        # Training loop
        while running:
            try:
                epoch += 1
                epoch_start = time.time()
                logger.info(f"Epoch {epoch} - Starting at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
                # Drain any ticks queued by the websocket since last iteration.
                while not tick_processor.data_queue.empty():
                    tick = await tick_processor.data_queue.get()
                    tick_processor.process_tick(tick)
                # Check if we need to refresh data
                if time.time() - last_data_refresh_time > data_refresh_interval:
                    logger.info("Refreshing training data...")
                    last_data_refresh_time = time.time()
                # Get candles for all timeframes.
                # NOTE(review): candles_data is built but never consumed below;
                # training data comes from data_interface instead — verify intent.
                candles_data = {}
                for timeframe in timeframes:
                    df = tick_processor.get_candles(timeframe)
                    if not df.empty:
                        candles_data[timeframe] = df
                # Prepare training data with multiple timeframes
                X_train, y_train, X_val, y_val, train_prices, val_prices = data_interface.prepare_training_data(
                    refresh=True,
                    refresh_interval=data_refresh_interval
                )
                if X_train is None or y_train is None:
                    logger.warning("Failed to prepare training data. Using previous data.")
                    consecutive_failures += 1
                    if consecutive_failures >= max_consecutive_failures:
                        logger.error("Too many consecutive failures. Stopping training.")
                        break
                    await asyncio.sleep(5)  # Wait before retrying
                    continue
                consecutive_failures = 0  # Reset failure counter on success
                logger.info(f"Training data prepared - X shape: {X_train.shape}, y shape: {y_train.shape}")
                # Calculate future prices for profitability-focused loss function
                train_future_prices = data_interface.get_future_prices(train_prices, n_candles=8)
                val_future_prices = data_interface.get_future_prices(val_prices, n_candles=8)
                # Train one epoch
                train_action_loss, train_price_loss, train_acc = model.train_epoch(
                    X_train, y_train, train_future_prices, batch_size
                )
                # Evaluate
                val_action_loss, val_price_loss, val_acc = model.evaluate(
                    X_val, y_val, val_future_prices
                )
                logger.info(f"Epoch {epoch} results:")
                logger.info(f" Train - Loss: {train_action_loss:.4f}, Accuracy: {train_acc:.4f}")
                logger.info(f" Valid - Loss: {val_action_loss:.4f}, Accuracy: {val_acc:.4f}")
                # Get predictions for PnL calculation
                train_action_probs, train_price_preds = model.predict(X_train)
                val_action_probs, val_price_preds = model.predict(X_val)
                # Convert probabilities to actions (class indices: 0=SELL, 1=HOLD, 2=BUY)
                train_preds = np.argmax(train_action_probs, axis=1)
                val_preds = np.argmax(val_action_probs, axis=1)
                # Track signal distribution
                train_buy_count = np.sum(train_preds == 2)
                train_sell_count = np.sum(train_preds == 0)
                train_hold_count = np.sum(train_preds == 1)
                val_buy_count = np.sum(val_preds == 2)
                val_sell_count = np.sum(val_preds == 0)
                val_hold_count = np.sum(val_preds == 1)
                signal_dist = {
                    "train": {
                        "BUY": float(train_buy_count / len(train_preds)) if len(train_preds) > 0 else 0,
                        "SELL": float(train_sell_count / len(train_preds)) if len(train_preds) > 0 else 0,
                        "HOLD": float(train_hold_count / len(train_preds)) if len(train_preds) > 0 else 0
                    },
                    "val": {
                        "BUY": float(val_buy_count / len(val_preds)) if len(val_preds) > 0 else 0,
                        "SELL": float(val_sell_count / len(val_preds)) if len(val_preds) > 0 else 0,
                        "HOLD": float(val_hold_count / len(val_preds)) if len(val_preds) > 0 else 0
                    }
                }
                # Calculate PnL and win rates with different position sizes
                position_sizes = [0.1, 0.25, 0.5, 1.0, 2.0]
                best_position_train_pnl = -float('inf')
                best_position_val_pnl = -float('inf')
                best_position_train_wr = 0
                best_position_val_wr = 0
                best_position_size = 1.0
                for position_size in position_sizes:
                    train_pnl, train_win_rate, train_trades, train_balance_history, train_total_return = calculate_pnl_with_fees(
                        train_preds, train_prices, position_size=position_size
                    )
                    val_pnl, val_win_rate, val_trades, val_balance_history, val_total_return = calculate_pnl_with_fees(
                        val_preds, val_prices, position_size=position_size
                    )
                    # Update cumulative PnL and trade statistics.
                    # NOTE(review): these accumulators are bumped once per position
                    # size (5x per epoch), so cumulative totals mix all position
                    # sizes — confirm this is intended.
                    training_stats["cumulative_pnl"]["train"] += train_pnl
                    training_stats["cumulative_pnl"]["val"] += val_pnl
                    training_stats["total_trades"]["train"] += len(train_trades)
                    training_stats["total_trades"]["val"] += len(val_trades)
                    training_stats["total_wins"]["train"] += sum(1 for t in train_trades if t['net_pnl_usd'] > 0)
                    training_stats["total_wins"]["val"] += sum(1 for t in val_trades if t['net_pnl_usd'] > 0)
                    # Calculate average fees per trade
                    avg_train_fees = np.mean([t['fees_usd'] for t in train_trades]) if train_trades else 0
                    avg_val_fees = np.mean([t['fees_usd'] for t in val_trades]) if val_trades else 0
                    # Calculate max drawdown
                    train_drawdown = calculate_max_drawdown(train_balance_history)
                    val_drawdown = calculate_max_drawdown(val_balance_history)
                    # Calculate overall win rate
                    overall_train_wr = training_stats["total_wins"]["train"] / training_stats["total_trades"]["train"] if training_stats["total_trades"]["train"] > 0 else 0
                    overall_val_wr = training_stats["total_wins"]["val"] / training_stats["total_trades"]["val"] if training_stats["total_trades"]["val"] > 0 else 0
                    logger.info(f" Position Size: {position_size}")
                    logger.info(f" Train - PnL: {train_pnl:.4f}, Win Rate: {train_win_rate:.4f}, Trades: {len(train_trades)}")
                    logger.info(f" Train - Total Return: {train_total_return:.2f}%, Max Drawdown: {train_drawdown:.2f}%")
                    logger.info(f" Train - Avg Fees: ${avg_train_fees:.2f} per trade")
                    logger.info(f" Train - Cumulative PnL: {training_stats['cumulative_pnl']['train']:.4f}, Overall WR: {overall_train_wr:.4f}")
                    logger.info(f" Valid - PnL: {val_pnl:.4f}, Win Rate: {val_win_rate:.4f}, Trades: {len(val_trades)}")
                    logger.info(f" Valid - Total Return: {val_total_return:.2f}%, Max Drawdown: {val_drawdown:.2f}%")
                    logger.info(f" Valid - Avg Fees: ${avg_val_fees:.2f} per trade")
                    logger.info(f" Valid - Cumulative PnL: {training_stats['cumulative_pnl']['val']:.4f}, Overall WR: {overall_val_wr:.4f}")
                    # Log to TensorBoard (one scalar series per position size)
                    writer.add_scalar(f'Balance/train/position_{position_size}', train_balance_history[-1], epoch)
                    writer.add_scalar(f'Balance/validation/position_{position_size}', val_balance_history[-1], epoch)
                    writer.add_scalar(f'Return/train/position_{position_size}', train_total_return, epoch)
                    writer.add_scalar(f'Return/validation/position_{position_size}', val_total_return, epoch)
                    writer.add_scalar(f'Drawdown/train/position_{position_size}', train_drawdown, epoch)
                    writer.add_scalar(f'Drawdown/validation/position_{position_size}', val_drawdown, epoch)
                    writer.add_scalar(f'CumulativePnL/train/position_{position_size}', training_stats["cumulative_pnl"]["train"], epoch)
                    writer.add_scalar(f'CumulativePnL/validation/position_{position_size}', training_stats["cumulative_pnl"]["val"], epoch)
                    writer.add_scalar(f'OverallWinRate/train/position_{position_size}', overall_train_wr, epoch)
                    writer.add_scalar(f'OverallWinRate/validation/position_{position_size}', overall_val_wr, epoch)
                    # Track best position size for this epoch
                    if val_pnl > best_position_val_pnl:
                        best_position_val_pnl = val_pnl
                        best_position_val_wr = val_win_rate
                        best_position_size = position_size
                    if train_pnl > best_position_train_pnl:
                        best_position_train_pnl = train_pnl
                        best_position_train_wr = train_win_rate
                    # Track best model overall (using position size 1.0 as reference)
                    if val_pnl > best_val_pnl and position_size == 1.0:
                        best_val_pnl = val_pnl
                        best_win_rate = val_win_rate
                        best_epoch = epoch
                        logger.info(f" New best validation PnL: {best_val_pnl:.4f} at epoch {best_epoch}")
                        # Save the best model
                        model.save(f"NN/models/saved/optimized_short_term_model_ticks_best")
                # Store epoch metrics with cumulative statistics
                epoch_metrics = {
                    "epoch": epoch,
                    "train_loss": float(train_action_loss),
                    "val_loss": float(val_action_loss),
                    "train_acc": float(train_acc),
                    "val_acc": float(val_acc),
                    "train_pnl": float(best_position_train_pnl),
                    "val_pnl": float(best_position_val_pnl),
                    "train_win_rate": float(best_position_train_wr),
                    "val_win_rate": float(best_position_val_wr),
                    "best_position_size": float(best_position_size),
                    "signal_distribution": signal_dist,
                    "timestamp": datetime.now().isoformat(),
                    "data_age": int(time.time() - last_data_refresh_time),
                    "cumulative_pnl": {
                        "train": float(training_stats["cumulative_pnl"]["train"]),
                        "val": float(training_stats["cumulative_pnl"]["val"])
                    },
                    "total_trades": {
                        "train": int(training_stats["total_trades"]["train"]),
                        "val": int(training_stats["total_trades"]["val"])
                    },
                    # Uses the win-rate values from the last position size iterated.
                    "overall_win_rate": {
                        "train": float(overall_train_wr),
                        "val": float(overall_val_wr)
                    }
                }
                # Update training stats
                training_stats["epochs_completed"] = epoch
                training_stats["best_val_pnl"] = float(best_val_pnl)
                training_stats["best_epoch"] = best_epoch
                training_stats["best_win_rate"] = float(best_win_rate)
                training_stats["last_update"] = datetime.now().isoformat()
                training_stats["epochs"].append(epoch_metrics)
                # Check if we need to save checkpoint
                if time.time() - last_checkpoint_time > checkpoint_interval:
                    logger.info(f"Saving checkpoint at epoch {epoch}")
                    # Save model checkpoint
                    model.save(f"{checkpoint_dir}/checkpoint_epoch_{epoch}")
                    # Save training statistics
                    save_training_stats(training_stats)
                    last_checkpoint_time = time.time()
                # Test trade signal generation with a random sample
                random_idx = np.random.randint(0, len(X_val))
                sample_X = X_val[random_idx:random_idx+1]
                sample_probs, sample_price_pred = model.predict(sample_X)
                # Process with signal interpreter
                signal = signal_interpreter.interpret_signal(
                    sample_probs[0],
                    float(sample_price_pred[0][0]) if hasattr(sample_price_pred, "__getitem__") else float(sample_price_pred[0]),
                    market_data={'price': float(val_prices[random_idx]) if random_idx < len(val_prices) else 50000.0}
                )
                logger.info(f" Sample trade signal: {signal['action']} with confidence {signal['confidence']:.4f}")
                # Log trading statistics
                logger.info(f" Train - Actions: BUY={train_buy_count}, SELL={train_sell_count}, HOLD={train_hold_count}")
                logger.info(f" Valid - Actions: BUY={val_buy_count}, SELL={val_sell_count}, HOLD={val_hold_count}")
                # Log epoch timing
                epoch_time = time.time() - epoch_start
                total_elapsed = time.time() - start_time
                logger.info(f" Epoch completed in {epoch_time:.2f} seconds")
                logger.info(f" Total training time: {total_elapsed/3600:.2f} hours")
                # Small delay to prevent CPU overload
                await asyncio.sleep(0.1)
            except Exception as e:
                # Per-epoch failure: log, count, and retry unless the failure
                # streak exceeds the limit.
                logger.error(f"Error during epoch {epoch}: {str(e)}")
                import traceback
                logger.error(traceback.format_exc())
                consecutive_failures += 1
                if consecutive_failures >= max_consecutive_failures:
                    logger.error("Too many consecutive failures. Stopping training.")
                    break
                await asyncio.sleep(5)  # Wait before retrying
                continue
        # Cleanup: stop the tick stream and cancel the websocket task.
        tick_processor.running = False
        websocket_task.cancel()
        # Save final model and performance metrics
        logger.info("Saving final optimized model...")
        model.save("NN/models/saved/optimized_short_term_model_ticks_final")
        # Save performance metrics to file
        save_training_stats(training_stats)
        # Generate performance plots
        try:
            model.plot_training_history("NN/models/saved/realtime_ticks_training_stats.json")
            logger.info("Performance plots generated successfully")
        except Exception as e:
            logger.error(f"Error generating plots: {str(e)}")
        # Calculate total training time
        total_time = time.time() - start_time
        hours, remainder = divmod(total_time, 3600)
        minutes, seconds = divmod(remainder, 60)
        logger.info(f"Continuous training completed in {int(hours)}h {int(minutes)}m {int(seconds)}s")
        logger.info(f"Best model performance - Epoch: {best_epoch}, PnL: {best_val_pnl:.4f}, Win Rate: {best_win_rate:.4f}")
        # Close TensorBoard writer
        writer.close()
    except Exception as e:
        # Unexpected fatal error: log, then best-effort emergency save.
        logger.error(f"Error during real-time training: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        # Try to save the model and stats in case of error
        try:
            if 'model' in locals():
                model.save("NN/models/saved/optimized_short_term_model_ticks_emergency")
                logger.info("Emergency model save completed")
            if 'training_stats' in locals():
                save_training_stats(training_stats, "NN/models/saved/realtime_ticks_training_stats_emergency.json")
            if 'writer' in locals():
                writer.close()
        except Exception as e2:
            logger.error(f"Failed to save emergency checkpoint: {str(e2)}")
if __name__ == "__main__":
    # Print startup banner
    banner = (
        "=" * 80,
        "CONTINUOUS REALTIME TICKS TRAINING SESSION",
        "This script will continuously train the model using real-time tick data",
        "Press Ctrl+C to safely stop training and save the model",
        "TensorBoard logs will be saved to runs/realtime_ticks_training",
        "To view TensorBoard, run: tensorboard --logdir=runs/realtime_ticks_training",
        "=" * 80,
    )
    for line in banner:
        print(line)
    # Run the async training loop
    asyncio.run(run_realtime_training())