BIG CLEANUP
@@ -1,16 +0,0 @@
"""
Neural Network Trading System
=============================

A comprehensive neural network trading system that uses deep learning models
to analyze cryptocurrency price data and generate trading signals.

The system consists of:
1. Data Interface: Connects to real-time trading data
2. CNN Model: Deep convolutional neural network for feature extraction
3. Transformer Model: Processes high-level features for improved pattern recognition
4. MoE: Mixture of Experts model that combines multiple neural networks
"""

__version__ = '0.1.0'
__author__ = 'Gogo2 Project'
@@ -1,27 +0,0 @@
"""
Neural Network Models
=====================

This package contains the neural network models used in the trading system:

- CNN Model: Deep convolutional neural network for feature extraction
- DQN Agent: Deep Q-Network for reinforcement learning
- COB RL Model: Specialized RL model for order book data
- Advanced Transformer: High-performance transformer for trading

PyTorch implementation only.
"""

# Import core models
from NN.models.dqn_agent import DQNAgent
from NN.models.cob_rl_model import COBRLModelInterface
from NN.models.advanced_transformer_trading import AdvancedTradingTransformer, TradingTransformerConfig
from NN.models.standardized_cnn import StandardizedCNN  # Use the unified CNN model

# Import model interfaces
from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface

# Export the unified StandardizedCNN as CNNModel for compatibility
CNNModel = StandardizedCNN

__all__ = ['CNNModel', 'StandardizedCNN', 'DQNAgent', 'COBRLModelInterface', 'AdvancedTradingTransformer', 'TradingTransformerConfig',
           'ModelInterface', 'CNNModelInterface', 'RLAgentInterface', 'ExtremaTrainerInterface']
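
For context, a minimal consumer of these exports could look like the sketch below. The BaseDataInput usage and the predict_from_base_input call are inferred from the legacy compatibility layer later in this commit, so treat the exact API as an assumption rather than a guarantee.

# Sketch: using the CNNModel compatibility alias (API assumed from the
# legacy wrapper below, not verified against StandardizedCNN itself)
from NN.models import CNNModel  # alias for StandardizedCNN
from core.data_models import BaseDataInput

model = CNNModel()
base_input = BaseDataInput()  # normally populated with real market features
result = model.predict_from_base_input(base_input)
print(result.predictions['action'])  # 'BUY', 'SELL', or 'HOLD'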
@@ -1,201 +0,0 @@
# """
# Legacy CNN Model Compatibility Layer
#
# This module provides compatibility redirects to the unified StandardizedCNN model.
# All legacy models (EnhancedCNNModel, CNNModelTrainer, CNNModel) have been retired
# in favor of the StandardizedCNN architecture.
# """

# import logging
# import warnings
# from typing import Tuple, Dict, Any, Optional
# import torch
# import numpy as np

# # Import the standardized CNN model
# from .standardized_cnn import StandardizedCNN

# logger = logging.getLogger(__name__)

# # Compatibility aliases and wrappers
# class EnhancedCNNModel:
#     """Legacy compatibility wrapper - redirects to StandardizedCNN"""
#
#     def __init__(self, *args, **kwargs):
#         warnings.warn(
#             "EnhancedCNNModel is deprecated. Use StandardizedCNN instead.",
#             DeprecationWarning,
#             stacklevel=2
#         )
#         # Create StandardizedCNN with default parameters
#         self.standardized_cnn = StandardizedCNN()
#         logger.warning("EnhancedCNNModel compatibility wrapper created - please migrate to StandardizedCNN")
#
#     def __getattr__(self, name):
#         """Delegate all method calls to StandardizedCNN"""
#         return getattr(self.standardized_cnn, name)


# class CNNModelTrainer:
#     """Legacy compatibility wrapper for CNN training"""
#
#     def __init__(self, model=None, *args, **kwargs):
#         warnings.warn(
#             "CNNModelTrainer is deprecated. Use StandardizedCNN.train_step() instead.",
#             DeprecationWarning,
#             stacklevel=2
#         )
#         if isinstance(model, EnhancedCNNModel):
#             self.model = model.standardized_cnn
#         else:
#             self.model = StandardizedCNN()
#         logger.warning("CNNModelTrainer compatibility wrapper created - please use StandardizedCNN.train_step()")
#
#     def train_step(self, x, y, *args, **kwargs):
#         """Legacy train step wrapper"""
#         try:
#             # Convert to BaseDataInput format if needed
#             if hasattr(x, 'get_feature_vector'):
#                 # Already BaseDataInput
#                 base_input = x
#             else:
#                 # Create mock BaseDataInput for legacy compatibility
#                 from core.data_models import BaseDataInput
#                 base_input = BaseDataInput()
#                 # Set mock feature vector
#                 if isinstance(x, torch.Tensor):
#                     feature_vector = x.flatten().cpu().numpy()
#                 else:
#                     feature_vector = np.array(x).flatten()
#
#                 # Pad or truncate to expected size
#                 expected_size = self.model.expected_feature_dim
#                 if len(feature_vector) < expected_size:
#                     padding = np.zeros(expected_size - len(feature_vector))
#                     feature_vector = np.concatenate([feature_vector, padding])
#                 else:
#                     feature_vector = feature_vector[:expected_size]
#
#                 base_input._feature_vector = feature_vector
#
#             # Convert target to string format
#             if isinstance(y, torch.Tensor):
#                 y_val = y.item() if y.numel() == 1 else y.argmax().item()
#             else:
#                 y_val = int(y) if np.isscalar(y) else int(np.argmax(y))
#
#             target_map = {0: 'BUY', 1: 'SELL', 2: 'HOLD'}
#             target = target_map.get(y_val, 'HOLD')
#
#             # Use StandardizedCNN training
#             optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
#             loss = self.model.train_step([base_input], [target], optimizer)
#
#             return {'total_loss': loss, 'main_loss': loss, 'accuracy': 0.5}
#
#         except Exception as e:
#             logger.error(f"Legacy train_step error: {e}")
#             return {'total_loss': 0.0, 'main_loss': 0.0, 'accuracy': 0.5}


# # class CNNModel:
# #     """Legacy compatibility wrapper for CNN model interface"""
#
# #     def __init__(self, input_shape=(900, 50), output_size=3, model_path=None):
# #         warnings.warn(
# #             "CNNModel is deprecated. Use StandardizedCNN directly.",
# #             DeprecationWarning,
# #             stacklevel=2
# #         )
# #         self.input_shape = input_shape
# #         self.output_size = output_size
# #         self.standardized_cnn = StandardizedCNN()
# #         self.trainer = CNNModelTrainer(self.standardized_cnn)
# #         logger.warning("CNNModel compatibility wrapper created - please migrate to StandardizedCNN")
#
# #     def build_model(self, **kwargs):
# #         """Legacy build method - no-op for StandardizedCNN"""
# #         return self
#
# #     def predict(self, X):
# #         """Legacy predict method"""
# #         try:
# #             # Convert input to BaseDataInput
# #             from core.data_models import BaseDataInput
# #             base_input = BaseDataInput()
#
# #             if isinstance(X, np.ndarray):
# #                 feature_vector = X.flatten()
# #             else:
# #                 feature_vector = np.array(X).flatten()
#
# #             # Pad or truncate to expected size
# #             expected_size = self.standardized_cnn.expected_feature_dim
# #             if len(feature_vector) < expected_size:
# #                 padding = np.zeros(expected_size - len(feature_vector))
# #                 feature_vector = np.concatenate([feature_vector, padding])
# #             else:
# #                 feature_vector = feature_vector[:expected_size]
#
# #             base_input._feature_vector = feature_vector
#
# #             # Get prediction from StandardizedCNN
# #             result = self.standardized_cnn.predict_from_base_input(base_input)
#
# #             # Convert to legacy format
# #             action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
# #             pred_class = np.array([action_map.get(result.predictions['action'], 2)])
# #             pred_proba = np.array([result.predictions['action_probabilities']])
#
# #             return pred_class, pred_proba
#
# #         except Exception as e:
# #             logger.error(f"Legacy predict error: {e}")
# #             # Return safe defaults
# #             pred_class = np.array([2])  # HOLD
# #             pred_proba = np.array([[0.33, 0.33, 0.34]])
# #             return pred_class, pred_proba
#
# #     def fit(self, X, y, **kwargs):
# #         """Legacy fit method"""
# #         try:
# #             return self.trainer.train_step(X, y)
# #         except Exception as e:
# #             logger.error(f"Legacy fit error: {e}")
# #             return self
#
# #     def save(self, filepath: str):
# #         """Legacy save method"""
# #         try:
# #             torch.save(self.standardized_cnn.state_dict(), filepath)
# #             logger.info(f"StandardizedCNN saved to {filepath}")
# #         except Exception as e:
# #             logger.error(f"Error saving model: {e}")


# def create_enhanced_cnn_model(input_size: int = 60,
#                               feature_dim: int = 50,
#                               output_size: int = 3,
#                               base_channels: int = 256,
#                               device: str = 'cuda') -> Tuple[StandardizedCNN, CNNModelTrainer]:
#     """Legacy compatibility function - returns StandardizedCNN"""
#     warnings.warn(
#         "create_enhanced_cnn_model is deprecated. Use StandardizedCNN() directly.",
#         DeprecationWarning,
#         stacklevel=2
#     )
#
#     model = StandardizedCNN()
#     trainer = CNNModelTrainer(model)
#
#     logger.warning("Legacy create_enhanced_cnn_model called - please use StandardizedCNN directly")
#     return model, trainer


# # Export compatibility symbols
# __all__ = [
#     'EnhancedCNNModel',
#     'CNNModelTrainer',
#     # 'CNNModel',
#     'create_enhanced_cnn_model'
# ]
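
The wrappers above rely on __getattr__ delegation: attribute lookups that fail on the shim fall through to the replacement object, so legacy call sites keep working while emitting a DeprecationWarning. A self-contained sketch of the pattern (hypothetical names, not project code):

import warnings

class _NewImpl:
    def predict(self, x):
        return x * 2

class LegacyWrapper:
    """Deprecation shim: forwards unknown attributes to the new implementation."""
    def __init__(self):
        warnings.warn("LegacyWrapper is deprecated; use _NewImpl.",
                      DeprecationWarning, stacklevel=2)
        self._impl = _NewImpl()

    def __getattr__(self, name):
        # __getattr__ fires only when normal lookup fails, so self._impl
        # itself resolves normally and there is no recursion.
        return getattr(self._impl, name)

w = LegacyWrapper()
assert w.predict(3) == 6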
@@ -1,821 +0,0 @@
"""
Transformer Neural Network for time series analysis

This module implements a Transformer model with attention mechanisms for cryptocurrency price analysis.
It also includes a Mixture of Experts model that combines predictions from multiple models.
"""

import os
import json
import logging
import datetime

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (
    Input, Dense, Dropout, BatchNormalization,
    Concatenate, Layer, LayerNormalization, MultiHeadAttention,
    Add, GlobalAveragePooling1D, Conv1D, Reshape
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

logger = logging.getLogger(__name__)

class TransformerBlock(Layer):
    """
    Transformer block implementation with multi-head attention and feed-forward networks.
    """
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super(TransformerBlock, self).__init__(**kwargs)
        # Keep the constructor arguments so get_config() can serialize the layer
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation="relu"),
            Dense(embed_dim),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=False):
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        # Serialize the constructor arguments rather than the live sub-layers;
        # storing layer objects in the config breaks save/load via from_config().
        config = super().get_config()
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'rate': self.rate
        })
        return config
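
A quick shape check for the block above (a minimal sketch, assuming TensorFlow 2.x is installed):

import tensorflow as tf

block = TransformerBlock(embed_dim=32, num_heads=4, ff_dim=64)
x = tf.random.normal((2, 20, 32))   # (batch, seq_len, embed_dim)
y = block(x, training=False)
print(y.shape)                      # (2, 20, 32): the block preserves shape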
class PositionalEncoding(Layer):
    """
    Positional encoding layer to add position information to input embeddings.
    """
    def __init__(self, position, d_model, **kwargs):
        super(PositionalEncoding, self).__init__(**kwargs)
        self.position = position
        self.d_model = d_model
        self.pos_encoding = self.positional_encoding(position, d_model)

    def get_angles(self, position, i, d_model):
        angles = 1 / tf.pow(10000, (2 * (i // 2)) / tf.cast(d_model, tf.float32))
        return position * angles

    def positional_encoding(self, position, d_model):
        angle_rads = self.get_angles(
            position=tf.range(position, dtype=tf.float32)[:, tf.newaxis],
            i=tf.range(d_model, dtype=tf.float32)[tf.newaxis, :],
            d_model=d_model
        )

        # Apply sin to even indices in the array
        sines = tf.math.sin(angle_rads[:, 0::2])

        # Apply cos to odd indices in the array
        cosines = tf.math.cos(angle_rads[:, 1::2])

        pos_encoding = tf.concat([sines, cosines], axis=-1)
        pos_encoding = pos_encoding[tf.newaxis, ...]

        return tf.cast(pos_encoding, tf.float32)

    def call(self, inputs):
        return inputs + self.pos_encoding[:, :tf.shape(inputs)[1], :]

    def get_config(self):
        # Only the constructor arguments are serialized; pos_encoding is a
        # tensor that __init__ rebuilds from them, so it must not go into
        # the config.
        config = super().get_config()
        config.update({
            'position': self.position,
            'd_model': self.d_model
        })
        return config
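
Note that positional_encoding() concatenates all sines followed by all cosines along the last axis rather than interleaving them as in the original Transformer paper. A NumPy sanity check mirroring that choice for one position (a sketch, not project code):

import numpy as np

# Closed form: angle(pos, i) = pos / 10000**(2*(i//2)/d_model)
pos, d_model = 5.0, 8
i = np.arange(d_model, dtype=np.float32)
angles = pos / np.power(10000.0, (2 * (i // 2)) / d_model)
row = np.concatenate([np.sin(angles[0::2]), np.cos(angles[1::2])])
print(row)  # one row of the table built by positional_encoding()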
class TransformerModel:
    """
    Transformer Neural Network for time series analysis.

    This model uses self-attention mechanisms to capture relationships between
    different time points in the input data.
    """

    def __init__(self, ts_input_shape=(20, 5), feature_input_shape=64, output_size=1, model_dir="NN/models/saved"):
        """
        Initialize the Transformer model.

        Args:
            ts_input_shape (tuple): Shape of time series input data (sequence_length, features)
            feature_input_shape (int): Size of the additional feature input (e.g., from CNN)
            output_size (int): Number of output classes (1 for binary, 3 for buy/hold/sell)
            model_dir (str): Directory to save trained models
        """
        self.ts_input_shape = ts_input_shape
        self.feature_input_shape = feature_input_shape
        self.output_size = output_size
        self.model_dir = model_dir
        self.model = None
        self.history = None

        # Create model directory if it doesn't exist
        os.makedirs(self.model_dir, exist_ok=True)

        logger.info(f"Initialized Transformer model with TS input shape {ts_input_shape}, "
                    f"feature input shape {feature_input_shape}, and output size {output_size}")

    def build_model(self, embed_dim=32, num_heads=4, ff_dim=64, num_transformer_blocks=2, dropout_rate=0.1, learning_rate=0.001):
        """
        Build the Transformer model architecture.

        Args:
            embed_dim (int): Embedding dimension for transformer
            num_heads (int): Number of attention heads
            ff_dim (int): Hidden dimension of the feed-forward network
            num_transformer_blocks (int): Number of transformer blocks
            dropout_rate (float): Dropout rate for regularization
            learning_rate (float): Learning rate for Adam optimizer

        Returns:
            The compiled model
        """
        # Time series input
        ts_inputs = Input(shape=self.ts_input_shape, name="ts_input")

        # Additional feature input (e.g., from CNN)
        feature_inputs = Input(shape=(self.feature_input_shape,), name="feature_input")

        # Process the time series with the transformer.
        # First, project the input to the embedding dimension
        x = Conv1D(embed_dim, 1, activation="relu")(ts_inputs)

        # Add positional encoding
        x = PositionalEncoding(self.ts_input_shape[0], embed_dim)(x)

        # Add transformer blocks
        for _ in range(num_transformer_blocks):
            x = TransformerBlock(embed_dim, num_heads, ff_dim, dropout_rate)(x)

        # Global pooling to get a single vector representation
        x = GlobalAveragePooling1D()(x)
        x = Dropout(dropout_rate)(x)

        # Combine with additional features
        combined = Concatenate()([x, feature_inputs])

        # Dense layers for final classification/regression
        x = Dense(64, activation="relu")(combined)
        x = BatchNormalization()(x)
        x = Dropout(dropout_rate)(x)

        # Output layer
        if self.output_size == 1:
            # Binary classification (up/down)
            outputs = Dense(1, activation='sigmoid', name='output')(x)
            loss = 'binary_crossentropy'
            metrics = ['accuracy']
        elif self.output_size == 3:
            # Multi-class classification (buy/hold/sell)
            outputs = Dense(3, activation='softmax', name='output')(x)
            loss = 'categorical_crossentropy'
            metrics = ['accuracy']
        else:
            # Regression
            outputs = Dense(self.output_size, activation='linear', name='output')(x)
            loss = 'mse'
            metrics = ['mae']

        # Create and compile model
        self.model = Model(inputs=[ts_inputs, feature_inputs], outputs=outputs)

        # Compile with Adam optimizer
        self.model.compile(
            optimizer=Adam(learning_rate=learning_rate),
            loss=loss,
            metrics=metrics
        )

        # Log model summary
        self.model.summary(print_fn=lambda x: logger.info(x))

        return self.model
    def train(self, X_ts, X_features, y, batch_size=32, epochs=100, validation_split=0.2,
              callbacks=None, class_weights=None):
        """
        Train the Transformer model on the provided data.

        Args:
            X_ts (numpy.ndarray): Time series input features
            X_features (numpy.ndarray): Additional input features
            y (numpy.ndarray): Target labels
            batch_size (int): Batch size
            epochs (int): Number of epochs
            validation_split (float): Fraction of data to use for validation
            callbacks (list): List of Keras callbacks
            class_weights (dict): Class weights for imbalanced datasets

        Returns:
            History object containing training metrics
        """
        if self.model is None:
            self.build_model()

        # Default callbacks if none provided
        if callbacks is None:
            # Create a timestamp for model checkpoints
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

            callbacks = [
                EarlyStopping(
                    monitor='val_loss',
                    patience=10,
                    restore_best_weights=True
                ),
                ReduceLROnPlateau(
                    monitor='val_loss',
                    factor=0.5,
                    patience=5,
                    min_lr=1e-6
                ),
                ModelCheckpoint(
                    filepath=os.path.join(self.model_dir, f"transformer_model_{timestamp}.h5"),
                    monitor='val_loss',
                    save_best_only=True
                )
            ]

        # Check if y needs to be one-hot encoded for multi-class
        if self.output_size == 3 and len(y.shape) == 1:
            y = tf.keras.utils.to_categorical(y, num_classes=3)

        # Train the model
        logger.info(f"Training Transformer model with {len(X_ts)} samples, batch size {batch_size}, epochs {epochs}")
        self.history = self.model.fit(
            [X_ts, X_features], y,
            batch_size=batch_size,
            epochs=epochs,
            validation_split=validation_split,
            callbacks=callbacks,
            class_weight=class_weights,
            verbose=2
        )

        # Save the trained model
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        model_path = os.path.join(self.model_dir, f"transformer_model_final_{timestamp}.h5")
        self.model.save(model_path)
        logger.info(f"Model saved to {model_path}")

        # Save training history
        history_path = os.path.join(self.model_dir, f"transformer_model_history_{timestamp}.json")
        with open(history_path, 'w') as f:
            # Convert numpy values to Python native types for JSON serialization
            history_dict = {key: [float(value) for value in values] for key, values in self.history.history.items()}
            json.dump(history_dict, f, indent=2)

        return self.history

    def evaluate(self, X_ts, X_features, y):
        """
        Evaluate the model on test data.

        Args:
            X_ts (numpy.ndarray): Time series input features
            X_features (numpy.ndarray): Additional input features
            y (numpy.ndarray): Target labels

        Returns:
            dict: Evaluation metrics
        """
        if self.model is None:
            raise ValueError("Model has not been built or trained yet")

        # Convert y to one-hot encoding for multi-class
        if self.output_size == 3 and len(y.shape) == 1:
            y = tf.keras.utils.to_categorical(y, num_classes=3)

        # Evaluate model
        logger.info(f"Evaluating Transformer model on {len(X_ts)} samples")
        eval_results = self.model.evaluate([X_ts, X_features], y, verbose=0)

        metrics = {}
        for metric, value in zip(self.model.metrics_names, eval_results):
            metrics[metric] = value
            logger.info(f"{metric}: {value:.4f}")

        return metrics
    def predict(self, X_ts, X_features=None):
        """
        Make predictions on new data.

        Args:
            X_ts (numpy.ndarray): Time series input features
            X_features (numpy.ndarray): Additional input features

        Returns:
            tuple: (y_pred, y_proba) where:
                y_pred is the predicted class (0/1 for binary, 0/1/2 for multi-class)
                y_proba is the class probability
        """
        if self.model is None:
            raise ValueError("Model has not been built or trained yet")

        # Ensure X_ts has the right shape
        if len(X_ts.shape) == 2:
            # Single sample, add batch dimension
            X_ts = np.expand_dims(X_ts, axis=0)

        # Ensure X_features has the right shape
        if X_features is None:
            # Extract features from time series data if no external features provided
            X_features = self._extract_features_from_timeseries(X_ts)
        elif len(X_features.shape) == 1:
            # Single sample, add batch dimension
            X_features = np.expand_dims(X_features, axis=0)

        # Get predictions
        y_proba = self.model.predict([X_ts, X_features])

        # Process based on output type
        if self.output_size == 1:
            # Binary classification
            y_pred = (y_proba > 0.5).astype(int).flatten()
            return y_pred, y_proba.flatten()
        elif self.output_size == 3:
            # Multi-class classification
            y_pred = np.argmax(y_proba, axis=1)
            return y_pred, y_proba
        else:
            # Regression
            return y_proba, y_proba

    def _extract_features_from_timeseries(self, X_ts: np.ndarray) -> np.ndarray:
        """Extract meaningful features from time series data instead of using dummy zeros"""
        try:
            batch_size = X_ts.shape[0]
            features = []

            for i in range(batch_size):
                sample = X_ts[i]  # Shape: (timesteps, features)

                # Extract statistical features from each feature dimension
                sample_features = []

                for feature_idx in range(sample.shape[1]):
                    feature_data = sample[:, feature_idx]

                    # Basic statistical features
                    sample_features.extend([
                        np.mean(feature_data),            # Mean
                        np.std(feature_data),             # Standard deviation
                        np.min(feature_data),             # Minimum
                        np.max(feature_data),             # Maximum
                        np.percentile(feature_data, 25),  # 25th percentile
                        np.percentile(feature_data, 75),  # 75th percentile
                    ])

                    # Trend features
                    if len(feature_data) > 1:
                        # Linear trend (slope)
                        x = np.arange(len(feature_data))
                        slope = np.polyfit(x, feature_data, 1)[0]
                        sample_features.append(slope)

                        # Rate of change
                        rate_of_change = (feature_data[-1] - feature_data[0]) / feature_data[0] if feature_data[0] != 0 else 0
                        sample_features.append(rate_of_change)
                    else:
                        sample_features.extend([0.0, 0.0])

                # Pad or truncate to expected feature size
                while len(sample_features) < self.feature_input_shape:
                    sample_features.append(0.0)
                sample_features = sample_features[:self.feature_input_shape]

                features.append(sample_features)

            return np.array(features, dtype=np.float32)

        except Exception as e:
            logger.error(f"Error extracting features from time series: {e}")
            # Fallback to zeros if extraction fails
            return np.zeros((X_ts.shape[0], self.feature_input_shape), dtype=np.float32)
    def save(self, filepath=None):
        """
        Save the model to disk.

        Args:
            filepath (str): Path to save the model

        Returns:
            str: Path where the model was saved
        """
        if self.model is None:
            raise ValueError("Model has not been built yet")

        if filepath is None:
            # Create a default filepath with timestamp
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            filepath = os.path.join(self.model_dir, f"transformer_model_{timestamp}.h5")

        self.model.save(filepath)
        logger.info(f"Model saved to {filepath}")
        return filepath

    def load(self, filepath):
        """
        Load a saved model from disk.

        Args:
            filepath (str): Path to the saved model

        Returns:
            The loaded model
        """
        # Register custom layers
        custom_objects = {
            'TransformerBlock': TransformerBlock,
            'PositionalEncoding': PositionalEncoding
        }

        self.model = load_model(filepath, custom_objects=custom_objects)
        logger.info(f"Model loaded from {filepath}")
        return self.model

    def plot_training_history(self):
        """
        Plot training history (loss and metrics).

        Returns:
            str: Path to the saved plot
        """
        if self.history is None:
            raise ValueError("Model has not been trained yet")

        plt.figure(figsize=(12, 5))

        # Plot loss
        plt.subplot(1, 2, 1)
        plt.plot(self.history.history['loss'], label='Training Loss')
        if 'val_loss' in self.history.history:
            plt.plot(self.history.history['val_loss'], label='Validation Loss')
        plt.title('Model Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        # Plot accuracy (or MAE for regression)
        plt.subplot(1, 2, 2)

        if 'accuracy' in self.history.history:
            plt.plot(self.history.history['accuracy'], label='Training Accuracy')
            if 'val_accuracy' in self.history.history:
                plt.plot(self.history.history['val_accuracy'], label='Validation Accuracy')
            plt.title('Model Accuracy')
            plt.ylabel('Accuracy')
        elif 'mae' in self.history.history:
            plt.plot(self.history.history['mae'], label='Training MAE')
            if 'val_mae' in self.history.history:
                plt.plot(self.history.history['val_mae'], label='Validation MAE')
            plt.title('Model MAE')
            plt.ylabel('MAE')

        plt.xlabel('Epoch')
        plt.legend()

        plt.tight_layout()

        # Save figure
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        fig_path = os.path.join(self.model_dir, f"transformer_training_history_{timestamp}.png")
        plt.savefig(fig_path)
        plt.close()

        logger.info(f"Training history plot saved to {fig_path}")
        return fig_path
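
End-to-end use of the class above on synthetic data, as a minimal sketch (shapes match the defaults; a real run would use market data and more epochs):

import numpy as np

# 200 samples of a 20-step, 5-feature window plus a 64-dim auxiliary vector
X_ts = np.random.rand(200, 20, 5).astype(np.float32)
X_feat = np.random.rand(200, 64).astype(np.float32)
y = np.random.randint(0, 3, size=200)   # BUY/HOLD/SELL labels

model = TransformerModel(ts_input_shape=(20, 5), feature_input_shape=64, output_size=3)
model.build_model()
model.train(X_ts, X_feat, y, batch_size=16, epochs=2)   # short run for illustration
y_pred, y_proba = model.predict(X_ts[:4], X_feat[:4])
print(y_pred)   # array of class indices in {0, 1, 2}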
class MixtureOfExpertsModel:
    """
    Mixture of Experts (MoE) model.

    This model combines predictions from multiple expert models (such as CNN and Transformer)
    using a weighted ensemble approach.
    """

    def __init__(self, output_size=1, model_dir="NN/models/saved"):
        """
        Initialize the MoE model.

        Args:
            output_size (int): Number of output classes (1 for binary, 3 for buy/hold/sell)
            model_dir (str): Directory to save trained models
        """
        self.output_size = output_size
        self.model_dir = model_dir
        self.model = None
        self.history = None
        self.experts = {}

        # Create model directory if it doesn't exist
        os.makedirs(self.model_dir, exist_ok=True)

        logger.info(f"Initialized Mixture of Experts model with output size {output_size}")

    def add_expert(self, name, model):
        """
        Add an expert model to the MoE.

        Args:
            name (str): Name of the expert model
            model: The expert model instance

        Returns:
            None
        """
        self.experts[name] = model
        logger.info(f"Added expert model '{name}' to MoE")

    def build_model(self, ts_input_shape=(20, 5), expert_weights=None, learning_rate=0.001):
        """
        Build the MoE model by combining expert models.

        Args:
            ts_input_shape (tuple): Shape of time series input data
            expert_weights (dict): Weights for each expert model
            learning_rate (float): Learning rate for Adam optimizer

        Returns:
            The compiled model
        """
        # Time series input
        ts_inputs = Input(shape=ts_input_shape, name="ts_input")

        # Additional feature input (from CNN)
        feature_inputs = Input(shape=(64,), name="feature_input")  # Default size for features

        # Process with each expert model
        expert_outputs = []
        expert_names = []

        for name, expert in self.experts.items():
            # Skip if expert model is not valid or doesn't have a call/predict method
            if expert is None:
                logger.warning(f"Expert model '{name}' is None, skipping")
                continue

            try:
                # Different handling based on model type
                if name == 'cnn':
                    # CNN model takes only time series input
                    expert_output = expert(ts_inputs)
                    expert_outputs.append(expert_output)
                    expert_names.append(name)
                elif name == 'transformer':
                    # Transformer model takes both time series and feature inputs
                    expert_output = expert([ts_inputs, feature_inputs])
                    expert_outputs.append(expert_output)
                    expert_names.append(name)
                else:
                    logger.warning(f"Unknown expert model type: {name}")
            except Exception as e:
                logger.error(f"Error adding expert '{name}': {str(e)}")

        if not expert_outputs:
            logger.error("No valid expert models found")
            return None

        # Use expert weighting
        if expert_weights is None:
            # Equal weighting
            weights = [1.0 / len(expert_outputs)] * len(expert_outputs)
        else:
            # User-provided weights, normalized to sum to 1
            weights = [expert_weights.get(name, 1.0 / len(expert_outputs)) for name in expert_names]
            total = sum(weights)
            weights = [w / total for w in weights]

        # Combine expert outputs using weighted average
        if len(expert_outputs) == 1:
            # Only one expert, use its output directly
            combined_output = expert_outputs[0]
        else:
            # Multiple experts, compute weighted average
            weighted_outputs = [output * weight for output, weight in zip(expert_outputs, weights)]
            combined_output = Add()(weighted_outputs)

        # Create the MoE model
        moe_model = Model(inputs=[ts_inputs, feature_inputs], outputs=combined_output)

        # Compile the model
        if self.output_size == 1:
            # Binary classification
            moe_model.compile(
                optimizer=Adam(learning_rate=learning_rate),
                loss='binary_crossentropy',
                metrics=['accuracy']
            )
        elif self.output_size == 3:
            # Multi-class classification for BUY/HOLD/SELL
            moe_model.compile(
                optimizer=Adam(learning_rate=learning_rate),
                loss='categorical_crossentropy',
                metrics=['accuracy']
            )
        else:
            # Regression
            moe_model.compile(
                optimizer=Adam(learning_rate=learning_rate),
                loss='mse',
                metrics=['mae']
            )

        self.model = moe_model

        # Log model summary
        self.model.summary(print_fn=lambda x: logger.info(x))

        logger.info(f"Built MoE model with weights: {weights}")
        return self.model

    def train(self, X_ts, X_features, y, batch_size=32, epochs=100, validation_split=0.2,
              callbacks=None, class_weights=None):
        """
        Train the MoE model on the provided data.

        Args:
            X_ts (numpy.ndarray): Time series input features
            X_features (numpy.ndarray): Additional input features
            y (numpy.ndarray): Target labels
            batch_size (int): Batch size
            epochs (int): Number of epochs
            validation_split (float): Fraction of data to use for validation
            callbacks (list): List of Keras callbacks
            class_weights (dict): Class weights for imbalanced datasets

        Returns:
            History object containing training metrics
        """
        if self.model is None:
            logger.error("MoE model has not been built yet")
            return None

        # Default callbacks if none provided
        if callbacks is None:
            # Create a timestamp for model checkpoints
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

            callbacks = [
                EarlyStopping(
                    monitor='val_loss',
                    patience=10,
                    restore_best_weights=True
                ),
                ReduceLROnPlateau(
                    monitor='val_loss',
                    factor=0.5,
                    patience=5,
                    min_lr=1e-6
                ),
                ModelCheckpoint(
                    filepath=os.path.join(self.model_dir, f"moe_model_{timestamp}.h5"),
                    monitor='val_loss',
                    save_best_only=True
                )
            ]

        # Check if y needs to be one-hot encoded for multi-class
        if self.output_size == 3 and len(y.shape) == 1:
            y = tf.keras.utils.to_categorical(y, num_classes=3)

        # Train the model
        logger.info(f"Training MoE model with {len(X_ts)} samples, batch size {batch_size}, epochs {epochs}")
        self.history = self.model.fit(
            [X_ts, X_features], y,
            batch_size=batch_size,
            epochs=epochs,
            validation_split=validation_split,
            callbacks=callbacks,
            class_weight=class_weights,
            verbose=2
        )

        # Save the trained model
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        model_path = os.path.join(self.model_dir, f"moe_model_final_{timestamp}.h5")
        self.model.save(model_path)
        logger.info(f"Model saved to {model_path}")

        # Save training history
        history_path = os.path.join(self.model_dir, f"moe_model_history_{timestamp}.json")
        with open(history_path, 'w') as f:
            # Convert numpy values to Python native types for JSON serialization
            history_dict = {key: [float(value) for value in values] for key, values in self.history.history.items()}
            json.dump(history_dict, f, indent=2)

        return self.history

    def predict(self, X_ts, X_features=None):
        """
        Make predictions on new data.

        Args:
            X_ts (numpy.ndarray): Time series input features
            X_features (numpy.ndarray): Additional input features

        Returns:
            tuple: (y_pred, y_proba) where:
                y_pred is the predicted class (0/1 for binary, 0/1/2 for multi-class)
                y_proba is the class probability
        """
        if self.model is None:
            raise ValueError("Model has not been built or trained yet")

        # Ensure X_ts has the right shape
        if len(X_ts.shape) == 2:
            # Single sample, add batch dimension
            X_ts = np.expand_dims(X_ts, axis=0)

        # Ensure X_features has the right shape
        if X_features is None:
            # Create dummy features with zeros
            X_features = np.zeros((X_ts.shape[0], 64))  # Default size
        elif len(X_features.shape) == 1:
            # Single sample, add batch dimension
            X_features = np.expand_dims(X_features, axis=0)

        # Get predictions
        y_proba = self.model.predict([X_ts, X_features])

        # Process based on output type
        if self.output_size == 1:
            # Binary classification
            y_pred = (y_proba > 0.5).astype(int).flatten()
            return y_pred, y_proba.flatten()
        elif self.output_size == 3:
            # Multi-class classification
            y_pred = np.argmax(y_proba, axis=1)
            return y_pred, y_proba
        else:
            # Regression
            return y_proba, y_proba

    def save(self, filepath=None):
        """
        Save the model to disk.

        Args:
            filepath (str): Path to save the model

        Returns:
            str: Path where the model was saved
        """
        if self.model is None:
            raise ValueError("Model has not been built yet")

        if filepath is None:
            # Create a default filepath with timestamp
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            filepath = os.path.join(self.model_dir, f"moe_model_{timestamp}.h5")

        self.model.save(filepath)
        logger.info(f"Model saved to {filepath}")
        return filepath

    def load(self, filepath):
        """
        Load a saved model from disk.

        Args:
            filepath (str): Path to the saved model

        Returns:
            The loaded model
        """
        # Register custom layers
        custom_objects = {
            'TransformerBlock': TransformerBlock,
            'PositionalEncoding': PositionalEncoding
        }

        self.model = load_model(filepath, custom_objects=custom_objects)
        logger.info(f"Model loaded from {filepath}")
        return self.model
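
Wiring an expert into the MoE might look like the sketch below; build_model() dispatches on the names 'cnn' and 'transformer' as in the loop above, and the transformer_model handle is a hypothetical built TransformerModel, so treat the exact wiring as an assumption:

# Sketch: one transformer expert with full weight (hypothetical handles)
moe = MixtureOfExpertsModel(output_size=3)
moe.add_expert('transformer', transformer_model.model)  # underlying Keras model
moe.build_model(ts_input_shape=(20, 5), expert_weights={'transformer': 1.0})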
# Example usage:
if __name__ == "__main__":
    # A complete pipeline would be wired up in a real system
    print("Transformer and MoE models defined; no demo pipeline is run here.")
@@ -1,88 +0,0 @@
#!/usr/bin/env python
"""
Start TensorBoard for monitoring neural network training
"""

import os
import sys
import subprocess
import webbrowser
from time import sleep

def start_tensorboard(logdir="NN/models/saved/logs", port=6006, open_browser=True):
    """
    Start TensorBoard in a subprocess

    Args:
        logdir: Directory containing TensorBoard logs
        port: Port to run TensorBoard on
        open_browser: Whether to open a browser automatically
    """
    # Make sure the log directory exists
    os.makedirs(logdir, exist_ok=True)

    # Create command
    cmd = [
        sys.executable,
        "-m",
        "tensorboard.main",
        f"--logdir={logdir}",
        f"--port={port}",
        "--bind_all"
    ]

    print(f"Starting TensorBoard with logs from {logdir} on port {port}")
    print(f"Command: {' '.join(cmd)}")

    # Start TensorBoard in a subprocess
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True
    )

    # Wait for TensorBoard to start up
    for line in process.stdout:
        print(line.strip())
        if "TensorBoard" in line and "http://" in line:
            # TensorBoard is running, extract the URL
            url = None
            for part in line.split():
                if part.startswith(("http://", "https://")):
                    url = part
                    break

            # Open browser if requested and URL found
            if open_browser and url:
                print(f"Opening TensorBoard in browser: {url}")
                webbrowser.open(url)

            break

    # Return the process for the caller to manage
    return process

if __name__ == "__main__":
    import argparse

    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Start TensorBoard for NN training visualization")
    parser.add_argument("--logdir", default="NN/models/saved/logs", help="Directory containing TensorBoard logs")
    parser.add_argument("--port", type=int, default=6006, help="Port to run TensorBoard on")
    parser.add_argument("--no-browser", action="store_true", help="Don't open browser automatically")

    args = parser.parse_args()

    # Start TensorBoard
    process = start_tensorboard(args.logdir, args.port, not args.no_browser)

    try:
        # Keep the script running until Ctrl+C
        print("TensorBoard is running. Press Ctrl+C to stop.")
        while True:
            sleep(1)
    except KeyboardInterrupt:
        print("Stopping TensorBoard...")
        process.terminate()
        process.wait()
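
Programmatic use of the helper above, as a minimal sketch (the script's filename is not shown in this diff):

# Launch TensorBoard headless, do other work, then shut it down
proc = start_tensorboard(logdir="NN/models/saved/logs", port=6006, open_browser=False)
# ... run or monitor training here ...
proc.terminate()
proc.wait()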
@@ -1,490 +0,0 @@
#!/usr/bin/env python3
"""
Enhanced RL Training Integration - Comprehensive Fix

This script addresses the critical RL training audit issues:
1. MASSIVE INPUT DATA GAP (99.25% Missing) - Implements the full 13,400-feature state
2. Disconnected Training Pipeline - Provides proper data flow integration
3. Missing Enhanced State Builder - Connects orchestrator to dashboard
4. Reward Calculation Issues - Ensures enhanced pivot-based rewards
5. Williams Market Structure Integration - Proper feature extraction
6. Real-time Data Integration - Live market data to RL

Usage:
    python enhanced_rl_training_integration.py
"""

import os
import sys
import asyncio
import logging
import numpy as np
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Any

# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

from core.config import setup_logging, get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from core.trading_executor import TradingExecutor
from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
from utils.tensorboard_logger import TensorBoardLogger

logger = logging.getLogger(__name__)
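
The script leans on a project-local utils.tensorboard_logger module that this commit does not show. A minimal stand-in with the interface used below (log_scalar, log_scalars, log_histogram, log_state_metrics), built on torch.utils.tensorboard, might look like this; it is an assumption-heavy sketch, not the real module:

# Hypothetical stand-in for utils.tensorboard_logger (not part of this commit)
from torch.utils.tensorboard import SummaryWriter

class TensorBoardLogger:
    def __init__(self, log_dir="runs", experiment_name="exp", enabled=True):
        self.enabled = enabled
        self.writer = SummaryWriter(f"{log_dir}/{experiment_name}") if enabled else None

    def log_scalar(self, tag, value, step):
        if self.enabled:
            self.writer.add_scalar(tag, value, step)

    def log_scalars(self, main_tag, values, step):
        if self.enabled:
            self.writer.add_scalars(main_tag, values, step)

    def log_histogram(self, tag, values, step):
        if self.enabled:
            self.writer.add_histogram(tag, values, step)

    def log_state_metrics(self, symbol, state_info, step):
        # Flatten the nested dict the integrator below passes in
        if self.enabled:
            self.writer.add_scalar(f"State/{symbol}/size", state_info['size'], step)
            self.writer.add_scalar(f"State/{symbol}/quality", state_info['quality'], step)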
class EnhancedRLTrainingIntegrator:
    """
    Comprehensive RL Training Integrator

    Fixes all audit issues by ensuring proper data flow and feature completeness.
    """

    def __init__(self):
        """Initialize the enhanced RL training integrator"""
        # Setup logging
        setup_logging()
        logger.info("=" * 70)
        logger.info("ENHANCED RL TRAINING INTEGRATION - COMPREHENSIVE FIX")
        logger.info("=" * 70)

        # Get configuration
        self.config = get_config()

        # Initialize core components
        self.data_provider = DataProvider()
        self.enhanced_orchestrator = None
        self.trading_executor = TradingExecutor()
        self.dashboard = None

        # Training metrics
        self.training_stats = {
            'total_episodes': 0,
            'successful_state_builds': 0,
            'enhanced_reward_calculations': 0,
            'comprehensive_features_used': 0,
            'pivot_features_extracted': 0,
            'cob_features_available': 0
        }

        # Initialize TensorBoard logger
        experiment_name = f"enhanced_rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.tb_logger = TensorBoardLogger(
            log_dir="runs",
            experiment_name=experiment_name,
            enabled=True
        )
        logger.info(f"TensorBoard logging enabled for experiment: {experiment_name}")

        logger.info("Enhanced RL Training Integrator initialized")
    async def start_integration(self):
        """Start the comprehensive RL training integration"""
        try:
            logger.info("Starting comprehensive RL training integration...")

            # 1. Initialize Enhanced Orchestrator with comprehensive features
            await self._initialize_enhanced_orchestrator()

            # 2. Create enhanced dashboard with proper connections
            await self._create_enhanced_dashboard()

            # 3. Verify comprehensive state building
            await self._verify_comprehensive_state_building()

            # 4. Test enhanced reward calculation
            await self._test_enhanced_reward_calculation()

            # 5. Validate Williams market structure integration
            await self._validate_williams_integration()

            # 6. Start live training with comprehensive features
            await self._start_live_comprehensive_training()

            logger.info("=" * 70)
            logger.info("COMPREHENSIVE RL TRAINING INTEGRATION COMPLETE")
            logger.info("=" * 70)
            self._log_integration_stats()

        except Exception as e:
            logger.error(f"Error in RL training integration: {e}")
            import traceback
            logger.error(traceback.format_exc())

    async def _initialize_enhanced_orchestrator(self):
        """Initialize enhanced orchestrator with comprehensive RL capabilities"""
        try:
            logger.info("[STEP 1] Initializing Enhanced Orchestrator...")

            # Create enhanced orchestrator with RL training enabled
            self.enhanced_orchestrator = EnhancedTradingOrchestrator(
                data_provider=self.data_provider,
                symbols=['ETH/USDT', 'BTC/USDT'],
                enhanced_rl_training=True,
                model_registry={}  # Will be populated as needed
            )

            # Start COB integration for real-time market microstructure
            await self.enhanced_orchestrator.start_cob_integration()

            # Start real-time processing
            await self.enhanced_orchestrator.start_realtime_processing()

            logger.info("[SUCCESS] Enhanced Orchestrator initialized with:")
            logger.info("  - Comprehensive RL state building: ENABLED")
            logger.info("  - Enhanced pivot-based rewards: ENABLED")
            logger.info("  - COB integration: ENABLED")
            logger.info("  - Williams market structure: ENABLED")
            logger.info("  - Real-time tick processing: ENABLED")

        except Exception as e:
            logger.error(f"Error initializing enhanced orchestrator: {e}")
            raise

    async def _create_enhanced_dashboard(self):
        """Create dashboard with enhanced orchestrator connections"""
        try:
            logger.info("[STEP 2] Creating Enhanced Dashboard...")

            # Create trading dashboard with enhanced orchestrator
            self.dashboard = TradingDashboard(
                data_provider=self.data_provider,
                orchestrator=self.enhanced_orchestrator,  # Use enhanced orchestrator
                trading_executor=self.trading_executor
            )

            # Verify enhanced connections
            has_comprehensive_state_builder = hasattr(self.dashboard.orchestrator, 'build_comprehensive_rl_state')
            has_enhanced_reward_calc = hasattr(self.dashboard.orchestrator, 'calculate_enhanced_pivot_reward')
            has_symbol_correlation = hasattr(self.dashboard.orchestrator, '_get_symbol_correlation')

            logger.info("[SUCCESS] Enhanced Dashboard created with:")
            logger.info(f"  - Comprehensive state builder: {'AVAILABLE' if has_comprehensive_state_builder else 'MISSING'}")
            logger.info(f"  - Enhanced reward calculation: {'AVAILABLE' if has_enhanced_reward_calc else 'MISSING'}")
            logger.info(f"  - Symbol correlation analysis: {'AVAILABLE' if has_symbol_correlation else 'MISSING'}")

            if not all([has_comprehensive_state_builder, has_enhanced_reward_calc, has_symbol_correlation]):
                logger.warning("Some enhanced features are missing - this will cause fallbacks to basic training")
            else:
                logger.info("  - ALL ENHANCED FEATURES AVAILABLE!")

        except Exception as e:
            logger.error(f"Error creating enhanced dashboard: {e}")
            raise

    async def _verify_comprehensive_state_building(self):
        """Verify that comprehensive RL state building works correctly"""
        try:
            logger.info("[STEP 3] Verifying Comprehensive State Building...")

            # Test comprehensive state building for ETH
            eth_state = self.enhanced_orchestrator.build_comprehensive_rl_state('ETH/USDT')

            if eth_state is not None:
                logger.info(f"[SUCCESS] ETH comprehensive state built: {len(eth_state)} features")

                # Verify feature count
                if len(eth_state) == 13400:
                    logger.info("  - PERFECT: Exactly 13,400 features as required!")
                    self.training_stats['comprehensive_features_used'] += 1
                else:
                    logger.warning(f"  - MISMATCH: Expected 13,400 features, got {len(eth_state)}")

                # Analyze feature distribution
                self._analyze_state_features(eth_state)
                self.training_stats['successful_state_builds'] += 1

            else:
                logger.error("  - FAILED: Comprehensive state building returned None")

            # Test for BTC reference
            btc_state = self.enhanced_orchestrator.build_comprehensive_rl_state('BTC/USDT')
            if btc_state is not None:
                logger.info(f"[SUCCESS] BTC reference state built: {len(btc_state)} features")
                self.training_stats['successful_state_builds'] += 1

        except Exception as e:
            logger.error(f"Error verifying comprehensive state building: {e}")

    def _analyze_state_features(self, state_vector: np.ndarray):
        """Analyze the comprehensive state feature distribution"""
        try:
            # Calculate feature statistics
            non_zero_features = np.count_nonzero(state_vector)
            zero_features = len(state_vector) - non_zero_features
            feature_mean = np.mean(state_vector)
            feature_std = np.std(state_vector)
            feature_min = np.min(state_vector)
            feature_max = np.max(state_vector)

            logger.info("  - Feature Analysis:")
            logger.info(f"    * Non-zero features: {non_zero_features:,} ({non_zero_features/len(state_vector)*100:.1f}%)")
            logger.info(f"    * Zero features: {zero_features:,} ({zero_features/len(state_vector)*100:.1f}%)")
            logger.info(f"    * Mean: {feature_mean:.6f}")
            logger.info(f"    * Std: {feature_std:.6f}")
            logger.info(f"    * Range: [{feature_min:.6f}, {feature_max:.6f}]")

            # Log feature statistics to TensorBoard
            step = self.training_stats['total_episodes']
            self.tb_logger.log_scalars('Features/Distribution', {
                'non_zero_percentage': non_zero_features/len(state_vector)*100,
                'mean': feature_mean,
                'std': feature_std,
                'min': feature_min,
                'max': feature_max
            }, step)

            # Log feature histogram to TensorBoard
            self.tb_logger.log_histogram('Features/Values', state_vector, step)

            # Check if features are properly distributed
            if non_zero_features > len(state_vector) * 0.1:  # At least 10% non-zero
                logger.info("    * GOOD: Features are well distributed")
            else:
                logger.warning("    * WARNING: Too many zero features - data may be incomplete")

        except Exception as e:
            logger.warning(f"Error analyzing state features: {e}")

    async def _test_enhanced_reward_calculation(self):
        """Test enhanced pivot-based reward calculation"""
        try:
            logger.info("[STEP 4] Testing Enhanced Reward Calculation...")

            # Create mock trade data for testing
            trade_decision = {
                'action': 'BUY',
                'confidence': 0.75,
                'price': 2500.0,
                'timestamp': datetime.now()
            }

            trade_outcome = {
                'net_pnl': 50.0,
                'exit_price': 2550.0,
                'duration': timedelta(minutes=15)
            }

            # Get market data for reward calculation
            market_data = {
                'volatility': 0.03,
                'order_flow_direction': 'bullish',
                'order_flow_strength': 0.8
            }

            # Test enhanced reward calculation
            if hasattr(self.enhanced_orchestrator, 'calculate_enhanced_pivot_reward'):
                enhanced_reward = self.enhanced_orchestrator.calculate_enhanced_pivot_reward(
                    trade_decision, market_data, trade_outcome
                )

                logger.info(f"[SUCCESS] Enhanced reward calculated: {enhanced_reward:.3f}")
                logger.info("  - Enhanced pivot-based reward system: WORKING")
                self.training_stats['enhanced_reward_calculations'] += 1

                # Log reward metrics to TensorBoard
                step = self.training_stats['enhanced_reward_calculations']
                self.tb_logger.log_scalar('Rewards/Enhanced', enhanced_reward, step)

                # Log reward components to TensorBoard
                self.tb_logger.log_scalars('Rewards/Components', {
                    'pnl_component': trade_outcome['net_pnl'],
                    'confidence': trade_decision['confidence'],
                    'volatility': market_data['volatility'],
                    'order_flow_strength': market_data['order_flow_strength']
                }, step)

            else:
                logger.error("  - FAILED: Enhanced reward calculation method not available")

        except Exception as e:
            logger.error(f"Error testing enhanced reward calculation: {e}")

    async def _validate_williams_integration(self):
        """Validate Williams market structure integration"""
        try:
            logger.info("[STEP 5] Validating Williams Market Structure Integration...")

            # Test Williams pivot feature extraction
            try:
                from training.williams_market_structure import extract_pivot_features, analyze_pivot_context

                # Get test market data
                df = self.data_provider.get_historical_data('ETH/USDT', '1m', limit=100)

                if df is not None and not df.empty:
                    # Test pivot feature extraction
                    pivot_features = extract_pivot_features(df)

                    if pivot_features is not None:
                        logger.info(f"[SUCCESS] Williams pivot features extracted: {len(pivot_features)} features")
                        self.training_stats['pivot_features_extracted'] += 1

                        # Test pivot context analysis
                        market_data = {'ohlcv_data': df}
                        pivot_context = analyze_pivot_context(
                            market_data, datetime.now(), 'BUY'
                        )

                        if pivot_context is not None:
                            logger.info("[SUCCESS] Williams pivot context analysis: WORKING")
                            logger.info(f"  - Near pivot: {pivot_context.get('near_pivot', False)}")
                            logger.info(f"  - Pivot strength: {pivot_context.get('pivot_strength', 0):.3f}")
                        else:
                            logger.warning("  - Williams pivot context analysis returned None")
                    else:
                        logger.warning("  - Williams pivot feature extraction returned None")
                else:
                    logger.warning("  - No market data available for Williams testing")

            except ImportError:
                logger.error("  - Williams market structure module not available")
            except Exception as e:
                logger.error(f"  - Error in Williams integration: {e}")

        except Exception as e:
            logger.error(f"Error validating Williams integration: {e}")

    async def _start_live_comprehensive_training(self):
        """Start live training with comprehensive feature integration"""
        try:
            logger.info("[STEP 6] Starting Live Comprehensive Training...")

            # Run a few training iterations to verify integration
            for iteration in range(5):
                logger.info(f"Training iteration {iteration + 1}/5")

                # Make coordinated decisions using enhanced orchestrator
                decisions = await self.enhanced_orchestrator.make_coordinated_decisions()

                # Track iteration metrics for TensorBoard
                iteration_metrics = {
                    'decisions_count': len(decisions),
                    'confidence_avg': 0.0,
                    'state_size_avg': 0.0,
                    'successful_states': 0
                }

                # Process each decision
                for symbol, decision in decisions.items():
                    if decision:
                        logger.info(f"  {symbol}: {decision.action} (confidence: {decision.confidence:.3f})")

                        # Track confidence for TensorBoard
                        iteration_metrics['confidence_avg'] += decision.confidence

                        # Build comprehensive state for this decision
                        comprehensive_state = self.enhanced_orchestrator.build_comprehensive_rl_state(symbol)

                        if comprehensive_state is not None:
                            state_size = len(comprehensive_state)
                            logger.info(f"    - Comprehensive state: {state_size} features")
                            self.training_stats['total_episodes'] += 1

                            # Track state size for TensorBoard
                            iteration_metrics['state_size_avg'] += state_size
                            iteration_metrics['successful_states'] += 1

                            # Log individual state metrics to TensorBoard
                            self.tb_logger.log_state_metrics(
                                symbol=symbol,
                                state_info={
                                    'size': state_size,
                                    'quality': 1.0 if state_size == 13400 else 0.8,
                                    'feature_counts': {
                                        'total': state_size,
                                        'non_zero': np.count_nonzero(comprehensive_state)
                                    }
                                },
                                step=self.training_stats['total_episodes']
                            )
                        else:
                            logger.warning(f"    - Failed to build comprehensive state for {symbol}")

                # Calculate averages for TensorBoard
                if decisions:
                    iteration_metrics['confidence_avg'] /= len(decisions)

                if iteration_metrics['successful_states'] > 0:
                    iteration_metrics['state_size_avg'] /= iteration_metrics['successful_states']

                # Log iteration metrics to TensorBoard
                self.tb_logger.log_scalars('Training/Iteration', {
                    'iteration': iteration + 1,
                    'decisions_count': iteration_metrics['decisions_count'],
                    'confidence_avg': iteration_metrics['confidence_avg'],
                    'state_size_avg': iteration_metrics['state_size_avg'],
                    'successful_states': iteration_metrics['successful_states']
                }, iteration + 1)

                # Wait between iterations
                await asyncio.sleep(2)

            logger.info("[SUCCESS] Live comprehensive training demonstration complete")

        except Exception as e:
            logger.error(f"Error in live comprehensive training: {e}")

    def _log_integration_stats(self):
        """Log comprehensive integration statistics"""
        logger.info("INTEGRATION STATISTICS:")
        logger.info(f"  - Total training episodes: {self.training_stats['total_episodes']}")
        logger.info(f"  - Successful state builds: {self.training_stats['successful_state_builds']}")
        logger.info(f"  - Enhanced reward calculations: {self.training_stats['enhanced_reward_calculations']}")
        logger.info(f"  - Comprehensive features used: {self.training_stats['comprehensive_features_used']}")
        logger.info(f"  - Pivot features extracted: {self.training_stats['pivot_features_extracted']}")

        # Calculate success rates
        state_success_rate = 0
        if self.training_stats['total_episodes'] > 0:
            state_success_rate = self.training_stats['successful_state_builds'] / self.training_stats['total_episodes'] * 100
            logger.info(f"  - State building success rate: {state_success_rate:.1f}%")

        # Log final statistics to TensorBoard
        self.tb_logger.log_scalars('Integration/Statistics', {
            'total_episodes': self.training_stats['total_episodes'],
            'successful_state_builds': self.training_stats['successful_state_builds'],
            'enhanced_reward_calculations': self.training_stats['enhanced_reward_calculations'],
            'comprehensive_features_used': self.training_stats['comprehensive_features_used'],
            'pivot_features_extracted': self.training_stats['pivot_features_extracted'],
            'state_success_rate': state_success_rate
        }, 0)  # Use step 0 for final summary stats

        # Integration status
        if self.training_stats['comprehensive_features_used'] > 0:
            logger.info("STATUS: COMPREHENSIVE RL TRAINING INTEGRATION SUCCESSFUL! ✅")
|
||||
logger.info("The system is now using the full 13,400 feature comprehensive state.")
|
||||
|
||||
# Log success status to TensorBoard
|
||||
self.tb_logger.log_scalar('Integration/Success', 1.0, 0)
|
||||
else:
|
||||
logger.warning("STATUS: Integration partially successful - some fallbacks may occur")
|
||||
|
||||
# Log partial success status to TensorBoard
|
||||
self.tb_logger.log_scalar('Integration/Success', 0.5, 0)
|
||||
|
||||
async def main():
|
||||
"""Main entry point"""
|
||||
try:
|
||||
# Create and run the enhanced RL training integrator
|
||||
integrator = EnhancedRLTrainingIntegrator()
|
||||
await integrator.start_integration()
|
||||
|
||||
logger.info("Enhanced RL training integration completed successfully!")
|
||||
return 0
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Integration interrupted by user")
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error in integration: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit_code = asyncio.run(main())
|
||||
sys.exit(exit_code)
|
@@ -1,148 +0,0 @@
#!/usr/bin/env python3
"""
Example: Using the Checkpoint Management System
"""

import logging
import torch
import torch.nn as nn
import numpy as np
from datetime import datetime

from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint, get_checkpoint_manager
from utils.training_integration import get_training_integration

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ExampleCNN(nn.Module):
    def __init__(self, input_channels=5, num_classes=3):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

def example_cnn_training():
    logger.info("=== CNN Training Example ===")

    model = ExampleCNN()
    training_integration = get_training_integration()

    for epoch in range(5):  # Simulate 5 epochs
        # Simulate training metrics
        train_loss = 2.0 - (epoch * 0.15) + np.random.normal(0, 0.1)
        train_acc = 0.3 + (epoch * 0.06) + np.random.normal(0, 0.02)
        val_loss = train_loss + np.random.normal(0, 0.05)
        val_acc = train_acc - 0.05 + np.random.normal(0, 0.02)

        # Clamp values to realistic ranges
        train_acc = max(0.0, min(1.0, train_acc))
        val_acc = max(0.0, min(1.0, val_acc))
        train_loss = max(0.1, train_loss)
        val_loss = max(0.1, val_loss)

        logger.info(f"Epoch {epoch+1}: train_acc={train_acc:.3f}, val_acc={val_acc:.3f}")

        # Save checkpoint
        saved = training_integration.save_cnn_checkpoint(
            cnn_model=model,
            model_name="example_cnn",
            epoch=epoch + 1,
            train_accuracy=train_acc,
            val_accuracy=val_acc,
            train_loss=train_loss,
            val_loss=val_loss,
            training_time_hours=0.1 * (epoch + 1)
        )

        if saved:
            logger.info(f" Checkpoint saved for epoch {epoch+1}")
        else:
            logger.info(" Checkpoint not saved (performance not improved)")

    # Load the best checkpoint
    logger.info("\nLoading best checkpoint...")
    best_result = load_best_checkpoint("example_cnn")
    if best_result:
        file_path, metadata = best_result
        logger.info(f"Best checkpoint: {metadata.checkpoint_id}")
        logger.info(f"Performance score: {metadata.performance_score:.4f}")

def example_manual_checkpoint():
    logger.info("\n=== Manual Checkpoint Example ===")

    model = nn.Linear(10, 3)

    performance_metrics = {
        'accuracy': 0.85,
        'val_accuracy': 0.82,
        'loss': 0.45,
        'val_loss': 0.48
    }

    training_metadata = {
        'epoch': 25,
        'training_time_hours': 2.5,
        'total_parameters': sum(p.numel() for p in model.parameters())
    }

    logger.info("Saving checkpoint manually...")
    metadata = save_checkpoint(
        model=model,
        model_name="example_manual",
        model_type="cnn",
        performance_metrics=performance_metrics,
        training_metadata=training_metadata,
        force_save=True
    )

    if metadata:
        logger.info(f" Manual checkpoint saved: {metadata.checkpoint_id}")
        logger.info(f" Performance score: {metadata.performance_score:.4f}")

def show_checkpoint_stats():
    logger.info("\n=== Checkpoint Statistics ===")

    checkpoint_manager = get_checkpoint_manager()
    stats = checkpoint_manager.get_checkpoint_stats()

    logger.info(f"Total models: {stats['total_models']}")
    logger.info(f"Total checkpoints: {stats['total_checkpoints']}")
    logger.info(f"Total size: {stats['total_size_mb']:.2f} MB")

    for model_name, model_stats in stats['models'].items():
        logger.info(f"\n{model_name}:")
        logger.info(f" Checkpoints: {model_stats['checkpoint_count']}")
        logger.info(f" Size: {model_stats['total_size_mb']:.2f} MB")
        logger.info(f" Best performance: {model_stats['best_performance']:.4f}")

def main():
    logger.info(" Checkpoint Management System Examples")
    logger.info("=" * 50)

    try:
        example_cnn_training()
        example_manual_checkpoint()
        show_checkpoint_stats()

        logger.info("\n All examples completed successfully!")
        logger.info("\nTo use in your training:")
        logger.info("1. Import: from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint")
        logger.info("2. Or use: from utils.training_integration import get_training_integration")
        logger.info("3. Save checkpoints during training with performance metrics")
        logger.info("4. Load best checkpoints for inference or continued training")

    except Exception as e:
        logger.error(f"Error in examples: {e}")
        raise

if __name__ == "__main__":
    main()
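# --- Editor's sketch (not part of the original file) ---
# Loading the best checkpoint's weights back into a fresh model for inference,
# completing step 4 of the how-to above. Assumption: the saved file is either a
# bare state_dict or a dict wrapping one under a 'model_state_dict' key; the
# actual on-disk layout used by checkpoint_manager is not confirmed here.
def example_load_for_inference(model_name: str = "example_cnn"):
    best = load_best_checkpoint(model_name)
    if best is None:
        return None
    file_path, metadata = best
    state = torch.load(file_path, map_location="cpu")
    if isinstance(state, dict) and "model_state_dict" in state:
        state = state["model_state_dict"]  # unwrap (assumed key name)
    model = ExampleCNN()
    model.load_state_dict(state)
    model.eval()
    logger.info(f"Loaded {metadata.checkpoint_id} for inference")
    return model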
@@ -1,517 +0,0 @@
#!/usr/bin/env python3
"""
Comprehensive Checkpoint Management Integration

This script demonstrates how to integrate the checkpoint management system
across all training pipelines in the gogo2 project.

Features:
- DQN Agent training with automatic checkpointing
- CNN Model training with checkpoint management
- ExtremaTrainer with checkpoint persistence
- NegativeCaseTrainer with checkpoint integration
- Unified training orchestration with checkpoint coordination
"""

import asyncio
import logging
import time
import signal
import sys
import numpy as np
import torch
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List

# Setup logging (the logs directory must exist before the FileHandler is created)
Path("logs").mkdir(exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('logs/checkpoint_integration.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Import checkpoint management
from utils.checkpoint_manager import get_checkpoint_manager, get_checkpoint_stats
from utils.training_integration import get_training_integration

# Import training components
from NN.models.dqn_agent import DQNAgent
from NN.models.standardized_cnn import StandardizedCNN
from core.extrema_trainer import ExtremaTrainer
from core.negative_case_trainer import NegativeCaseTrainer
from core.data_provider import DataProvider
from core.config import get_config

class CheckpointIntegratedTrainingSystem:
    """Unified training system with comprehensive checkpoint management"""

    def __init__(self):
        """Initialize the checkpoint-integrated training system"""
        self.config = get_config()
        self.running = False

        # Checkpoint management
        self.checkpoint_manager = get_checkpoint_manager()
        self.training_integration = get_training_integration()

        # Data provider
        self.data_provider = DataProvider(
            symbols=['ETH/USDT', 'BTC/USDT'],
            timeframes=['1s', '1m', '1h', '1d']
        )

        # Training components with checkpoint management
        self.dqn_agent = None
        self.cnn_trainer = None
        self.extrema_trainer = None
        self.negative_case_trainer = None

        # Training statistics
        self.training_stats = {
            'start_time': None,
            'total_training_sessions': 0,
            'checkpoints_saved': 0,
            'models_loaded': 0,
            'best_performances': {}
        }

        logger.info("Checkpoint-Integrated Training System initialized")

    async def initialize_components(self):
        """Initialize all training components with checkpoint management"""
        try:
            logger.info("Initializing training components with checkpoint management...")

            # Initialize data provider
            await self.data_provider.start_real_time_streaming()
            logger.info("Data provider streaming started")

            # Initialize DQN Agent with checkpoint management
            logger.info("Initializing DQN Agent with checkpoints...")
            self.dqn_agent = DQNAgent(
                state_shape=(100,),  # Example state shape
                n_actions=3,
                model_name="integrated_dqn_agent",
                enable_checkpoints=True
            )
            logger.info("✅ DQN Agent initialized with checkpoint management")

            # Initialize StandardizedCNN Model with checkpoint management
            logger.info("Initializing StandardizedCNN Model with checkpoints...")
            self.cnn_model = StandardizedCNN(model_name="integrated_cnn_model")
            # Keep the attribute the training loop drives (self.cnn_trainer) in
            # sync with the created model; leaving it as None would silently
            # skip CNN training below.
            self.cnn_trainer = self.cnn_model
            logger.info("✅ StandardizedCNN Model initialized with checkpoint management")

            # Initialize ExtremaTrainer with checkpoint management
            logger.info("Initializing ExtremaTrainer with checkpoints...")
            self.extrema_trainer = ExtremaTrainer(
                data_provider=self.data_provider,
                symbols=['ETH/USDT', 'BTC/USDT'],
                model_name="integrated_extrema_trainer",
                enable_checkpoints=True
            )
            await self.extrema_trainer.initialize_context_data()
            logger.info("✅ ExtremaTrainer initialized with checkpoint management")

            # Initialize NegativeCaseTrainer with checkpoint management
            logger.info("Initializing NegativeCaseTrainer with checkpoints...")
            self.negative_case_trainer = NegativeCaseTrainer(
                model_name="integrated_negative_case_trainer",
                enable_checkpoints=True
            )
            logger.info("✅ NegativeCaseTrainer initialized with checkpoint management")

            # Load existing checkpoints for all components
            self.training_stats['models_loaded'] = await self._load_all_checkpoints()

            logger.info("All training components initialized successfully")

        except Exception as e:
            logger.error(f"Error initializing components: {e}")
            raise

    async def _load_all_checkpoints(self) -> int:
        """Load checkpoints for all training components"""
        loaded_count = 0

        try:
            # DQN Agent checkpoint loading is handled in __init__
            if hasattr(self.dqn_agent, 'episode_count') and self.dqn_agent.episode_count > 0:
                loaded_count += 1
                logger.info(f"DQN Agent resumed from episode {self.dqn_agent.episode_count}")

            # CNN Trainer checkpoint loading is handled in __init__
            if hasattr(self.cnn_trainer, 'epoch_count') and self.cnn_trainer.epoch_count > 0:
                loaded_count += 1
                logger.info(f"CNN Trainer resumed from epoch {self.cnn_trainer.epoch_count}")

            # ExtremaTrainer checkpoint loading is handled in __init__
            if hasattr(self.extrema_trainer, 'training_session_count') and self.extrema_trainer.training_session_count > 0:
                loaded_count += 1
                logger.info(f"ExtremaTrainer resumed from session {self.extrema_trainer.training_session_count}")

            # NegativeCaseTrainer checkpoint loading is handled in __init__
            if hasattr(self.negative_case_trainer, 'training_session_count') and self.negative_case_trainer.training_session_count > 0:
                loaded_count += 1
                logger.info(f"NegativeCaseTrainer resumed from session {self.negative_case_trainer.training_session_count}")

            return loaded_count

        except Exception as e:
            logger.error(f"Error loading checkpoints: {e}")
            return 0

    async def run_integrated_training_loop(self):
        """Run the integrated training loop with checkpoint coordination"""
        logger.info("Starting integrated training loop with checkpoint management...")

        self.running = True
        self.training_stats['start_time'] = datetime.now()

        training_cycle = 0

        try:
            while self.running:
                training_cycle += 1
                cycle_start = time.time()

                logger.info(f"=== Training Cycle {training_cycle} ===")

                # DQN Training
                dqn_results = await self._train_dqn_agent()

                # CNN Training
                cnn_results = await self._train_cnn_model()

                # Extrema Detection Training
                extrema_results = await self._train_extrema_detector()

                # Negative Case Training (runs in background)
                negative_results = await self._process_negative_cases()

                # Coordinate checkpoint saving
                await self._coordinate_checkpoint_saving(
                    dqn_results, cnn_results, extrema_results, negative_results
                )

                # Update statistics
                self.training_stats['total_training_sessions'] += 1

                # Log cycle summary
                cycle_duration = time.time() - cycle_start
                logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s")

                # Wait before next cycle
                await asyncio.sleep(60)  # 1-minute cycles

        except KeyboardInterrupt:
            logger.info("Training interrupted by user")
        except Exception as e:
            logger.error(f"Error in training loop: {e}")
        finally:
            await self.shutdown()

    async def _train_dqn_agent(self) -> Dict[str, Any]:
        """Train DQN agent with automatic checkpointing"""
        try:
            if not self.dqn_agent:
                return {'status': 'skipped', 'reason': 'no_agent'}

            # Simulate DQN training episode
            episode_reward = 0.0

            # Add some training experiences (simulate real training)
            for _ in range(10):  # Simulate 10 training steps
                state = np.random.randn(100).astype(np.float32)
                action = np.random.randint(0, 3)
                reward = np.random.randn() * 0.1
                next_state = np.random.randn(100).astype(np.float32)
                done = np.random.random() < 0.1

                self.dqn_agent.remember(state, action, reward, next_state, done)
                episode_reward += reward

            # Train if enough experiences
            loss = 0.0
            if len(self.dqn_agent.memory) >= self.dqn_agent.batch_size:
                loss = self.dqn_agent.replay()

            # Save checkpoint (automatic based on performance)
            checkpoint_saved = self.dqn_agent.save_checkpoint(episode_reward)

            if checkpoint_saved:
                self.training_stats['checkpoints_saved'] += 1

            return {
                'status': 'completed',
                'episode_reward': episode_reward,
                'loss': loss,
                'checkpoint_saved': checkpoint_saved,
                'episode': self.dqn_agent.episode_count
            }

        except Exception as e:
            logger.error(f"Error training DQN agent: {e}")
            return {'status': 'error', 'error': str(e)}

    async def _train_cnn_model(self) -> Dict[str, Any]:
        """Train CNN model with automatic checkpointing"""
        try:
            if not self.cnn_trainer:
                return {'status': 'skipped', 'reason': 'no_trainer'}

            # Simulate CNN training step
            batch_size = 32
            input_size = 60
            feature_dim = 50

            # Generate synthetic training data
            x = torch.randn(batch_size, input_size, feature_dim)
            y = torch.randint(0, 3, (batch_size,))

            # Training step
            results = self.cnn_trainer.train_step(x, y)

            # Simulate validation
            val_x = torch.randn(16, input_size, feature_dim)
            val_y = torch.randint(0, 3, (16,))
            val_results = self.cnn_trainer.train_step(val_x, val_y)

            # Save checkpoint (automatic based on performance)
            checkpoint_saved = self.cnn_trainer.save_checkpoint(
                train_accuracy=results.get('accuracy', 0.5),
                val_accuracy=val_results.get('accuracy', 0.5),
                train_loss=results.get('total_loss', 1.0),
                val_loss=val_results.get('total_loss', 1.0)
            )

            if checkpoint_saved:
                self.training_stats['checkpoints_saved'] += 1

            return {
                'status': 'completed',
                'train_accuracy': results.get('accuracy', 0.5),
                'val_accuracy': val_results.get('accuracy', 0.5),
                'train_loss': results.get('total_loss', 1.0),
                'val_loss': val_results.get('total_loss', 1.0),
                'checkpoint_saved': checkpoint_saved,
                'epoch': self.cnn_trainer.epoch_count
            }

        except Exception as e:
            logger.error(f"Error training CNN model: {e}")
            return {'status': 'error', 'error': str(e)}

    async def _train_extrema_detector(self) -> Dict[str, Any]:
        """Train extrema detector with automatic checkpointing"""
        try:
            if not self.extrema_trainer:
                return {'status': 'skipped', 'reason': 'no_trainer'}

            # Update context data and detect extrema
            update_results = self.extrema_trainer.update_context_data()

            # Get training data
            extrema_data = self.extrema_trainer.get_extrema_training_data(count=10)

            # Simulate training accuracy improvement
            if extrema_data:
                self.extrema_trainer.training_stats['total_extrema_detected'] += len(extrema_data)
                self.extrema_trainer.training_stats['successful_predictions'] += len(extrema_data) // 2
                self.extrema_trainer.training_stats['failed_predictions'] += len(extrema_data) // 2

            # Save checkpoint (automatic based on performance)
            checkpoint_saved = self.extrema_trainer.save_checkpoint()

            if checkpoint_saved:
                self.training_stats['checkpoints_saved'] += 1

            return {
                'status': 'completed',
                'extrema_detected': len(extrema_data),
                'context_updates': sum(1 for success in update_results.values() if success),
                'checkpoint_saved': checkpoint_saved,
                'session': self.extrema_trainer.training_session_count
            }

        except Exception as e:
            logger.error(f"Error training extrema detector: {e}")
            return {'status': 'error', 'error': str(e)}

    async def _process_negative_cases(self) -> Dict[str, Any]:
        """Process negative cases with automatic checkpointing"""
        try:
            if not self.negative_case_trainer:
                return {'status': 'skipped', 'reason': 'no_trainer'}

            # Simulate adding a negative case
            if np.random.random() < 0.1:  # 10% chance of negative case
                trade_info = {
                    'symbol': 'ETH/USDT',
                    'action': 'BUY',
                    'price': 2000.0,
                    'pnl': -50.0,  # Loss
                    'value': 1000.0,
                    'confidence': 0.7,
                    'timestamp': datetime.now()
                }

                market_data = {
                    'exit_price': 1950.0,
                    'state_before': {},
                    'state_after': {},
                    'tick_data': [],
                    'technical_indicators': {}
                }

                case_id = self.negative_case_trainer.add_losing_trade(trade_info, market_data)

                # Simulate loss improvement
                loss_improvement = np.random.random() * 0.1

                # Save checkpoint (automatic based on performance)
                checkpoint_saved = self.negative_case_trainer.save_checkpoint(loss_improvement)

                if checkpoint_saved:
                    self.training_stats['checkpoints_saved'] += 1

                return {
                    'status': 'completed',
                    'case_added': case_id,
                    'loss_improvement': loss_improvement,
                    'checkpoint_saved': checkpoint_saved,
                    'session': self.negative_case_trainer.training_session_count
                }
            else:
                return {'status': 'no_cases'}

        except Exception as e:
            logger.error(f"Error processing negative cases: {e}")
            return {'status': 'error', 'error': str(e)}

    async def _coordinate_checkpoint_saving(self, dqn_results: Dict, cnn_results: Dict,
                                            extrema_results: Dict, negative_results: Dict):
        """Coordinate checkpoint saving across all components"""
        try:
            # Count successful checkpoints
            checkpoints_saved = sum([
                dqn_results.get('checkpoint_saved', False),
                cnn_results.get('checkpoint_saved', False),
                extrema_results.get('checkpoint_saved', False),
                negative_results.get('checkpoint_saved', False)
            ])

            if checkpoints_saved > 0:
                logger.info(f"Saved {checkpoints_saved} checkpoints this cycle")

            # Update best performances
            if 'episode_reward' in dqn_results:
                current_best = self.training_stats['best_performances'].get('dqn_reward', float('-inf'))
                if dqn_results['episode_reward'] > current_best:
                    self.training_stats['best_performances']['dqn_reward'] = dqn_results['episode_reward']

            if 'val_accuracy' in cnn_results:
                current_best = self.training_stats['best_performances'].get('cnn_accuracy', 0.0)
                if cnn_results['val_accuracy'] > current_best:
                    self.training_stats['best_performances']['cnn_accuracy'] = cnn_results['val_accuracy']

            # Log checkpoint statistics every 10 cycles
            if self.training_stats['total_training_sessions'] % 10 == 0:
                await self._log_checkpoint_statistics()

        except Exception as e:
            logger.error(f"Error coordinating checkpoint saving: {e}")

    async def _log_checkpoint_statistics(self):
        """Log comprehensive checkpoint statistics"""
        try:
            stats = get_checkpoint_stats()

            logger.info("=== Checkpoint Statistics ===")
            logger.info(f"Total checkpoints: {stats['total_checkpoints']}")
            logger.info(f"Total size: {stats['total_size_mb']:.2f} MB")
            logger.info(f"Models managed: {len(stats['models'])}")

            for model_name, model_stats in stats['models'].items():
                logger.info(f" {model_name}: {model_stats['checkpoint_count']} checkpoints, "
                            f"{model_stats['total_size_mb']:.2f} MB, "
                            f"best: {model_stats['best_performance']:.4f}")

            logger.info(f"Training sessions: {self.training_stats['total_training_sessions']}")
            logger.info(f"Checkpoints saved: {self.training_stats['checkpoints_saved']}")
            logger.info(f"Best performances: {self.training_stats['best_performances']}")

        except Exception as e:
            logger.error(f"Error logging checkpoint statistics: {e}")

    async def shutdown(self):
        """Shutdown the training system and save final checkpoints"""
        logger.info("Shutting down checkpoint-integrated training system...")

        self.running = False

        try:
            # Force save checkpoints for all components
            if self.dqn_agent:
                self.dqn_agent.save_checkpoint(0.0, force_save=True)

            if self.cnn_trainer:
                self.cnn_trainer.save_checkpoint(0.0, 0.0, 0.0, 0.0, force_save=True)

            if self.extrema_trainer:
                self.extrema_trainer.save_checkpoint(force_save=True)

            if self.negative_case_trainer:
                self.negative_case_trainer.save_checkpoint(force_save=True)

            # Final statistics
            await self._log_checkpoint_statistics()

            logger.info("Checkpoint-integrated training system shutdown complete")

        except Exception as e:
            logger.error(f"Error during shutdown: {e}")

async def main():
    """Main function to run the checkpoint-integrated training system"""
    logger.info("🚀 Starting Checkpoint-Integrated Training System")

    # Create and initialize the training system
    training_system = CheckpointIntegratedTrainingSystem()

    # Setup signal handlers for graceful shutdown
    def signal_handler(signum, frame):
        logger.info("Received shutdown signal")
        asyncio.create_task(training_system.shutdown())

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        # Initialize components
        await training_system.initialize_components()

        # Run the integrated training loop
        await training_system.run_integrated_training_loop()

    except Exception as e:
        logger.error(f"Error in main: {e}")
        raise
    finally:
        await training_system.shutdown()

    logger.info("✅ Checkpoint management integration complete!")
    logger.info("All training pipelines now support automatic checkpointing")

if __name__ == "__main__":
    # Ensure logs directory exists
    Path("logs").mkdir(exist_ok=True)

    # Run the checkpoint-integrated training system
    asyncio.run(main())
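# --- Editor's sketch (not part of the original file) ---
# Drives a single DQN training pass and a clean shutdown without entering the
# infinite loop above; useful for smoke-testing the checkpoint wiring.
# Assumption: initialize_components() can complete in the test environment
# (it starts live data streaming).
async def smoke_test_one_pass():
    system = CheckpointIntegratedTrainingSystem()
    await system.initialize_components()
    dqn_results = await system._train_dqn_agent()
    logger.info(f"Smoke test DQN status: {dqn_results.get('status')}, "
                f"checkpoint saved: {dqn_results.get('checkpoint_saved', False)}")
    await system.shutdown()

# To run the smoke test instead of the full system:
# asyncio.run(smoke_test_one_pass())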
@@ -1,13 +0,0 @@
"""
Neural Network Utilities
========================

This package contains utility functions and classes used in the neural network trading system:
- Data Interface: Connects to realtime trading data and processes it for the neural network models
"""

from .data_interface import DataInterface
from .trading_env import TradingEnvironment
from .signal_interpreter import SignalInterpreter

__all__ = ['DataInterface', 'TradingEnvironment', 'SignalInterpreter']
@@ -1,123 +0,0 @@
"""
Enhanced Data Interface with additional NN trading parameters
"""
from typing import List, Optional, Tuple
import numpy as np
import pandas as pd
from datetime import datetime
from .data_interface import DataInterface

class MultiDataInterface(DataInterface):
    """
    Enhanced data interface that supports window_size and output_size parameters
    for neural network trading models.
    """

    def __init__(self, symbol: str,
                 timeframes: List[str],
                 window_size: int = 20,
                 output_size: int = 3,
                 data_dir: str = "NN/data"):
        """
        Initialize with window_size and output_size for NN predictions.
        """
        super().__init__(symbol, timeframes, data_dir)
        self.window_size = window_size
        self.output_size = output_size
        self.scalers = {}  # Store scalers for each timeframe
        self.min_window_threshold = 100  # Minimum candles needed for training

    def get_feature_count(self) -> int:
        """
        Get number of features (OHLCV) for NN input.
        """
        return 5  # open, high, low, close, volume

    def prepare_training_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Prepare training data with windowed sequences"""
        # Get historical data for primary timeframe
        primary_tf = self.timeframes[0]
        df = self.get_historical_data(timeframe=primary_tf,
                                      n_candles=self.min_window_threshold + 1000)

        if df is None or len(df) < self.min_window_threshold:
            raise ValueError(f"Insufficient data for training. Need at least {self.min_window_threshold} candles")

        # Prepare OHLCV sequences
        ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values

        # Create sequences and labels
        X = []
        y = []

        for i in range(len(ohlcv) - self.window_size - self.output_size):
            # Input sequence
            seq = ohlcv[i:i+self.window_size]
            X.append(seq)

            # Output target (price movement direction)
            close_prices = ohlcv[i+self.window_size:i+self.window_size+self.output_size, 3]  # Close prices
            price_changes = np.diff(close_prices)

            if self.output_size == 1:
                # Binary classification (up/down): compare the single future
                # close with the last close of the input window (np.diff of a
                # one-element slice is empty, so indexing price_changes[0]
                # here would raise IndexError)
                prev_close = ohlcv[i+self.window_size-1, 3]
                label = 1 if close_prices[0] > prev_close else 0
            elif self.output_size == 3:
                # 3-class classification (buy/hold/sell); note the 0.002
                # threshold applies to the absolute price difference between
                # consecutive future closes, not a percentage return
                if price_changes[0] > 0.002:  # Significant rise
                    label = 0  # Buy
                elif price_changes[0] < -0.002:  # Significant drop
                    label = 2  # Sell
                else:
                    label = 1  # Hold
            else:
                raise ValueError(f"Unsupported output_size: {self.output_size}")

            y.append(label)

        # Convert to numpy arrays
        X = np.array(X)
        y = np.array(y)

        # Split into train/validation (80/20)
        split_idx = int(0.8 * len(X))
        X_train, y_train = X[:split_idx], y[:split_idx]
        X_val, y_val = X[split_idx:], y[split_idx:]

        return X_train, y_train, X_val, y_val

    def prepare_prediction_data(self) -> np.ndarray:
        """Prepare most recent window for predictions"""
        primary_tf = self.timeframes[0]
        df = self.get_historical_data(timeframe=primary_tf,
                                      n_candles=self.window_size,
                                      use_cache=False)

        if df is None or len(df) < self.window_size:
            raise ValueError(f"Need at least {self.window_size} candles for prediction")

        ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values[-self.window_size:]
        return np.array([ohlcv])  # Add batch dimension

    def process_predictions(self, predictions: np.ndarray):
        """Convert prediction probabilities to trading signals"""
        signals = []
        for pred in predictions:
            if self.output_size == 1:
                signal = "BUY" if pred[0] > 0.5 else "SELL"
                confidence = np.abs(pred[0] - 0.5) * 2  # Convert to 0-1 scale
            elif self.output_size == 3:
                action_idx = np.argmax(pred)
                signal = ["BUY", "HOLD", "SELL"][action_idx]
                confidence = pred[action_idx]
            else:
                signal = "HOLD"
                confidence = 0.0

            signals.append({
                'action': signal,
                'confidence': confidence,
                'timestamp': datetime.now().isoformat()
            })

        return signals
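# --- Editor's sketch (not part of the original file) ---
# End-to-end use of MultiDataInterface; `trained_model` stands in for any
# model exposing predict() and is an assumption for illustration.
def example_signal_generation(trained_model):
    di = MultiDataInterface(symbol="BTC/USDT", timeframes=["1m"],
                            window_size=20, output_size=3)
    X_train, y_train, X_val, y_val = di.prepare_training_data()
    batch = di.prepare_prediction_data()  # shape (1, window_size, 5)
    predictions = trained_model.predict(batch)
    return di.process_predictions(predictions)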
@@ -1,364 +0,0 @@
"""
Realtime Analyzer for Neural Network Trading System

This module implements real-time analysis of market data using trained neural network models.
"""

import logging
import time
import numpy as np
from threading import Thread
from queue import Queue
from datetime import datetime
import asyncio
import websockets
import json
import os
import pandas as pd
from collections import deque

logger = logging.getLogger(__name__)

class RealtimeAnalyzer:
    """
    Handles real-time analysis of market data using trained neural network models.

    Features:
    - Connects to real-time data sources (websockets)
    - Processes tick data into multiple timeframes (1s, 1m, 1h, 1d)
    - Uses trained models to analyze all timeframes
    - Generates trading signals
    - Manages risk and position sizing
    - Logs all trading decisions
    """

    def __init__(self, data_interface, model, symbol="BTC/USDT", timeframes=None):
        """
        Initialize the realtime analyzer.

        Args:
            data_interface (DataInterface): Preconfigured data interface
            model: Trained neural network model
            symbol (str): Trading pair symbol
            timeframes (list): List of timeframes to monitor (default: ['1s', '1m', '1h', '1d'])
        """
        self.data_interface = data_interface
        self.model = model
        self.symbol = symbol
        self.timeframes = timeframes or ['1s', '1m', '1h', '1d']
        self.running = False
        self.data_queue = Queue()
        self.prediction_interval = 10  # Seconds between predictions
        self.ws_url = f"wss://stream.binance.com:9443/ws/{symbol.replace('/', '').lower()}@trade"
        self.ws = None
        self.tick_storage = deque(maxlen=10000)  # Store up to 10,000 ticks
        self.candle_cache = {
            '1s': deque(maxlen=5000),
            '1m': deque(maxlen=5000),
            '1h': deque(maxlen=5000),
            '1d': deque(maxlen=5000)
        }

        logger.info(f"RealtimeAnalyzer initialized for {symbol} with timeframes: {self.timeframes}")

    def start(self):
        """Start the realtime analysis process."""
        if self.running:
            logger.warning("Realtime analyzer already running")
            return

        self.running = True

        # Start WebSocket connection thread
        self.ws_thread = Thread(target=self._run_websocket, daemon=True)
        self.ws_thread.start()

        # Start data processing thread
        self.processing_thread = Thread(target=self._process_data, daemon=True)
        self.processing_thread.start()

        # Start analysis thread
        self.analysis_thread = Thread(target=self._analyze_data, daemon=True)
        self.analysis_thread.start()

        logger.info("Realtime analysis started")

    def stop(self):
        """Stop the realtime analysis process."""
        self.running = False
        if self.ws:
            # Best-effort close; the socket belongs to the websocket thread's event loop
            asyncio.run(self.ws.close())
        if hasattr(self, 'ws_thread'):
            self.ws_thread.join(timeout=1)
        if hasattr(self, 'processing_thread'):
            self.processing_thread.join(timeout=1)
        if hasattr(self, 'analysis_thread'):
            self.analysis_thread.join(timeout=1)
        logger.info("Realtime analysis stopped")

    def _run_websocket(self):
        """Thread function for running WebSocket connection."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self._connect_websocket())

    async def _connect_websocket(self):
        """Connect to WebSocket and receive data."""
        while self.running:
            try:
                logger.info(f"Connecting to WebSocket: {self.ws_url}")
                async with websockets.connect(self.ws_url) as ws:
                    self.ws = ws
                    logger.info("WebSocket connected")

                    while self.running:
                        try:
                            message = await ws.recv()
                            data = json.loads(message)

                            if 'e' in data and data['e'] == 'trade':
                                tick = {
                                    'timestamp': data['T'],
                                    'price': float(data['p']),
                                    'volume': float(data['q']),
                                    'symbol': self.symbol
                                }
                                self.tick_storage.append(tick)
                                self.data_queue.put(tick)

                        except websockets.exceptions.ConnectionClosed:
                            logger.warning("WebSocket connection closed")
                            break
                        except Exception as e:
                            logger.error(f"Error receiving WebSocket message: {str(e)}")
                            await asyncio.sleep(1)  # non-blocking sleep inside the coroutine

            except Exception as e:
                logger.error(f"WebSocket connection error: {str(e)}")
                await asyncio.sleep(5)  # Wait before reconnecting

    def _process_data(self):
        """Process incoming tick data into candles for all timeframes."""
        logger.info("Starting data processing thread")

        while self.running:
            try:
                # Process any new ticks
                while not self.data_queue.empty():
                    tick = self.data_queue.get()

                    # Convert timestamp to datetime
                    timestamp = datetime.fromtimestamp(tick['timestamp'] / 1000)

                    # Process for each timeframe
                    for timeframe in self.timeframes:
                        interval = self._get_interval_seconds(timeframe)
                        if interval is None:
                            continue

                        # Round timestamp down to the start of its candle
                        # interval, e.g. with a 1m interval a tick at 12:00:37
                        # maps to the 12:00:00 bucket
                        candle_ts = int(tick['timestamp'] // (interval * 1000)) * (interval * 1000)

                        # Get or create candle for this timeframe
                        if not self.candle_cache[timeframe]:
                            # First candle for this timeframe
                            candle = {
                                'timestamp': candle_ts,
                                'open': tick['price'],
                                'high': tick['price'],
                                'low': tick['price'],
                                'close': tick['price'],
                                'volume': tick['volume']
                            }
                            self.candle_cache[timeframe].append(candle)
                        else:
                            # Update existing candle
                            last_candle = self.candle_cache[timeframe][-1]

                            if last_candle['timestamp'] == candle_ts:
                                # Update current candle
                                last_candle['high'] = max(last_candle['high'], tick['price'])
                                last_candle['low'] = min(last_candle['low'], tick['price'])
                                last_candle['close'] = tick['price']
                                last_candle['volume'] += tick['volume']
                            else:
                                # New candle
                                candle = {
                                    'timestamp': candle_ts,
                                    'open': tick['price'],
                                    'high': tick['price'],
                                    'low': tick['price'],
                                    'close': tick['price'],
                                    'volume': tick['volume']
                                }
                                self.candle_cache[timeframe].append(candle)

                time.sleep(0.1)

            except Exception as e:
                logger.error(f"Error in data processing: {str(e)}")
                time.sleep(1)

    def _get_interval_seconds(self, timeframe):
        """Convert timeframe string to seconds."""
        intervals = {
            '1s': 1,
            '1m': 60,
            '1h': 3600,
            '1d': 86400
        }
        return intervals.get(timeframe)

    def _analyze_data(self):
        """Thread function for analyzing data and generating signals."""
        logger.info("Starting analysis thread")

        last_prediction_time = 0

        while self.running:
            try:
                current_time = time.time()

                # Only make predictions at the specified interval
                if current_time - last_prediction_time < self.prediction_interval:
                    time.sleep(0.1)
                    continue

                # Prepare input data from all timeframes
                input_data = {}
                valid = True

                for timeframe in self.timeframes:
                    if not self.candle_cache[timeframe]:
                        logger.warning(f"No data available for timeframe {timeframe}")
                        valid = False
                        break

                    # Get last N candles for this timeframe
                    candles = list(self.candle_cache[timeframe])[-self.data_interface.window_size:]

                    # Convert to numpy array
                    ohlcv = np.array([
                        [c['open'], c['high'], c['low'], c['close'], c['volume']]
                        for c in candles
                    ])

                    # Normalize data
                    ohlcv_normalized = (ohlcv - ohlcv.mean(axis=0)) / (ohlcv.std(axis=0) + 1e-8)
                    input_data[timeframe] = ohlcv_normalized

                if not valid:
                    time.sleep(0.1)
                    continue

                # Make prediction using the model
                try:
                    prediction = self.model.predict(input_data)

                    # Get latest timestamp from 1s timeframe
                    latest_ts = self.candle_cache['1s'][-1]['timestamp'] if self.candle_cache['1s'] else int(time.time() * 1000)

                    # Process prediction
                    self._process_prediction(
                        prediction=prediction,
                        timeframe='multi',
                        timestamp=latest_ts
                    )

                    last_prediction_time = current_time
                except Exception as e:
                    logger.error(f"Error making prediction: {str(e)}")

                time.sleep(0.1)

            except Exception as e:
                logger.error(f"Error in analysis: {str(e)}")
                time.sleep(1)

    def _process_prediction(self, prediction, timeframe, timestamp):
        """
        Process model prediction and generate trading signals.

        Args:
            prediction: Model prediction output
            timeframe (str): Timeframe the prediction is for ('multi' for combined)
            timestamp: Timestamp of the prediction (ms)
        """
        # Convert prediction to trading signal
        signal, confidence = self._prediction_to_signal(prediction)

        # Convert timestamp to datetime
        try:
            dt = datetime.fromtimestamp(timestamp / 1000)
        except Exception:
            dt = datetime.now()

        # Log the signal with all timeframes
        logger.info(
            f"Signal generated - Timeframes: {', '.join(self.timeframes)}, "
            f"Timestamp: {dt}, "
            f"Signal: {signal} (Confidence: {confidence:.2f})"
        )

        # In a real implementation, we would execute trades here
        # For now, we'll just log the signals

    def _prediction_to_signal(self, prediction):
        """
        Convert model prediction to trading signal and confidence.

        Args:
            prediction: Model prediction output (can be dict for multi-timeframe)

        Returns:
            tuple: (signal, confidence) where signal is BUY/SELL/HOLD,
                   confidence is probability (0-1)
        """
        if isinstance(prediction, dict):
            # Multi-timeframe prediction - combine signals
            signals = []
            confidences = []

            for tf, pred in prediction.items():
                if len(pred.shape) == 1:
                    # Binary classification
                    signal = "BUY" if pred[0] > 0.5 else "SELL"
                    confidence = pred[0] if signal == "BUY" else 1 - pred[0]
                else:
                    # Multi-class; NOTE: this class order must match the label
                    # encoding used at training time
                    class_idx = np.argmax(pred)
                    signal = ["SELL", "HOLD", "BUY"][class_idx]
                    confidence = pred[class_idx]

                signals.append(signal)
                confidences.append(confidence)

            # Simple voting system - count BUY/SELL signals, e.g.
            # ["BUY", "BUY", "SELL", "HOLD"] resolves to BUY with the mean
            # confidence of the BUY votes
            buy_count = signals.count("BUY")
            sell_count = signals.count("SELL")

            if buy_count > sell_count:
                final_signal = "BUY"
                final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "BUY"])
            elif sell_count > buy_count:
                final_signal = "SELL"
                final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "SELL"])
            else:
                final_signal = "HOLD"
                final_confidence = np.mean(confidences)

            return final_signal, final_confidence

        else:
            # Single prediction
            if len(prediction.shape) == 1:
                # Binary classification
                signal = "BUY" if prediction[0] > 0.5 else "SELL"
                confidence = prediction[0] if signal == "BUY" else 1 - prediction[0]
            else:
                # Multi-class
                class_idx = np.argmax(prediction)
                signal = ["SELL", "HOLD", "BUY"][class_idx]
                confidence = prediction[class_idx]

            return signal, confidence
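# --- Editor's sketch (not part of the original file) ---
# Minimal wiring of the analyzer; `data_interface` must expose a window_size
# attribute and `model` a predict() method, as the class above assumes.
def run_analyzer_briefly(data_interface, model, seconds=60):
    analyzer = RealtimeAnalyzer(data_interface, model, symbol="BTC/USDT",
                                timeframes=['1s', '1m'])
    analyzer.start()
    time.sleep(seconds)  # let ticks accumulate and predictions fire
    analyzer.stop()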