massive cleanup
This commit is contained in:
261
NN/example.py
261
NN/example.py
@ -1,261 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Example script for the Neural Network Trading System
|
||||
This shows basic usage patterns for the system components
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow as tf
|
||||
import matplotlib.pyplot as plt
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
# Add project root to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# Import components
|
||||
from NN.utils.data_interface import DataInterface
|
||||
from NN.models.cnn_model import CNNModel
|
||||
from NN.models.transformer_model import TransformerModel, MixtureOfExpertsModel
|
||||
from NN.main import NeuralNetworkOrchestrator
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
logger = logging.getLogger('example')
|
||||
|
||||
def example_data_interface():
    """Show how to use the data interface.

    Fetches recent 1-hour candles, prepares windowed neural-network input
    and generates a training dataset, logging what each step produced.

    Returns:
        tuple: ``(X, y, timestamps)`` as returned by
        ``DataInterface.prepare_nn_input``, or ``(None, None, None)`` when
        no input could be prepared (``X`` is ``None``).
    """
    logger.info("=== Data Interface Example ===")

    # Initialize data interface
    di = DataInterface(symbol="BTC/USDT", timeframes=['1h', '4h', '1d'])

    # Get historical data
    df_1h = di.get_historical_data(timeframe='1h', n_candles=100)
    if df_1h is not None and not df_1h.empty:
        logger.info(f"Retrieved {len(df_1h)} 1-hour candles")
        logger.info(f"Most recent candle: {df_1h.iloc[-1]}")

    # Prepare data for neural network
    X, y, timestamps = di.prepare_nn_input(timeframes=['1h'], n_candles=500, window_size=20)
    if X is not None and y is not None:
        logger.info(f"Prepared input shape: {X.shape}, target shape: {y.shape}")

    # Generate a dataset
    dataset = di.generate_training_dataset(
        timeframes=['1h', '4h'],
        n_candles=1000,
        window_size=20
    )
    if dataset:
        logger.info(f"Dataset generated and saved to: {list(dataset.values())}")

    # The parentheses matter: without them the conditional expression binds
    # only to ``timestamps`` and a failed preparation would come back as
    # ``(None, y, (None, None, None))``.
    if X is None:
        return None, None, None
    return X, y, timestamps
|
||||
|
||||
def example_cnn_model(X=None, y=None):
    """Demonstrate building, training, evaluating and querying the CNN model.

    Args:
        X: Optional input windows. Synthetic data of shape (1000, 20, 5)
            is generated when ``X`` or ``y`` is missing.
        y: Optional labels matching ``X``.

    Returns:
        The CNNModel instance after training and evaluation.
    """
    logger.info("=== CNN Model Example ===")

    # Fall back to synthetic data when either half of the dataset is missing.
    have_data = X is not None and y is not None
    if not have_data:
        logger.info("Creating dummy data for CNN example")
        X = np.random.random((1000, 20, 5))  # 1000 samples, 20 time steps, 5 features
        y = np.random.randint(0, 2, size=(1000,))  # Binary labels

    # Imported inside the function, so scikit-learn is only needed when
    # this example actually runs.
    from sklearn.model_selection import train_test_split

    # 80/20 train/test split with a fixed seed.
    train_x, test_x, train_y, test_y = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Build the network for 20-step windows of 5 features.
    cnn = CNNModel(input_shape=(20, 5), output_size=1, model_dir='NN/models/saved')
    cnn.build_model(filters=(32, 64, 128), kernel_sizes=(3, 5, 7), dropout_rate=0.3)

    # Only a handful of epochs: this is a demonstration, not a real fit.
    cnn.train(
        train_x, train_y,
        batch_size=32,
        epochs=5,
        validation_split=0.2,
    )

    # Evaluate on the held-out split.
    evaluation = cnn.evaluate(test_x, test_y, plot_results=True)
    if evaluation:
        logger.info(f"CNN Evaluation metrics: {evaluation}")

    # Predict for a single sample and report it.
    first_sample = test_x[:1]
    y_pred, y_proba = cnn.predict(first_sample)
    logger.info(f"CNN Prediction: {y_pred[0]}, Probability: {y_proba[0]:.4f}")

    return cnn
|
||||
|
||||
def example_transformer_model(X=None, y=None, cnn_model=None):
    """Demonstrate building, training and querying the Transformer model.

    Args:
        X: Optional time-series windows. Synthetic data of shape
            (1000, 20, 5) is generated when ``X`` or ``y`` is missing.
        y: Optional labels matching ``X``.
        cnn_model: Optional model exposing ``extract_hidden_features``;
            when present it supplies the second (feature) input, otherwise
            random 128-wide features are used.

    Returns:
        The TransformerModel instance after training.
    """
    logger.info("=== Transformer Model Example ===")

    # Fall back to synthetic data when either half of the dataset is missing.
    have_data = X is not None and y is not None
    if not have_data:
        logger.info("Creating dummy data for Transformer example")
        X = np.random.random((1000, 20, 5))  # 1000 samples, 20 time steps, 5 features
        y = np.random.randint(0, 2, size=(1000,))  # Binary labels

    # High-level features: random unless a CNN that can extract them is given.
    if cnn_model is None or not hasattr(cnn_model, 'extract_hidden_features'):
        X_features = np.random.random((len(X), 128))
        logger.info("Generated random features for Transformer model")
    else:
        X_features = cnn_model.extract_hidden_features(X)
        logger.info(f"Extracted {X_features.shape[1]} features from CNN model")

    # Imported inside the function, so scikit-learn is only needed when
    # this example actually runs.
    from sklearn.model_selection import train_test_split

    # Split all three arrays consistently (80/20, fixed seed).
    splits = train_test_split(X, X_features, y, test_size=0.2, random_state=42)
    ts_train, ts_test, feat_train, feat_test, labels_train, _labels_test = splits

    # Build the two-input model; the feature width follows the features above.
    feature_width = X_features.shape[1]
    transformer = TransformerModel(
        ts_input_shape=(20, 5),
        feature_input_shape=feature_width,
        output_size=1,
        model_dir='NN/models/saved',
    )
    transformer.build_model(
        embed_dim=32,
        num_heads=2,
        ff_dim=64,
        num_transformer_blocks=2,
        dropout_rate=0.2,
    )

    # Only a handful of epochs: this is a demonstration, not a real fit.
    transformer.train(
        ts_train, feat_train, labels_train,
        batch_size=32,
        epochs=5,
        validation_split=0.2,
    )

    # Predict for a single sample and report it.
    y_pred, y_proba = transformer.predict(ts_test[:1], feat_test[:1])
    logger.info(f"Transformer Prediction: {y_pred[0]}, Probability: {y_proba[0]:.4f}")

    return transformer
|
||||
|
||||
def example_moe_model(X=None, y=None, cnn_model=None, transformer_model=None):
    """Demonstrate assembling a Mixture of Experts from a CNN and a Transformer.

    Args:
        X: Optional time-series windows; synthetic data is generated when
            ``X`` or ``y`` is missing.
        y: Optional labels matching ``X``.
        cnn_model: Optional pre-built CNN expert; a default one is built
            when omitted.
        transformer_model: Optional pre-built Transformer expert; a default
            one is built when omitted.

    Returns:
        The built (untrained) MixtureOfExpertsModel instance.
    """
    logger.info("=== Mixture of Experts Example ===")

    # Fall back to synthetic data when either half of the dataset is missing.
    have_data = X is not None and y is not None
    if not have_data:
        logger.info("Creating dummy data for MoE example")
        X = np.random.random((1000, 20, 5))  # 1000 samples, 20 time steps, 5 features
        y = np.random.randint(0, 2, size=(1000,))  # Binary labels

    # Create any expert the caller did not supply.
    if cnn_model is None:
        logger.info("Creating a new CNN model for MoE")
        cnn_model = CNNModel(input_shape=(20, 5), output_size=1)
        cnn_model.build_model()

    if transformer_model is None:
        logger.info("Creating a new Transformer model for MoE")
        transformer_model = TransformerModel(
            ts_input_shape=(20, 5), feature_input_shape=128, output_size=1
        )
        transformer_model.build_model()

    # Register the experts, CNN first, then the Transformer.
    moe = MixtureOfExpertsModel(output_size=1, model_dir='NN/models/saved')
    for expert_name, expert in (('cnn', cnn_model), ('transformer', transformer_model)):
        moe.add_expert(expert_name, expert)

    # Simplified wiring: a real setup would need to handle the interfaces
    # between the expert models more carefully.
    mixing_weights = {'cnn': 0.7, 'transformer': 0.3}
    moe.build_model(ts_input_shape=(20, 5), expert_weights=mixing_weights)

    # Training is intentionally left out of this example.
    logger.info("MoE model built - in a real implementation, you would train it here")

    return moe
|
||||
|
||||
def example_orchestrator():
    """Demonstrate driving training and inference through the orchestrator.

    Builds a NeuralNetworkOrchestrator from a config dict, asks it for
    training data and, when data is available, trains the CNN briefly and
    runs one inference. Nothing is returned; results are only logged.
    """
    logger.info("=== Orchestrator Example ===")

    # Orchestrator configuration.
    settings = {
        'symbol': 'BTC/USDT',
        'timeframes': ['1h', '4h'],
        'window_size': 20,
        'n_features': 5,
        'output_size': 3,  # BUY/HOLD/SELL
        'batch_size': 32,
        'epochs': 5,  # Small number for example
        'model_dir': 'NN/models/saved',
        'data_dir': 'NN/data',
    }
    orchestrator = NeuralNetworkOrchestrator(settings)

    # NOTE(review): the NN/main.py compatibility class visible next to this
    # file defines neither prepare_training_data nor train_cnn_model -
    # confirm which orchestrator implementation this example targets.
    X, y, timestamps = orchestrator.prepare_training_data(
        timeframes=['1h'],
        n_candles=200,
    )

    # Without data there is nothing to train on; report that and stop.
    if X is None or y is None:
        logger.warning("Could not prepare training data - this is expected if no real data is available")
        logger.info("The orchestrator would normally handle training and inference")
        return

    logger.info(f"Prepared training data: X shape {X.shape}, y shape {y.shape}")

    # Train the CNN for a very small number of epochs.
    logger.info("Training CNN model with orchestrator...")
    orchestrator.train_cnn_model(X, y, epochs=2)

    # Run a single prediction through the inference pipeline.
    outcome = orchestrator.run_inference_pipeline(model_type='cnn', timeframe='1h')
    if outcome:
        logger.info(f"Inference result: {outcome}")
|
||||
|
||||
def main():
    """Run every example in sequence, feeding each one the previous results."""
    logger.info("Starting Neural Network Trading System Examples")

    # Example 1: Data Interface (the timestamps are not used further on).
    features, labels, _timestamps = example_data_interface()

    # Example 2: CNN Model
    cnn = example_cnn_model(features, labels)

    # Example 3: Transformer Model, optionally fed by the CNN's features
    transformer = example_transformer_model(features, labels, cnn)

    # Example 4: Mixture of Experts built from both models
    example_moe_model(features, labels, cnn, transformer)

    # Example 5: Orchestrator
    example_orchestrator()

    logger.info("Examples completed")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
244
NN/main.py
244
NN/main.py
@ -1,244 +0,0 @@
|
||||
"""
|
||||
Neural Network Trading System Main Module (Compatibility Layer)
|
||||
|
||||
This module serves as a compatibility layer for the realtime.py module.
|
||||
It re-exports the functionality from realtime_main.py that is needed by realtime.py.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger('NN')
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# Re-export everything from realtime_main.py
|
||||
from .realtime_main import (
|
||||
parse_arguments,
|
||||
realtime,
|
||||
train,
|
||||
predict
|
||||
)
|
||||
|
||||
# Create a class that realtime.py expects
|
||||
class NeuralNetworkOrchestrator:
    """
    Orchestrates the neural network operations.

    Builds a data interface and a model from a plain configuration dict and
    exposes a single inference entry point. Construction never raises for
    import, model or data problems: failures are logged and a DummyModel
    (or a canned HOLD prediction) is used instead.
    """

    # Index -> label mapping shared by real and dummy predictions.
    ACTION_NAMES = ['SELL', 'HOLD', 'BUY']

    def __init__(self, config):
        """
        Initialize the orchestrator with configuration.

        Args:
            config (dict): Configuration parameters. Recognised keys are
                'symbol', 'timeframes', 'window_size', 'n_features',
                'output_size', 'model_dir' and 'data_dir'; each has a
                default.
        """
        self.config = config
        self.symbol = config.get('symbol', 'BTC/USDT')
        self.timeframes = config.get('timeframes', ['1m', '5m', '1h', '4h'])
        self.window_size = config.get('window_size', 20)
        self.n_features = config.get('n_features', 5)
        self.output_size = config.get('output_size', 3)
        self.model_dir = config.get('model_dir', 'NN/models/saved')
        self.data_dir = config.get('data_dir', 'NN/data')
        self.model = None
        self.data_interface = None

        # Initialize with default values in case imports fail
        self.model_initialized = False
        self.data_initialized = False

        # Import necessary modules dynamically so a missing optional
        # dependency degrades to the dummy model instead of breaking import.
        try:
            from .utils.data_interface import DataInterface

            # Initialize data interface
            self.data_interface = DataInterface(
                symbol=self.symbol,
                timeframes=self.timeframes
            )
            self.data_initialized = True
            logger.info(f"Data interface initialized for {self.symbol}")

            try:
                from .models.cnn_model_pytorch import CNNModelPyTorch as Model

                # Initialize model
                feature_count = self.data_interface.get_feature_count() if hasattr(self.data_interface, 'get_feature_count') else 5
                try:
                    # First try with expected parameters
                    self.model = Model(
                        window_size=self.window_size,
                        num_features=feature_count,
                        output_size=self.output_size,
                        timeframes=self.timeframes
                    )
                except TypeError as e:
                    logger.warning(f"TypeError in model initialization with num_features: {str(e)}")
                    # Try alternate parameter naming
                    try:
                        self.model = Model(
                            input_shape=(self.window_size, feature_count),
                            output_size=self.output_size
                        )
                        logger.info("Model initialized with alternate parameters")
                    except Exception as ex:
                        logger.error(f"Failed to initialize model with alternate parameters: {str(ex)}")
                        self.model = DummyModel()

                # Try to load the best model
                self._load_model()
                self.model_initialized = True
                logger.info("Model initialized successfully")
            except Exception as e:
                logger.error(f"Error initializing model: {str(e)}")
                import traceback
                logger.error(traceback.format_exc())
                self.model = DummyModel()

            logger.info(f"NeuralNetworkOrchestrator initialized with config: {config}")
        except Exception as e:
            logger.error(f"Error initializing NeuralNetworkOrchestrator: {str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            self.model = DummyModel()

    def _load_model(self):
        """Load the best trained model from available files.

        Candidate checkpoints are tried in order; the first one that exists
        and loads wins. When none can be loaded the current model is
        replaced by a DummyModel.

        Returns:
            bool: True if a checkpoint was loaded, False otherwise.
        """
        try:
            model_paths = [
                os.path.join(self.model_dir, "dqn_agent_best_policy.pt"),
                os.path.join(self.model_dir, "cnn_model_best.pt"),
                os.path.join("models/saved", "dqn_agent_best_policy.pt"),
                os.path.join("models/saved", "cnn_model_best.pt")
            ]

            for model_path in model_paths:
                if os.path.exists(model_path):
                    try:
                        self.model.load(model_path)
                        logger.info(f"Loaded model from {model_path}")
                        return True
                    except Exception as e:
                        logger.warning(f"Failed to load model from {model_path}: {str(e)}")
                        continue

            logger.warning("No trained model found, using dummy model")
            self.model = DummyModel()
            return False
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            self.model = DummyModel()
            return False

    def run_inference_pipeline(self, model_type='cnn', timeframe='1h'):
        """
        Run the inference pipeline using the trained model.

        Args:
            model_type (str): Type of model to use (cnn, transformer, etc.).
                Accepted for interface compatibility; the currently loaded
                model is used regardless of this value.
            timeframe (str): Timeframe to use for inference

        Returns:
            dict: Inference result with keys 'timestamp', 'action',
            'action_index', 'probability', 'probabilities' and
            'price_prediction'. When the model or data is unavailable, or
            any step fails, the dummy HOLD prediction (which additionally
            carries 'is_dummy': True) is returned instead.
        """
        try:
            # Check if we have a model
            if getattr(self, 'model', None) is None:
                logger.warning("No model available, initializing dummy model")
                self.model = DummyModel()

            # Check if we have a data interface
            if getattr(self, 'data_interface', None) is None:
                logger.warning("No data interface available")
                # Return a dummy prediction
                return self._get_dummy_prediction()

            # Prepare input data for the selected timeframe
            X, timestamp = self.data_interface.prepare_realtime_input(
                timeframe=timeframe,
                n_candles=self.window_size + 10,  # Extra candles for safety
                window_size=self.window_size
            )

            if X is None:
                logger.warning(f"No data available for {self.symbol}")
                return self._get_dummy_prediction()

            # Get model predictions
            action_probs, price_pred = self.model.predict(X)

            # A model may return a batched array such as (1, n_actions).
            # Indexing that with the flat argmax used below would fail and
            # silently discard a valid prediction, so reduce it to the
            # first sample's probability vector.
            if isinstance(action_probs, np.ndarray) and action_probs.ndim > 1:
                action_probs = action_probs.reshape(-1, action_probs.shape[-1])[0]

            # Convert predictions to action
            action_names = self.ACTION_NAMES
            action_idx = np.argmax(action_probs) if hasattr(action_probs, 'argmax') else 1  # Default to HOLD
            action = action_names[action_idx]

            # Format timestamp
            if not isinstance(timestamp, str):
                try:
                    if hasattr(timestamp, 'isoformat'):  # If it's already a datetime-like object
                        timestamp = timestamp.isoformat()
                    else:  # If it's a numeric timestamp (treated as milliseconds)
                        timestamp = datetime.fromtimestamp(float(timestamp)/1000).isoformat()
                except (TypeError, ValueError, OverflowError, OSError):
                    # An unusable timestamp must not cost us the prediction;
                    # fromtimestamp raises OverflowError/OSError for
                    # out-of-range values.
                    timestamp = datetime.now().isoformat()

            # Return result
            result = {
                'timestamp': timestamp,
                'action': action,
                'action_index': int(action_idx),
                'probability': float(action_probs[action_idx]) if hasattr(action_probs, '__getitem__') else 0.33,
                'probabilities': {name: float(prob) for name, prob in zip(action_names, action_probs)} if hasattr(action_probs, '__iter__') else {'SELL': 0.33, 'HOLD': 0.34, 'BUY': 0.33},
                'price_prediction': float(price_pred) if price_pred is not None else None
            }

            logger.info(f"Inference result: {result}")
            return result

        except Exception as e:
            logger.error(f"Error in inference pipeline: {str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            return self._get_dummy_prediction()

    def _get_dummy_prediction(self):
        """Return a dummy prediction when model or data is unavailable.

        Returns:
            dict: A HOLD result stamped with the current time and marked
            with 'is_dummy': True.
        """
        action_idx = 1  # Default to HOLD

        return {
            'timestamp': datetime.now().isoformat(),
            'action': self.ACTION_NAMES[action_idx],
            'action_index': action_idx,
            'probability': 0.8,
            'probabilities': {'SELL': 0.1, 'HOLD': 0.8, 'BUY': 0.1},
            'price_prediction': None,
            'is_dummy': True
        }
|
||||
|
||||
|
||||
class DummyModel:
    """Stand-in model used when no real model can be built or loaded.

    Despite the name of its role, the predictions are fixed rather than
    random: it always favours HOLD.
    """

    def __init__(self):
        logger.info("Initializing dummy model")

    def predict(self, X):
        """Return a constant prediction; the input ``X`` is ignored.

        Returns:
            tuple: ``(action_probs, price_pred)`` where ``action_probs`` is
            a fresh array ordered SELL, HOLD, BUY and ``price_pred`` is
            always ``None``.
        """
        # Fixed probabilities for SELL, HOLD, BUY - biased towards HOLD.
        # No price prediction is produced.
        return np.array([0.1, 0.8, 0.1]), None

    def load(self, model_path):
        """Pretend to load a checkpoint; only logs the path and reports success."""
        logger.info(f"Dummy model pretending to load from {model_path}")
        return True
|
@ -1,287 +0,0 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional, Callable, Tuple
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from .trading_agent import TradingAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class NeuralNetworkOrchestrator:
    """Orchestrator for neural network models and trading operations.

    This class coordinates between neural network models and trading agents,
    ensuring that signals from the models are properly processed and trades
    are executed according to the strategy.

    Inference runs on a background daemon thread that wakes roughly once a
    second and runs the model every ``inference_interval`` seconds (taken
    from the ``NN_INFERENCE_INTERVAL`` environment variable, default 60).
    """

    def __init__(self, model, data_interface, chart=None,
                 symbols: Optional[List[str]] = None,
                 timeframes: Optional[List[str]] = None,
                 window_size: int = 20,
                 num_features: int = 5,
                 output_size: int = 3,
                 models_dir: str = "NN/models/saved",
                 data_dir: str = "NN/data",
                 exchange_config: Optional[Dict[str, Any]] = None):
        """Initialize the neural network orchestrator.

        Args:
            model: Neural network model instance
            data_interface: Data interface for retrieving market data
            chart: Real-time chart for visualization (optional)
            symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT'])
            timeframes: List of timeframes to monitor (e.g., ['1m', '5m', '1h'])
            window_size: Window size for model input
            num_features: Number of features per datapoint
            output_size: Number of output classes (e.g., 3 for BUY/HOLD/SELL)
            models_dir: Directory for saved models
            data_dir: Directory for data storage
            exchange_config: Configuration for trading agent (exchange, API keys, etc.)

        Raises:
            ValueError: If ``NN_INFERENCE_INTERVAL`` is set to a value that
                is not an integer.
        """
        self.model = model
        self.data_interface = data_interface
        self.chart = chart

        self.symbols = symbols or ["BTC/USDT"]
        self.timeframes = timeframes or ["1m", "5m", "1h", "4h", "1d"]
        self.window_size = window_size
        self.num_features = num_features
        self.output_size = output_size
        self.models_dir = models_dir
        self.data_dir = data_dir

        # Initialize trading agent if configuration provided
        self.trading_agent = None
        if exchange_config:
            self.init_trading_agent(exchange_config)

        # Initialize inference state
        self.is_running = False
        self.inference_thread = None
        self.stop_event = threading.Event()
        self.last_inference_time = 0
        self.inference_interval = int(os.environ.get("NN_INFERENCE_INTERVAL", "60"))

        logger.info("Initializing NeuralNetworkOrchestrator with:")
        logger.info(f"- Symbol: {self.symbols[0]}")
        logger.info(f"- Timeframes: {', '.join(self.timeframes)}")
        logger.info(f"- Window size: {window_size}")
        logger.info(f"- Num features: {num_features}")
        logger.info(f"- Output size: {output_size}")
        logger.info(f"- Models dir: {models_dir}")
        logger.info(f"- Data dir: {data_dir}")
        logger.info(f"- Inference interval: {self.inference_interval} seconds")

    def init_trading_agent(self, config: Dict[str, Any]):
        """Initialize the trading agent with the given configuration.

        Args:
            config: Configuration for the trading agent. Missing keys fall
                back to defaults (binance exchange, test mode on, this
                orchestrator's symbols, 0.1 position size, 5 trades per
                day, 60 minute cooldown).
        """
        exchange_name = config.get("exchange", "binance")
        api_key = config.get("api_key")
        api_secret = config.get("api_secret")
        test_mode = config.get("test_mode", True)
        trade_symbols = config.get("trade_symbols", self.symbols)
        position_size = config.get("position_size", 0.1)
        max_trades_per_day = config.get("max_trades_per_day", 5)
        trade_cooldown_minutes = config.get("trade_cooldown_minutes", 60)

        self.trading_agent = TradingAgent(
            exchange_name=exchange_name,
            api_key=api_key,
            api_secret=api_secret,
            test_mode=test_mode,
            trade_symbols=trade_symbols,
            position_size=position_size,
            max_trades_per_day=max_trades_per_day,
            trade_cooldown_minutes=trade_cooldown_minutes
        )

        logger.info(f"Trading agent initialized for {exchange_name} exchange.")

    def start_inference(self):
        """Start the inference thread (and the trading agent, if any)."""
        if self.is_running:
            logger.warning("Neural network inference is already running.")
            return

        self.is_running = True
        self.stop_event.clear()

        # Start inference thread
        self.inference_thread = threading.Thread(target=self._inference_loop)
        self.inference_thread.daemon = True
        self.inference_thread.start()

        logger.info(f"Neural network inference thread started with {self.inference_interval}s interval.")

        # Start trading agent if available
        if self.trading_agent:
            self.trading_agent.start(signal_callback=self._on_trade_executed)

    def stop_inference(self):
        """Stop the inference thread (and the trading agent, if any)."""
        if not self.is_running:
            logger.warning("Neural network inference is not running.")
            return

        logger.info("Stopping neural network inference...")
        self.is_running = False
        self.stop_event.set()

        if self.inference_thread and self.inference_thread.is_alive():
            self.inference_thread.join(timeout=10)

        logger.info("Neural network inference stopped.")

        # Stop trading agent if available
        if self.trading_agent:
            self.trading_agent.stop()

    def _inference_loop(self):
        """Main inference loop that processes data and generates signals."""
        logger.info("Inference loop started.")

        try:
            while self.is_running and not self.stop_event.is_set():
                current_time = time.time()

                # Check if we should run inference
                if current_time - self.last_inference_time >= self.inference_interval:
                    try:
                        # Run inference for all symbols
                        for symbol in self.symbols:
                            prediction = self._run_inference(symbol)
                            if prediction:
                                self._process_prediction(symbol, prediction)

                        # Only advanced on success, so a failed round is
                        # retried on the next tick.
                        self.last_inference_time = current_time
                    except Exception as e:
                        logger.error(f"Error during inference: {str(e)}")

                # Idle for up to a second to prevent CPU hogging. Waiting on
                # the stop event (instead of sleeping) lets stop_inference()
                # end the loop immediately.
                self.stop_event.wait(1)
        except Exception as e:
            logger.error(f"Error in inference loop: {str(e)}")
        finally:
            logger.info("Inference loop stopped.")

    def _run_inference(self, symbol: str) -> Optional[Tuple[np.ndarray, float]]:
        """Run inference for a specific symbol.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')

        Returns:
            tuple: (action probabilities, current price) or None if inference failed
        """
        # Number of candles requested for a single inference run.
        n_candles = 1000

        try:
            # Get the model timeframe from environment
            model_timeframe = os.environ.get("NN_TIMEFRAME", "1h")
            if model_timeframe not in self.timeframes:
                logger.warning(f"Model timeframe {model_timeframe} not in available timeframes. Using {self.timeframes[0]}.")
                model_timeframe = self.timeframes[0]

            # Load candles for the model timeframe
            logger.info(f"Loading {n_candles} candles from cache for {symbol} at {model_timeframe} timeframe")
            candles = self.data_interface.get_historical_data(
                symbol=symbol,
                timeframe=model_timeframe,
                n_candles=n_candles
            )

            if candles is None or len(candles) < self.window_size:
                logger.warning(f"Not enough data for {symbol} at {model_timeframe} timeframe. Need at least {self.window_size} candles.")
                return None

            # Prepare input data
            X, timestamp = self.data_interface.prepare_model_input(
                data=candles,
                window_size=self.window_size,
                symbol=symbol
            )

            if X is None:
                logger.warning(f"Failed to prepare model input for {symbol}.")
                return None

            # Get current price (close of the most recent candle)
            current_price = candles['close'].iloc[-1]

            # Run model inference; the model's price prediction is not used here.
            action_probs, _price_pred = self.model.predict(X)

            return action_probs, current_price

        except Exception as e:
            logger.error(f"Error running inference for {symbol}: {str(e)}")
            return None

    def _process_prediction(self, symbol: str, prediction: Tuple[np.ndarray, float]):
        """Process a prediction and generate signals.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            prediction: Tuple of (action probabilities, current price)
        """
        action_probs, current_price = prediction

        # A batched output such as (1, n_actions) cannot be indexed with the
        # flat argmax below; reduce it to the first sample's vector.
        if isinstance(action_probs, np.ndarray) and action_probs.ndim > 1:
            action_probs = action_probs.reshape(-1, action_probs.shape[-1])[0]

        # Get the best action (0=SELL, 1=HOLD, 2=BUY)
        best_action = int(np.argmax(action_probs))
        best_prob = float(action_probs[best_action])

        # Convert to action name
        action_names = ["SELL", "HOLD", "BUY"]
        action_name = action_names[best_action]

        # Log the prediction
        logger.info(f"Inference result for {symbol}: Action={action_name}, Probability={best_prob:.2f}, Price={current_price:.2f}")

        # One timestamp per signal so the chart and the trading agent agree.
        signal_time = int(time.time())

        # Add signal to chart if available
        if self.chart:
            self.chart.add_nn_signal(symbol=symbol, signal=action_name, confidence=best_prob, timestamp=signal_time)

        # Process signal with trading agent if available
        if self.trading_agent:
            self.trading_agent.process_signal(
                symbol=symbol,
                action=action_name,
                confidence=best_prob,
                timestamp=signal_time
            )

    def _on_trade_executed(self, trade_record: Dict[str, Any]):
        """Callback for when a trade is executed.

        Args:
            trade_record: Trade information; 'action' and 'timestamp' are
                required, 'price' and 'pnl' default to 0.
        """
        if self.chart and trade_record:
            # Add trade to chart
            self.chart.add_trade(
                action=trade_record['action'],
                price=trade_record.get('price', 0),
                timestamp=trade_record['timestamp'],
                pnl=trade_record.get('pnl', 0)
            )

            logger.info(f"Trade added to chart: {trade_record['action']} at {trade_record.get('price', 0):.2f}")

    def get_trading_agent_info(self) -> Optional[Dict[str, Any]]:
        """Get information about the trading agent.

        Returns:
            dict: Trading agent information ('exchange_info', 'positions'
            and the number of 'trades') or None if no agent is available
        """
        if self.trading_agent:
            return {
                'exchange_info': self.trading_agent.get_exchange_info(),
                'positions': self.trading_agent.get_current_positions(),
                'trades': len(self.trading_agent.get_trade_history())
            }
        return None
|
@ -1,287 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Neural Network Trading System Main Module
|
||||
|
||||
This module serves as the main entry point for the NN trading system,
|
||||
coordinating data flow between different components and implementing
|
||||
training and inference pipelines.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
|
||||
# Configure logging
|
||||
# The log directory must exist before the FileHandler below is constructed:
# FileHandler opens its file immediately and raises FileNotFoundError when
# the parent directory is missing.
os.makedirs('logs', exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        # One timestamped log file per process start.
        logging.FileHandler(os.path.join('logs', f'nn_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'))
    ]
)

logger = logging.getLogger('NN')
|
||||
|
||||
def parse_arguments():
    """Parse the command line options of the NN trading system.

    Returns:
        argparse.Namespace: Parsed options with attributes ``mode``,
        ``symbol``, ``timeframes``, ``window_size``, ``output_size``,
        ``batch_size``, ``epochs``, ``model_type`` and ``framework``.
    """
    # (flag, add_argument keyword arguments), registered in this order so
    # the generated --help output keeps the same layout.
    option_specs = (
        ('--mode', {
            'type': str, 'choices': ['train', 'predict', 'realtime'], 'default': 'train',
            'help': 'Mode to run (train, predict, realtime)',
        }),
        ('--symbol', {
            'type': str, 'default': 'BTC/USDT',
            'help': 'Trading pair symbol',
        }),
        ('--timeframes', {
            'type': str, 'nargs': '+', 'default': ['1h', '4h'],
            'help': 'Timeframes to use',
        }),
        ('--window-size', {
            'type': int, 'default': 20,
            'help': 'Window size for input data',
        }),
        ('--output-size', {
            'type': int, 'default': 3,
            'help': 'Output size (1 for binary, 3 for BUY/HOLD/SELL)',
        }),
        ('--batch-size', {
            'type': int, 'default': 32,
            'help': 'Batch size for training',
        }),
        ('--epochs', {
            'type': int, 'default': 100,
            'help': 'Number of epochs for training',
        }),
        ('--model-type', {
            'type': str, 'choices': ['cnn', 'transformer', 'moe'], 'default': 'cnn',
            'help': 'Model type to use',
        }),
        ('--framework', {
            'type': str, 'choices': ['tensorflow', 'pytorch'], 'default': 'pytorch',
            'help': 'Deep learning framework to use',
        }),
    )

    cli = argparse.ArgumentParser(description='Neural Network Trading System')
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
|
||||
|
||||
def main():
    """Run the NN trading system from the command line.

    Parses the CLI arguments, imports the model class matching the chosen
    framework and model type, builds the data interface and the model, and
    then hands over to ``train``, ``predict`` or ``realtime``. Every failure
    is logged and ends the function early; nothing is raised to the caller.
    """
    args = parse_arguments()

    logger.info(f"Starting NN Trading System in {args.mode} mode")
    logger.info(f"Configuration: Symbol={args.symbol}, Timeframes={args.timeframes}, "
                f"Window Size={args.window_size}, Output Size={args.output_size}, "
                f"Model Type={args.model_type}, Framework={args.framework}")

    # Reject an unsupported framework before attempting any heavy import.
    if args.framework != 'pytorch' and args.framework != 'tensorflow':
        logger.error(f"Unknown framework: {args.framework}")
        return

    if args.framework == 'pytorch':
        try:
            import torch
            logger.info(f"Using PyTorch {torch.__version__}")

            # PyTorch-backed implementations
            from NN.utils.data_interface import DataInterface

            if args.model_type == 'cnn':
                from NN.models.cnn_model_pytorch import CNNModelPyTorch as Model
            elif args.model_type == 'transformer':
                from NN.models.transformer_model_pytorch import TransformerModelPyTorchWrapper as Model
            elif args.model_type == 'moe':
                from NN.models.transformer_model_pytorch import MixtureOfExpertsModelPyTorch as Model
            else:
                logger.error(f"Unknown model type: {args.model_type}")
                return
        except ImportError as import_err:
            logger.error(f"Failed to import PyTorch modules: {import_err!s}")
            logger.error("Please make sure PyTorch is installed or use the TensorFlow framework.")
            return
    else:
        try:
            import tensorflow as tf
            logger.info(f"Using TensorFlow {tf.__version__}")

            # TensorFlow-backed implementations
            from NN.utils.data_interface import DataInterface

            if args.model_type == 'cnn':
                from NN.models.cnn_model import CNNModel as Model
            elif args.model_type == 'transformer':
                from NN.models.transformer_model import TransformerModel as Model
            elif args.model_type == 'moe':
                from NN.models.transformer_model import MixtureOfExpertsModel as Model
            else:
                logger.error(f"Unknown model type: {args.model_type}")
                return
        except ImportError as import_err:
            logger.error(f"Failed to import TensorFlow modules: {import_err!s}")
            logger.error("Please make sure TensorFlow is installed or use the PyTorch framework.")
            return

    # Build the data source.
    try:
        logger.info("Initializing data interface...")
        data_interface = DataInterface(symbol=args.symbol, timeframes=args.timeframes)
    except Exception as init_err:
        logger.error(f"Failed to initialize data interface: {init_err!s}")
        return

    # Build the model; the feature count comes from the data interface.
    try:
        logger.info(f"Initializing {args.model_type.upper()} model...")
        model = Model(
            window_size=args.window_size,
            num_features=data_interface.get_feature_count(),
            output_size=args.output_size,
            timeframes=args.timeframes,
        )
    except Exception as init_err:
        logger.error(f"Failed to initialize model: {init_err!s}")
        return

    # Dispatch to the handler for the requested mode.
    mode_handlers = {'train': train, 'predict': predict, 'realtime': realtime}
    run_mode = mode_handlers.get(args.mode)
    if run_mode is None:
        logger.error(f"Unknown mode: {args.mode}")
        return
    run_mode(data_interface, model, args)

    logger.info("Neural Network Trading System finished successfully")
|
||||
|
||||
def train(data_interface, model, args):
    """Train ``model`` for ``args.epochs`` epochs and log metrics to TensorBoard.

    Training data is re-fetched with ``refresh=True`` on every third epoch
    (epochs 0, 3, 6, ...) and taken from ``prepare_training_data()`` without
    refresh otherwise. After each epoch the model is evaluated on the
    validation split; whenever validation accuracy improves the model is
    saved as the "best" checkpoint. A timestamped "final" checkpoint is
    written once all epochs have run.

    Args:
        data_interface: Object providing ``prepare_training_data`` which
            returns ``(X_train, y_train, X_val, y_val)``.
        model: Object providing ``train_epoch``, ``evaluate`` and ``save``.
        args: Parsed CLI namespace; reads ``model_type``, ``symbol``,
            ``epochs`` and ``batch_size``.

    Returns:
        None. Errors are logged rather than raised.
    """
    from torch.utils.tensorboard import SummaryWriter

    logger.info("Starting training mode...")
    writer = SummaryWriter(log_dir=f"runs/{args.model_type}_{datetime.now().strftime('%Y%m%d_%H%M%S')}")

    try:
        # Checkpoints are written below 'models'; make sure it exists so the
        # first save cannot fail on a missing directory.
        os.makedirs('models', exist_ok=True)

        best_val_acc = 0

        for epoch in range(args.epochs):
            # Refresh data every few epochs
            if epoch % 3 == 0:
                X_train, y_train, X_val, y_val = data_interface.prepare_training_data(refresh=True)
            else:
                X_train, y_train, X_val, y_val = data_interface.prepare_training_data()

            # Train for one epoch
            train_loss, train_acc = model.train_epoch(
                X_train, y_train,
                batch_size=args.batch_size
            )

            # Validate
            val_loss, val_acc = model.evaluate(X_val, y_val)

            # Log metrics
            writer.add_scalar('Loss/Train', train_loss, epoch)
            writer.add_scalar('Accuracy/Train', train_acc, epoch)
            writer.add_scalar('Loss/Validation', val_loss, epoch)
            writer.add_scalar('Accuracy/Validation', val_acc, epoch)

            # Save best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                model_path = os.path.join(
                    'models',
                    f"{args.model_type}_best_{args.symbol.replace('/', '_')}.pt"
                )
                model.save(model_path)
                logger.info(f"New best model saved with val_acc: {val_acc:.2f}")

            logger.info(f"Epoch {epoch+1}/{args.epochs} - "
                        f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f} - "
                        f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}")

        # Save final model
        model_path = os.path.join(
            'models',
            f"{args.model_type}_final_{args.symbol.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"
        )
        model.save(model_path)

        logger.info(f"Training Complete - Best Val Accuracy: {best_val_acc:.2f}")

    except Exception as e:
        logger.error(f"Error in training mode: {str(e)}")
        return
    finally:
        # Flush pending events and release the event file on success and on
        # failure alike.
        writer.close()
|
||||
|
||||
def predict(data_interface, model, args):
    """Load the most recent saved model and run one prediction pass.

    Files in the ``models`` directory whose name starts with
    ``args.model_type`` are candidates; the one that sorts last by name is
    loaded. The predictions are handed to
    ``data_interface.process_predictions``. Any exception is logged and
    swallowed.
    """
    logger.info("Starting prediction mode...")

    try:
        # Collect checkpoints belonging to the requested model type.
        models_root = os.path.join('models')
        candidates = []
        for file_name in os.listdir(models_root):
            if file_name.startswith(args.model_type):
                candidates.append(file_name)

        if len(candidates) == 0:
            logger.error(f"No saved model found for type {args.model_type}")
            return

        # The name that sorts last is treated as the latest checkpoint.
        candidates.sort()
        checkpoint_path = os.path.join(models_root, candidates[-1])

        logger.info(f"Loading model from {checkpoint_path}...")
        model.load(checkpoint_path)

        logger.info("Preparing prediction data...")
        inputs = data_interface.prepare_prediction_data()

        logger.info("Making predictions...")
        outputs = model.predict(inputs)

        logger.info("Processing predictions...")
        data_interface.process_predictions(outputs)

    except Exception as err:
        logger.error(f"Error in prediction mode: {err!s}")
        return
|
||||
|
||||
def realtime(data_interface, model, args):
    """Load the most recent saved model and start the real-time analyzer.

    The checkpoint is chosen the same way as in prediction mode: among the
    files in ``models`` whose name starts with ``args.model_type``, the one
    sorting last by name is loaded. A ``RealtimeAnalyzer`` is then created
    for ``args.symbol`` / ``args.timeframes`` and started. Any exception is
    logged and swallowed.
    """
    logger.info("Starting real-time mode...")

    try:
        # Imported lazily so the other modes do not depend on the analyzer.
        from NN.utils.realtime_analyzer import RealtimeAnalyzer

        checkpoint_dir = os.path.join('models')
        matching = sorted(
            filter(lambda name: name.startswith(args.model_type), os.listdir(checkpoint_dir))
        )

        if not matching:
            logger.error(f"No saved model found for type {args.model_type}")
            return

        checkpoint_path = os.path.join(checkpoint_dir, matching[-1])

        logger.info(f"Loading model from {checkpoint_path}...")
        model.load(checkpoint_path)

        logger.info("Initializing real-time analyzer...")
        analyzer_kwargs = {
            'data_interface': data_interface,
            'model': model,
            'symbol': args.symbol,
            'timeframes': args.timeframes,
        }
        analyzer = RealtimeAnalyzer(**analyzer_kwargs)

        logger.info("Starting real-time analysis...")
        analyzer.start()

    except Exception as err:
        logger.error(f"Error in real-time mode: {err!s}")
        return
|
||||
|
||||
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
@ -1,241 +0,0 @@
|
||||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RealtimeDataInterface:
    """Interface for retrieving real-time market data for neural network models.

    This class serves as a bridge between the RealTimeChart data sources and
    the neural network models, providing properly formatted data for model
    inference.
    """

    # Seconds per candle for the known timeframe labels. Kept at class level
    # so the table is built once instead of on every chart query.
    _INTERVAL_SECONDS = {
        '1s': 1,
        '5s': 5,
        '10s': 10,
        '15s': 15,
        '30s': 30,
        '1m': 60,
        '3m': 180,
        '5m': 300,
        '15m': 900,
        '30m': 1800,
        '1h': 3600,
        '2h': 7200,
        '4h': 14400,
        '6h': 21600,
        '8h': 28800,
        '12h': 43200,
        '1d': 86400,
        '3d': 259200,
        '1w': 604800,
    }

    # Unit multipliers used to parse labels missing from the table above.
    _UNIT_SECONDS = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400, 'w': 604800}

    # Columns that must be present to build a model input.
    _OHLCV_COLUMNS = ('open', 'high', 'low', 'close', 'volume')

    def __init__(self, symbols: List[str], chart=None, max_cache_size: int = 5000):
        """Initialize the data interface.

        Args:
            symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT'])
            chart: RealTimeChart instance (optional)
            max_cache_size: Maximum number of cached candles
        """
        self.symbols = symbols
        self.chart = chart
        self.max_cache_size = max_cache_size

        # Initialize data cache
        self.ohlcv_cache = {}  # timeframe -> symbol -> DataFrame

        logger.info(f"Initialized RealtimeDataInterface with symbols: {', '.join(symbols)}")

    def get_historical_data(self, symbol: Optional[str] = None, timeframe: str = '1h',
                            n_candles: int = 500) -> Optional[pd.DataFrame]:
        """Get historical OHLCV data for a symbol and timeframe.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT'). Defaults to the first
                tracked symbol when omitted.
            timeframe: Time interval (e.g., '1m', '5m', '1h')
            n_candles: Number of candles to retrieve

        Returns:
            DataFrame with the chart's data when the chart yields any rows;
            an empty DataFrame with the OHLCV columns when it yields none
            (or no chart is attached); ``None`` when the symbol is missing,
            not tracked, or an error occurs.
        """
        if not symbol:
            if len(self.symbols) > 0:
                symbol = self.symbols[0]
            else:
                logger.error("No symbol specified and no default symbols available")
                return None

        if symbol not in self.symbols:
            logger.warning(f"Symbol {symbol} not in tracked symbols")
            return None

        try:
            # Get data from chart if available
            if self.chart:
                candles = self._get_chart_data(symbol, timeframe, n_candles)
                if candles is not None and len(candles) > 0:
                    return candles

            # Fallback to default empty DataFrame
            logger.warning(f"No historical data available for {symbol} at timeframe {timeframe}")
            return pd.DataFrame(columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])

        except Exception as e:
            logger.error(f"Error getting historical data for {symbol}: {str(e)}")
            return None

    @classmethod
    def _timeframe_to_seconds(cls, timeframe: str) -> Optional[int]:
        """Convert a timeframe label to seconds, or return None if unparseable.

        Known labels come from the lookup table. Other labels are parsed as
        ``<integer><unit>`` with unit in s/m/h/d/w, or as a bare integer
        number of seconds.
        """
        if timeframe in cls._INTERVAL_SECONDS:
            return cls._INTERVAL_SECONDS[timeframe]

        try:
            unit = timeframe[-1:]
            if unit in cls._UNIT_SECONDS:
                return int(timeframe[:-1]) * cls._UNIT_SECONDS[unit]
            return int(timeframe)
        except ValueError:
            logger.error(f"Could not parse timeframe: {timeframe}")
            return None

    @staticmethod
    def _to_epoch_seconds(value) -> int:
        """Coerce a timestamp cell to an integer.

        Datetime-like values (anything exposing ``.timestamp()``, such as
        ``pd.Timestamp``) are converted to whole epoch seconds, matching the
        ``int(time.time())`` fallback used when no timestamp column exists.
        Numeric values are truncated with ``int()`` and otherwise kept in
        whatever unit the caller supplied.
        """
        if hasattr(value, 'timestamp'):
            return int(value.timestamp())
        return int(value)

    def _get_chart_data(self, symbol: str, timeframe: str, n_candles: int) -> Optional[pd.DataFrame]:
        """Get data from the RealTimeChart for the specified symbol and timeframe.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT'); used for log messages
                only, the chart is queried by interval alone.
            timeframe: Time interval (e.g., '1m', '5m', '1h')
            n_candles: Number of candles to retrieve

        Returns:
            DataFrame holding at most the last ``n_candles`` rows, or None if
            no chart is attached, the timeframe cannot be parsed, the chart
            returns nothing, or an error occurs.
        """
        if not self.chart:
            return None

        try:
            interval_seconds = self._timeframe_to_seconds(timeframe)
            if interval_seconds is None:
                return None

            # Get data from chart
            df = self.chart._get_chart_data(interval_seconds)

            if df is not None and not df.empty:
                # Limit to requested number of candles
                if len(df) > n_candles:
                    df = df.iloc[-n_candles:]
                return df

            logger.warning(f"No data retrieved from chart for {symbol} at timeframe {timeframe}")
            return None

        except Exception as e:
            logger.error(f"Error getting chart data for {symbol} at {timeframe}: {str(e)}")
            return None

    def prepare_model_input(self, data: pd.DataFrame, window_size: int = 20,
                            symbol: Optional[str] = None) -> Tuple[Optional[np.ndarray], Optional[int]]:
        """Prepare model input from OHLCV data.

        Prices in the last ``window_size`` rows are divided by the most recent
        close, volume by the window's maximum volume (a divisor of 0 is
        replaced by 1.0), and NaN/inf values are replaced afterwards.

        Args:
            data: DataFrame with OHLCV data
            window_size: Window size for model input
            symbol: Symbol for the data (for logging)

        Returns:
            tuple: (X, timestamp) where X has shape (1, window_size, 5) in the
            column order open, high, low, close, volume, and timestamp is
            the latest candle's timestamp (current time if the frame has no
            'timestamp' column). ``(None, None)`` when there is too little
            data, columns are missing, or an error occurs.
        """
        if data is None or len(data) < window_size:
            logger.warning(f"Not enough data to prepare model input for {symbol or 'unknown symbol'}")
            return None, None

        try:
            # Get last window_size candles
            recent_data = data.iloc[-window_size:].copy()

            # Timestamp of the most recent candle. Reading the column first
            # keeps the cell's native dtype (a row-first lookup upcasts
            # integers to float), and datetime-like cells are converted
            # instead of making int() raise.
            if 'timestamp' in recent_data.columns:
                timestamp = self._to_epoch_seconds(recent_data['timestamp'].iloc[-1])
            else:
                timestamp = int(time.time())

            if not all(column in recent_data.columns for column in self._OHLCV_COLUMNS):
                logger.error(f"Data missing required OHLCV columns for {symbol or 'unknown symbol'}")
                return None, None

            # Normalize price data by the last close price
            last_close = recent_data['close'].iloc[-1]

            # Avoid division by zero
            if last_close == 0:
                last_close = 1.0

            opens = (recent_data['open'] / last_close).values
            highs = (recent_data['high'] / last_close).values
            lows = (recent_data['low'] / last_close).values
            closes = (recent_data['close'] / last_close).values

            # Normalize volume by the max volume in the window
            max_volume = recent_data['volume'].max()
            if max_volume == 0:
                max_volume = 1.0
            volumes = (recent_data['volume'] / max_volume).values

            # Stack features into a 3D array [batch_size=1, window_size, n_features=5]
            X = np.column_stack((opens, highs, lows, closes, volumes))
            X = X.reshape(1, window_size, 5)

            # Replace any NaN or infinite values
            X = np.nan_to_num(X, nan=0.0, posinf=1.0, neginf=0.0)

            return X, timestamp

        except Exception as e:
            logger.error(f"Error preparing model input for {symbol or 'unknown symbol'}: {str(e)}")
            return None, None

    def prepare_realtime_input(self, timeframe: str = '1h', n_candles: int = 30,
                               window_size: int = 20) -> Tuple[Optional[np.ndarray], Optional[int]]:
        """Prepare real-time input for the model.

        Always uses the first tracked symbol.

        Args:
            timeframe: Time interval (e.g., '1m', '5m', '1h')
            n_candles: Number of candles to retrieve
            window_size: Window size for model input

        Returns:
            tuple: (X, timestamp) as produced by ``prepare_model_input``, or
            ``(None, None)`` when no symbols are tracked, fewer than
            ``window_size`` candles are available, or an error occurs.
        """
        # Get data for the main symbol
        if len(self.symbols) == 0:
            logger.error("No symbols available for real-time input")
            return None, None

        symbol = self.symbols[0]

        try:
            # Get historical data
            data = self.get_historical_data(symbol, timeframe, n_candles)

            if data is None or len(data) < window_size:
                logger.warning(f"Not enough data for real-time input. Need at least {window_size} candles.")
                return None, None

            # Prepare model input
            return self.prepare_model_input(data, window_size, symbol)

        except Exception as e:
            logger.error(f"Error preparing real-time input: {str(e)}")
            return None, None
|
@ -1,507 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Neural Network Trading System Main Module - PyTorch Version
|
||||
|
||||
This module serves as the main entry point for the NN trading system,
|
||||
using PyTorch exclusively for all model operations.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
# Configure logging
logger = logging.getLogger('NN')
logger.setLevel(logging.INFO)

# One formatter shared by the console and file handlers.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Always setup console logging. This is done first so that a failure while
# setting up file logging below is reported through a configured handler
# rather than being emitted before any handler is attached.
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
logger.addHandler(ch)

try:
    # Create logs directory if it doesn't exist
    os.makedirs('logs', exist_ok=True)

    # Try setting up file logging; one timestamped file per process start.
    log_file = os.path.join('logs', f'nn_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    logger.info(f"Logging to file: {log_file}")
except Exception as e:
    # File logging is best-effort: keep running with console output only.
    logger.warning(f"Failed to setup file logging: {str(e)}. Falling back to console logging only.")
|
||||
|
||||
def parse_arguments(argv=None):
    """Parse command line arguments for the PyTorch NN trading system.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what
            the function always did before this parameter existed.

    Returns:
        argparse.Namespace with the attributes ``mode``, ``symbol``,
        ``timeframes``, ``window_size``, ``output_size``, ``batch_size``,
        ``epochs`` and ``model_type``.
    """
    parser = argparse.ArgumentParser(description='Neural Network Trading System')

    parser.add_argument('--mode', type=str, choices=['train', 'predict', 'realtime'], default='train',
                        help='Mode to run (train, predict, realtime)')
    parser.add_argument('--symbol', type=str, default='BTC/USDT',
                        help='Trading pair symbol')
    parser.add_argument('--timeframes', type=str, nargs='+', default=['1s', '1m', '5m', '1h', '4h'],
                        help='Timeframes to use (include 1s for ticks)')
    parser.add_argument('--window-size', type=int, default=20,
                        help='Window size for input data')
    parser.add_argument('--output-size', type=int, default=3,
                        help='Output size (1 for binary, 3 for BUY/HOLD/SELL)')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Number of epochs for training')
    parser.add_argument('--model-type', type=str, choices=['cnn', 'transformer', 'moe'], default='cnn',
                        help='Model type to use')

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
|
||||
|
||||
def main():
    """Main entry point for the NN trading system.

    Parses the CLI arguments, imports the PyTorch model class for the
    requested model type, builds and verifies the data interface, builds
    the model (moving it to CUDA when available) and dispatches to
    ``train``, ``predict`` or ``realtime``. Failures are logged and end the
    function early; nothing is raised to the caller.
    """
    args = parse_arguments()

    logger.info(f"Starting NN Trading System in {args.mode} mode")
    logger.info(f"Configuration: Symbol={args.symbol}, Timeframes={args.timeframes}")

    try:
        import torch
        from NN.utils.data_interface import DataInterface

        # Import appropriate PyTorch model
        if args.model_type == 'cnn':
            from NN.models.cnn_model_pytorch import CNNModelPyTorch as Model
        elif args.model_type == 'transformer':
            from NN.models.transformer_model_pytorch import TransformerModelPyTorchWrapper as Model
        elif args.model_type == 'moe':
            from NN.models.transformer_model_pytorch import MixtureOfExpertsModelPyTorch as Model
        else:
            logger.error(f"Unknown model type: {args.model_type}")
            return

    except ImportError as e:
        logger.error(f"Failed to import PyTorch modules: {str(e)}")
        logger.error("Please make sure PyTorch is installed")
        return

    # Initialize data interface
    try:
        data_interface = DataInterface(
            symbol=args.symbol,
            timeframes=args.timeframes
        )

        # Verify data interface by fetching initial data
        logger.info("Verifying data interface...")
        X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
        # Abort only when either array is missing. The previous check
        # (`y_sample is not None`) was inverted and aborted on every
        # successful data preparation.
        if X_sample is None or y_sample is None:
            logger.error("Failed to prepare initial training data")
            return

        logger.info(f"Data interface verified - X shape: {X_sample.shape}, y shape: {y_sample.shape}")

    except Exception as e:
        logger.error(f"Failed to initialize data interface: {str(e)}")
        return

    # Initialize model
    try:
        # Feature count as reported by the data interface
        num_features = data_interface.get_feature_count()
        logger.info(f"Initializing model with {num_features} features")

        model = Model(
            window_size=args.window_size,
            num_features=num_features,
            output_size=args.output_size,
            timeframes=args.timeframes
        )

        # Ensure model is on the correct device
        if torch.cuda.is_available():
            model.model = model.model.cuda()
            logger.info("Model moved to CUDA device")
    except Exception as e:
        logger.error(f"Failed to initialize model: {str(e)}")
        return

    # Execute requested mode
    if args.mode == 'train':
        train(data_interface, model, args)
    elif args.mode == 'predict':
        predict(data_interface, model, args)
    elif args.mode == 'realtime':
        realtime(data_interface, model, args)
|
||||
|
||||
def _timeframe_seconds(timeframe, default=60):
    """Return the length in seconds of a label such as '5m', else ``default``.

    A label is an integer followed by one of the units s/m/h/d/w.
    """
    unit_seconds = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400, 'w': 604800}
    try:
        return int(timeframe[:-1]) * unit_seconds[timeframe[-1]]
    except (KeyError, ValueError, IndexError, TypeError):
        return default


def _mean_abs_price_error(future_prices, price_preds):
    """Return the MAE between predictions and future prices, inf if unavailable.

    The two sequences are compared over their common leading length.
    """
    if future_prices is None or price_preds is None:
        return float('inf')

    actual = future_prices if isinstance(future_prices, np.ndarray) else np.array(future_prices)
    predicted = price_preds if isinstance(price_preds, np.ndarray) else np.array(price_preds)

    if len(predicted) == 0 or len(actual) == 0:
        return float('inf')

    common_len = min(len(predicted), len(actual))
    return np.mean(np.abs(predicted[:common_len] - actual[:common_len]))


def train(data_interface, model, args):
    """Enhanced training with performance tracking and retrospective fine-tuning.

    For every epoch the training data is (re)prepared, the model is trained
    and evaluated, PnL / win rate / price MAE are computed from the model's
    predictions, and everything is logged to TensorBoard and to the logger.
    The model is saved as ``models/<model_type>_best.pt`` whenever the
    validation PnL improves (ties broken by validation accuracy), and a
    timestamped final checkpoint is written after the last epoch. An error
    inside one epoch's train/evaluate step is logged and that epoch is
    skipped; any other error ends training and is logged.

    Args:
        data_interface: Provides ``prepare_training_data``,
            ``get_future_prices`` and ``calculate_pnl``.
        model: Provides ``train_epoch``, ``evaluate``, ``predict``,
            ``predict_next_candles`` and ``save``.
        args: Parsed CLI namespace; reads ``timeframes``, ``epochs``,
            ``batch_size`` and ``model_type``.

    Returns:
        None.
    """
    logger.info("Starting training mode...")
    writer = SummaryWriter()
    action_names = ['SELL', 'HOLD', 'BUY']

    try:
        best_val_acc = 0
        best_val_pnl = float('-inf')
        best_win_rate = 0
        best_price_mae = float('inf')

        # Checkpoints are written below 'models/'; make sure it exists so a
        # save cannot fail on a missing directory.
        os.makedirs('models', exist_ok=True)

        logger.info("Verifying data interface...")
        X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
        logger.info(f"Data validation - X shape: {X_sample.shape}, y shape: {y_sample.shape}")

        # Calculate refresh intervals based on timeframes. The shortest
        # timeframe must be chosen by duration: a plain min() compares the
        # labels as strings and ranks '1h' below '1m' and '1s'.
        min_timeframe = min(args.timeframes, key=_timeframe_seconds)
        refresh_interval = {
            '1s': 1,
            '1m': 60,
            '5m': 300,
            '15m': 900,
            '1h': 3600,
            '4h': 14400,
            '1d': 86400
        }.get(min_timeframe, 60)

        logger.info(f"Using refresh interval of {refresh_interval} seconds based on {min_timeframe} timeframe")

        for epoch in range(args.epochs):
            # Always refresh for tick data or when using multiple timeframes
            refresh = '1s' in args.timeframes or len(args.timeframes) > 1

            logger.info(f"\nStarting epoch {epoch+1}/{args.epochs}")
            X_train, y_train, X_val, y_val, train_prices, val_prices = data_interface.prepare_training_data(
                refresh=refresh,
                refresh_interval=refresh_interval
            )
            logger.info(f"Training data - X shape: {X_train.shape}, y shape: {y_train.shape}")
            logger.info(f"Validation data - X shape: {X_val.shape}, y shape: {y_val.shape}")

            # Get future prices for retrospective training
            train_future_prices = data_interface.get_future_prices(train_prices, n_candles=3)
            val_future_prices = data_interface.get_future_prices(val_prices, n_candles=3)

            # Train and validate
            try:
                train_action_loss, train_price_loss, train_acc = model.train_epoch(
                    X_train, y_train, train_future_prices, args.batch_size
                )
                val_action_loss, val_price_loss, val_acc = model.evaluate(
                    X_val, y_val, val_future_prices
                )

                # Get predictions for PnL calculation
                train_action_probs, train_price_preds = model.predict(X_train)
                val_action_probs, val_price_preds = model.predict(X_val)

                # Convert probabilities to actions for PnL calculation
                train_preds = np.argmax(train_action_probs, axis=1)
                val_preds = np.argmax(val_action_probs, axis=1)

                # Calculate PnL and win rates; a failure here zeroes the
                # trading metrics for this epoch but does not skip the epoch.
                try:
                    if train_preds is not None and train_prices is not None:
                        train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
                            train_preds, train_prices, position_size=1.0
                        )
                    else:
                        train_pnl, train_win_rate, train_trades = 0, 0, []

                    if val_preds is not None and val_prices is not None:
                        val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
                            val_preds, val_prices, position_size=1.0
                        )
                    else:
                        val_pnl, val_win_rate, val_trades = 0, 0, []
                except Exception as e:
                    logger.error(f"Error calculating PnL: {str(e)}")
                    train_pnl, train_win_rate, val_pnl, val_win_rate = 0, 0, 0, 0
                    train_trades, val_trades = [], []

                # Calculate price prediction error
                train_price_mae = _mean_abs_price_error(train_future_prices, train_price_preds)
                val_price_mae = _mean_abs_price_error(val_future_prices, val_price_preds)

                # Monitor action distribution
                train_actions = np.bincount(train_preds, minlength=3)
                val_actions = np.bincount(val_preds, minlength=3)

                # Log metrics
                writer.add_scalar('Loss/action_train', train_action_loss, epoch)
                writer.add_scalar('Loss/price_train', train_price_loss, epoch)
                writer.add_scalar('Loss/action_val', val_action_loss, epoch)
                writer.add_scalar('Loss/price_val', val_price_loss, epoch)
                writer.add_scalar('Accuracy/train', train_acc, epoch)
                writer.add_scalar('Accuracy/val', val_acc, epoch)
                writer.add_scalar('PnL/train', train_pnl, epoch)
                writer.add_scalar('PnL/val', val_pnl, epoch)
                writer.add_scalar('WinRate/train', train_win_rate, epoch)
                writer.add_scalar('WinRate/val', val_win_rate, epoch)
                writer.add_scalar('PriceMAE/train', train_price_mae, epoch)
                writer.add_scalar('PriceMAE/val', val_price_mae, epoch)

                # Log action distribution
                for i, action in enumerate(action_names):
                    writer.add_scalar(f'Actions/train_{action}', train_actions[i], epoch)
                    writer.add_scalar(f'Actions/val_{action}', val_actions[i], epoch)

                # Save best model based on validation metrics: higher PnL
                # wins, equal PnL is broken by higher accuracy.
                if np.isscalar(val_pnl) and np.isscalar(best_val_pnl) and (val_pnl > best_val_pnl or (np.isclose(val_pnl, best_val_pnl) and val_acc > best_val_acc)):
                    best_val_pnl = val_pnl
                    best_val_acc = val_acc
                    best_win_rate = val_win_rate
                    best_price_mae = val_price_mae
                    model.save(f"models/{args.model_type}_best.pt")
                    logger.info("Saved new best model based on validation metrics")

                # Log detailed metrics
                logger.info(f"Epoch {epoch+1}/{args.epochs}")
                logger.info("Training Metrics:")
                logger.info(f" Action Loss: {train_action_loss:.4f}")
                logger.info(f" Price Loss: {train_price_loss:.4f}")
                logger.info(f" Accuracy: {train_acc:.2f}")
                logger.info(f" PnL: {train_pnl:.2%}")
                logger.info(f" Win Rate: {train_win_rate:.2%}")
                logger.info(f" Price MAE: {train_price_mae:.2f}")

                logger.info("Validation Metrics:")
                logger.info(f" Action Loss: {val_action_loss:.4f}")
                logger.info(f" Price Loss: {val_price_loss:.4f}")
                logger.info(f" Accuracy: {val_acc:.2f}")
                logger.info(f" PnL: {val_pnl:.2%}")
                logger.info(f" Win Rate: {val_win_rate:.2%}")
                logger.info(f" Price MAE: {val_price_mae:.2f}")

                # Log action distribution
                logger.info("Action Distribution:")
                for i, action in enumerate(action_names):
                    logger.info(f" {action}: Train={train_actions[i]}, Val={val_actions[i]}")

                # Log trade statistics
                logger.info("Trade Statistics:")
                logger.info(f" Training trades: {len(train_trades)}")
                logger.info(f" Validation trades: {len(val_trades)}")

                # Log next candle predictions
                if epoch % 10 == 0:  # Every 10 epochs
                    logger.info("\nNext Candle Predictions:")
                    next_candles = model.predict_next_candles(X_val[-1:], n_candles=3)
                    for tf in args.timeframes:
                        if tf in next_candles:
                            logger.info(f"\n{tf} timeframe predictions:")
                            for i, pred in enumerate(next_candles[tf]):
                                action = action_names[np.argmax(pred)]
                                confidence = np.max(pred)
                                logger.info(f" Candle {i+1}: {action} (confidence: {confidence:.2f})")

            except Exception as e:
                logger.error(f"Error during epoch {epoch+1}: {str(e)}")
                continue

        # Save final model
        model.save(f"models/{args.model_type}_final_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt")
        logger.info(f"\nTraining complete. Best validation metrics:")
        logger.info(f"Accuracy: {best_val_acc:.2f}")
        logger.info(f"PnL: {best_val_pnl:.2%}")
        logger.info(f"Win Rate: {best_win_rate:.2%}")
        logger.info(f"Price MAE: {best_price_mae:.2f}")

    except Exception as e:
        logger.error(f"Error in training: {str(e)}")
    finally:
        # Flush pending TensorBoard events and release the event file on
        # success and on failure alike.
        writer.close()
|
||||
|
||||
def predict(data_interface, model, args):
    """Run a single prediction pass with the most recent saved model.

    Looks in ``models`` for files whose name starts with ``args.model_type``
    and loads the one that sorts last by name. The resulting predictions are
    passed to ``data_interface.process_predictions``. Any exception is
    logged and swallowed.
    """
    logger.info("Starting prediction mode...")

    try:
        # Find the checkpoints written for this model type.
        checkpoint_dir = os.path.join('models')
        checkpoint_names = sorted(
            entry for entry in os.listdir(checkpoint_dir)
            if entry.startswith(args.model_type)
        )

        if len(checkpoint_names) == 0:
            logger.error(f"No saved model found for type {args.model_type}")
            return

        # The last name in sort order is treated as the latest checkpoint.
        newest_name = checkpoint_names[-1]
        newest_path = os.path.join(checkpoint_dir, newest_name)

        logger.info(f"Loading model from {newest_path}...")
        model.load(newest_path)

        logger.info("Preparing prediction data...")
        model_input = data_interface.prepare_prediction_data()

        logger.info("Making predictions...")
        model_output = model.predict(model_input)

        logger.info("Processing predictions...")
        data_interface.process_predictions(model_output)

    except Exception as err:
        logger.error(f"Error in prediction mode: {err!s}")
|
||||
|
||||
def realtime(data_interface, model, args, chart=None, symbol=None):
    """Run real-time inference with the trained model.

    Loads the newest saved model of ``args.model_type`` from ``models/``,
    starts a ``RealtimeAnalyzer`` and then polls for fresh input once per
    second, logging the predicted action for ``symbol``. PnL tracking and
    chart updates happen only when ``symbol`` is "BTC/USDT"; for any other
    symbol the prediction is logged as context only.

    Args:
        data_interface: Market data source providing
            ``prepare_realtime_input(...)`` and ``get_historical_data(...)``.
        model: Model with ``load(path)`` and ``predict(X)`` returning
            ``(action_probs, price_pred)``.
        args: Parsed CLI arguments. Reads ``model_type``, ``symbol``,
            ``symbols``, ``timeframes`` and ``window_size``.
        chart: Optional object with ``add_trade(action, price, timestamp, pnl)``.
        symbol: Pair this call reports on; must be present in ``args.symbols``.

    Raises:
        Exception: Any error other than KeyboardInterrupt is logged and
            re-raised.
    """
    logger.info(f"Starting real-time inference mode for {symbol}...")

    # Bound before the try block so the KeyboardInterrupt handler can always
    # read them, even when the interrupt arrives while the model is loading.
    total_pnl = 0.0
    is_main_pair = symbol == "BTC/USDT"

    try:
        from NN.utils.realtime_analyzer import RealtimeAnalyzer

        # Load the latest model
        model_dir = os.path.join('models')
        model_files = [f for f in os.listdir(model_dir) if f.startswith(args.model_type)]

        if not model_files:
            logger.error(f"No saved model found for type {args.model_type}")
            return

        latest_model = sorted(model_files)[-1]
        model_path = os.path.join(model_dir, latest_model)

        logger.info(f"Loading model from {model_path}...")
        model.load(model_path)

        # Initialize realtime analyzer
        logger.info("Initializing real-time analyzer...")
        realtime_analyzer = RealtimeAnalyzer(
            data_interface=data_interface,
            model=model,
            symbol=args.symbol,
            timeframes=args.timeframes
        )

        # Start real-time analysis
        logger.info("Starting real-time analysis...")
        realtime_analyzer.start()

        # Entry tracking for the most recent non-HOLD signal
        last_action = None
        last_price = None

        # Row of the model output that belongs to this symbol
        pair_index = args.symbols.index(symbol)

        while True:
            # Get current market data for all pairs
            # NOTE(review): prepare_realtime_input() is called with identical
            # arguments for every symbol, so each entry looks like the same
            # data -- confirm whether a per-symbol interface was intended.
            all_pairs_data = []
            for s in args.symbols:
                X, timestamp = data_interface.prepare_realtime_input(
                    timeframe=args.timeframes[0],  # Use shortest timeframe
                    n_candles=args.window_size + 10,  # Extra candles for safety
                    window_size=args.window_size
                )
                if X is not None:
                    all_pairs_data.append(X)
                else:
                    logger.warning(f"No data available for {s}")
                    time.sleep(1)

            if not all_pairs_data:
                logger.warning("No data available for any pair")
                time.sleep(1)
                continue

            # Stack data from all pairs for model input
            X_combined = np.concatenate(all_pairs_data, axis=2)

            # Get model predictions
            action_probs, price_pred = model.predict(X_combined)

            # Get predictions for this specific pair
            action = np.argmax(action_probs[pair_index])  # 0=SELL, 1=HOLD, 2=BUY

            # Latest close on the shortest timeframe
            current_price = data_interface.get_historical_data(
                timeframe=args.timeframes[0],
                n_candles=1
            )['close'].iloc[-1]

            # PnL of the tracked position relative to its entry price
            # (only for main pair)
            pnl = 0.0
            if is_main_pair and last_action is not None and last_price is not None:
                if last_action == 2:  # BUY
                    pnl = (current_price - last_price) / last_price
                elif last_action == 0:  # SELL
                    pnl = (last_price - current_price) / last_price

            # Realise the PnL only when a new BUY/SELL replaces the tracked
            # entry. Adding it on every HOLD tick would count the same open
            # position's move once per second.
            if is_main_pair and action != 1 and pnl != 0:
                total_pnl += pnl

            # Log the prediction
            action_name = "SELL" if action == 0 else "HOLD" if action == 1 else "BUY"
            log_msg = f"Time: {timestamp}, Symbol: {symbol}, Action: {action_name}, "
            if is_main_pair:
                log_msg += f"Price: {current_price:.2f}, PnL: {pnl:.2%}, Total PnL: {total_pnl:.2%}"
            else:
                log_msg += f"Price: {current_price:.2f} (Context Only)"
            logger.info(log_msg)

            # Update the chart if provided (only for main pair)
            if chart is not None and is_main_pair and action != 1:  # Skip HOLD actions
                chart.add_trade(
                    action=action_name,
                    price=current_price,
                    timestamp=timestamp,
                    pnl=pnl
                )

            # A non-HOLD signal becomes the new tracked entry (only for main pair)
            if is_main_pair and action != 1:
                last_action = action
                last_price = current_price

            # Sleep for a short time
            time.sleep(1)

    except KeyboardInterrupt:
        if is_main_pair:
            logger.info(f"Real-time inference stopped by user for {symbol}")
            logger.info(f"Final performance for {symbol} - Total PnL: {total_pnl:.2%}")
        else:
            logger.info(f"Real-time inference stopped by user for {symbol} (Context Only)")
    except Exception as e:
        logger.error(f"Error in real-time inference for {symbol}: {str(e)}")
        raise
|
||||
|
||||
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
@ -1,310 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict, Any, List, Optional, Callable, Tuple, Union
|
||||
|
||||
from .exchanges import ExchangeInterface, MEXCInterface, BinanceInterface
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TradingAgent:
    """Trading agent that executes trades based on neural network signals.

    This agent interfaces with different exchanges and executes trades
    based on the signals received from the neural network. A background
    thread periodically refreshes the balances of the traded base assets;
    trades themselves are only placed from ``process_signal``.
    """

    def __init__(self,
                 exchange_name: str = 'binance',
                 api_key: Optional[str] = None,
                 api_secret: Optional[str] = None,
                 test_mode: bool = True,
                 trade_symbols: Optional[List[str]] = None,
                 position_size: float = 0.1,
                 max_trades_per_day: int = 5,
                 trade_cooldown_minutes: int = 60):
        """Initialize the trading agent.

        Args:
            exchange_name: Name of the exchange to use ('binance', 'mexc')
            api_key: API key for the exchange
            api_secret: API secret for the exchange
            test_mode: If True, use test/sandbox environment
            trade_symbols: List of trading symbols to monitor (e.g., ['BTC/USDT'])
            position_size: Size of each position as a fraction of total available balance (0.0-1.0)
            max_trades_per_day: Maximum number of trades to execute per day
            trade_cooldown_minutes: Minimum time between trades in minutes

        Raises:
            ValueError: If ``exchange_name`` is not a supported exchange.
        """
        self.exchange_name = exchange_name.lower()
        self.api_key = api_key
        self.api_secret = api_secret
        self.test_mode = test_mode
        self.trade_symbols = trade_symbols or ['BTC/USDT']
        self.position_size = min(max(position_size, 0.01), 1.0)  # Ensure between 0.01 and 1.0
        self.max_trades_per_day = max(1, max_trades_per_day)
        self.trade_cooldown_seconds = max(60, trade_cooldown_minutes * 60)

        # Initialize exchange interface
        self.exchange = self._create_exchange()

        # Trading state
        self.active = False
        self.current_positions = {}  # Symbol -> quantity
        self.trades_today = {}  # Symbol -> count
        self.last_trade_time = {}  # Symbol -> timestamp
        self.trade_history = []  # List of trade records

        # Threading
        self.trading_thread = None
        self.stop_event = threading.Event()

        # Signal callback
        self.signal_callback = None

        # Connect to exchange
        if not self.exchange.connect():
            logger.error(f"Failed to connect to {self.exchange_name} exchange. Trading agent disabled.")
        else:
            logger.info(f"Successfully connected to {self.exchange_name} exchange.")
            self._load_current_positions()

    def _create_exchange(self) -> "ExchangeInterface":
        """Create an exchange interface based on the exchange name.

        Raises:
            ValueError: If ``self.exchange_name`` is neither 'mexc' nor 'binance'.
        """
        if self.exchange_name == 'mexc':
            return MEXCInterface(
                api_key=self.api_key,
                api_secret=self.api_secret,
                test_mode=self.test_mode
            )
        elif self.exchange_name == 'binance':
            return BinanceInterface(
                api_key=self.api_key,
                api_secret=self.api_secret,
                test_mode=self.test_mode
            )
        else:
            raise ValueError(f"Unsupported exchange: {self.exchange_name}")

    def _load_current_positions(self):
        """Load current positions from the exchange.

        A symbol is recorded only when its base-asset balance is positive.
        Errors are logged per symbol and do not stop the remaining lookups.
        """
        for symbol in self.trade_symbols:
            try:
                base_asset, _ = symbol.split('/')
                balance = self.exchange.get_balance(base_asset)

                if balance > 0:
                    self.current_positions[symbol] = balance
                    logger.info(f"Loaded existing position for {symbol}: {balance} {base_asset}")
            except Exception as e:
                logger.error(f"Error loading position for {symbol}: {str(e)}")

    def start(self, signal_callback: Optional[Callable] = None):
        """Start the trading agent.

        Resets the per-symbol trade counters and cooldown timestamps and
        launches the daemon monitoring thread. Does nothing (beyond a
        warning) if the agent is already running.

        Args:
            signal_callback: Optional callback function to receive trade signals
        """
        if self.active:
            logger.warning("Trading agent is already running.")
            return

        self.active = True
        self.signal_callback = signal_callback
        self.stop_event.clear()

        logger.info(f"Starting trading agent for {self.exchange_name} exchange.")
        logger.info(f"Trading symbols: {', '.join(self.trade_symbols)}")
        logger.info(f"Position size: {self.position_size * 100:.1f}% of available balance")
        logger.info(f"Max trades per day: {self.max_trades_per_day}")
        logger.info(f"Trade cooldown: {self.trade_cooldown_seconds // 60} minutes")

        # Reset trading state
        self.trades_today = {symbol: 0 for symbol in self.trade_symbols}
        self.last_trade_time = {symbol: 0 for symbol in self.trade_symbols}

        # Start trading thread
        self.trading_thread = threading.Thread(target=self._trading_loop)
        self.trading_thread.daemon = True
        self.trading_thread.start()

    def stop(self):
        """Stop the trading agent and wait up to 10 seconds for its thread."""
        if not self.active:
            logger.warning("Trading agent is not running.")
            return

        logger.info("Stopping trading agent...")
        self.active = False
        self.stop_event.set()

        if self.trading_thread and self.trading_thread.is_alive():
            self.trading_thread.join(timeout=10)

        logger.info("Trading agent stopped.")

    def _trading_loop(self):
        """Main trading loop that monitors positions and executes trades.

        Refreshes ``current_positions`` from the exchange roughly every
        10 seconds until the agent is stopped.
        """
        logger.info("Trading loop started.")

        try:
            while self.active and not self.stop_event.is_set():
                # Check positions and update state
                for symbol in self.trade_symbols:
                    try:
                        base_asset, _ = symbol.split('/')
                        current_balance = self.exchange.get_balance(base_asset)

                        # Log only changes larger than 0.1% of the previous balance
                        if symbol in self.current_positions:
                            prev_balance = self.current_positions[symbol]
                            if abs(current_balance - prev_balance) > 0.001 * prev_balance:
                                logger.info(f"Position updated for {symbol}: {prev_balance} -> {current_balance} {base_asset}")

                        self.current_positions[symbol] = current_balance
                    except Exception as e:
                        logger.error(f"Error checking position for {symbol}: {str(e)}")

                # Wait on the stop event instead of sleeping so that stop()
                # wakes the loop immediately rather than after up to 10 s.
                self.stop_event.wait(10)
        except Exception as e:
            logger.error(f"Error in trading loop: {str(e)}")
        finally:
            logger.info("Trading loop stopped.")

    def reset_daily_limits(self):
        """Reset daily trading limits. Call this at the start of each trading day."""
        self.trades_today = {symbol: 0 for symbol in self.trade_symbols}
        logger.info("Daily trading limits reset.")

    def process_signal(self, symbol: str, action: str,
                       confidence: Optional[float] = None,
                       timestamp: Optional[int] = None) -> Optional[Dict[str, Any]]:
        """Process a trading signal and execute a trade if conditions are met.

        A signal is ignored when the agent is inactive, the symbol is not
        traded, the action is invalid or 'HOLD', the daily limit or cooldown
        applies, or the action contradicts the current position (BUY while
        holding, SELL while flat).

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            action: Trade action ('BUY', 'SELL', 'HOLD')
            confidence: Confidence level of the signal (0.0-1.0)
            timestamp: Timestamp of the signal (unix time)

        Returns:
            dict: Trade information if a trade was executed, None otherwise
        """
        if not self.active:
            logger.warning("Trading agent is not active. Signal ignored.")
            return None

        if symbol not in self.trade_symbols:
            logger.warning(f"Symbol {symbol} is not in the trading symbols list. Signal ignored.")
            return None

        if action not in ['BUY', 'SELL', 'HOLD']:
            logger.warning(f"Invalid action: {action}. Must be 'BUY', 'SELL', or 'HOLD'.")
            return None

        # Log the signal
        confidence_str = f" (confidence: {confidence:.2f})" if confidence is not None else ""
        logger.info(f"Received {action} signal for {symbol}{confidence_str}")

        # Ignore HOLD signals for trading
        if action == 'HOLD':
            return None

        # Check if we can trade based on limits
        current_time = time.time()

        # Check max trades per day
        if self.trades_today.get(symbol, 0) >= self.max_trades_per_day:
            logger.warning(f"Max trades per day reached for {symbol}. Signal ignored.")
            return None

        # Check trade cooldown
        last_trade_time = self.last_trade_time.get(symbol, 0)
        if current_time - last_trade_time < self.trade_cooldown_seconds:
            cooldown_remaining = self.trade_cooldown_seconds - (current_time - last_trade_time)
            logger.warning(f"Trade cooldown active for {symbol}. {cooldown_remaining:.1f} seconds remaining. Signal ignored.")
            return None

        # Check if the action makes sense based on current position
        current_position = self.current_positions.get(symbol, 0)

        if action == 'BUY' and current_position > 0:
            logger.warning(f"Already have a position in {symbol}. BUY signal ignored.")
            return None

        if action == 'SELL' and current_position <= 0:
            logger.warning(f"No position in {symbol} to sell. SELL signal ignored.")
            return None

        # Execute the trade
        try:
            trade_result = self.exchange.execute_trade(
                symbol=symbol,
                action=action,
                percent_of_balance=self.position_size
            )
        except Exception as e:
            logger.error(f"Error executing trade for {symbol}: {str(e)}")
            return None

        if not trade_result:
            logger.error(f"Failed to execute {action} trade for {symbol}")
            return None

        # Update trading state
        self.trades_today[symbol] = self.trades_today.get(symbol, 0) + 1
        self.last_trade_time[symbol] = current_time

        # Create trade record
        trade_record = {
            'symbol': symbol,
            'action': action,
            'timestamp': timestamp or int(current_time),
            'confidence': confidence,
            'order_id': trade_result.get('orderId') if isinstance(trade_result, dict) else None,
            'status': 'executed'
        }

        # Add to trade history
        self.trade_history.append(trade_record)

        # The order is already placed at this point, so a failing callback
        # must not turn the result into "no trade executed".
        if self.signal_callback:
            try:
                self.signal_callback(trade_record)
            except Exception as e:
                logger.error(f"Signal callback failed for {symbol}: {str(e)}")

        logger.info(f"Successfully executed {action} trade for {symbol}")
        return trade_record

    def get_current_positions(self) -> Dict[str, float]:
        """Get current positions.

        Returns:
            dict: Symbol -> position size (a copy of the internal mapping)
        """
        return self.current_positions.copy()

    def get_trade_history(self) -> List[Dict[str, Any]]:
        """Get trade history.

        Returns:
            list: List of trade records (a shallow copy of the internal list)
        """
        return self.trade_history.copy()

    def get_exchange_info(self) -> Dict[str, Any]:
        """Get exchange information.

        Returns:
            dict: Exchange information. Mutable members are copies, so the
                caller cannot alter the agent's state through them.
        """
        return {
            'name': self.exchange_name,
            'test_mode': self.test_mode,
            'active': self.active,
            'trade_symbols': list(self.trade_symbols),
            'position_size': self.position_size,
            'max_trades_per_day': self.max_trades_per_day,
            'trade_cooldown_seconds': self.trade_cooldown_seconds,
            'trades_today': self.trades_today.copy()
        }
|
@ -1,585 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
import argparse
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
import contextlib
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# Import our enhanced agent
|
||||
from NN.models.dqn_agent_enhanced import EnhancedDQNAgent
|
||||
from NN.utils.data_interface import DataInterface
|
||||
|
||||
# Configure logging
# logging.FileHandler does not create missing parent directories, so make
# sure the log directory exists before the handler opens its file; otherwise
# importing this module fails on a fresh checkout.
os.makedirs('logs', exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('logs/enhanced_training.log')
    ]
)
logger = logging.getLogger(__name__)
|
||||
|
||||
def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings. When None (the default)
            the arguments are taken from ``sys.argv``, exactly as before;
            passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with the attributes ``episodes``, ``max_steps``,
        ``symbol``, ``no_gpu``, ``confidence``, ``load_model``,
        ``batch_size``, ``learning_rate``, ``no_pretrain`` and
        ``pretrain_epochs``.
    """
    parser = argparse.ArgumentParser(description='Train enhanced RL trading agent')
    parser.add_argument('--episodes', type=int, default=100, help='Number of episodes to train')
    parser.add_argument('--max-steps', type=int, default=2000, help='Maximum steps per episode')
    parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading symbol')
    parser.add_argument('--no-gpu', action='store_true', help='Disable GPU usage')
    parser.add_argument('--confidence', type=float, default=0.4, help='Confidence threshold')
    parser.add_argument('--load-model', type=str, default='', help='Load existing model')
    parser.add_argument('--batch-size', type=int, default=128, help='Training batch size')
    parser.add_argument('--learning-rate', type=float, default=0.0003, help='Learning rate')
    parser.add_argument('--no-pretrain', action='store_true', help='Skip pre-training')
    parser.add_argument('--pretrain-epochs', type=int, default=20, help='Number of pre-training epochs')
    return parser.parse_args(argv)
|
||||
|
||||
def generate_price_prediction_training_data(data_1m, data_1h, data_1d, window_size=20):
    """
    Generate labeled training data for price prediction pre-training

    One example is produced for every 60th row of ``data_1m`` (row positions
    that are multiples of 60, starting at or after ``window_size``). The
    input is the flattened OHLCV window of the preceding ``window_size``
    one-minute rows; the labels describe the move from the current 1m close
    to the close 5 rows ahead, and to the close of the row following the
    nearest 1h and 1d candle.

    Args:
        data_1m: 1-minute candle data (DataFrame with OHLCV columns and an
            index comparable with the indexes of the other two frames)
        data_1h: 1-hour candle data
        data_1d: 1-day candle data
        window_size: Size of the observation window

    Returns:
        X, y_immediate, y_midterm, y_longterm, y_values
        ``y_*`` labels are 0=down, 1=sideways, 2=up (threshold +/-0.5%);
        ``y_values`` rows are ``[change_5m, change_1h, change_1d, 0]`` in
        percent. All five arrays are empty when there is not enough data.
    """
    logger.info("Generating price prediction training data")

    # Features to use
    ohlcv_columns = ['open', 'high', 'low', 'close', 'volume']

    def get_direction(change):
        """Map a percentage change to 0=down, 1=sideways, 2=up."""
        if change < -0.5:  # Down if less than -0.5%
            return 0
        if change > 0.5:  # Up if more than 0.5%
            return 2
        return 1  # Sideways if between -0.5% and 0.5%

    # Create feature sets
    X = []
    y_immediate = []  # 1m prediction (next 5min)
    y_midterm = []    # 1h prediction (next few hours)
    y_longterm = []   # 1d prediction (next day)
    y_values = []     # % change for each timeframe

    # Need enough data for all timeframes (a missing frame counts as no data)
    frames_missing = data_1m is None or data_1h is None or data_1d is None
    if frames_missing or len(data_1m) < window_size + 5 or len(data_1h) < 2 or len(data_1d) < 2:
        logger.error("Not enough data for all timeframes")
        return np.array([]), np.array([]), np.array([]), np.array([]), np.array([])

    # Only row positions that are multiples of 60 are used (they are assumed
    # to line up with hour boundaries), so step straight from one to the next.
    first_aligned = -(-window_size // 60) * 60
    for i in range(first_aligned, len(data_1m) - 5, 60):
        try:
            # Get window of 1m data as input
            window_1m = data_1m.iloc[i - window_size:i][ohlcv_columns].values

            # Find corresponding indices in higher timeframes
            curr_timestamp = data_1m.index[i]
            h_idx = data_1h.index.get_indexer([curr_timestamp], method='nearest')[0]
            d_idx = data_1d.index.get_indexer([curr_timestamp], method='nearest')[0]

            # Skip if indices are out of bounds
            if h_idx < 0 or h_idx >= len(data_1h) - 1 or d_idx < 0 or d_idx >= len(data_1d) - 1:
                continue

            # Rows must be addressed by position with .iloc; plain
            # ``frame[i]`` is a column lookup on a DataFrame and raises
            # KeyError, which used to discard every single example.
            future_5m = data_1m.iloc[i + 5]['close']
            future_1h = data_1h.iloc[h_idx + 1]['close']
            future_1d = data_1d.iloc[d_idx + 1]['close']

            current_price = data_1m.iloc[i]['close']

            # Calculate % change for each timeframe
            change_5m = (future_5m - current_price) / current_price * 100
            change_1h = (future_1h - current_price) / current_price * 100
            change_1d = (future_1d - current_price) / current_price * 100

            # Add to dataset
            X.append(window_1m.flatten())
            y_immediate.append(get_direction(change_5m))
            y_midterm.append(get_direction(change_1h))
            y_longterm.append(get_direction(change_1d))
            y_values.append([change_5m, change_1h, change_1d, 0])  # Last value reserved

        except Exception as e:
            logger.warning(f"Error generating training example at index {i}: {str(e)}")

    # Convert to numpy arrays. Labels get an explicit integer dtype so that
    # np.bincount below also works when no example was generated.
    X = np.array(X)
    y_immediate = np.array(y_immediate, dtype=int)
    y_midterm = np.array(y_midterm, dtype=int)
    y_longterm = np.array(y_longterm, dtype=int)
    y_values = np.array(y_values)

    logger.info(f"Generated {len(X)} training examples")
    logger.info(f"Class distribution - Immediate: {np.bincount(y_immediate)}, "
                f"Midterm: {np.bincount(y_midterm)}, Long-term: {np.bincount(y_longterm)}")

    return X, y_immediate, y_midterm, y_longterm, y_values
|
||||
|
||||
def pretrain_price_prediction(agent, data_interface, n_epochs=20, batch_size=128, device=None):
    """
    Pre-train the price prediction capabilities of the agent

    Builds a supervised dataset from the 1m/1h/1d candles, splits it 80/20
    into training and validation sets and trains ``agent.policy_net`` on a
    weighted sum of three direction-classification losses and one value
    regression loss. Whenever the validation loss improves, the weights are
    copied to ``agent.target_net`` and the agent is saved. Training stops
    early after 5 epochs without improvement.

    Args:
        agent: EnhancedDQNAgent instance
        data_interface: DataInterface instance
        n_epochs: Number of pre-training epochs
        batch_size: Batch size for pre-training
        device: Device to use for pre-training

    Returns:
        The pre-trained agent. The same agent is returned unchanged when no
        training examples could be generated; if an error occurs it is logged
        and the agent is returned as-is, possibly partially trained.
    """
    logger.info("Starting price prediction pre-training")

    # Created inside the try block; tracked here so it is closed on every exit path.
    writer = None

    try:
        # Ensure we have the necessary timeframes
        for tf in ('1m', '1h', '1d'):
            if tf not in data_interface.timeframes:
                logger.info(f"Adding timeframe {tf} for pre-training")
                data_interface.timeframes.append(tf)
                data_interface.dataframes[tf] = None

        # Get data for each timeframe
        data_1m = data_interface.get_historical_data(timeframe='1m')
        data_1h = data_interface.get_historical_data(timeframe='1h')
        data_1d = data_interface.get_historical_data(timeframe='1d')

        # Generate labeled training data
        X, y_immediate, y_midterm, y_longterm, y_values = generate_price_prediction_training_data(
            data_1m, data_1h, data_1d, window_size=20
        )

        if len(X) == 0:
            logger.error("No training examples generated. Skipping pre-training.")
            return agent

        # Split data into training and validation sets
        (X_train, X_val,
         y_imm_train, y_imm_val,
         y_mid_train, y_mid_val,
         y_long_train, y_long_val,
         y_val_train, y_val_val) = train_test_split(
            X, y_immediate, y_midterm, y_longterm, y_values, test_size=0.2, random_state=42
        )

        # Convert to torch tensors
        X_train_tensor = torch.FloatTensor(X_train).to(device)
        y_imm_train_tensor = torch.LongTensor(y_imm_train).to(device)
        y_mid_train_tensor = torch.LongTensor(y_mid_train).to(device)
        y_long_train_tensor = torch.LongTensor(y_long_train).to(device)
        y_val_train_tensor = torch.FloatTensor(y_val_train).to(device)

        X_val_tensor = torch.FloatTensor(X_val).to(device)
        y_imm_val_tensor = torch.LongTensor(y_imm_val).to(device)
        y_mid_val_tensor = torch.LongTensor(y_mid_val).to(device)
        y_long_val_tensor = torch.LongTensor(y_long_val).to(device)
        y_val_val_tensor = torch.FloatTensor(y_val_val).to(device)

        def get_class_weights(labels):
            """Return normalized inverse-frequency weights for the 3 classes."""
            counts = np.bincount(labels, minlength=3)
            # A class missing from the training split has count 0; dividing
            # by it would yield inf and, after normalization, NaN weights
            # that poison the loss. Treat such a class as seen once instead.
            weights = 1.0 / np.maximum(counts, 1)
            return weights / np.sum(weights)

        imm_weights = torch.FloatTensor(get_class_weights(y_imm_train)).to(device)
        mid_weights = torch.FloatTensor(get_class_weights(y_mid_train)).to(device)
        long_weights = torch.FloatTensor(get_class_weights(y_long_train)).to(device)

        # Create DataLoader for batch training
        train_dataset = TensorDataset(
            X_train_tensor, y_imm_train_tensor, y_mid_train_tensor,
            y_long_train_tensor, y_val_train_tensor
        )
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        # Set up loss functions with class weights
        imm_criterion = nn.CrossEntropyLoss(weight=imm_weights)
        mid_criterion = nn.CrossEntropyLoss(weight=mid_weights)
        long_criterion = nn.CrossEntropyLoss(weight=long_weights)
        value_criterion = nn.MSELoss()

        # Set up optimizer (separate from agent's optimizer).
        # The scheduler's ``verbose`` flag is deprecated in recent PyTorch
        # releases, so learning-rate reductions are logged explicitly below.
        pretrain_optimizer = torch.optim.Adam(agent.policy_net.parameters(), lr=0.0002)
        pretrain_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            pretrain_optimizer, mode='min', factor=0.5, patience=3
        )

        # Set model to training mode
        agent.policy_net.train()

        # Early-stopping state
        best_val_loss = float('inf')
        patience = 5
        patience_counter = 0

        # Create TensorBoard writer for pre-training
        writer = SummaryWriter(log_dir=f'runs/pretrain_{int(time.time())}')

        for epoch in range(n_epochs):
            # ---- Training phase ----
            train_loss = 0.0
            imm_correct, mid_correct, long_correct = 0, 0, 0
            total = 0

            for X_batch, y_imm_batch, y_mid_batch, y_long_batch, y_val_batch in train_loader:
                pretrain_optimizer.zero_grad()

                # Forward pass (under autocast only when the agent uses mixed precision)
                with torch.cuda.amp.autocast() if agent.use_mixed_precision else contextlib.nullcontext():
                    _, _, price_preds, _ = agent.policy_net(X_batch)

                # Calculate losses for each prediction head
                imm_loss = imm_criterion(price_preds['immediate'], y_imm_batch)
                mid_loss = mid_criterion(price_preds['midterm'], y_mid_batch)
                long_loss = long_criterion(price_preds['longterm'], y_long_batch)
                value_loss = value_criterion(price_preds['values'], y_val_batch)

                # Combined loss (weighted by importance)
                total_loss = imm_loss + 0.7 * mid_loss + 0.5 * long_loss + 0.3 * value_loss

                # Backward pass and optimize, clipping the gradient norm to 1.0
                if agent.use_mixed_precision:
                    agent.scaler.scale(total_loss).backward()
                    # Gradients must be unscaled before they can be clipped
                    agent.scaler.unscale_(pretrain_optimizer)
                    torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
                    agent.scaler.step(pretrain_optimizer)
                    agent.scaler.update()
                else:
                    total_loss.backward()
                    torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
                    pretrain_optimizer.step()

                # Accumulate metrics
                train_loss += total_loss.item()
                total += X_batch.size(0)

                # Calculate accuracy
                _, imm_pred = torch.max(price_preds['immediate'], 1)
                _, mid_pred = torch.max(price_preds['midterm'], 1)
                _, long_pred = torch.max(price_preds['longterm'], 1)

                imm_correct += (imm_pred == y_imm_batch).sum().item()
                mid_correct += (mid_pred == y_mid_batch).sum().item()
                long_correct += (long_pred == y_long_batch).sum().item()

            # Calculate epoch metrics
            train_loss /= len(train_loader)
            imm_acc = imm_correct / total
            mid_acc = mid_correct / total
            long_acc = long_correct / total

            # ---- Validation phase (whole validation set in one pass) ----
            agent.policy_net.eval()

            with torch.no_grad():
                _, _, val_price_preds, _ = agent.policy_net(X_val_tensor)

                # Calculate validation losses
                val_imm_loss = imm_criterion(val_price_preds['immediate'], y_imm_val_tensor)
                val_mid_loss = mid_criterion(val_price_preds['midterm'], y_mid_val_tensor)
                val_long_loss = long_criterion(val_price_preds['longterm'], y_long_val_tensor)
                val_value_loss = value_criterion(val_price_preds['values'], y_val_val_tensor)

                val_total_loss = val_imm_loss + 0.7 * val_mid_loss + 0.5 * val_long_loss + 0.3 * val_value_loss
                val_loss = val_total_loss.item()

                # Calculate validation accuracy
                _, imm_val_pred = torch.max(val_price_preds['immediate'], 1)
                _, mid_val_pred = torch.max(val_price_preds['midterm'], 1)
                _, long_val_pred = torch.max(val_price_preds['longterm'], 1)

                imm_val_correct = (imm_val_pred == y_imm_val_tensor).sum().item()
                mid_val_correct = (mid_val_pred == y_mid_val_tensor).sum().item()
                long_val_correct = (long_val_pred == y_long_val_tensor).sum().item()

                imm_val_acc = imm_val_correct / len(X_val_tensor)
                mid_val_acc = mid_val_correct / len(X_val_tensor)
                long_val_acc = long_val_correct / len(X_val_tensor)

            # Log to TensorBoard
            writer.add_scalar('pretrain/train_loss', train_loss, epoch)
            writer.add_scalar('pretrain/val_loss', val_loss, epoch)
            writer.add_scalar('pretrain/imm_acc', imm_acc, epoch)
            writer.add_scalar('pretrain/mid_acc', mid_acc, epoch)
            writer.add_scalar('pretrain/long_acc', long_acc, epoch)
            writer.add_scalar('pretrain/imm_val_acc', imm_val_acc, epoch)
            writer.add_scalar('pretrain/mid_val_acc', mid_val_acc, epoch)
            writer.add_scalar('pretrain/long_val_acc', long_val_acc, epoch)

            # Learning rate scheduling, reporting any reduction ourselves
            prev_lr = pretrain_optimizer.param_groups[0]['lr']
            pretrain_scheduler.step(val_loss)
            new_lr = pretrain_optimizer.param_groups[0]['lr']
            if new_lr < prev_lr:
                logger.info(f"Reducing pre-training learning rate: {prev_lr:.6f} -> {new_lr:.6f}")

            # Early stopping check
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                # Copy policy_net weights to target_net
                agent.target_net.load_state_dict(agent.policy_net.state_dict())
                logger.info(f"Saved best model with validation loss: {val_loss:.4f}")
                # Save pre-trained model
                agent.save("NN/models/saved/enhanced_dqn_pretrained")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    logger.info(f"Early stopping triggered after {epoch+1} epochs")
                    break

            # Log progress
            logger.info(f"Epoch {epoch+1}/{n_epochs}: "
                        f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
                        f"Imm Acc: {imm_acc:.4f}/{imm_val_acc:.4f}, "
                        f"Mid Acc: {mid_acc:.4f}/{mid_val_acc:.4f}, "
                        f"Long Acc: {long_acc:.4f}/{long_val_acc:.4f}")

            # Set model back to training mode for next epoch
            agent.policy_net.train()

        # Early stopping leaves the loop right after validation; make sure the
        # network is handed back in training mode in that case as well.
        agent.policy_net.train()

        logger.info("Price prediction pre-training complete")
        return agent

    except Exception as e:
        logger.error(f"Error during price prediction pre-training: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return agent

    finally:
        # Flush and release the TensorBoard writer on success and failure alike.
        if writer is not None:
            writer.close()
|
||||
|
||||
def train_enhanced_rl(args):
    """
    Train the enhanced RL agent for trading.

    Builds a ``DataInterface`` and a ``TradingEnvironment`` for ``args.symbol``,
    creates an ``EnhancedDQNAgent``, optionally loads a saved model or runs the
    price-prediction pre-training, and then runs the episode loop while logging
    metrics to TensorBoard and saving best / checkpoint / final models under
    ``NN/models/saved``.

    Args:
        args: Command line arguments. The attributes read here are
            ``no_gpu``, ``symbol``, ``learning_rate``, ``batch_size``,
            ``confidence``, ``load_model``, ``no_pretrain``,
            ``pretrain_epochs``, ``episodes`` and ``max_steps``.

    Returns:
        tuple: (agent, train_env)
    """
    # Setup device
    if args.no_gpu:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logger.info(f"Using device: {device}")

    # Set up data interface
    data_interface = DataInterface(symbol=args.symbol, timeframes=['1m', '5m', '15m'])

    # Fetch historical data for each timeframe. The returned frame is not
    # needed here; the candle count is read from data_interface.dataframes
    # (presumably populated by get_historical_data - confirm in DataInterface).
    for timeframe in data_interface.timeframes:
        data_interface.get_historical_data(timeframe=timeframe)
        logger.info(f"Using data for {args.symbol} {timeframe} ({len(data_interface.dataframes[timeframe])} candles)")

    # Create environment for training
    from NN.environments.trading_env import TradingEnvironment
    window_size = 20
    train_env = TradingEnvironment(
        data_interface=data_interface,
        initial_balance=10000.0,
        transaction_fee=0.0002,
        window_size=window_size,
        max_position=1.0,
        reward_scaling=100.0
    )

    # Create agent with improved parameters
    state_shape = train_env.observation_space.shape
    n_actions = train_env.action_space.n

    agent = EnhancedDQNAgent(
        state_shape=state_shape,
        n_actions=n_actions,
        learning_rate=args.learning_rate,
        gamma=0.95,
        epsilon=1.0,
        epsilon_min=0.05,
        epsilon_decay=0.995,
        buffer_size=50000,
        batch_size=args.batch_size,
        target_update=10,
        confidence_threshold=args.confidence,
        device=device
    )

    # Load existing model if specified
    if args.load_model:
        model_path = args.load_model
        if agent.load(model_path):
            logger.info(f"Loaded existing model from {model_path}")
        else:
            logger.error(f"Error loading model from {model_path}")

    # Pre-training for price prediction (skipped when a model was requested
    # via --load_model, even if loading it failed)
    if not args.no_pretrain and not args.load_model:
        logger.info("Starting pre-training phase")
        agent = pretrain_price_prediction(
            agent=agent,
            data_interface=data_interface,
            n_epochs=args.pretrain_epochs,
            batch_size=args.batch_size,
            device=device
        )
        logger.info("Pre-training completed")

    # Setup TensorBoard
    writer = SummaryWriter(log_dir=f'runs/enhanced_rl_{int(time.time())}')

    # Log hardware info
    writer.add_text("hardware/device", str(device), 0)
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            writer.add_text(f"hardware/gpu_{i}", torch.cuda.get_device_name(i), 0)

    # Move agent to device
    agent.move_models_to_device(device)

    # Training loop
    logger.info(f"Starting enhanced training for {args.episodes} episodes")

    total_rewards = []
    episode_losses = []
    trade_win_rates = []
    best_reward = -np.inf

    try:
        for episode in range(args.episodes):
            # Reset environment for new episode
            state = train_env.reset()
            total_reward = 0.0
            done = False
            step = 0
            episode_start_time = time.time()

            # Track trade statistics
            trades = []
            wins = 0
            losses = 0

            # Run episode
            while not done and step < args.max_steps:
                # Choose action
                action, confidence = agent.act(state)

                # Take action in environment
                next_state, reward, done, info = train_env.step(action)

                # Remember experience
                agent.remember(state, action, reward, next_state, done)

                # Track trade results
                if 'trade_result' in info and info['trade_result'] is not None:
                    trade_result = info['trade_result']
                    trade_pnl = trade_result['pnl']
                    trades.append(trade_pnl)

                    if trade_pnl > 0:
                        wins += 1
                        logger.info(f"Profitable trade! {trade_pnl:.2f}% profit, reward: {reward:.4f}")
                    else:
                        # Break-even trades (pnl == 0) are counted as losses
                        losses += 1
                        logger.info(f"Loss trade! {trade_pnl:.2f}% loss, penalty: {reward:.4f}")

                # Update state and counters
                state = next_state
                total_reward += reward
                step += 1

                # Train agent. Guard against replay() returning None (e.g.
                # when no optimisation step was performed) so the comparison
                # below cannot raise a TypeError.
                loss = agent.replay()
                if loss is not None and loss > 0:
                    episode_losses.append(loss)

            # Log training metrics for each episode
            episode_time = time.time() - episode_start_time
            total_rewards.append(total_reward)

            # Calculate win rate (max(1, ...) avoids division by zero when no trade closed)
            win_rate = wins / max(1, (wins + losses))
            trade_win_rates.append(win_rate)

            # Log to console and TensorBoard
            logger.info(f"Episode {episode}/{args.episodes} - Reward: {total_reward:.4f}, Win Rate: {win_rate:.2f}, "
                        f"Trades: {len(trades)}, Balance: ${train_env.balance:.2f}, Epsilon: {agent.epsilon:.4f}, "
                        f"Time: {episode_time:.2f}s")

            writer.add_scalar('metrics/reward', total_reward, episode)
            writer.add_scalar('metrics/balance', train_env.balance, episode)
            writer.add_scalar('metrics/win_rate', win_rate, episode)
            writer.add_scalar('metrics/trades', len(trades), episode)
            writer.add_scalar('metrics/epsilon', agent.epsilon, episode)

            if episode_losses:
                avg_loss = sum(episode_losses) / len(episode_losses)
                writer.add_scalar('metrics/loss', avg_loss, episode)

            # Check if this is the best model so far
            if total_reward > best_reward:
                best_reward = total_reward
                # Save best model
                agent.save("NN/models/saved/enhanced_dqn_best")
                logger.info(f"New best model saved with reward: {best_reward:.4f}")

            # Save checkpoint every 10 episodes (the same path is overwritten each time)
            if episode % 10 == 0 and episode > 0:
                agent.save("NN/models/saved/enhanced_dqn_checkpoint")
                logger.info(f"Checkpoint saved at episode {episode}")

            # Reset episode losses
            episode_losses = []

        # Final save
        agent.save("NN/models/saved/enhanced_dqn_final")
        logger.info("Enhanced training completed, final model saved")

    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
    except Exception as e:
        logger.error(f"Training failed: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
    finally:
        # Close TensorBoard writer
        writer.close()

    return agent, train_env
|
||||
|
||||
if __name__ == "__main__":
    # Create the log and saved-model output directories if they don't exist
    os.makedirs("logs", exist_ok=True)
    os.makedirs("NN/models/saved", exist_ok=True)

    # Parse command line arguments
    args = parse_args()

    # Start training with the parsed arguments
    train_enhanced_rl(args)
|
657
NN/train_rl.py
657
NN/train_rl.py
@ -1,657 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
import os
|
||||
import sys
|
||||
import pandas as pd
|
||||
import gym
|
||||
import json
|
||||
import random
|
||||
import torch.nn as nn
|
||||
import contextlib
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from NN.utils.data_interface import DataInterface
|
||||
from NN.utils.trading_env import TradingEnvironment
|
||||
from NN.models.dqn_agent import DQNAgent
|
||||
from NN.utils.signal_interpreter import SignalInterpreter
|
||||
|
||||
# Configure logging: messages go both to 'rl_training.log' and to the console
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('rl_training.log'),
        logging.StreamHandler()
    ]
)

# Set up device for PyTorch (use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Log GPU status
if torch.cuda.is_available():
    gpu_count = torch.cuda.device_count()
    gpu_names = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
    logger.info(f"Using GPU: {gpu_names}")

    # Report whether BFloat16 is supported. This only logs the capability;
    # no precision setting is changed here.
    if hasattr(torch.cuda, 'amp') and torch.cuda.is_bf16_supported():
        logger.info("BFloat16 precision is supported - will use for faster training")
else:
    logger.warning("GPU not available. Using CPU for training (slower).")
|
||||
|
||||
class RLTradingEnvironment(gym.Env):
    """
    Reinforcement Learning environment for trading with technical indicators
    from multiple timeframes.

    Each feature array is 2-D with the close price stored in the LAST column.
    Observations are built from the remaining (non-price) columns of the 1m,
    1h and 1d arrays, stacked side by side, so they match
    ``observation_space`` exactly.

    Actions: 0 = Buy (all-in), 1 = Sell (close position), 2 = Hold.

    NOTE(review): ``trading_fee``, ``min_trade_interval`` and
    ``last_trade_step`` are stored but not applied anywhere in ``step``;
    likewise ``action_callback`` is stored but not invoked by this class.
    """

    def __init__(self, features_1m, features_1h, features_1d, window_size=20, trading_fee=0.0025, min_trade_interval=15):
        """
        Args:
            features_1m: 1-minute feature array, close price in the last column
            features_1h: 1-hour feature array, close price in the last column
            features_1d: 1-day feature array, close price in the last column
            window_size: Number of rows per timeframe in each observation
            trading_fee: Fee rate (stored only, see class note)
            min_trade_interval: Minimum steps between trades (stored only)
        """
        super().__init__()

        self.window_size = window_size
        self.num_features = features_1m.shape[1] - 1  # Exclude close price

        # We require all three timeframes
        self.num_timeframes = 3
        # Width of one observation row: the non-price columns of every
        # timeframe. Equals num_features * num_timeframes when all arrays
        # have the same number of columns.
        self.feature_dim = sum(
            features.shape[1] - 1
            for features in (features_1m, features_1h, features_1d)
        )

        # Store features from different timeframes
        self.features_1m = features_1m
        self.features_1h = features_1h
        self.features_1d = features_1d

        # Trading parameters
        self.initial_balance = 1.0
        self.trading_fee = trading_fee
        self.min_trade_interval = min_trade_interval  # Minimum steps between trades

        # Define action and observation spaces
        self.action_space = gym.spaces.Discrete(3)  # 0: Buy, 1: Sell, 2: Hold
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.window_size, self.feature_dim),
            dtype=np.float32
        )

        # State variables
        self.reset()

        # Callback for visualization or external monitoring
        self.action_callback = None

    def reset(self):
        """Reset the environment to its initial state and return the first observation."""
        self.balance = self.initial_balance
        self.position = 0.0  # Amount of asset held
        self.current_value = self.initial_balance  # Portfolio value (cash + position)
        self.current_step = self.window_size
        self.trades = 0  # Completed (sold) trades
        self.wins = 0
        self.losses = 0
        self.trade_history = []
        self.last_trade_step = -self.min_trade_interval  # Initialize to allow immediate first trade

        return self._get_observation()

    def _padded_window(self, features, end_idx):
        """Return the ``window_size`` rows before ``end_idx`` without the close column, zero-padded at the top."""
        start_idx = max(0, end_idx - self.window_size)
        window = features[start_idx:end_idx, :-1]  # drop close price column
        missing = self.window_size - len(window)
        if missing > 0:
            padding = np.zeros((missing, window.shape[1]))
            window = np.vstack([padding, window])
        return window

    def _trade_frequency_penalty(self):
        """Return -0.2 when the previous recorded trade was fewer than 5 steps ago, else 0.0."""
        if self.trade_history and self.current_step - self.trade_history[-1]['step'] < 5:
            return -0.2
        return 0.0

    def _get_observation(self):
        """
        Get the current state observation.

        Combines the non-price features of all timeframes into a float32
        array of shape ``(window_size, feature_dim)``; NaNs become 0.
        """
        # Calculate indices for each timeframe
        idx_1m = min(self.current_step, self.features_1m.shape[0] - 1)
        idx_1h = idx_1m // 60  # 60 minutes in an hour
        idx_1d = idx_1h // 24  # 24 hours in a day

        # Cap indices to prevent out of bounds
        idx_1h = min(idx_1h, self.features_1h.shape[0] - 1)
        idx_1d = min(idx_1d, self.features_1d.shape[0] - 1)

        # Combine feature windows from all timeframes
        combined_features = np.hstack([
            self._padded_window(self.features_1m, idx_1m),
            self._padded_window(self.features_1h, idx_1h),
            self._padded_window(self.features_1d, idx_1d)
        ])

        # Convert to float32 and handle any NaN values
        return np.nan_to_num(combined_features, nan=0.0).astype(np.float32)

    def step(self, action):
        """
        Take an action and return the next state, reward, done flag, and info.

        Args:
            action: 0 (Buy), 1 (Sell) or 2 (Hold)

        Returns:
            tuple: (next_state, reward, done, info). ``reward`` is clipped to
            [-10, 10]; ``info`` carries prices, the per-component rewards,
            the portfolio value change and look-ahead prices for evaluation.
        """
        # Initialize info dictionary for additional data
        info = {
            'trade_executed': False,
            'price_change': 0.0,
            'position_change': 0,
            'current_price': 0.0,
            'next_price': 0.0,
            'balance_change': 0.0,
            'reward_components': {},
            'future_prices': {}
        }

        # Get the current and next price. The index is clamped so a call made
        # at (or past) the end of the data cannot raise an IndexError.
        last_index = len(self.features_1m) - 1
        current_price = self.features_1m[min(self.current_step, last_index), -1]

        # Handle edge case at the end of the data
        if self.current_step >= last_index:
            next_price = current_price  # Use current price as next price
            done = True
        else:
            next_price = self.features_1m[self.current_step + 1, -1]
            done = False

        # Handle zero or negative price (data error)
        if current_price <= 0:
            current_price = 0.01  # Set to a small positive number
            logger.warning(f"Zero or negative price detected at step {self.current_step}. Setting to 0.01.")

        if next_price <= 0:
            next_price = current_price  # Use current price instead
            logger.warning(f"Zero or negative next price detected at step {self.current_step + 1}. Using current price.")

        # Calculate price change as percentage
        price_change_pct = ((next_price - current_price) / current_price) * 100

        # Store prices in info
        info['current_price'] = current_price
        info['next_price'] = next_price
        info['price_change'] = price_change_pct

        # Initialize reward components dictionary
        reward_components = {
            'holding_reward': 0.0,
            'action_reward': 0.0,
            'profit_reward': 0.0,
            'trade_freq_penalty': 0.0
        }

        # Default small negative reward, applied to every step
        reward = -0.01
        reward_components['holding_reward'] = -0.01

        # Portfolio value before acting (cash plus position at the current
        # price), so the reported change is not distorted while holding.
        previous_value = self.balance + (self.position * current_price)

        # Execute action (0: Buy, 1: Sell, 2: Hold)
        if action == 0:  # Buy
            if self.position == 0:  # Only buy if we don't already have a position
                # Buy with 100% of balance
                self.position = self.balance / current_price
                self.balance = 0  # All balance used

                # If price goes up after buying, that's good
                expected_profit = price_change_pct
                if expected_profit > 0:
                    # Base reward + profit-based bonus
                    action_reward = 0.1 + (expected_profit * 0.05)
                else:
                    # Base penalty, growing with the size of the drop
                    action_reward = -0.1 + (expected_profit * 0.03)
                reward_components['action_reward'] = action_reward
                reward += action_reward

                # Check if we've traded too frequently
                freq_penalty = self._trade_frequency_penalty()
                if freq_penalty:
                    reward += freq_penalty
                    reward_components['trade_freq_penalty'] = freq_penalty

                # Record the trade
                self.trade_history.append({
                    'step': self.current_step,
                    'action': 'buy',
                    'price': current_price,
                    'position': self.position,
                    'balance': self.balance
                })

                info['trade_executed'] = True
                logger.info(f"Buy at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")

        elif action == 1:  # Sell
            if self.position > 0:  # Only sell if we have a position
                # Calculate sale proceeds
                sale_value = self.position * current_price

                # Find the price of the most recent buy
                last_buy_price = None
                for trade in reversed(self.trade_history):
                    if trade['action'] == 'buy':
                        last_buy_price = trade['price']
                        break

                # If we found the last buy price, calculate profit
                if last_buy_price is not None:
                    profit_pct = ((current_price - last_buy_price) / last_buy_price) * 100

                    # A sell closes a round trip: update the trade statistics
                    # that callers read (env.trades / env.wins / env.losses).
                    self.trades += 1

                    if profit_pct > 0:
                        self.wins += 1
                        # Progressive reward based on profit percentage
                        profit_reward = min(5.0, profit_pct * 0.2)  # Cap at 5.0 to prevent exploitation
                        reward_components['profit_reward'] = profit_reward
                        reward += profit_reward
                        logger.info(f"Profitable trade! {profit_pct:.2f}% profit, reward: {profit_reward:.4f}")
                    else:
                        self.losses += 1
                        # Penalize losses based on size of loss
                        loss_penalty = max(-3.0, profit_pct * 0.15)  # Cap at -3.0 to prevent excessive punishment
                        reward_components['profit_reward'] = loss_penalty
                        reward += loss_penalty
                        logger.info(f"Loss trade! {profit_pct:.2f}% loss, penalty: {loss_penalty:.4f}")

                # If price goes down after selling, that's good
                if price_change_pct < 0:
                    # Reward for good timing on sell (avoiding future loss)
                    timing_reward = min(1.0, abs(price_change_pct) * 0.05)
                    reward_components['action_reward'] = timing_reward
                    reward += timing_reward

                # Check for trading too frequently
                freq_penalty = self._trade_frequency_penalty()
                if freq_penalty:
                    reward += freq_penalty
                    reward_components['trade_freq_penalty'] = freq_penalty

                # Update balance and position
                self.balance = sale_value
                position_change = self.position
                self.position = 0

                # Record the trade
                self.trade_history.append({
                    'step': self.current_step,
                    'action': 'sell',
                    'price': current_price,
                    'position': self.position,
                    'balance': self.balance
                })

                info['trade_executed'] = True
                info['position_change'] = position_change
                logger.info(f"Sell at step {self.current_step}, price: {current_price:.4f}, new balance: {self.balance:.4f}")

        elif action == 2:  # Hold
            # Small reward if holding was a good decision
            if self.position > 0 and price_change_pct > 0:  # Holding long position during price increase
                hold_reward = price_change_pct * 0.01  # Small reward proportional to price increase
                reward += hold_reward
                reward_components['holding_reward'] = hold_reward
            elif self.position == 0 and price_change_pct < 0:  # Holding cash during price decrease
                hold_reward = abs(price_change_pct) * 0.01  # Small reward for avoiding loss
                reward += hold_reward
                reward_components['holding_reward'] = hold_reward

        # Move to the next step
        self.current_step += 1

        # Update current portfolio value
        if self.position > 0:
            self.current_value = self.balance + (self.position * next_price)
        else:
            self.current_value = self.balance

        # Change in portfolio value caused by this step
        info['balance_change'] = self.current_value - previous_value

        # Check if we've reached the end of the data
        if self.current_step >= len(self.features_1m) - 1:
            done = True

            # Final evaluation if we have a position
            if self.position > 0:
                # Value the remaining position at the final price
                final_balance = self.balance + (self.position * next_price)

                # Calculate final portfolio return
                final_return_pct = ((final_balance - self.initial_balance) / self.initial_balance) * 100

                # Add big reward/penalty based on overall performance
                performance_reward = final_return_pct * 0.1
                reward += performance_reward
                reward_components['final_performance'] = performance_reward

                logger.info(f"Episode ended. Final balance: {final_balance:.4f}, Return: {final_return_pct:.2f}%")

        # Get future prices for evaluation (next 1-hour and 1-day rows),
        # using the step counter as a minute index
        info['future_prices'] = {}
        for label, features, minutes in (('1h', self.features_1h, 60), ('1d', self.features_1d, 1440)):
            if features is None or self.current_step >= len(self.features_1m):
                continue
            row_idx = min(self.current_step // minutes, len(features) - 1)
            if row_idx < len(features) - 1:
                info['future_prices'][label] = features[row_idx + 1, -1]

        # Get next observation
        next_state = self._get_observation()

        # Store reward components in info
        info['reward_components'] = reward_components

        # Clip reward to prevent extreme values
        reward = np.clip(reward, -10.0, 10.0)

        return next_state, reward, done, info

    def set_action_callback(self, callback):
        """
        Store a callback intended for monitoring actions.

        NOTE(review): this class only stores the callback; ``step`` does not
        call it, so the expected signature should be confirmed against the
        code that invokes ``action_callback``.

        Args:
            callback: Function with signature (action, price, reward, info)
        """
        self.action_callback = callback
|
||||
|
||||
# Candles requested per timeframe when loading training data
_TRAIN_RL_CANDLE_COUNTS = {'1m': 5000, '5m': 5000, '15m': 5000, '1h': 1000, '1d': 500}


def _load_timeframe_candles(data_interface):
    """Fetch historical candles for every training timeframe; values may be None."""
    return {
        timeframe: data_interface.get_historical_data(timeframe=timeframe, n_candles=n_candles)
        for timeframe, n_candles in _TRAIN_RL_CANDLE_COUNTS.items()
    }


def _build_feature_matrix(data_interface, candles):
    """Add technical indicators to *candles* and return a 2-D array with the close price as the last column."""
    enriched = data_interface.add_technical_indicators(candles)
    return np.hstack([
        enriched.drop(['timestamp', 'close'], axis=1).values,
        enriched['close'].values.reshape(-1, 1)
    ])


def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent",
             action_callback=None, episode_callback=None, symbol="BTC/USDT",
             pretrain_price_prediction_enabled=False, pretrain_epochs=10):
    """
    Train a reinforcement learning agent for trading using ONLY real market data

    Args:
        env_class: Optional environment class override; called as
            ``env_class(features_1m, features_1h, features_1d)``
        num_episodes: Number of episodes to train for
        max_steps: Maximum steps per episode
        save_path: Path to save the trained model
        action_callback: Callback function for monitoring actions
        episode_callback: Callback function for monitoring episodes; called as
            ``episode_callback(episode, total_reward, info)`` where ``info`` is
            the last step's info dict plus an ``'env'`` entry
        symbol: Trading symbol to use
        pretrain_price_prediction_enabled: DEPRECATED - No longer supported (synthetic data not used)
        pretrain_epochs: DEPRECATED - No longer supported (synthetic data not used)

    Returns:
        tuple: (trained agent, environment)

    Raises:
        ValueError: If data for a required timeframe cannot be retrieved,
            even after falling back to BTC/USDT.
    """
    timeframes = list(_TRAIN_RL_CANDLE_COUNTS)

    # Load data for the selected symbol
    data_interface = DataInterface(symbol=symbol, timeframes=timeframes)

    try:
        candles = _load_timeframe_candles(data_interface)
        if any(frame is None for frame in candles.values()):
            raise FileNotFoundError("Could not retrieve all required timeframes data for specified symbol")
    except Exception as e:
        logger.warning(f"Data for {symbol} not available: {str(e)}. Using default cached data.")
        # Fall back to the default symbol
        symbol = "BTC/USDT"
        data_interface = DataInterface(symbol=symbol, timeframes=timeframes)
        candles = _load_timeframe_candles(data_interface)

    if any(frame is None for frame in candles.values()):
        logger.error("Failed to retrieve all required timeframes data. Cannot continue training.")
        raise ValueError("No data available for training")

    # Create features from the data by adding technical indicators and
    # converting to numpy format. All timeframes are processed, but only the
    # 1m, 1h and 1d matrices are passed to the environment below.
    features = {
        timeframe: _build_feature_matrix(data_interface, frame)
        for timeframe, frame in candles.items()
    }
    features_1m = features['1m']
    features_1h = features['1h']
    features_1d = features['1d']

    # Create the environment
    if env_class:
        # Use provided environment class
        env = env_class(features_1m, features_1h, features_1d)
    else:
        # Use the default environment
        env = RLTradingEnvironment(features_1m, features_1h, features_1d)

    # Set action callback if provided
    if action_callback:
        env.set_action_callback(action_callback)

    # Get environment properties for agent creation
    input_shape = env.observation_space.shape
    n_actions = env.action_space.n

    # Create the agent
    agent = DQNAgent(
        state_shape=input_shape,
        n_actions=n_actions,
        epsilon=1.0,
        epsilon_decay=0.995,
        epsilon_min=0.01,
        learning_rate=0.0001,
        gamma=0.99,
        buffer_size=10000,
        batch_size=64,
        device=device  # Pass device to agent for GPU usage
    )

    # Check if model file exists and load it
    model_file = f"{save_path}_model.pth"
    if os.path.exists(model_file):
        try:
            agent.load(model_file)
            logger.info(f"Loaded existing model from {model_file}")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
    else:
        logger.info("No existing model found. Starting with a new model.")

    # Pre-training used synthetic data and has been removed;
    # pre-training with real data would require a separate implementation
    if pretrain_price_prediction_enabled:
        logger.warning("Pre-training with synthetic data is no longer supported. Continuing with RL training only.")

    # Create TensorBoard writer
    writer = SummaryWriter(log_dir=f'runs/dqn_{int(time.time())}')

    # Log GPU status to TensorBoard
    writer.add_text("hardware/device", str(device), 0)
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            writer.add_text(f"hardware/gpu_{i}", torch.cuda.get_device_name(i), 0)

    # Training loop state
    total_rewards = []
    trade_win_rates = []
    best_reward = -np.inf

    # Move models to the appropriate device if not already there
    agent.move_models_to_device(device)

    # Enable mixed precision if GPU and feature is available
    # NOTE(review): the flag and the scaler are not used anywhere below in
    # this function.
    use_mixed_precision = False
    if torch.cuda.is_available() and hasattr(torch.cuda, 'amp'):
        logger.info("Enabling mixed precision training")
        use_mixed_precision = True
        scaler = torch.cuda.amp.GradScaler()

    # Main training loop
    logger.info(f"Starting training for {num_episodes} episodes...")
    logger.info(f"Starting training on device: {agent.device}")

    try:
        for episode in range(num_episodes):
            state = env.reset()
            total_reward = 0
            info = {}  # stays empty only if no step runs (max_steps <= 0)

            for step in range(max_steps):
                # Select action
                action = agent.act(state)

                # Take action and observe next state and reward
                next_state, reward, done, info = env.step(action)

                # Store the experience in memory
                agent.remember(state, action, reward, next_state, done)

                # Update state and reward
                state = next_state
                total_reward += reward

                # Train the agent once enough experiences are stored for a batch
                if len(agent.memory) >= agent.batch_size:
                    agent.replay()

                if done:
                    break

            # Track rewards
            total_rewards.append(total_reward)

            # Calculate trading metrics
            win_rate = env.wins / max(1, env.trades)
            trades = env.trades

            # Log to TensorBoard
            writer.add_scalar('Reward/Episode', total_reward, episode)
            writer.add_scalar('Trade/WinRate', win_rate, episode)
            writer.add_scalar('Trade/Count', trades, episode)

            # Save best model (only considered after the first episodes)
            if total_reward > best_reward and episode > 10:
                logger.info(f"New best average reward: {total_reward:.4f}, saving model")
                agent.save(save_path)
                best_reward = total_reward

            # Periodic save every 100 episodes
            if episode % 100 == 0 and episode > 0:
                agent.save(f"{save_path}_episode_{episode}")

            # Call episode callback if provided
            if episode_callback:
                # Add environment to info dict to use for extrema training
                info_with_env = info.copy()
                info_with_env['env'] = env
                episode_callback(episode, total_reward, info_with_env)

        # Final save
        logger.info("Training completed, saving final model")
        agent.save(f"{save_path}_final")

    except Exception as e:
        logger.error(f"Training failed: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
    finally:
        # Close TensorBoard writer even when training is aborted by something
        # that is not an Exception subclass (e.g. KeyboardInterrupt)
        writer.close()

    return agent, env
|
||||
|
||||
if __name__ == "__main__":
    # Run RL training with the default arguments when executed as a script
    train_rl()
|
Reference in New Issue
Block a user