massive clenup

This commit is contained in:
Dobromir Popov
2025-05-24 10:32:00 +03:00
parent 310f3c5bf9
commit b5ad023b16
87 changed files with 1930 additions and 784568 deletions

View File

@ -1,261 +0,0 @@
#!/usr/bin/env python
"""
Example script for the Neural Network Trading System
This shows basic usage patterns for the system components
"""
import os
import sys
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from datetime import datetime
import logging
# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Import components
from NN.utils.data_interface import DataInterface
from NN.models.cnn_model import CNNModel
from NN.models.transformer_model import TransformerModel, MixtureOfExpertsModel
from NN.main import NeuralNetworkOrchestrator
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('example')
def example_data_interface():
"""Show how to use the data interface"""
logger.info("=== Data Interface Example ===")
# Initialize data interface
di = DataInterface(symbol="BTC/USDT", timeframes=['1h', '4h', '1d'])
# Get historical data
df_1h = di.get_historical_data(timeframe='1h', n_candles=100)
if df_1h is not None and not df_1h.empty:
logger.info(f"Retrieved {len(df_1h)} 1-hour candles")
logger.info(f"Most recent candle: {df_1h.iloc[-1]}")
# Prepare data for neural network
X, y, timestamps = di.prepare_nn_input(timeframes=['1h'], n_candles=500, window_size=20)
if X is not None and y is not None:
logger.info(f"Prepared input shape: {X.shape}, target shape: {y.shape}")
# Generate a dataset
dataset = di.generate_training_dataset(
timeframes=['1h', '4h'],
n_candles=1000,
window_size=20
)
if dataset:
logger.info(f"Dataset generated and saved to: {list(dataset.values())}")
return X, y, timestamps if X is not None else (None, None, None)
def example_cnn_model(X=None, y=None):
"""Show how to use the CNN model"""
logger.info("=== CNN Model Example ===")
# If no data provided, create dummy data
if X is None or y is None:
logger.info("Creating dummy data for CNN example")
X = np.random.random((1000, 20, 5)) # 1000 samples, 20 time steps, 5 features
y = np.random.randint(0, 2, size=(1000,)) # Binary labels
# Split data into training and testing sets
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Initialize and build the CNN model
cnn = CNNModel(input_shape=(20, 5), output_size=1, model_dir='NN/models/saved')
cnn.build_model(filters=(32, 64, 128), kernel_sizes=(3, 5, 7), dropout_rate=0.3)
# Train the model (very small number of epochs for this example)
history = cnn.train(
X_train, y_train,
batch_size=32,
epochs=5, # Just a few epochs for the example
validation_split=0.2
)
# Evaluate the model
metrics = cnn.evaluate(X_test, y_test, plot_results=True)
if metrics:
logger.info(f"CNN Evaluation metrics: {metrics}")
# Make a prediction
y_pred, y_proba = cnn.predict(X_test[:1])
logger.info(f"CNN Prediction: {y_pred[0]}, Probability: {y_proba[0]:.4f}")
return cnn
def example_transformer_model(X=None, y=None, cnn_model=None):
"""Show how to use the Transformer model"""
logger.info("=== Transformer Model Example ===")
# If no data provided, create dummy data
if X is None or y is None:
logger.info("Creating dummy data for Transformer example")
X = np.random.random((1000, 20, 5)) # 1000 samples, 20 time steps, 5 features
y = np.random.randint(0, 2, size=(1000,)) # Binary labels
# Generate high-level features (from CNN model or random if no CNN provided)
if cnn_model is not None and hasattr(cnn_model, 'extract_hidden_features'):
# Extract features from CNN model
X_features = cnn_model.extract_hidden_features(X)
logger.info(f"Extracted {X_features.shape[1]} features from CNN model")
else:
# Generate random features
X_features = np.random.random((len(X), 128))
logger.info("Generated random features for Transformer model")
# Split data into training and testing sets
from sklearn.model_selection import train_test_split
X_train, X_test, X_feat_train, X_feat_test, y_train, y_test = train_test_split(
X, X_features, y, test_size=0.2, random_state=42
)
# Initialize and build the Transformer model
transformer = TransformerModel(
ts_input_shape=(20, 5),
feature_input_shape=X_features.shape[1],
output_size=1,
model_dir='NN/models/saved'
)
transformer.build_model(
embed_dim=32,
num_heads=2,
ff_dim=64,
num_transformer_blocks=2,
dropout_rate=0.2
)
# Train the model (very small number of epochs for this example)
history = transformer.train(
X_train, X_feat_train, y_train,
batch_size=32,
epochs=5, # Just a few epochs for the example
validation_split=0.2
)
# Make a prediction
y_pred, y_proba = transformer.predict(X_test[:1], X_feat_test[:1])
logger.info(f"Transformer Prediction: {y_pred[0]}, Probability: {y_proba[0]:.4f}")
return transformer
def example_moe_model(X=None, y=None, cnn_model=None, transformer_model=None):
"""Show how to use the Mixture of Experts model"""
logger.info("=== Mixture of Experts Example ===")
# If no data provided, create dummy data
if X is None or y is None:
logger.info("Creating dummy data for MoE example")
X = np.random.random((1000, 20, 5)) # 1000 samples, 20 time steps, 5 features
y = np.random.randint(0, 2, size=(1000,)) # Binary labels
# If models not provided, create them
if cnn_model is None:
logger.info("Creating a new CNN model for MoE")
cnn_model = CNNModel(input_shape=(20, 5), output_size=1)
cnn_model.build_model()
if transformer_model is None:
logger.info("Creating a new Transformer model for MoE")
transformer_model = TransformerModel(ts_input_shape=(20, 5), feature_input_shape=128, output_size=1)
transformer_model.build_model()
# Initialize MoE model
moe = MixtureOfExpertsModel(output_size=1, model_dir='NN/models/saved')
# Add expert models
moe.add_expert('cnn', cnn_model)
moe.add_expert('transformer', transformer_model)
# Build the MoE model (this is a simplified implementation - in a real scenario
# you would need to handle the interfaces between models more carefully)
moe.build_model(
ts_input_shape=(20, 5),
expert_weights={'cnn': 0.7, 'transformer': 0.3}
)
# In a real implementation, you would train the MoE model here
logger.info("MoE model built - in a real implementation, you would train it here")
return moe
def example_orchestrator():
"""Show how to use the Orchestrator"""
logger.info("=== Orchestrator Example ===")
# Configure the orchestrator
config = {
'symbol': 'BTC/USDT',
'timeframes': ['1h', '4h'],
'window_size': 20,
'n_features': 5,
'output_size': 3, # BUY/HOLD/SELL
'batch_size': 32,
'epochs': 5, # Small number for example
'model_dir': 'NN/models/saved',
'data_dir': 'NN/data'
}
# Initialize the orchestrator
orchestrator = NeuralNetworkOrchestrator(config)
# Prepare training data
X, y, timestamps = orchestrator.prepare_training_data(
timeframes=['1h'],
n_candles=200
)
if X is not None and y is not None:
logger.info(f"Prepared training data: X shape {X.shape}, y shape {y.shape}")
# Train CNN model
logger.info("Training CNN model with orchestrator...")
history = orchestrator.train_cnn_model(X, y, epochs=2) # Very small for example
# Make a prediction
result = orchestrator.run_inference_pipeline(
model_type='cnn',
timeframe='1h'
)
if result:
logger.info(f"Inference result: {result}")
else:
logger.warning("Could not prepare training data - this is expected if no real data is available")
logger.info("The orchestrator would normally handle training and inference")
def main():
"""Run all examples"""
logger.info("Starting Neural Network Trading System Examples")
# Example 1: Data Interface
X, y, timestamps = example_data_interface()
# Example 2: CNN Model
cnn_model = example_cnn_model(X, y)
# Example 3: Transformer Model
transformer_model = example_transformer_model(X, y, cnn_model)
# Example 4: Mixture of Experts
moe_model = example_moe_model(X, y, cnn_model, transformer_model)
# Example 5: Orchestrator
example_orchestrator()
logger.info("Examples completed")
if __name__ == "__main__":
main()

View File

@ -1,244 +0,0 @@
"""
Neural Network Trading System Main Module (Compatibility Layer)
This module serves as a compatibility layer for the realtime.py module.
It re-exports the functionality from realtime_main.py that is needed by realtime.py.
"""
import os
import sys
import logging
from datetime import datetime
import numpy as np
# Configure logging
logger = logging.getLogger('NN')
logger.setLevel(logging.INFO)
# Re-export everything from realtime_main.py
from .realtime_main import (
parse_arguments,
realtime,
train,
predict
)
# Create a class that realtime.py expects
class NeuralNetworkOrchestrator:
"""
Orchestrates the neural network operations.
"""
def __init__(self, config):
"""
Initialize the orchestrator with configuration.
Args:
config (dict): Configuration parameters
"""
self.config = config
self.symbol = config.get('symbol', 'BTC/USDT')
self.timeframes = config.get('timeframes', ['1m', '5m', '1h', '4h'])
self.window_size = config.get('window_size', 20)
self.n_features = config.get('n_features', 5)
self.output_size = config.get('output_size', 3)
self.model_dir = config.get('model_dir', 'NN/models/saved')
self.data_dir = config.get('data_dir', 'NN/data')
self.model = None
self.data_interface = None
# Initialize with default values in case imports fail
self.model_initialized = False
self.data_initialized = False
# Import necessary modules dynamically
try:
from .utils.data_interface import DataInterface
# Initialize data interface
self.data_interface = DataInterface(
symbol=self.symbol,
timeframes=self.timeframes
)
self.data_initialized = True
logger.info(f"Data interface initialized for {self.symbol}")
try:
from .models.cnn_model_pytorch import CNNModelPyTorch as Model
# Initialize model
feature_count = self.data_interface.get_feature_count() if hasattr(self.data_interface, 'get_feature_count') else 5
try:
# First try with expected parameters
self.model = Model(
window_size=self.window_size,
num_features=feature_count,
output_size=self.output_size,
timeframes=self.timeframes
)
except TypeError as e:
logger.warning(f"TypeError in model initialization with num_features: {str(e)}")
# Try alternate parameter naming
try:
self.model = Model(
input_shape=(self.window_size, feature_count),
output_size=self.output_size
)
logger.info("Model initialized with alternate parameters")
except Exception as ex:
logger.error(f"Failed to initialize model with alternate parameters: {str(ex)}")
self.model = DummyModel()
# Try to load the best model
self._load_model()
self.model_initialized = True
logger.info("Model initialized successfully")
except Exception as e:
logger.error(f"Error initializing model: {str(e)}")
import traceback
logger.error(traceback.format_exc())
self.model = DummyModel()
logger.info(f"NeuralNetworkOrchestrator initialized with config: {config}")
except Exception as e:
logger.error(f"Error initializing NeuralNetworkOrchestrator: {str(e)}")
import traceback
logger.error(traceback.format_exc())
self.model = DummyModel()
def _load_model(self):
"""Load the best trained model from available files"""
try:
model_paths = [
os.path.join(self.model_dir, "dqn_agent_best_policy.pt"),
os.path.join(self.model_dir, "cnn_model_best.pt"),
os.path.join("models/saved", "dqn_agent_best_policy.pt"),
os.path.join("models/saved", "cnn_model_best.pt")
]
for model_path in model_paths:
if os.path.exists(model_path):
try:
self.model.load(model_path)
logger.info(f"Loaded model from {model_path}")
return True
except Exception as e:
logger.warning(f"Failed to load model from {model_path}: {str(e)}")
continue
logger.warning("No trained model found, using dummy model")
self.model = DummyModel()
return False
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
self.model = DummyModel()
return False
def run_inference_pipeline(self, model_type='cnn', timeframe='1h'):
"""
Run the inference pipeline using the trained model.
Args:
model_type (str): Type of model to use (cnn, transformer, etc.)
timeframe (str): Timeframe to use for inference
Returns:
dict: Inference result
"""
try:
# Check if we have a model
if not hasattr(self, 'model') or self.model is None:
logger.warning("No model available, initializing dummy model")
self.model = DummyModel()
# Check if we have a data interface
if not hasattr(self, 'data_interface') or self.data_interface is None:
logger.warning("No data interface available")
# Return a dummy prediction
return self._get_dummy_prediction()
# Prepare input data for the selected timeframe
X, timestamp = self.data_interface.prepare_realtime_input(
timeframe=timeframe,
n_candles=self.window_size + 10, # Extra candles for safety
window_size=self.window_size
)
if X is None:
logger.warning(f"No data available for {self.symbol}")
return self._get_dummy_prediction()
# Get model predictions
action_probs, price_pred = self.model.predict(X)
# Convert predictions to action
action_idx = np.argmax(action_probs) if hasattr(action_probs, 'argmax') else 1 # Default to HOLD
action_names = ['SELL', 'HOLD', 'BUY']
action = action_names[action_idx]
# Format timestamp
if not isinstance(timestamp, str):
try:
if hasattr(timestamp, 'isoformat'): # If it's already a datetime-like object
timestamp = timestamp.isoformat()
else: # If it's a numeric timestamp
timestamp = datetime.fromtimestamp(float(timestamp)/1000).isoformat()
except (TypeError, ValueError):
timestamp = datetime.now().isoformat()
# Return result
result = {
'timestamp': timestamp,
'action': action,
'action_index': int(action_idx),
'probability': float(action_probs[action_idx]) if hasattr(action_probs, '__getitem__') else 0.33,
'probabilities': {name: float(prob) for name, prob in zip(action_names, action_probs)} if hasattr(action_probs, '__iter__') else {'SELL': 0.33, 'HOLD': 0.34, 'BUY': 0.33},
'price_prediction': float(price_pred) if price_pred is not None else None
}
logger.info(f"Inference result: {result}")
return result
except Exception as e:
logger.error(f"Error in inference pipeline: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return self._get_dummy_prediction()
def _get_dummy_prediction(self):
"""Return a dummy prediction when model or data is unavailable"""
action_names = ['SELL', 'HOLD', 'BUY']
action_idx = 1 # Default to HOLD
timestamp = datetime.now().isoformat()
return {
'timestamp': timestamp,
'action': 'HOLD',
'action_index': action_idx,
'probability': 0.8,
'probabilities': {'SELL': 0.1, 'HOLD': 0.8, 'BUY': 0.1},
'price_prediction': None,
'is_dummy': True
}
class DummyModel:
"""Dummy model that returns random predictions"""
def __init__(self):
logger.info("Initializing dummy model")
def predict(self, X):
"""Return random predictions"""
# Generate random probabilities for SELL, HOLD, BUY
action_probs = np.array([0.1, 0.8, 0.1]) # Bias towards HOLD
# Generate a random price prediction (None for now)
price_pred = None
return action_probs, price_pred
def load(self, model_path):
"""Dummy load method"""
logger.info(f"Dummy model pretending to load from {model_path}")
return True

View File

@ -1,287 +0,0 @@
import logging
import threading
import time
from typing import Dict, Any, List, Optional, Callable, Tuple
import os
import numpy as np
import pandas as pd
from .trading_agent import TradingAgent
logger = logging.getLogger(__name__)
class NeuralNetworkOrchestrator:
"""Orchestrator for neural network models and trading operations.
This class coordinates between neural network models and trading agents,
ensuring that signals from the models are properly processed and trades
are executed according to the strategy.
"""
def __init__(self, model, data_interface, chart=None,
symbols: List[str] = None,
timeframes: List[str] = None,
window_size: int = 20,
num_features: int = 5,
output_size: int = 3,
models_dir: str = "NN/models/saved",
data_dir: str = "NN/data",
exchange_config: Dict[str, Any] = None):
"""Initialize the neural network orchestrator.
Args:
model: Neural network model instance
data_interface: Data interface for retrieving market data
chart: Real-time chart for visualization (optional)
symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT'])
timeframes: List of timeframes to monitor (e.g., ['1m', '5m', '1h'])
window_size: Window size for model input
num_features: Number of features per datapoint
output_size: Number of output classes (e.g., 3 for BUY/HOLD/SELL)
models_dir: Directory for saved models
data_dir: Directory for data storage
exchange_config: Configuration for trading agent (exchange, API keys, etc.)
"""
self.model = model
self.data_interface = data_interface
self.chart = chart
self.symbols = symbols or ["BTC/USDT"]
self.timeframes = timeframes or ["1m", "5m", "1h", "4h", "1d"]
self.window_size = window_size
self.num_features = num_features
self.output_size = output_size
self.models_dir = models_dir
self.data_dir = data_dir
# Initialize trading agent if configuration provided
self.trading_agent = None
if exchange_config:
self.init_trading_agent(exchange_config)
# Initialize inference state
self.is_running = False
self.inference_thread = None
self.stop_event = threading.Event()
self.last_inference_time = 0
self.inference_interval = int(os.environ.get("NN_INFERENCE_INTERVAL", "60"))
logger.info(f"Initializing NeuralNetworkOrchestrator with:")
logger.info(f"- Symbol: {self.symbols[0]}")
logger.info(f"- Timeframes: {', '.join(self.timeframes)}")
logger.info(f"- Window size: {window_size}")
logger.info(f"- Num features: {num_features}")
logger.info(f"- Output size: {output_size}")
logger.info(f"- Models dir: {models_dir}")
logger.info(f"- Data dir: {data_dir}")
logger.info(f"- Inference interval: {self.inference_interval} seconds")
def init_trading_agent(self, config: Dict[str, Any]):
"""Initialize the trading agent with the given configuration.
Args:
config: Configuration for the trading agent
"""
exchange_name = config.get("exchange", "binance")
api_key = config.get("api_key")
api_secret = config.get("api_secret")
test_mode = config.get("test_mode", True)
trade_symbols = config.get("trade_symbols", self.symbols)
position_size = config.get("position_size", 0.1)
max_trades_per_day = config.get("max_trades_per_day", 5)
trade_cooldown_minutes = config.get("trade_cooldown_minutes", 60)
self.trading_agent = TradingAgent(
exchange_name=exchange_name,
api_key=api_key,
api_secret=api_secret,
test_mode=test_mode,
trade_symbols=trade_symbols,
position_size=position_size,
max_trades_per_day=max_trades_per_day,
trade_cooldown_minutes=trade_cooldown_minutes
)
logger.info(f"Trading agent initialized for {exchange_name} exchange.")
def start_inference(self):
"""Start the inference thread."""
if self.is_running:
logger.warning("Neural network inference is already running.")
return
self.is_running = True
self.stop_event.clear()
# Start inference thread
self.inference_thread = threading.Thread(target=self._inference_loop)
self.inference_thread.daemon = True
self.inference_thread.start()
logger.info(f"Neural network inference thread started with {self.inference_interval}s interval.")
# Start trading agent if available
if self.trading_agent:
self.trading_agent.start(signal_callback=self._on_trade_executed)
def stop_inference(self):
"""Stop the inference thread."""
if not self.is_running:
logger.warning("Neural network inference is not running.")
return
logger.info("Stopping neural network inference...")
self.is_running = False
self.stop_event.set()
if self.inference_thread and self.inference_thread.is_alive():
self.inference_thread.join(timeout=10)
logger.info("Neural network inference stopped.")
# Stop trading agent if available
if self.trading_agent:
self.trading_agent.stop()
def _inference_loop(self):
"""Main inference loop that processes data and generates signals."""
logger.info("Inference loop started.")
try:
while self.is_running and not self.stop_event.is_set():
current_time = time.time()
# Check if we should run inference
if current_time - self.last_inference_time >= self.inference_interval:
try:
# Run inference for all symbols
for symbol in self.symbols:
prediction = self._run_inference(symbol)
if prediction:
self._process_prediction(symbol, prediction)
self.last_inference_time = current_time
except Exception as e:
logger.error(f"Error during inference: {str(e)}")
# Sleep for a short time to prevent CPU hogging
time.sleep(1)
except Exception as e:
logger.error(f"Error in inference loop: {str(e)}")
finally:
logger.info("Inference loop stopped.")
def _run_inference(self, symbol: str) -> Optional[Tuple[np.ndarray, float]]:
"""Run inference for a specific symbol.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
Returns:
tuple: (action probabilities, current price) or None if inference failed
"""
try:
# Get the model timeframe from environment
model_timeframe = os.environ.get("NN_TIMEFRAME", "1h")
if model_timeframe not in self.timeframes:
logger.warning(f"Model timeframe {model_timeframe} not in available timeframes. Using {self.timeframes[0]}.")
model_timeframe = self.timeframes[0]
# Load candles for the model timeframe
logger.info(f"Loading {1000} candles from cache for {symbol} at {model_timeframe} timeframe")
candles = self.data_interface.get_historical_data(
symbol=symbol,
timeframe=model_timeframe,
n_candles=1000
)
if candles is None or len(candles) < self.window_size:
logger.warning(f"Not enough data for {symbol} at {model_timeframe} timeframe. Need at least {self.window_size} candles.")
return None
# Prepare input data
X, timestamp = self.data_interface.prepare_model_input(
data=candles,
window_size=self.window_size,
symbol=symbol
)
if X is None:
logger.warning(f"Failed to prepare model input for {symbol}.")
return None
# Get current price
current_price = candles['close'].iloc[-1]
# Run model inference
action_probs, price_pred = self.model.predict(X)
return action_probs, current_price
except Exception as e:
logger.error(f"Error running inference for {symbol}: {str(e)}")
return None
def _process_prediction(self, symbol: str, prediction: Tuple[np.ndarray, float]):
"""Process a prediction and generate signals.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
prediction: Tuple of (action probabilities, current price)
"""
action_probs, current_price = prediction
# Get the best action (0=SELL, 1=HOLD, 2=BUY)
best_action = np.argmax(action_probs)
best_prob = float(action_probs[best_action])
# Convert to action name
action_names = ["SELL", "HOLD", "BUY"]
action_name = action_names[best_action]
# Log the prediction
logger.info(f"Inference result for {symbol}: Action={action_name}, Probability={best_prob:.2f}, Price={current_price:.2f}")
# Add signal to chart if available
if self.chart:
self.chart.add_nn_signal(symbol=symbol, signal=action_name, confidence=best_prob, timestamp=int(time.time()))
# Process signal with trading agent if available
if self.trading_agent:
self.trading_agent.process_signal(
symbol=symbol,
action=action_name,
confidence=best_prob,
timestamp=int(time.time())
)
def _on_trade_executed(self, trade_record: Dict[str, Any]):
"""Callback for when a trade is executed.
Args:
trade_record: Trade information
"""
if self.chart and trade_record:
# Add trade to chart
self.chart.add_trade(
action=trade_record['action'],
price=trade_record.get('price', 0),
timestamp=trade_record['timestamp'],
pnl=trade_record.get('pnl', 0)
)
logger.info(f"Trade added to chart: {trade_record['action']} at {trade_record.get('price', 0):.2f}")
def get_trading_agent_info(self) -> Dict[str, Any]:
"""Get information about the trading agent.
Returns:
dict: Trading agent information or None if no agent is available
"""
if self.trading_agent:
return {
'exchange_info': self.trading_agent.get_exchange_info(),
'positions': self.trading_agent.get_current_positions(),
'trades': len(self.trading_agent.get_trade_history())
}
return None

View File

@ -1,287 +0,0 @@
#!/usr/bin/env python3
"""
Neural Network Trading System Main Module
This module serves as the main entry point for the NN trading system,
coordinating data flow between different components and implementing
training and inference pipelines.
"""
import os
import sys
import logging
import argparse
from datetime import datetime
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(),
logging.FileHandler(os.path.join('logs', f'nn_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'))
]
)
logger = logging.getLogger('NN')
# Create logs directory if it doesn't exist
os.makedirs('logs', exist_ok=True)
def parse_arguments():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(description='Neural Network Trading System')
parser.add_argument('--mode', type=str, choices=['train', 'predict', 'realtime'], default='train',
help='Mode to run (train, predict, realtime)')
parser.add_argument('--symbol', type=str, default='BTC/USDT',
help='Trading pair symbol')
parser.add_argument('--timeframes', type=str, nargs='+', default=['1h', '4h'],
help='Timeframes to use')
parser.add_argument('--window-size', type=int, default=20,
help='Window size for input data')
parser.add_argument('--output-size', type=int, default=3,
help='Output size (1 for binary, 3 for BUY/HOLD/SELL)')
parser.add_argument('--batch-size', type=int, default=32,
help='Batch size for training')
parser.add_argument('--epochs', type=int, default=100,
help='Number of epochs for training')
parser.add_argument('--model-type', type=str, choices=['cnn', 'transformer', 'moe'], default='cnn',
help='Model type to use')
parser.add_argument('--framework', type=str, choices=['tensorflow', 'pytorch'], default='pytorch',
help='Deep learning framework to use')
return parser.parse_args()
def main():
"""Main entry point for the NN trading system"""
# Parse arguments
args = parse_arguments()
logger.info(f"Starting NN Trading System in {args.mode} mode")
logger.info(f"Configuration: Symbol={args.symbol}, Timeframes={args.timeframes}, "
f"Window Size={args.window_size}, Output Size={args.output_size}, "
f"Model Type={args.model_type}, Framework={args.framework}")
# Import the appropriate modules based on the framework
if args.framework == 'pytorch':
try:
import torch
logger.info(f"Using PyTorch {torch.__version__}")
# Import PyTorch-based modules
from NN.utils.data_interface import DataInterface
if args.model_type == 'cnn':
from NN.models.cnn_model_pytorch import CNNModelPyTorch as Model
elif args.model_type == 'transformer':
from NN.models.transformer_model_pytorch import TransformerModelPyTorchWrapper as Model
elif args.model_type == 'moe':
from NN.models.transformer_model_pytorch import MixtureOfExpertsModelPyTorch as Model
else:
logger.error(f"Unknown model type: {args.model_type}")
return
except ImportError as e:
logger.error(f"Failed to import PyTorch modules: {str(e)}")
logger.error("Please make sure PyTorch is installed or use the TensorFlow framework.")
return
elif args.framework == 'tensorflow':
try:
import tensorflow as tf
logger.info(f"Using TensorFlow {tf.__version__}")
# Import TensorFlow-based modules
from NN.utils.data_interface import DataInterface
if args.model_type == 'cnn':
from NN.models.cnn_model import CNNModel as Model
elif args.model_type == 'transformer':
from NN.models.transformer_model import TransformerModel as Model
elif args.model_type == 'moe':
from NN.models.transformer_model import MixtureOfExpertsModel as Model
else:
logger.error(f"Unknown model type: {args.model_type}")
return
except ImportError as e:
logger.error(f"Failed to import TensorFlow modules: {str(e)}")
logger.error("Please make sure TensorFlow is installed or use the PyTorch framework.")
return
else:
logger.error(f"Unknown framework: {args.framework}")
return
# Initialize data interface
try:
logger.info("Initializing data interface...")
data_interface = DataInterface(
symbol=args.symbol,
timeframes=args.timeframes
)
except Exception as e:
logger.error(f"Failed to initialize data interface: {str(e)}")
return
# Initialize model
try:
logger.info(f"Initializing {args.model_type.upper()} model...")
model = Model(
window_size=args.window_size,
num_features=data_interface.get_feature_count(),
output_size=args.output_size,
timeframes=args.timeframes
)
except Exception as e:
logger.error(f"Failed to initialize model: {str(e)}")
return
# Execute the requested mode
if args.mode == 'train':
train(data_interface, model, args)
elif args.mode == 'predict':
predict(data_interface, model, args)
elif args.mode == 'realtime':
realtime(data_interface, model, args)
else:
logger.error(f"Unknown mode: {args.mode}")
return
logger.info("Neural Network Trading System finished successfully")
def train(data_interface, model, args):
"""Enhanced training with performance tracking"""
from torch.utils.tensorboard import SummaryWriter
logger.info("Starting training mode...")
writer = SummaryWriter(log_dir=f"runs/{args.model_type}_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
try:
best_val_acc = 0
for epoch in range(args.epochs):
# Refresh data every few epochs
if epoch % 3 == 0:
X_train, y_train, X_val, y_val = data_interface.prepare_training_data(refresh=True)
else:
X_train, y_train, X_val, y_val = data_interface.prepare_training_data()
# Train for one epoch
train_loss, train_acc = model.train_epoch(
X_train, y_train,
batch_size=args.batch_size
)
# Validate
val_loss, val_acc = model.evaluate(X_val, y_val)
# Log metrics
writer.add_scalar('Loss/Train', train_loss, epoch)
writer.add_scalar('Accuracy/Train', train_acc, epoch)
writer.add_scalar('Loss/Validation', val_loss, epoch)
writer.add_scalar('Accuracy/Validation', val_acc, epoch)
# Save best model
if val_acc > best_val_acc:
best_val_acc = val_acc
model_path = os.path.join(
'models',
f"{args.model_type}_best_{args.symbol.replace('/', '_')}.pt"
)
model.save(model_path)
logger.info(f"New best model saved with val_acc: {val_acc:.2f}")
logger.info(f"Epoch {epoch+1}/{args.epochs} - "
f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f} - "
f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}")
# Save final model
model_path = os.path.join(
'models',
f"{args.model_type}_final_{args.symbol.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"
)
model.save(model_path)
logger.info(f"Training Complete - Best Val Accuracy: {best_val_acc:.2f}")
except Exception as e:
logger.error(f"Error in training mode: {str(e)}")
return
def predict(data_interface, model, args):
"""Make predictions using the trained model"""
logger.info("Starting prediction mode...")
try:
# Load the latest model
model_dir = os.path.join('models')
model_files = [f for f in os.listdir(model_dir) if f.startswith(args.model_type)]
if not model_files:
logger.error(f"No saved model found for type {args.model_type}")
return
latest_model = sorted(model_files)[-1]
model_path = os.path.join(model_dir, latest_model)
logger.info(f"Loading model from {model_path}...")
model.load(model_path)
# Prepare prediction data
logger.info("Preparing prediction data...")
X_pred = data_interface.prepare_prediction_data()
# Make predictions
logger.info("Making predictions...")
predictions = model.predict(X_pred)
# Process and display predictions
logger.info("Processing predictions...")
data_interface.process_predictions(predictions)
except Exception as e:
logger.error(f"Error in prediction mode: {str(e)}")
return
def realtime(data_interface, model, args):
"""Run the model in real-time mode"""
logger.info("Starting real-time mode...")
try:
# Import realtime analyzer
from NN.utils.realtime_analyzer import RealtimeAnalyzer
# Load the latest model
model_dir = os.path.join('models')
model_files = [f for f in os.listdir(model_dir) if f.startswith(args.model_type)]
if not model_files:
logger.error(f"No saved model found for type {args.model_type}")
return
latest_model = sorted(model_files)[-1]
model_path = os.path.join(model_dir, latest_model)
logger.info(f"Loading model from {model_path}...")
model.load(model_path)
# Initialize realtime analyzer
logger.info("Initializing real-time analyzer...")
realtime_analyzer = RealtimeAnalyzer(
data_interface=data_interface,
model=model,
symbol=args.symbol,
timeframes=args.timeframes
)
# Start real-time analysis
logger.info("Starting real-time analysis...")
realtime_analyzer.start()
except Exception as e:
logger.error(f"Error in real-time mode: {str(e)}")
return
if __name__ == "__main__":
main()

View File

@ -1,241 +0,0 @@
import logging
import numpy as np
import pandas as pd
import time
from typing import Dict, Any, List, Optional, Tuple
logger = logging.getLogger(__name__)
class RealtimeDataInterface:
"""Interface for retrieving real-time market data for neural network models.
This class serves as a bridge between the RealTimeChart data sources and
the neural network models, providing properly formatted data for model
inference.
"""
def __init__(self, symbols: List[str], chart=None, max_cache_size: int = 5000):
"""Initialize the data interface.
Args:
symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT'])
chart: RealTimeChart instance (optional)
max_cache_size: Maximum number of cached candles
"""
self.symbols = symbols
self.chart = chart
self.max_cache_size = max_cache_size
# Initialize data cache
self.ohlcv_cache = {} # timeframe -> symbol -> DataFrame
logger.info(f"Initialized RealtimeDataInterface with symbols: {', '.join(symbols)}")
def get_historical_data(self, symbol: str = None, timeframe: str = '1h',
n_candles: int = 500) -> Optional[pd.DataFrame]:
"""Get historical OHLCV data for a symbol and timeframe.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
timeframe: Time interval (e.g., '1m', '5m', '1h')
n_candles: Number of candles to retrieve
Returns:
DataFrame with OHLCV data or None if not available
"""
if not symbol:
if len(self.symbols) > 0:
symbol = self.symbols[0]
else:
logger.error("No symbol specified and no default symbols available")
return None
if symbol not in self.symbols:
logger.warning(f"Symbol {symbol} not in tracked symbols")
return None
try:
# Get data from chart if available
if self.chart:
candles = self._get_chart_data(symbol, timeframe, n_candles)
if candles is not None and len(candles) > 0:
return candles
# Fallback to default empty DataFrame
logger.warning(f"No historical data available for {symbol} at timeframe {timeframe}")
return pd.DataFrame(columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
except Exception as e:
logger.error(f"Error getting historical data for {symbol}: {str(e)}")
return None
def _get_chart_data(self, symbol: str, timeframe: str, n_candles: int) -> Optional[pd.DataFrame]:
"""Get data from the RealTimeChart for the specified symbol and timeframe.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
timeframe: Time interval (e.g., '1m', '5m', '1h')
n_candles: Number of candles to retrieve
Returns:
DataFrame with OHLCV data or None if not available
"""
if not self.chart:
return None
# Get chart data using the _get_chart_data method
try:
# Map to interval seconds
interval_map = {
'1s': 1,
'5s': 5,
'10s': 10,
'15s': 15,
'30s': 30,
'1m': 60,
'3m': 180,
'5m': 300,
'15m': 900,
'30m': 1800,
'1h': 3600,
'2h': 7200,
'4h': 14400,
'6h': 21600,
'8h': 28800,
'12h': 43200,
'1d': 86400,
'3d': 259200,
'1w': 604800
}
# Convert timeframe to seconds
if timeframe in interval_map:
interval_seconds = interval_map[timeframe]
else:
# Try to parse the interval (e.g., '1m' -> 60)
try:
if timeframe.endswith('s'):
interval_seconds = int(timeframe[:-1])
elif timeframe.endswith('m'):
interval_seconds = int(timeframe[:-1]) * 60
elif timeframe.endswith('h'):
interval_seconds = int(timeframe[:-1]) * 3600
elif timeframe.endswith('d'):
interval_seconds = int(timeframe[:-1]) * 86400
elif timeframe.endswith('w'):
interval_seconds = int(timeframe[:-1]) * 604800
else:
interval_seconds = int(timeframe)
except ValueError:
logger.error(f"Could not parse timeframe: {timeframe}")
return None
# Get data from chart
df = self.chart._get_chart_data(interval_seconds)
if df is not None and not df.empty:
# Limit to requested number of candles
if len(df) > n_candles:
df = df.iloc[-n_candles:]
return df
else:
logger.warning(f"No data retrieved from chart for {symbol} at timeframe {timeframe}")
return None
except Exception as e:
logger.error(f"Error getting chart data for {symbol} at {timeframe}: {str(e)}")
return None
def prepare_model_input(self, data: pd.DataFrame, window_size: int = 20,
symbol: str = None) -> Tuple[np.ndarray, Optional[int]]:
"""Prepare model input from OHLCV data.
Args:
data: DataFrame with OHLCV data
window_size: Window size for model input
symbol: Symbol for the data (for logging)
Returns:
tuple: (X, timestamp) where X is the model input and timestamp is the latest timestamp
"""
if data is None or len(data) < window_size:
logger.warning(f"Not enough data to prepare model input for {symbol or 'unknown symbol'}")
return None, None
try:
# Get last window_size candles
recent_data = data.iloc[-window_size:].copy()
# Get timestamp of the most recent candle
timestamp = int(recent_data.iloc[-1]['timestamp']) if 'timestamp' in recent_data.columns else int(time.time())
# Extract OHLCV features and normalize
if 'open' in recent_data.columns and 'high' in recent_data.columns and 'low' in recent_data.columns and 'close' in recent_data.columns and 'volume' in recent_data.columns:
# Normalize price data by the last close price
last_close = recent_data['close'].iloc[-1]
# Avoid division by zero
if last_close == 0:
last_close = 1.0
opens = (recent_data['open'] / last_close).values
highs = (recent_data['high'] / last_close).values
lows = (recent_data['low'] / last_close).values
closes = (recent_data['close'] / last_close).values
# Normalize volume by the max volume in the window
max_volume = recent_data['volume'].max()
if max_volume == 0:
max_volume = 1.0
volumes = (recent_data['volume'] / max_volume).values
# Stack features into a 3D array [batch_size=1, window_size, n_features=5]
X = np.column_stack((opens, highs, lows, closes, volumes))
X = X.reshape(1, window_size, 5)
# Replace any NaN or infinite values
X = np.nan_to_num(X, nan=0.0, posinf=1.0, neginf=0.0)
return X, timestamp
else:
logger.error(f"Data missing required OHLCV columns for {symbol or 'unknown symbol'}")
return None, None
except Exception as e:
logger.error(f"Error preparing model input for {symbol or 'unknown symbol'}: {str(e)}")
return None, None
def prepare_realtime_input(self, timeframe: str = '1h', n_candles: int = 30,
window_size: int = 20) -> Tuple[np.ndarray, Optional[int]]:
"""Prepare real-time input for the model.
Args:
timeframe: Time interval (e.g., '1m', '5m', '1h')
n_candles: Number of candles to retrieve
window_size: Window size for model input
Returns:
tuple: (X, timestamp) where X is the model input and timestamp is the latest timestamp
"""
# Get data for the main symbol
if len(self.symbols) == 0:
logger.error("No symbols available for real-time input")
return None, None
symbol = self.symbols[0]
try:
# Get historical data
data = self.get_historical_data(symbol, timeframe, n_candles)
if data is None or len(data) < window_size:
logger.warning(f"Not enough data for real-time input. Need at least {window_size} candles.")
return None, None
# Prepare model input
return self.prepare_model_input(data, window_size, symbol)
except Exception as e:
logger.error(f"Error preparing real-time input: {str(e)}")
return None, None

View File

@ -1,507 +0,0 @@
#!/usr/bin/env python3
"""
Neural Network Trading System Main Module - PyTorch Version
This module serves as the main entry point for the NN trading system,
using PyTorch exclusively for all model operations.
"""
import os
import sys
import logging
import argparse
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import time
# Configure logging
logger = logging.getLogger('NN')
logger.setLevel(logging.INFO)
try:
# Create logs directory if it doesn't exist
os.makedirs('logs', exist_ok=True)
# Try setting up file logging
log_file = os.path.join('logs', f'nn_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
fh = logging.FileHandler(log_file)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh.setFormatter(formatter)
logger.addHandler(fh)
logger.info(f"Logging to file: {log_file}")
except Exception as e:
logger.warning(f"Failed to setup file logging: {str(e)}. Falling back to console logging only.")
# Always setup console logging
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)
def parse_arguments():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(description='Neural Network Trading System')
parser.add_argument('--mode', type=str, choices=['train', 'predict', 'realtime'], default='train',
help='Mode to run (train, predict, realtime)')
parser.add_argument('--symbol', type=str, default='BTC/USDT',
help='Trading pair symbol')
parser.add_argument('--timeframes', type=str, nargs='+', default=['1s', '1m', '5m', '1h', '4h'],
help='Timeframes to use (include 1s for ticks)')
parser.add_argument('--window-size', type=int, default=20,
help='Window size for input data')
parser.add_argument('--output-size', type=int, default=3,
help='Output size (1 for binary, 3 for BUY/HOLD/SELL)')
parser.add_argument('--batch-size', type=int, default=32,
help='Batch size for training')
parser.add_argument('--epochs', type=int, default=10,
help='Number of epochs for training')
parser.add_argument('--model-type', type=str, choices=['cnn', 'transformer', 'moe'], default='cnn',
help='Model type to use')
return parser.parse_args()
def main():
"""Main entry point for the NN trading system"""
args = parse_arguments()
logger.info(f"Starting NN Trading System in {args.mode} mode")
logger.info(f"Configuration: Symbol={args.symbol}, Timeframes={args.timeframes}")
try:
import torch
from NN.utils.data_interface import DataInterface
# Import appropriate PyTorch model
if args.model_type == 'cnn':
from NN.models.cnn_model_pytorch import CNNModelPyTorch as Model
elif args.model_type == 'transformer':
from NN.models.transformer_model_pytorch import TransformerModelPyTorchWrapper as Model
elif args.model_type == 'moe':
from NN.models.transformer_model_pytorch import MixtureOfExpertsModelPyTorch as Model
else:
logger.error(f"Unknown model type: {args.model_type}")
return
except ImportError as e:
logger.error(f"Failed to import PyTorch modules: {str(e)}")
logger.error("Please make sure PyTorch is installed")
return
# Initialize data interface
try:
data_interface = DataInterface(
symbol=args.symbol,
timeframes=args.timeframes
)
# Verify data interface by fetching initial data
logger.info("Verifying data interface...")
X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
if X_sample is None or y_sample is not None:
logger.error("Failed to prepare initial training data")
return
logger.info(f"Data interface verified - X shape: {X_sample.shape}, y shape: {y_sample.shape}")
except Exception as e:
logger.error(f"Failed to initialize data interface: {str(e)}")
return
# Initialize model
try:
# Calculate total number of features across all timeframes
num_features = data_interface.get_feature_count()
logger.info(f"Initializing model with {num_features} features")
model = Model(
window_size=args.window_size,
num_features=num_features,
output_size=args.output_size,
timeframes=args.timeframes
)
# Ensure model is on the correct device
if torch.cuda.is_available():
model.model = model.model.cuda()
logger.info("Model moved to CUDA device")
except Exception as e:
logger.error(f"Failed to initialize model: {str(e)}")
return
# Execute requested mode
if args.mode == 'train':
train(data_interface, model, args)
elif args.mode == 'predict':
predict(data_interface, model, args)
elif args.mode == 'realtime':
realtime(data_interface, model, args)
def train(data_interface, model, args):
"""Enhanced training with performance tracking and retrospective fine-tuning"""
logger.info("Starting training mode...")
writer = SummaryWriter()
try:
best_val_acc = 0
best_val_pnl = float('-inf')
best_win_rate = 0
best_price_mae = float('inf')
logger.info("Verifying data interface...")
X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
logger.info(f"Data validation - X shape: {X_sample.shape}, y shape: {y_sample.shape}")
# Calculate refresh intervals based on timeframes
min_timeframe = min(args.timeframes)
refresh_interval = {
'1s': 1,
'1m': 60,
'5m': 300,
'15m': 900,
'1h': 3600,
'4h': 14400,
'1d': 86400
}.get(min_timeframe, 60)
logger.info(f"Using refresh interval of {refresh_interval} seconds based on {min_timeframe} timeframe")
for epoch in range(args.epochs):
# Always refresh for tick data or when using multiple timeframes
refresh = '1s' in args.timeframes or len(args.timeframes) > 1
logger.info(f"\nStarting epoch {epoch+1}/{args.epochs}")
X_train, y_train, X_val, y_val, train_prices, val_prices = data_interface.prepare_training_data(
refresh=refresh,
refresh_interval=refresh_interval
)
logger.info(f"Training data - X shape: {X_train.shape}, y shape: {y_train.shape}")
logger.info(f"Validation data - X shape: {X_val.shape}, y shape: {y_val.shape}")
# Get future prices for retrospective training
train_future_prices = data_interface.get_future_prices(train_prices, n_candles=3)
val_future_prices = data_interface.get_future_prices(val_prices, n_candles=3)
# Train and validate
try:
train_action_loss, train_price_loss, train_acc = model.train_epoch(
X_train, y_train, train_future_prices, args.batch_size
)
val_action_loss, val_price_loss, val_acc = model.evaluate(
X_val, y_val, val_future_prices
)
# Get predictions for PnL calculation
train_action_probs, train_price_preds = model.predict(X_train)
val_action_probs, val_price_preds = model.predict(X_val)
# Convert probabilities to actions for PnL calculation
train_preds = np.argmax(train_action_probs, axis=1)
val_preds = np.argmax(val_action_probs, axis=1)
# Calculate PnL and win rates
try:
if train_preds is not None and train_prices is not None:
train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
train_preds, train_prices, position_size=1.0
)
else:
train_pnl, train_win_rate, train_trades = 0, 0, []
if val_preds is not None and val_prices is not None:
val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
val_preds, val_prices, position_size=1.0
)
else:
val_pnl, val_win_rate, val_trades = 0, 0, []
except Exception as e:
logger.error(f"Error calculating PnL: {str(e)}")
train_pnl, train_win_rate, val_pnl, val_win_rate = 0, 0, 0, 0
train_trades, val_trades = [], []
# Calculate price prediction error
if train_future_prices is not None and train_price_preds is not None:
# Ensure arrays have the same shape and are numpy arrays
train_future_prices_np = np.array(train_future_prices) if not isinstance(train_future_prices, np.ndarray) else train_future_prices
train_price_preds_np = np.array(train_price_preds) if not isinstance(train_price_preds, np.ndarray) else train_price_preds
if len(train_price_preds_np) > 0 and len(train_future_prices_np) > 0:
min_len = min(len(train_price_preds_np), len(train_future_prices_np))
train_price_mae = np.mean(np.abs(train_price_preds_np[:min_len] - train_future_prices_np[:min_len]))
else:
train_price_mae = float('inf')
else:
train_price_mae = float('inf')
if val_future_prices is not None and val_price_preds is not None:
# Ensure arrays have the same shape and are numpy arrays
val_future_prices_np = np.array(val_future_prices) if not isinstance(val_future_prices, np.ndarray) else val_future_prices
val_price_preds_np = np.array(val_price_preds) if not isinstance(val_price_preds, np.ndarray) else val_price_preds
if len(val_price_preds_np) > 0 and len(val_future_prices_np) > 0:
min_len = min(len(val_price_preds_np), len(val_future_prices_np))
val_price_mae = np.mean(np.abs(val_price_preds_np[:min_len] - val_future_prices_np[:min_len]))
else:
val_price_mae = float('inf')
else:
val_price_mae = float('inf')
# Monitor action distribution
train_actions = np.bincount(np.argmax(train_action_probs, axis=1), minlength=3)
val_actions = np.bincount(np.argmax(val_action_probs, axis=1), minlength=3)
# Log metrics
writer.add_scalar('Loss/action_train', train_action_loss, epoch)
writer.add_scalar('Loss/price_train', train_price_loss, epoch)
writer.add_scalar('Loss/action_val', val_action_loss, epoch)
writer.add_scalar('Loss/price_val', val_price_loss, epoch)
writer.add_scalar('Accuracy/train', train_acc, epoch)
writer.add_scalar('Accuracy/val', val_acc, epoch)
writer.add_scalar('PnL/train', train_pnl, epoch)
writer.add_scalar('PnL/val', val_pnl, epoch)
writer.add_scalar('WinRate/train', train_win_rate, epoch)
writer.add_scalar('WinRate/val', val_win_rate, epoch)
writer.add_scalar('PriceMAE/train', train_price_mae, epoch)
writer.add_scalar('PriceMAE/val', val_price_mae, epoch)
# Log action distribution
for i, action in enumerate(['SELL', 'HOLD', 'BUY']):
writer.add_scalar(f'Actions/train_{action}', train_actions[i], epoch)
writer.add_scalar(f'Actions/val_{action}', val_actions[i], epoch)
# Save best model based on validation metrics
if np.isscalar(val_pnl) and np.isscalar(best_val_pnl) and (val_pnl > best_val_pnl or (np.isclose(val_pnl, best_val_pnl) and val_acc > best_val_acc)):
best_val_pnl = val_pnl
best_val_acc = val_acc
best_win_rate = val_win_rate
best_price_mae = val_price_mae
model.save(f"models/{args.model_type}_best.pt")
logger.info("Saved new best model based on validation metrics")
# Log detailed metrics
logger.info(f"Epoch {epoch+1}/{args.epochs}")
logger.info("Training Metrics:")
logger.info(f" Action Loss: {train_action_loss:.4f}")
logger.info(f" Price Loss: {train_price_loss:.4f}")
logger.info(f" Accuracy: {train_acc:.2f}")
logger.info(f" PnL: {train_pnl:.2%}")
logger.info(f" Win Rate: {train_win_rate:.2%}")
logger.info(f" Price MAE: {train_price_mae:.2f}")
logger.info("Validation Metrics:")
logger.info(f" Action Loss: {val_action_loss:.4f}")
logger.info(f" Price Loss: {val_price_loss:.4f}")
logger.info(f" Accuracy: {val_acc:.2f}")
logger.info(f" PnL: {val_pnl:.2%}")
logger.info(f" Win Rate: {val_win_rate:.2%}")
logger.info(f" Price MAE: {val_price_mae:.2f}")
# Log action distribution
logger.info("Action Distribution:")
for i, action in enumerate(['SELL', 'HOLD', 'BUY']):
logger.info(f" {action}: Train={train_actions[i]}, Val={val_actions[i]}")
# Log trade statistics
logger.info("Trade Statistics:")
logger.info(f" Training trades: {len(train_trades)}")
logger.info(f" Validation trades: {len(val_trades)}")
# Log next candle predictions
if epoch % 10 == 0: # Every 10 epochs
logger.info("\nNext Candle Predictions:")
next_candles = model.predict_next_candles(X_val[-1:], n_candles=3)
for tf in args.timeframes:
if tf in next_candles:
logger.info(f"\n{tf} timeframe predictions:")
for i, pred in enumerate(next_candles[tf]):
action = ['SELL', 'HOLD', 'BUY'][np.argmax(pred)]
confidence = np.max(pred)
logger.info(f" Candle {i+1}: {action} (confidence: {confidence:.2f})")
except Exception as e:
logger.error(f"Error during epoch {epoch+1}: {str(e)}")
continue
# Save final model
model.save(f"models/{args.model_type}_final_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt")
logger.info(f"\nTraining complete. Best validation metrics:")
logger.info(f"Accuracy: {best_val_acc:.2f}")
logger.info(f"PnL: {best_val_pnl:.2%}")
logger.info(f"Win Rate: {best_win_rate:.2%}")
logger.info(f"Price MAE: {best_price_mae:.2f}")
except Exception as e:
logger.error(f"Error in training: {str(e)}")
def predict(data_interface, model, args):
"""Make predictions using the trained model"""
logger.info("Starting prediction mode...")
try:
# Load the latest model
model_dir = os.path.join('models')
model_files = [f for f in os.listdir(model_dir) if f.startswith(args.model_type)]
if not model_files:
logger.error(f"No saved model found for type {args.model_type}")
return
latest_model = sorted(model_files)[-1]
model_path = os.path.join(model_dir, latest_model)
logger.info(f"Loading model from {model_path}...")
model.load(model_path)
# Prepare prediction data
logger.info("Preparing prediction data...")
X_pred = data_interface.prepare_prediction_data()
# Make predictions
logger.info("Making predictions...")
predictions = model.predict(X_pred)
# Process and display predictions
logger.info("Processing predictions...")
data_interface.process_predictions(predictions)
except Exception as e:
logger.error(f"Error in prediction mode: {str(e)}")
def realtime(data_interface, model, args, chart=None, symbol=None):
"""Run real-time inference with the trained model"""
logger.info(f"Starting real-time inference mode for {symbol}...")
try:
from NN.utils.realtime_analyzer import RealtimeAnalyzer
# Load the latest model
model_dir = os.path.join('models')
model_files = [f for f in os.listdir(model_dir) if f.startswith(args.model_type)]
if not model_files:
logger.error(f"No saved model found for type {args.model_type}")
return
latest_model = sorted(model_files)[-1]
model_path = os.path.join(model_dir, latest_model)
logger.info(f"Loading model from {model_path}...")
model.load(model_path)
# Initialize realtime analyzer
logger.info("Initializing real-time analyzer...")
realtime_analyzer = RealtimeAnalyzer(
data_interface=data_interface,
model=model,
symbol=args.symbol,
timeframes=args.timeframes
)
# Start real-time analysis
logger.info("Starting real-time analysis...")
realtime_analyzer.start()
# Initialize variables for tracking performance
total_pnl = 0.0
trades = []
current_position = 0.0
last_action = None
last_price = None
# Get the pair index for this symbol
pair_index = args.symbols.index(symbol)
# Only execute trades if this is the main pair (BTC/USDT)
is_main_pair = symbol == "BTC/USDT"
while True:
# Get current market data for all pairs
all_pairs_data = []
for s in args.symbols:
X, timestamp = data_interface.prepare_realtime_input(
timeframe=args.timeframes[0], # Use shortest timeframe
n_candles=args.window_size + 10, # Extra candles for safety
window_size=args.window_size
)
if X is not None:
all_pairs_data.append(X)
else:
logger.warning(f"No data available for {s}")
time.sleep(1)
continue
if not all_pairs_data:
logger.warning("No data available for any pair")
time.sleep(1)
continue
# Stack data from all pairs for model input
X_combined = np.concatenate(all_pairs_data, axis=2)
# Get model predictions
action_probs, price_pred = model.predict(X_combined)
# Get predictions for this specific pair
action = np.argmax(action_probs[pair_index]) # 0=SELL, 1=HOLD, 2=BUY
# Get current price for the main pair
current_price = data_interface.get_historical_data(
timeframe=args.timeframes[0],
n_candles=1
)['close'].iloc[-1]
# Calculate PnL if we have a position (only for main pair)
pnl = 0.0
if is_main_pair and last_action is not None and last_price is not None:
if last_action == 2: # BUY
pnl = (current_price - last_price) / last_price
elif last_action == 0: # SELL
pnl = (last_price - current_price) / last_price
# Update total PnL (only for main pair)
if is_main_pair and pnl != 0:
total_pnl += pnl
# Log the prediction
action_name = "SELL" if action == 0 else "HOLD" if action == 1 else "BUY"
log_msg = f"Time: {timestamp}, Symbol: {symbol}, Action: {action_name}, "
if is_main_pair:
log_msg += f"Price: {current_price:.2f}, PnL: {pnl:.2%}, Total PnL: {total_pnl:.2%}"
else:
log_msg += f"Price: {current_price:.2f} (Context Only)"
logger.info(log_msg)
# Update the chart if provided (only for main pair)
if chart is not None and is_main_pair and action != 1: # Skip HOLD actions
chart.add_trade(
action=action_name,
price=current_price,
timestamp=timestamp,
pnl=pnl
)
# Update tracking variables (only for main pair)
if is_main_pair and action != 1: # If not HOLD
last_action = action
last_price = current_price
# Sleep for a short time
time.sleep(1)
except KeyboardInterrupt:
if is_main_pair:
logger.info(f"Real-time inference stopped by user for {symbol}")
logger.info(f"Final performance for {symbol} - Total PnL: {total_pnl:.2%}")
else:
logger.info(f"Real-time inference stopped by user for {symbol} (Context Only)")
except Exception as e:
logger.error(f"Error in real-time inference for {symbol}: {str(e)}")
raise
if __name__ == "__main__":
main()

View File

@ -1,310 +0,0 @@
import logging
import time
import threading
from typing import Dict, Any, List, Optional, Callable, Tuple, Union
from .exchanges import ExchangeInterface, MEXCInterface, BinanceInterface
logger = logging.getLogger(__name__)
class TradingAgent:
"""Trading agent that executes trades based on neural network signals.
This agent interfaces with different exchanges and executes trades
based on the signals received from the neural network.
"""
def __init__(self,
exchange_name: str = 'binance',
api_key: str = None,
api_secret: str = None,
test_mode: bool = True,
trade_symbols: List[str] = None,
position_size: float = 0.1,
max_trades_per_day: int = 5,
trade_cooldown_minutes: int = 60):
"""Initialize the trading agent.
Args:
exchange_name: Name of the exchange to use ('binance', 'mexc')
api_key: API key for the exchange
api_secret: API secret for the exchange
test_mode: If True, use test/sandbox environment
trade_symbols: List of trading symbols to monitor (e.g., ['BTC/USDT'])
position_size: Size of each position as a fraction of total available balance (0.0-1.0)
max_trades_per_day: Maximum number of trades to execute per day
trade_cooldown_minutes: Minimum time between trades in minutes
"""
self.exchange_name = exchange_name.lower()
self.api_key = api_key
self.api_secret = api_secret
self.test_mode = test_mode
self.trade_symbols = trade_symbols or ['BTC/USDT']
self.position_size = min(max(position_size, 0.01), 1.0) # Ensure between 0.01 and 1.0
self.max_trades_per_day = max(1, max_trades_per_day)
self.trade_cooldown_seconds = max(60, trade_cooldown_minutes * 60)
# Initialize exchange interface
self.exchange = self._create_exchange()
# Trading state
self.active = False
self.current_positions = {} # Symbol -> quantity
self.trades_today = {} # Symbol -> count
self.last_trade_time = {} # Symbol -> timestamp
self.trade_history = [] # List of trade records
# Threading
self.trading_thread = None
self.stop_event = threading.Event()
# Signal callback
self.signal_callback = None
# Connect to exchange
if not self.exchange.connect():
logger.error(f"Failed to connect to {self.exchange_name} exchange. Trading agent disabled.")
else:
logger.info(f"Successfully connected to {self.exchange_name} exchange.")
self._load_current_positions()
def _create_exchange(self) -> ExchangeInterface:
"""Create an exchange interface based on the exchange name."""
if self.exchange_name == 'mexc':
return MEXCInterface(
api_key=self.api_key,
api_secret=self.api_secret,
test_mode=self.test_mode
)
elif self.exchange_name == 'binance':
return BinanceInterface(
api_key=self.api_key,
api_secret=self.api_secret,
test_mode=self.test_mode
)
else:
raise ValueError(f"Unsupported exchange: {self.exchange_name}")
def _load_current_positions(self):
"""Load current positions from the exchange."""
for symbol in self.trade_symbols:
try:
base_asset, quote_asset = symbol.split('/')
balance = self.exchange.get_balance(base_asset)
if balance > 0:
self.current_positions[symbol] = balance
logger.info(f"Loaded existing position for {symbol}: {balance} {base_asset}")
except Exception as e:
logger.error(f"Error loading position for {symbol}: {str(e)}")
def start(self, signal_callback: Callable = None):
"""Start the trading agent.
Args:
signal_callback: Optional callback function to receive trade signals
"""
if self.active:
logger.warning("Trading agent is already running.")
return
self.active = True
self.signal_callback = signal_callback
self.stop_event.clear()
logger.info(f"Starting trading agent for {self.exchange_name} exchange.")
logger.info(f"Trading symbols: {', '.join(self.trade_symbols)}")
logger.info(f"Position size: {self.position_size * 100:.1f}% of available balance")
logger.info(f"Max trades per day: {self.max_trades_per_day}")
logger.info(f"Trade cooldown: {self.trade_cooldown_seconds // 60} minutes")
# Reset trading state
self.trades_today = {symbol: 0 for symbol in self.trade_symbols}
self.last_trade_time = {symbol: 0 for symbol in self.trade_symbols}
# Start trading thread
self.trading_thread = threading.Thread(target=self._trading_loop)
self.trading_thread.daemon = True
self.trading_thread.start()
def stop(self):
"""Stop the trading agent."""
if not self.active:
logger.warning("Trading agent is not running.")
return
logger.info("Stopping trading agent...")
self.active = False
self.stop_event.set()
if self.trading_thread and self.trading_thread.is_alive():
self.trading_thread.join(timeout=10)
logger.info("Trading agent stopped.")
def _trading_loop(self):
"""Main trading loop that monitors positions and executes trades."""
logger.info("Trading loop started.")
try:
while self.active and not self.stop_event.is_set():
# Check positions and update state
for symbol in self.trade_symbols:
try:
base_asset, _ = symbol.split('/')
current_balance = self.exchange.get_balance(base_asset)
# Update position if it has changed
if symbol in self.current_positions:
prev_balance = self.current_positions[symbol]
if abs(current_balance - prev_balance) > 0.001 * prev_balance:
logger.info(f"Position updated for {symbol}: {prev_balance} -> {current_balance} {base_asset}")
self.current_positions[symbol] = current_balance
except Exception as e:
logger.error(f"Error checking position for {symbol}: {str(e)}")
# Sleep for a while
time.sleep(10)
except Exception as e:
logger.error(f"Error in trading loop: {str(e)}")
finally:
logger.info("Trading loop stopped.")
def reset_daily_limits(self):
"""Reset daily trading limits. Call this at the start of each trading day."""
self.trades_today = {symbol: 0 for symbol in self.trade_symbols}
logger.info("Daily trading limits reset.")
def process_signal(self, symbol: str, action: str,
confidence: float = None, timestamp: int = None) -> Optional[Dict[str, Any]]:
"""Process a trading signal and execute a trade if conditions are met.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
action: Trade action ('BUY', 'SELL', 'HOLD')
confidence: Confidence level of the signal (0.0-1.0)
timestamp: Timestamp of the signal (unix time)
Returns:
dict: Trade information if a trade was executed, None otherwise
"""
if not self.active:
logger.warning("Trading agent is not active. Signal ignored.")
return None
if symbol not in self.trade_symbols:
logger.warning(f"Symbol {symbol} is not in the trading symbols list. Signal ignored.")
return None
if action not in ['BUY', 'SELL', 'HOLD']:
logger.warning(f"Invalid action: {action}. Must be 'BUY', 'SELL', or 'HOLD'.")
return None
# Log the signal
confidence_str = f" (confidence: {confidence:.2f})" if confidence is not None else ""
logger.info(f"Received {action} signal for {symbol}{confidence_str}")
# Ignore HOLD signals for trading
if action == 'HOLD':
return None
# Check if we can trade based on limits
current_time = time.time()
# Check max trades per day
if self.trades_today.get(symbol, 0) >= self.max_trades_per_day:
logger.warning(f"Max trades per day reached for {symbol}. Signal ignored.")
return None
# Check trade cooldown
last_trade_time = self.last_trade_time.get(symbol, 0)
if current_time - last_trade_time < self.trade_cooldown_seconds:
cooldown_remaining = self.trade_cooldown_seconds - (current_time - last_trade_time)
logger.warning(f"Trade cooldown active for {symbol}. {cooldown_remaining:.1f} seconds remaining. Signal ignored.")
return None
# Check if the action makes sense based on current position
base_asset, _ = symbol.split('/')
current_position = self.current_positions.get(symbol, 0)
if action == 'BUY' and current_position > 0:
logger.warning(f"Already have a position in {symbol}. BUY signal ignored.")
return None
if action == 'SELL' and current_position <= 0:
logger.warning(f"No position in {symbol} to sell. SELL signal ignored.")
return None
# Execute the trade
try:
trade_result = self.exchange.execute_trade(
symbol=symbol,
action=action,
percent_of_balance=self.position_size
)
if trade_result:
# Update trading state
self.trades_today[symbol] = self.trades_today.get(symbol, 0) + 1
self.last_trade_time[symbol] = current_time
# Create trade record
trade_record = {
'symbol': symbol,
'action': action,
'timestamp': timestamp or int(current_time),
'confidence': confidence,
'order_id': trade_result.get('orderId') if isinstance(trade_result, dict) else None,
'status': 'executed'
}
# Add to trade history
self.trade_history.append(trade_record)
# Call signal callback if provided
if self.signal_callback:
self.signal_callback(trade_record)
logger.info(f"Successfully executed {action} trade for {symbol}")
return trade_record
else:
logger.error(f"Failed to execute {action} trade for {symbol}")
return None
except Exception as e:
logger.error(f"Error executing trade for {symbol}: {str(e)}")
return None
def get_current_positions(self) -> Dict[str, float]:
"""Get current positions.
Returns:
dict: Symbol -> position size
"""
return self.current_positions.copy()
def get_trade_history(self) -> List[Dict[str, Any]]:
"""Get trade history.
Returns:
list: List of trade records
"""
return self.trade_history.copy()
def get_exchange_info(self) -> Dict[str, Any]:
"""Get exchange information.
Returns:
dict: Exchange information
"""
return {
'name': self.exchange_name,
'test_mode': self.test_mode,
'active': self.active,
'trade_symbols': self.trade_symbols,
'position_size': self.position_size,
'max_trades_per_day': self.max_trades_per_day,
'trade_cooldown_seconds': self.trade_cooldown_seconds,
'trades_today': self.trades_today.copy()
}

View File

@ -1,585 +0,0 @@
import os
import sys
import time
import logging
import argparse
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import contextlib
from sklearn.model_selection import train_test_split
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Import our enhanced agent
from NN.models.dqn_agent_enhanced import EnhancedDQNAgent
from NN.utils.data_interface import DataInterface
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(),
logging.FileHandler('logs/enhanced_training.log')
]
)
logger = logging.getLogger(__name__)
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(description='Train enhanced RL trading agent')
parser.add_argument('--episodes', type=int, default=100, help='Number of episodes to train')
parser.add_argument('--max-steps', type=int, default=2000, help='Maximum steps per episode')
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading symbol')
parser.add_argument('--no-gpu', action='store_true', help='Disable GPU usage')
parser.add_argument('--confidence', type=float, default=0.4, help='Confidence threshold')
parser.add_argument('--load-model', type=str, default='', help='Load existing model')
parser.add_argument('--batch-size', type=int, default=128, help='Training batch size')
parser.add_argument('--learning-rate', type=float, default=0.0003, help='Learning rate')
parser.add_argument('--no-pretrain', action='store_true', help='Skip pre-training')
parser.add_argument('--pretrain-epochs', type=int, default=20, help='Number of pre-training epochs')
return parser.parse_args()
def generate_price_prediction_training_data(data_1m, data_1h, data_1d, window_size=20):
"""
Generate labeled training data for price prediction pre-training
Args:
data_1m: 1-minute candle data
data_1h: 1-hour candle data
data_1d: 1-day candle data
window_size: Size of the observation window
Returns:
X, y_immediate, y_midterm, y_longterm, y_values
"""
logger.info("Generating price prediction training data")
# Features to use
ohlcv_columns = ['open', 'high', 'low', 'close', 'volume']
# Create feature sets
X = []
y_immediate = [] # 1m prediction (next 5min)
y_midterm = [] # 1h prediction (next few hours)
y_longterm = [] # 1d prediction (next day)
y_values = [] # % change for each timeframe
# Need enough data for all timeframes
if len(data_1m) < window_size + 5 or len(data_1h) < 2 or len(data_1d) < 2:
logger.error("Not enough data for all timeframes")
return np.array([]), np.array([]), np.array([]), np.array([]), np.array([])
# Generate examples
for i in range(window_size, len(data_1m) - 5):
# Skip if we can't align with higher timeframes
if i % 60 != 0: # Only use minutes that align with hour boundaries
continue
try:
# Get window of 1m data as input
window_1m = data_1m[i-window_size:i][ohlcv_columns].values
# Find corresponding indices in higher timeframes
curr_timestamp = data_1m.index[i]
h_idx = data_1h.index.get_indexer([curr_timestamp], method='nearest')[0]
d_idx = data_1d.index.get_indexer([curr_timestamp], method='nearest')[0]
# Skip if indices are out of bounds
if h_idx < 0 or h_idx >= len(data_1h) - 1 or d_idx < 0 or d_idx >= len(data_1d) - 1:
continue
# Get future prices for label generation
future_5m = data_1m[i+5]['close']
future_1h = data_1h[h_idx+1]['close']
future_1d = data_1d[d_idx+1]['close']
current_price = data_1m[i]['close']
# Calculate % change for each timeframe
change_5m = (future_5m - current_price) / current_price * 100
change_1h = (future_1h - current_price) / current_price * 100
change_1d = (future_1d - current_price) / current_price * 100
# Determine price direction (0=down, 1=sideways, 2=up)
def get_direction(change):
if change < -0.5: # Down if less than -0.5%
return 0
elif change > 0.5: # Up if more than 0.5%
return 2
else: # Sideways if between -0.5% and 0.5%
return 1
direction_5m = get_direction(change_5m)
direction_1h = get_direction(change_1h)
direction_1d = get_direction(change_1d)
# Add to dataset
X.append(window_1m.flatten())
y_immediate.append(direction_5m)
y_midterm.append(direction_1h)
y_longterm.append(direction_1d)
y_values.append([change_5m, change_1h, change_1d, 0]) # Last value reserved
except Exception as e:
logger.warning(f"Error generating training example at index {i}: {str(e)}")
# Convert to numpy arrays
X = np.array(X)
y_immediate = np.array(y_immediate)
y_midterm = np.array(y_midterm)
y_longterm = np.array(y_longterm)
y_values = np.array(y_values)
logger.info(f"Generated {len(X)} training examples")
logger.info(f"Class distribution - Immediate: {np.bincount(y_immediate)}, "
f"Midterm: {np.bincount(y_midterm)}, Long-term: {np.bincount(y_longterm)}")
return X, y_immediate, y_midterm, y_longterm, y_values
def pretrain_price_prediction(agent, data_interface, n_epochs=20, batch_size=128, device=None):
"""
Pre-train the price prediction capabilities of the agent
Args:
agent: EnhancedDQNAgent instance
data_interface: DataInterface instance
n_epochs: Number of pre-training epochs
batch_size: Batch size for pre-training
device: Device to use for pre-training
Returns:
The pre-trained agent
"""
logger.info("Starting price prediction pre-training")
try:
# Ensure we have the necessary timeframes
timeframes_needed = ['1m', '1h', '1d']
for tf in timeframes_needed:
if tf not in data_interface.timeframes:
logger.info(f"Adding timeframe {tf} for pre-training")
# Add timeframe to the list if not present
if tf not in data_interface.timeframes:
data_interface.timeframes.append(tf)
data_interface.dataframes[tf] = None
# Get data for each timeframe
data_1m = data_interface.get_historical_data(timeframe='1m')
data_1h = data_interface.get_historical_data(timeframe='1h')
data_1d = data_interface.get_historical_data(timeframe='1d')
# Generate labeled training data
X, y_immediate, y_midterm, y_longterm, y_values = generate_price_prediction_training_data(
data_1m, data_1h, data_1d, window_size=20
)
if len(X) == 0:
logger.error("No training examples generated. Skipping pre-training.")
return agent
# Split data into training and validation sets
X_train, X_val, y_imm_train, y_imm_val, y_mid_train, y_mid_val, y_long_train, y_long_val, y_val_train, y_val_val = train_test_split(
X, y_immediate, y_midterm, y_longterm, y_values, test_size=0.2, random_state=42
)
# Convert to torch tensors
X_train_tensor = torch.FloatTensor(X_train).to(device)
y_imm_train_tensor = torch.LongTensor(y_imm_train).to(device)
y_mid_train_tensor = torch.LongTensor(y_mid_train).to(device)
y_long_train_tensor = torch.LongTensor(y_long_train).to(device)
y_val_train_tensor = torch.FloatTensor(y_val_train).to(device)
X_val_tensor = torch.FloatTensor(X_val).to(device)
y_imm_val_tensor = torch.LongTensor(y_imm_val).to(device)
y_mid_val_tensor = torch.LongTensor(y_mid_val).to(device)
y_long_val_tensor = torch.LongTensor(y_long_val).to(device)
y_val_val_tensor = torch.FloatTensor(y_val_val).to(device)
# Calculate class weights for imbalanced data
def get_class_weights(labels):
counts = np.bincount(labels)
if len(counts) < 3: # Ensure we have 3 classes
counts = np.append(counts, [0] * (3 - len(counts)))
weights = 1.0 / np.array(counts)
weights = weights / np.sum(weights) # Normalize
return weights
imm_weights = torch.FloatTensor(get_class_weights(y_imm_train)).to(device)
mid_weights = torch.FloatTensor(get_class_weights(y_mid_train)).to(device)
long_weights = torch.FloatTensor(get_class_weights(y_long_train)).to(device)
# Create DataLoader for batch training
train_dataset = TensorDataset(
X_train_tensor, y_imm_train_tensor, y_mid_train_tensor,
y_long_train_tensor, y_val_train_tensor
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Set up loss functions with class weights
imm_criterion = nn.CrossEntropyLoss(weight=imm_weights)
mid_criterion = nn.CrossEntropyLoss(weight=mid_weights)
long_criterion = nn.CrossEntropyLoss(weight=long_weights)
value_criterion = nn.MSELoss()
# Set up optimizer (separate from agent's optimizer)
pretrain_optimizer = torch.optim.Adam(agent.policy_net.parameters(), lr=0.0002)
pretrain_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
pretrain_optimizer, mode='min', factor=0.5, patience=3, verbose=True
)
# Set model to training mode
agent.policy_net.train()
# Training loop
best_val_loss = float('inf')
patience = 5
patience_counter = 0
# Create TensorBoard writer for pre-training
writer = SummaryWriter(log_dir=f'runs/pretrain_{int(time.time())}')
for epoch in range(n_epochs):
# Training phase
train_loss = 0.0
imm_correct, mid_correct, long_correct = 0, 0, 0
total = 0
for X_batch, y_imm_batch, y_mid_batch, y_long_batch, y_val_batch in train_loader:
# Zero gradients
pretrain_optimizer.zero_grad()
# Forward pass
with torch.cuda.amp.autocast() if agent.use_mixed_precision else contextlib.nullcontext():
q_values, _, price_preds, _ = agent.policy_net(X_batch)
# Calculate losses for each prediction head
imm_loss = imm_criterion(price_preds['immediate'], y_imm_batch)
mid_loss = mid_criterion(price_preds['midterm'], y_mid_batch)
long_loss = long_criterion(price_preds['longterm'], y_long_batch)
value_loss = value_criterion(price_preds['values'], y_val_batch)
# Combined loss (weighted by importance)
total_loss = imm_loss + 0.7 * mid_loss + 0.5 * long_loss + 0.3 * value_loss
# Backward pass and optimize
if agent.use_mixed_precision:
agent.scaler.scale(total_loss).backward()
agent.scaler.unscale_(pretrain_optimizer)
torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
agent.scaler.step(pretrain_optimizer)
agent.scaler.update()
else:
total_loss.backward()
torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
pretrain_optimizer.step()
# Accumulate metrics
train_loss += total_loss.item()
total += X_batch.size(0)
# Calculate accuracy
_, imm_pred = torch.max(price_preds['immediate'], 1)
_, mid_pred = torch.max(price_preds['midterm'], 1)
_, long_pred = torch.max(price_preds['longterm'], 1)
imm_correct += (imm_pred == y_imm_batch).sum().item()
mid_correct += (mid_pred == y_mid_batch).sum().item()
long_correct += (long_pred == y_long_batch).sum().item()
# Calculate epoch metrics
train_loss /= len(train_loader)
imm_acc = imm_correct / total
mid_acc = mid_correct / total
long_acc = long_correct / total
# Validation phase
agent.policy_net.eval()
val_loss = 0.0
imm_val_correct, mid_val_correct, long_val_correct = 0, 0, 0
with torch.no_grad():
# Forward pass on validation data
q_values, _, val_price_preds, _ = agent.policy_net(X_val_tensor)
# Calculate validation losses
val_imm_loss = imm_criterion(val_price_preds['immediate'], y_imm_val_tensor)
val_mid_loss = mid_criterion(val_price_preds['midterm'], y_mid_val_tensor)
val_long_loss = long_criterion(val_price_preds['longterm'], y_long_val_tensor)
val_value_loss = value_criterion(val_price_preds['values'], y_val_val_tensor)
val_total_loss = val_imm_loss + 0.7 * val_mid_loss + 0.5 * val_long_loss + 0.3 * val_value_loss
val_loss = val_total_loss.item()
# Calculate validation accuracy
_, imm_val_pred = torch.max(val_price_preds['immediate'], 1)
_, mid_val_pred = torch.max(val_price_preds['midterm'], 1)
_, long_val_pred = torch.max(val_price_preds['longterm'], 1)
imm_val_correct = (imm_val_pred == y_imm_val_tensor).sum().item()
mid_val_correct = (mid_val_pred == y_mid_val_tensor).sum().item()
long_val_correct = (long_val_pred == y_long_val_tensor).sum().item()
imm_val_acc = imm_val_correct / len(X_val_tensor)
mid_val_acc = mid_val_correct / len(X_val_tensor)
long_val_acc = long_val_correct / len(X_val_tensor)
# Log to TensorBoard
writer.add_scalar('pretrain/train_loss', train_loss, epoch)
writer.add_scalar('pretrain/val_loss', val_loss, epoch)
writer.add_scalar('pretrain/imm_acc', imm_acc, epoch)
writer.add_scalar('pretrain/mid_acc', mid_acc, epoch)
writer.add_scalar('pretrain/long_acc', long_acc, epoch)
writer.add_scalar('pretrain/imm_val_acc', imm_val_acc, epoch)
writer.add_scalar('pretrain/mid_val_acc', mid_val_acc, epoch)
writer.add_scalar('pretrain/long_val_acc', long_val_acc, epoch)
# Learning rate scheduling
pretrain_scheduler.step(val_loss)
# Early stopping check
if val_loss < best_val_loss:
best_val_loss = val_loss
patience_counter = 0
# Copy policy_net weights to target_net
agent.target_net.load_state_dict(agent.policy_net.state_dict())
logger.info(f"Saved best model with validation loss: {val_loss:.4f}")
# Save pre-trained model
agent.save("NN/models/saved/enhanced_dqn_pretrained")
else:
patience_counter += 1
if patience_counter >= patience:
logger.info(f"Early stopping triggered after {epoch+1} epochs")
break
# Log progress
logger.info(f"Epoch {epoch+1}/{n_epochs}: "
f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
f"Imm Acc: {imm_acc:.4f}/{imm_val_acc:.4f}, "
f"Mid Acc: {mid_acc:.4f}/{mid_val_acc:.4f}, "
f"Long Acc: {long_acc:.4f}/{long_val_acc:.4f}")
# Set model back to training mode for next epoch
agent.policy_net.train()
writer.close()
logger.info("Price prediction pre-training complete")
return agent
except Exception as e:
logger.error(f"Error during price prediction pre-training: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return agent
def train_enhanced_rl(args):
"""
Train the enhanced RL agent for trading
Args:
args: Command line arguments
"""
# Setup device
if args.no_gpu:
device = torch.device('cpu')
else:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")
# Set up data interface
data_interface = DataInterface(symbol=args.symbol, timeframes=['1m', '5m', '15m'])
# Fetch historical data for each timeframe
for timeframe in data_interface.timeframes:
df = data_interface.get_historical_data(timeframe=timeframe)
logger.info(f"Using data for {args.symbol} {timeframe} ({len(data_interface.dataframes[timeframe])} candles)")
# Create environment for training
from NN.environments.trading_env import TradingEnvironment
window_size = 20
train_env = TradingEnvironment(
data_interface=data_interface,
initial_balance=10000.0,
transaction_fee=0.0002,
window_size=window_size,
max_position=1.0,
reward_scaling=100.0
)
# Create agent with improved parameters
state_shape = train_env.observation_space.shape
n_actions = train_env.action_space.n
agent = EnhancedDQNAgent(
state_shape=state_shape,
n_actions=n_actions,
learning_rate=args.learning_rate,
gamma=0.95,
epsilon=1.0,
epsilon_min=0.05,
epsilon_decay=0.995,
buffer_size=50000,
batch_size=args.batch_size,
target_update=10,
confidence_threshold=args.confidence,
device=device
)
# Load existing model if specified
if args.load_model:
model_path = args.load_model
if agent.load(model_path):
logger.info(f"Loaded existing model from {model_path}")
else:
logger.error(f"Error loading model from {model_path}")
# Pre-training for price prediction
if not args.no_pretrain and not args.load_model:
logger.info("Starting pre-training phase")
agent = pretrain_price_prediction(
agent=agent,
data_interface=data_interface,
n_epochs=args.pretrain_epochs,
batch_size=args.batch_size,
device=device
)
logger.info("Pre-training completed")
# Setup TensorBoard
writer = SummaryWriter(log_dir=f'runs/enhanced_rl_{int(time.time())}')
# Log hardware info
writer.add_text("hardware/device", str(device), 0)
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
writer.add_text(f"hardware/gpu_{i}", torch.cuda.get_device_name(i), 0)
# Move agent to device
agent.move_models_to_device(device)
# Training loop
logger.info(f"Starting enhanced training for {args.episodes} episodes")
total_rewards = []
episode_losses = []
trade_win_rates = []
best_reward = -np.inf
try:
for episode in range(args.episodes):
# Reset environment for new episode
state = train_env.reset()
total_reward = 0.0
done = False
step = 0
episode_start_time = time.time()
# Track trade statistics
trades = []
wins = 0
losses = 0
# Run episode
while not done and step < args.max_steps:
# Choose action
action, confidence = agent.act(state)
# Take action in environment
next_state, reward, done, info = train_env.step(action)
# Remember experience
agent.remember(state, action, reward, next_state, done)
# Track trade results
if 'trade_result' in info and info['trade_result'] is not None:
trade_result = info['trade_result']
trade_pnl = trade_result['pnl']
trades.append(trade_pnl)
if trade_pnl > 0:
wins += 1
logger.info(f"Profitable trade! {trade_pnl:.2f}% profit, reward: {reward:.4f}")
else:
losses += 1
logger.info(f"Loss trade! {trade_pnl:.2f}% loss, penalty: {reward:.4f}")
# Update state and counters
state = next_state
total_reward += reward
step += 1
# Train agent
loss = agent.replay()
if loss > 0:
episode_losses.append(loss)
# Log training metrics for each episode
episode_time = time.time() - episode_start_time
total_rewards.append(total_reward)
# Calculate win rate
win_rate = wins / max(1, (wins + losses))
trade_win_rates.append(win_rate)
# Log to console and TensorBoard
logger.info(f"Episode {episode}/{args.episodes} - Reward: {total_reward:.4f}, Win Rate: {win_rate:.2f}, "
f"Trades: {len(trades)}, Balance: ${train_env.balance:.2f}, Epsilon: {agent.epsilon:.4f}, "
f"Time: {episode_time:.2f}s")
writer.add_scalar('metrics/reward', total_reward, episode)
writer.add_scalar('metrics/balance', train_env.balance, episode)
writer.add_scalar('metrics/win_rate', win_rate, episode)
writer.add_scalar('metrics/trades', len(trades), episode)
writer.add_scalar('metrics/epsilon', agent.epsilon, episode)
if episode_losses:
avg_loss = sum(episode_losses) / len(episode_losses)
writer.add_scalar('metrics/loss', avg_loss, episode)
# Check if this is the best model so far
if total_reward > best_reward:
best_reward = total_reward
# Save best model
agent.save(f"NN/models/saved/enhanced_dqn_best")
logger.info(f"New best model saved with reward: {best_reward:.4f}")
# Save checkpoint every 10 episodes
if episode % 10 == 0 and episode > 0:
agent.save(f"NN/models/saved/enhanced_dqn_checkpoint")
logger.info(f"Checkpoint saved at episode {episode}")
# Reset episode losses
episode_losses = []
# Final save
agent.save(f"NN/models/saved/enhanced_dqn_final")
logger.info("Enhanced training completed, final model saved")
except KeyboardInterrupt:
logger.info("Training interrupted by user")
except Exception as e:
logger.error(f"Training failed: {str(e)}")
import traceback
logger.error(traceback.format_exc())
finally:
# Close TensorBoard writer
writer.close()
return agent, train_env
if __name__ == "__main__":
# Create logs directory if it doesn't exist
os.makedirs("logs", exist_ok=True)
os.makedirs("NN/models/saved", exist_ok=True)
# Parse arguments
args = parse_args()
# Start training
train_enhanced_rl(args)

View File

@ -1,657 +0,0 @@
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import logging
import time
from datetime import datetime
import os
import sys
import pandas as pd
import gym
import json
import random
import torch.nn as nn
import contextlib
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from NN.utils.data_interface import DataInterface
from NN.utils.trading_env import TradingEnvironment
from NN.models.dqn_agent import DQNAgent
from NN.utils.signal_interpreter import SignalInterpreter
# Configure logging
logger = logging.getLogger(__name__)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('rl_training.log'),
logging.StreamHandler()
]
)
# Set up device for PyTorch (use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Log GPU status
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
gpu_names = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
logger.info(f"Using GPU: {gpu_names}")
# Enable TensorFloat32 for NVIDIA Ampere GPUs for faster training
if hasattr(torch.cuda, 'amp') and torch.cuda.is_bf16_supported():
logger.info("BFloat16 precision is supported - will use for faster training")
else:
logger.warning("GPU not available. Using CPU for training (slower).")
class RLTradingEnvironment(gym.Env):
"""
Reinforcement Learning environment for trading with technical indicators
from multiple timeframes
"""
def __init__(self, features_1m, features_1h, features_1d, window_size=20, trading_fee=0.0025, min_trade_interval=15):
super().__init__()
# Initialize attributes before parent class
self.window_size = window_size
self.num_features = features_1m.shape[1] - 1 # Exclude close price
# Count available timeframes
self.num_timeframes = 3 # We require all timeframes now
self.feature_dim = self.num_features * self.num_timeframes
# Store features from different timeframes
self.features_1m = features_1m
self.features_1h = features_1h
self.features_1d = features_1d
# Trading parameters
self.initial_balance = 1.0
self.trading_fee = trading_fee # Increased from 0.001 to 0.0025 (0.25%)
self.min_trade_interval = min_trade_interval # Minimum steps between trades
# Define action and observation spaces
self.action_space = gym.spaces.Discrete(3) # 0: Buy, 1: Sell, 2: Hold
self.observation_space = gym.spaces.Box(
low=-np.inf,
high=np.inf,
shape=(self.window_size, self.feature_dim),
dtype=np.float32
)
# State variables
self.reset()
# Callback for visualization or external monitoring
self.action_callback = None
def reset(self):
"""Reset the environment to initial state"""
self.balance = self.initial_balance
self.position = 0.0 # Amount of asset held
self.current_step = self.window_size
self.trades = 0
self.wins = 0
self.losses = 0
self.trade_history = []
self.last_trade_step = -self.min_trade_interval # Initialize to allow immediate first trade
# Get initial observation
observation = self._get_observation()
return observation
def _get_observation(self):
"""
Get the current state observation.
Combine features from multiple timeframes, reshaped for the CNN.
"""
# Calculate indices for each timeframe
idx_1m = min(self.current_step, self.features_1m.shape[0] - 1)
idx_1h = idx_1m // 60 # 60 minutes in an hour
idx_1d = idx_1h // 24 # 24 hours in a day
# Cap indices to prevent out of bounds
idx_1h = min(idx_1h, self.features_1h.shape[0] - 1)
idx_1d = min(idx_1d, self.features_1d.shape[0] - 1)
# Extract feature windows from each timeframe
window_1m = self.features_1m[max(0, idx_1m - self.window_size):idx_1m]
# Handle hourly timeframe
start_1h = max(0, idx_1h - self.window_size)
window_1h = self.features_1h[start_1h:idx_1h]
# Handle daily timeframe
start_1d = max(0, idx_1d - self.window_size)
window_1d = self.features_1d[start_1d:idx_1d]
# Pad if needed (for higher timeframes)
if len(window_1m) < self.window_size:
padding = np.zeros((self.window_size - len(window_1m), window_1m.shape[1]))
window_1m = np.vstack([padding, window_1m])
if len(window_1h) < self.window_size:
padding = np.zeros((self.window_size - len(window_1h), window_1h.shape[1]))
window_1h = np.vstack([padding, window_1h])
if len(window_1d) < self.window_size:
padding = np.zeros((self.window_size - len(window_1d), window_1d.shape[1]))
window_1d = np.vstack([padding, window_1d])
# Combine features from all timeframes
combined_features = np.hstack([
window_1m.reshape(self.window_size, -1),
window_1h.reshape(self.window_size, -1),
window_1d.reshape(self.window_size, -1)
])
# Convert to float32 and handle any NaN values
combined_features = np.nan_to_num(combined_features, nan=0.0).astype(np.float32)
return combined_features
def step(self, action):
"""Take an action and return the next state, reward, done flag, and info"""
# Initialize info dictionary for additional data
info = {
'trade_executed': False,
'price_change': 0.0,
'position_change': 0,
'current_price': 0.0,
'next_price': 0.0,
'balance_change': 0.0,
'reward_components': {},
'future_prices': {}
}
# Get the current and next price
current_price = self.features_1m[self.current_step, -1]
# Handle edge case at the end of the data
if self.current_step >= len(self.features_1m) - 1:
next_price = current_price # Use current price as next price
done = True
else:
next_price = self.features_1m[self.current_step + 1, -1]
done = False
# Handle zero or negative price (data error)
if current_price <= 0:
current_price = 0.01 # Set to a small positive number
logger.warning(f"Zero or negative price detected at step {self.current_step}. Setting to 0.01.")
if next_price <= 0:
next_price = current_price # Use current price instead
logger.warning(f"Zero or negative next price detected at step {self.current_step + 1}. Using current price.")
# Calculate price change as percentage
price_change_pct = ((next_price - current_price) / current_price) * 100
# Store prices in info
info['current_price'] = current_price
info['next_price'] = next_price
info['price_change'] = price_change_pct
# Initialize reward components dictionary
reward_components = {
'holding_reward': 0.0,
'action_reward': 0.0,
'profit_reward': 0.0,
'trade_freq_penalty': 0.0
}
# Default small negative reward to discourage inaction
reward = -0.01
reward_components['holding_reward'] = -0.01
# Track previous balance for changes
previous_balance = self.balance
# Execute action (0: Buy, 1: Sell, 2: Hold)
if action == 0: # Buy
if self.position == 0: # Only buy if we don't already have a position
# Calculate how much of the asset we can buy with 100% of balance
self.position = self.balance / current_price
self.balance = 0 # All balance used
# If price goes up after buying, that's good
expected_profit = price_change_pct
# Scale reward based on expected profit
if expected_profit > 0:
# Positive reward for profitable buy decision
action_reward = 0.1 + (expected_profit * 0.05) # Base reward + profit-based bonus
reward_components['action_reward'] = action_reward
reward += action_reward
else:
# Small negative reward for unprofitable buy
action_reward = -0.1 + (expected_profit * 0.03) # Smaller penalty for small losses
reward_components['action_reward'] = action_reward
reward += action_reward
# Check if we've traded too frequently
if len(self.trade_history) > 0:
last_trade_step = self.trade_history[-1]['step']
if self.current_step - last_trade_step < 5: # If less than 5 steps since last trade
freq_penalty = -0.2 # Penalty for trading too frequently
reward += freq_penalty
reward_components['trade_freq_penalty'] = freq_penalty
# Record the trade
self.trade_history.append({
'step': self.current_step,
'action': 'buy',
'price': current_price,
'position': self.position,
'balance': self.balance
})
info['trade_executed'] = True
logger.info(f"Buy at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")
elif action == 1: # Sell
if self.position > 0: # Only sell if we have a position
# Calculate sale proceeds
sale_value = self.position * current_price
# Calculate profit or loss percentage from last buy
last_buy_price = None
for trade in reversed(self.trade_history):
if trade['action'] == 'buy':
last_buy_price = trade['price']
break
# If we found the last buy price, calculate profit
if last_buy_price is not None:
profit_pct = ((current_price - last_buy_price) / last_buy_price) * 100
# Highly reward profitable trades
if profit_pct > 0:
# Progressive reward based on profit percentage
profit_reward = min(5.0, profit_pct * 0.2) # Cap at 5.0 to prevent exploitation
reward_components['profit_reward'] = profit_reward
reward += profit_reward
logger.info(f"Profitable trade! {profit_pct:.2f}% profit, reward: {profit_reward:.4f}")
else:
# Penalize losses more heavily based on size of loss
loss_penalty = max(-3.0, profit_pct * 0.15) # Cap at -3.0 to prevent excessive punishment
reward_components['profit_reward'] = loss_penalty
reward += loss_penalty
logger.info(f"Loss trade! {profit_pct:.2f}% loss, penalty: {loss_penalty:.4f}")
# If price goes down after selling, that's good
if price_change_pct < 0:
# Reward for good timing on sell (avoiding future loss)
timing_reward = min(1.0, abs(price_change_pct) * 0.05)
reward_components['action_reward'] = timing_reward
reward += timing_reward
# Check for trading too frequently
if len(self.trade_history) > 0:
last_trade_step = self.trade_history[-1]['step']
if self.current_step - last_trade_step < 5: # If less than 5 steps since last trade
freq_penalty = -0.2 # Penalty for trading too frequently
reward += freq_penalty
reward_components['trade_freq_penalty'] = freq_penalty
# Update balance and position
self.balance = sale_value
position_change = self.position
self.position = 0
# Record the trade
self.trade_history.append({
'step': self.current_step,
'action': 'sell',
'price': current_price,
'position': self.position,
'balance': self.balance
})
info['trade_executed'] = True
info['position_change'] = position_change
logger.info(f"Sell at step {self.current_step}, price: {current_price:.4f}, new balance: {self.balance:.4f}")
elif action == 2: # Hold
# Small reward if holding was a good decision
if self.position > 0 and price_change_pct > 0: # Holding long position during price increase
hold_reward = price_change_pct * 0.01 # Small reward proportional to price increase
reward += hold_reward
reward_components['holding_reward'] = hold_reward
elif self.position == 0 and price_change_pct < 0: # Holding cash during price decrease
hold_reward = abs(price_change_pct) * 0.01 # Small reward for avoiding loss
reward += hold_reward
reward_components['holding_reward'] = hold_reward
# Move to the next step
self.current_step += 1
# Update current portfolio value
if self.position > 0:
self.current_value = self.balance + (self.position * next_price)
else:
self.current_value = self.balance
# Calculate balance change
balance_change = self.current_value - previous_balance
info['balance_change'] = balance_change
# Check if we've reached the end of the data
if self.current_step >= len(self.features_1m) - 1:
done = True
# Final evaluation if we have a position
if self.position > 0:
# Sell remaining position at the final price
final_balance = self.balance + (self.position * next_price)
# Calculate final portfolio value and return
final_return_pct = ((final_balance - self.initial_balance) / self.initial_balance) * 100
# Add big reward/penalty based on overall performance
performance_reward = final_return_pct * 0.1
reward += performance_reward
reward_components['final_performance'] = performance_reward
logger.info(f"Episode ended. Final balance: {final_balance:.4f}, Return: {final_return_pct:.2f}%")
# Get future prices for evaluation (1-hour and 1-day ahead)
info['future_prices'] = {}
# 1-hour future price if hourly data is available
if hasattr(self, 'features_1h') and self.features_1h is not None:
# Find the closest hourly data point
if self.current_step < len(self.features_1m):
current_time = self.current_step # Use as index for simplicity
hourly_idx = min(current_time // 60, len(self.features_1h) - 1) # Assuming 60 minutes per hour
if hourly_idx < len(self.features_1h) - 1:
future_1h_price = self.features_1h[hourly_idx + 1, -1]
info['future_prices']['1h'] = future_1h_price
# 1-day future price if daily data is available
if hasattr(self, 'features_1d') and self.features_1d is not None:
# Find the closest daily data point
if self.current_step < len(self.features_1m):
current_time = self.current_step # Use as index for simplicity
daily_idx = min(current_time // 1440, len(self.features_1d) - 1) # Assuming 1440 minutes per day
if daily_idx < len(self.features_1d) - 1:
future_1d_price = self.features_1d[daily_idx + 1, -1]
info['future_prices']['1d'] = future_1d_price
# Get next observation
next_state = self._get_observation()
# Store reward components in info
info['reward_components'] = reward_components
# Clip reward to prevent extreme values
reward = np.clip(reward, -10.0, 10.0)
return next_state, reward, done, info
def set_action_callback(self, callback):
"""
Set a callback function to be called after each action
Args:
callback: Function with signature (action, price, reward, info)
"""
self.action_callback = callback
def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent",
action_callback=None, episode_callback=None, symbol="BTC/USDT",
pretrain_price_prediction_enabled=False, pretrain_epochs=10):
"""
Train a reinforcement learning agent for trading using ONLY real market data
Args:
env_class: Optional environment class override
num_episodes: Number of episodes to train for
max_steps: Maximum steps per episode
save_path: Path to save the trained model
action_callback: Callback function for monitoring actions
episode_callback: Callback function for monitoring episodes
symbol: Trading symbol to use
pretrain_price_prediction_enabled: DEPRECATED - No longer supported (synthetic data not used)
pretrain_epochs: DEPRECATED - No longer supported (synthetic data not used)
Returns:
tuple: (trained agent, environment)
"""
# Load data for the selected symbol
data_interface = DataInterface(symbol=symbol, timeframes=['1m', '5m', '15m', '1h', '1d'])
try:
# Try to load data for the requested symbol using get_historical_data method
data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=5000)
data_5m = data_interface.get_historical_data(timeframe='5m', n_candles=5000)
data_15m = data_interface.get_historical_data(timeframe='15m', n_candles=5000)
data_1h = data_interface.get_historical_data(timeframe='1h', n_candles=1000)
data_1d = data_interface.get_historical_data(timeframe='1d', n_candles=500)
if data_1m is None or data_5m is None or data_15m is None or data_1h is None or data_1d is None:
raise FileNotFoundError("Could not retrieve all required timeframes data for specified symbol")
except Exception as e:
logger.warning(f"Data for {symbol} not available: {str(e)}. Using default cached data.")
# Try to use cached data if available
symbol = "BTC/USDT"
data_interface = DataInterface(symbol=symbol, timeframes=['1m', '5m', '15m', '1h', '1d'])
data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=5000)
data_5m = data_interface.get_historical_data(timeframe='5m', n_candles=5000)
data_15m = data_interface.get_historical_data(timeframe='15m', n_candles=5000)
data_1h = data_interface.get_historical_data(timeframe='1h', n_candles=1000)
data_1d = data_interface.get_historical_data(timeframe='1d', n_candles=500)
if data_1m is None or data_5m is None or data_15m is None or data_1h is None or data_1d is None:
logger.error("Failed to retrieve all required timeframes data. Cannot continue training.")
raise ValueError("No data available for training")
# Create features from the data by adding technical indicators and converting to numpy format
if data_1m is not None:
data_1m = data_interface.add_technical_indicators(data_1m)
# Convert to numpy array with close price as the last column
features_1m = np.hstack([
data_1m.drop(['timestamp', 'close'], axis=1).values,
data_1m['close'].values.reshape(-1, 1)
])
else:
features_1m = None
if data_5m is not None:
data_5m = data_interface.add_technical_indicators(data_5m)
# Convert to numpy array with close price as the last column
features_5m = np.hstack([
data_5m.drop(['timestamp', 'close'], axis=1).values,
data_5m['close'].values.reshape(-1, 1)
])
else:
features_5m = None
if data_15m is not None:
data_15m = data_interface.add_technical_indicators(data_15m)
# Convert to numpy array with close price as the last column
features_15m = np.hstack([
data_15m.drop(['timestamp', 'close'], axis=1).values,
data_15m['close'].values.reshape(-1, 1)
])
else:
features_15m = None
if data_1h is not None:
data_1h = data_interface.add_technical_indicators(data_1h)
# Convert to numpy array with close price as the last column
features_1h = np.hstack([
data_1h.drop(['timestamp', 'close'], axis=1).values,
data_1h['close'].values.reshape(-1, 1)
])
else:
features_1h = None
if data_1d is not None:
data_1d = data_interface.add_technical_indicators(data_1d)
# Convert to numpy array with close price as the last column
features_1d = np.hstack([
data_1d.drop(['timestamp', 'close'], axis=1).values,
data_1d['close'].values.reshape(-1, 1)
])
else:
features_1d = None
# Check if we have all the required features
if features_1m is None or features_5m is None or features_15m is None or features_1h is None or features_1d is None:
logger.error("Failed to create features for all timeframes.")
raise ValueError("Could not create features for training")
# Create the environment
if env_class:
# Use provided environment class
env = env_class(features_1m, features_1h, features_1d)
else:
# Use the default environment
env = RLTradingEnvironment(features_1m, features_1h, features_1d)
# Set action callback if provided
if action_callback:
env.set_action_callback(action_callback)
# Get environment properties for agent creation
input_shape = env.observation_space.shape
n_actions = env.action_space.n
# Create the agent
agent = DQNAgent(
state_shape=input_shape,
n_actions=n_actions,
epsilon=1.0,
epsilon_decay=0.995,
epsilon_min=0.01,
learning_rate=0.0001,
gamma=0.99,
buffer_size=10000,
batch_size=64,
device=device # Pass device to agent for GPU usage
)
# Check if model file exists and load it
model_file = f"{save_path}_model.pth"
if os.path.exists(model_file):
try:
agent.load(model_file)
logger.info(f"Loaded existing model from {model_file}")
except Exception as e:
logger.error(f"Error loading model: {e}")
else:
logger.info("No existing model found. Starting with a new model.")
# Remove pre-training code since it used synthetic data
# Pre-training with real data would require a separate implementation
if pretrain_price_prediction_enabled:
logger.warning("Pre-training with synthetic data is no longer supported. Continuing with RL training only.")
# Create TensorBoard writer
writer = SummaryWriter(log_dir=f'runs/dqn_{int(time.time())}')
# Log GPU status to TensorBoard
writer.add_text("hardware/device", str(device), 0)
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
writer.add_text(f"hardware/gpu_{i}", torch.cuda.get_device_name(i), 0)
# Training loop
total_rewards = []
trade_win_rates = []
best_reward = -np.inf
# Move models to the appropriate device if not already there
agent.move_models_to_device(device)
# Enable mixed precision if GPU and feature is available
use_mixed_precision = False
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp'):
logger.info("Enabling mixed precision training")
use_mixed_precision = True
scaler = torch.cuda.amp.GradScaler()
# Define step callback for tensorboard logging and model tracking
def step_callback(action, price, reward, info):
# Pass to external callback if provided
if action_callback:
action_callback(env.current_step, action, price, reward, info)
# Main training loop
logger.info(f"Starting training for {num_episodes} episodes...")
logger.info(f"Starting training on device: {agent.device}")
try:
for episode in range(num_episodes):
state = env.reset()
total_reward = 0
for step in range(max_steps):
# Select action
action = agent.act(state)
# Take action and observe next state and reward
next_state, reward, done, info = env.step(action)
# Store the experience in memory
agent.remember(state, action, reward, next_state, done)
# Update state and reward
state = next_state
total_reward += reward
# Train the agent by sampling from memory
if len(agent.memory) >= agent.batch_size:
loss = agent.replay()
if done or step == max_steps - 1:
break
# Track rewards
total_rewards.append(total_reward)
# Calculate trading metrics
win_rate = env.wins / max(1, env.trades)
trades = env.trades
# Log to TensorBoard
writer.add_scalar('Reward/Episode', total_reward, episode)
writer.add_scalar('Trade/WinRate', win_rate, episode)
writer.add_scalar('Trade/Count', trades, episode)
# Save best model
if total_reward > best_reward and episode > 10:
logger.info(f"New best average reward: {total_reward:.4f}, saving model")
agent.save(save_path)
best_reward = total_reward
# Periodic save every 100 episodes
if episode % 100 == 0 and episode > 0:
agent.save(f"{save_path}_episode_{episode}")
# Call episode callback if provided
if episode_callback:
# Add environment to info dict to use for extrema training
info_with_env = info.copy()
info_with_env['env'] = env
episode_callback(episode, total_reward, info_with_env)
# Final save
logger.info("Training completed, saving final model")
agent.save(f"{save_path}_final")
except Exception as e:
logger.error(f"Training failed: {str(e)}")
import traceback
logger.error(traceback.format_exc())
# Close TensorBoard writer
writer.close()
return agent, env
if __name__ == "__main__":
train_rl()