Compare commits
4 Commits
5f7032937e
...
aa2a1bf7ee
Author | SHA1 | Date | |
---|---|---|---|
![]() |
aa2a1bf7ee | ||
![]() |
b1ae557843 | ||
![]() |
0b5fa07498 | ||
![]() |
ac4068c168 |
@@ -1,201 +1,201 @@
|
|||||||
"""
|
# """
|
||||||
Legacy CNN Model Compatibility Layer
|
# Legacy CNN Model Compatibility Layer
|
||||||
|
|
||||||
This module provides compatibility redirects to the unified StandardizedCNN model.
|
# This module provides compatibility redirects to the unified StandardizedCNN model.
|
||||||
All legacy models (EnhancedCNNModel, CNNModelTrainer, CNNModel) have been retired
|
# All legacy models (EnhancedCNNModel, CNNModelTrainer, CNNModel) have been retired
|
||||||
in favor of the StandardizedCNN architecture.
|
# in favor of the StandardizedCNN architecture.
|
||||||
"""
|
# """
|
||||||
|
|
||||||
import logging
|
# import logging
|
||||||
import warnings
|
# import warnings
|
||||||
from typing import Tuple, Dict, Any, Optional
|
# from typing import Tuple, Dict, Any, Optional
|
||||||
import torch
|
# import torch
|
||||||
import numpy as np
|
# import numpy as np
|
||||||
|
|
||||||
# Import the standardized CNN model
|
# # Import the standardized CNN model
|
||||||
from .standardized_cnn import StandardizedCNN
|
# from .standardized_cnn import StandardizedCNN
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
# logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Compatibility aliases and wrappers
|
# # Compatibility aliases and wrappers
|
||||||
class EnhancedCNNModel:
|
# class EnhancedCNNModel:
|
||||||
"""Legacy compatibility wrapper - redirects to StandardizedCNN"""
|
# """Legacy compatibility wrapper - redirects to StandardizedCNN"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
# def __init__(self, *args, **kwargs):
|
||||||
warnings.warn(
|
# warnings.warn(
|
||||||
"EnhancedCNNModel is deprecated. Use StandardizedCNN instead.",
|
# "EnhancedCNNModel is deprecated. Use StandardizedCNN instead.",
|
||||||
DeprecationWarning,
|
# DeprecationWarning,
|
||||||
stacklevel=2
|
# stacklevel=2
|
||||||
)
|
# )
|
||||||
# Create StandardizedCNN with default parameters
|
# # Create StandardizedCNN with default parameters
|
||||||
self.standardized_cnn = StandardizedCNN()
|
# self.standardized_cnn = StandardizedCNN()
|
||||||
logger.warning("EnhancedCNNModel compatibility wrapper created - please migrate to StandardizedCNN")
|
# logger.warning("EnhancedCNNModel compatibility wrapper created - please migrate to StandardizedCNN")
|
||||||
|
|
||||||
def __getattr__(self, name):
|
# def __getattr__(self, name):
|
||||||
"""Delegate all method calls to StandardizedCNN"""
|
# """Delegate all method calls to StandardizedCNN"""
|
||||||
return getattr(self.standardized_cnn, name)
|
# return getattr(self.standardized_cnn, name)
|
||||||
|
|
||||||
|
|
||||||
class CNNModelTrainer:
|
# class CNNModelTrainer:
|
||||||
"""Legacy compatibility wrapper for CNN training"""
|
# """Legacy compatibility wrapper for CNN training"""
|
||||||
|
|
||||||
def __init__(self, model=None, *args, **kwargs):
|
# def __init__(self, model=None, *args, **kwargs):
|
||||||
warnings.warn(
|
# warnings.warn(
|
||||||
"CNNModelTrainer is deprecated. Use StandardizedCNN.train_step() instead.",
|
# "CNNModelTrainer is deprecated. Use StandardizedCNN.train_step() instead.",
|
||||||
DeprecationWarning,
|
# DeprecationWarning,
|
||||||
stacklevel=2
|
# stacklevel=2
|
||||||
)
|
# )
|
||||||
if isinstance(model, EnhancedCNNModel):
|
# if isinstance(model, EnhancedCNNModel):
|
||||||
self.model = model.standardized_cnn
|
# self.model = model.standardized_cnn
|
||||||
else:
|
# else:
|
||||||
self.model = StandardizedCNN()
|
# self.model = StandardizedCNN()
|
||||||
logger.warning("CNNModelTrainer compatibility wrapper created - please use StandardizedCNN.train_step()")
|
# logger.warning("CNNModelTrainer compatibility wrapper created - please use StandardizedCNN.train_step()")
|
||||||
|
|
||||||
def train_step(self, x, y, *args, **kwargs):
|
# def train_step(self, x, y, *args, **kwargs):
|
||||||
"""Legacy train step wrapper"""
|
# """Legacy train step wrapper"""
|
||||||
try:
|
# try:
|
||||||
# Convert to BaseDataInput format if needed
|
# # Convert to BaseDataInput format if needed
|
||||||
if hasattr(x, 'get_feature_vector'):
|
# if hasattr(x, 'get_feature_vector'):
|
||||||
# Already BaseDataInput
|
# # Already BaseDataInput
|
||||||
base_input = x
|
# base_input = x
|
||||||
else:
|
# else:
|
||||||
# Create mock BaseDataInput for legacy compatibility
|
# # Create mock BaseDataInput for legacy compatibility
|
||||||
from core.data_models import BaseDataInput
|
# from core.data_models import BaseDataInput
|
||||||
base_input = BaseDataInput()
|
# base_input = BaseDataInput()
|
||||||
# Set mock feature vector
|
# # Set mock feature vector
|
||||||
if isinstance(x, torch.Tensor):
|
# if isinstance(x, torch.Tensor):
|
||||||
feature_vector = x.flatten().cpu().numpy()
|
# feature_vector = x.flatten().cpu().numpy()
|
||||||
else:
|
# else:
|
||||||
feature_vector = np.array(x).flatten()
|
# feature_vector = np.array(x).flatten()
|
||||||
|
|
||||||
# Pad or truncate to expected size
|
# # Pad or truncate to expected size
|
||||||
expected_size = self.model.expected_feature_dim
|
# expected_size = self.model.expected_feature_dim
|
||||||
if len(feature_vector) < expected_size:
|
# if len(feature_vector) < expected_size:
|
||||||
padding = np.zeros(expected_size - len(feature_vector))
|
# padding = np.zeros(expected_size - len(feature_vector))
|
||||||
feature_vector = np.concatenate([feature_vector, padding])
|
# feature_vector = np.concatenate([feature_vector, padding])
|
||||||
else:
|
# else:
|
||||||
feature_vector = feature_vector[:expected_size]
|
# feature_vector = feature_vector[:expected_size]
|
||||||
|
|
||||||
base_input._feature_vector = feature_vector
|
# base_input._feature_vector = feature_vector
|
||||||
|
|
||||||
# Convert target to string format
|
# # Convert target to string format
|
||||||
if isinstance(y, torch.Tensor):
|
# if isinstance(y, torch.Tensor):
|
||||||
y_val = y.item() if y.numel() == 1 else y.argmax().item()
|
# y_val = y.item() if y.numel() == 1 else y.argmax().item()
|
||||||
else:
|
# else:
|
||||||
y_val = int(y) if np.isscalar(y) else int(np.argmax(y))
|
# y_val = int(y) if np.isscalar(y) else int(np.argmax(y))
|
||||||
|
|
||||||
target_map = {0: 'BUY', 1: 'SELL', 2: 'HOLD'}
|
# target_map = {0: 'BUY', 1: 'SELL', 2: 'HOLD'}
|
||||||
target = target_map.get(y_val, 'HOLD')
|
# target = target_map.get(y_val, 'HOLD')
|
||||||
|
|
||||||
# Use StandardizedCNN training
|
# # Use StandardizedCNN training
|
||||||
optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
|
# optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
|
||||||
loss = self.model.train_step([base_input], [target], optimizer)
|
# loss = self.model.train_step([base_input], [target], optimizer)
|
||||||
|
|
||||||
return {'total_loss': loss, 'main_loss': loss, 'accuracy': 0.5}
|
# return {'total_loss': loss, 'main_loss': loss, 'accuracy': 0.5}
|
||||||
|
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"Legacy train_step error: {e}")
|
# logger.error(f"Legacy train_step error: {e}")
|
||||||
return {'total_loss': 0.0, 'main_loss': 0.0, 'accuracy': 0.5}
|
# return {'total_loss': 0.0, 'main_loss': 0.0, 'accuracy': 0.5}
|
||||||
|
|
||||||
|
|
||||||
class CNNModel:
|
# # class CNNModel:
|
||||||
"""Legacy compatibility wrapper for CNN model interface"""
|
# # """Legacy compatibility wrapper for CNN model interface"""
|
||||||
|
|
||||||
def __init__(self, input_shape=(900, 50), output_size=3, model_path=None):
|
# # def __init__(self, input_shape=(900, 50), output_size=3, model_path=None):
|
||||||
warnings.warn(
|
# # warnings.warn(
|
||||||
"CNNModel is deprecated. Use StandardizedCNN directly.",
|
# # "CNNModel is deprecated. Use StandardizedCNN directly.",
|
||||||
DeprecationWarning,
|
# # DeprecationWarning,
|
||||||
stacklevel=2
|
# # stacklevel=2
|
||||||
)
|
# # )
|
||||||
self.input_shape = input_shape
|
# # self.input_shape = input_shape
|
||||||
self.output_size = output_size
|
# # self.output_size = output_size
|
||||||
self.standardized_cnn = StandardizedCNN()
|
# # self.standardized_cnn = StandardizedCNN()
|
||||||
self.trainer = CNNModelTrainer(self.standardized_cnn)
|
# # self.trainer = CNNModelTrainer(self.standardized_cnn)
|
||||||
logger.warning("CNNModel compatibility wrapper created - please migrate to StandardizedCNN")
|
# # logger.warning("CNNModel compatibility wrapper created - please migrate to StandardizedCNN")
|
||||||
|
|
||||||
def build_model(self, **kwargs):
|
# # def build_model(self, **kwargs):
|
||||||
"""Legacy build method - no-op for StandardizedCNN"""
|
# # """Legacy build method - no-op for StandardizedCNN"""
|
||||||
return self
|
# # return self
|
||||||
|
|
||||||
def predict(self, X):
|
# # def predict(self, X):
|
||||||
"""Legacy predict method"""
|
# # """Legacy predict method"""
|
||||||
try:
|
# # try:
|
||||||
# Convert input to BaseDataInput
|
# # # Convert input to BaseDataInput
|
||||||
from core.data_models import BaseDataInput
|
# # from core.data_models import BaseDataInput
|
||||||
base_input = BaseDataInput()
|
# # base_input = BaseDataInput()
|
||||||
|
|
||||||
if isinstance(X, np.ndarray):
|
# # if isinstance(X, np.ndarray):
|
||||||
feature_vector = X.flatten()
|
# # feature_vector = X.flatten()
|
||||||
else:
|
# # else:
|
||||||
feature_vector = np.array(X).flatten()
|
# # feature_vector = np.array(X).flatten()
|
||||||
|
|
||||||
# Pad or truncate to expected size
|
# # # Pad or truncate to expected size
|
||||||
expected_size = self.standardized_cnn.expected_feature_dim
|
# # expected_size = self.standardized_cnn.expected_feature_dim
|
||||||
if len(feature_vector) < expected_size:
|
# # if len(feature_vector) < expected_size:
|
||||||
padding = np.zeros(expected_size - len(feature_vector))
|
# # padding = np.zeros(expected_size - len(feature_vector))
|
||||||
feature_vector = np.concatenate([feature_vector, padding])
|
# # feature_vector = np.concatenate([feature_vector, padding])
|
||||||
else:
|
# # else:
|
||||||
feature_vector = feature_vector[:expected_size]
|
# # feature_vector = feature_vector[:expected_size]
|
||||||
|
|
||||||
base_input._feature_vector = feature_vector
|
# # base_input._feature_vector = feature_vector
|
||||||
|
|
||||||
# Get prediction from StandardizedCNN
|
# # # Get prediction from StandardizedCNN
|
||||||
result = self.standardized_cnn.predict_from_base_input(base_input)
|
# # result = self.standardized_cnn.predict_from_base_input(base_input)
|
||||||
|
|
||||||
# Convert to legacy format
|
# # # Convert to legacy format
|
||||||
action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
|
# # action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
|
||||||
pred_class = np.array([action_map.get(result.predictions['action'], 2)])
|
# # pred_class = np.array([action_map.get(result.predictions['action'], 2)])
|
||||||
pred_proba = np.array([result.predictions['action_probabilities']])
|
# # pred_proba = np.array([result.predictions['action_probabilities']])
|
||||||
|
|
||||||
return pred_class, pred_proba
|
# # return pred_class, pred_proba
|
||||||
|
|
||||||
except Exception as e:
|
# # except Exception as e:
|
||||||
logger.error(f"Legacy predict error: {e}")
|
# # logger.error(f"Legacy predict error: {e}")
|
||||||
# Return safe defaults
|
# # # Return safe defaults
|
||||||
pred_class = np.array([2]) # HOLD
|
# # pred_class = np.array([2]) # HOLD
|
||||||
pred_proba = np.array([[0.33, 0.33, 0.34]])
|
# # pred_proba = np.array([[0.33, 0.33, 0.34]])
|
||||||
return pred_class, pred_proba
|
# # return pred_class, pred_proba
|
||||||
|
|
||||||
def fit(self, X, y, **kwargs):
|
# # def fit(self, X, y, **kwargs):
|
||||||
"""Legacy fit method"""
|
# # """Legacy fit method"""
|
||||||
try:
|
# # try:
|
||||||
return self.trainer.train_step(X, y)
|
# # return self.trainer.train_step(X, y)
|
||||||
except Exception as e:
|
# # except Exception as e:
|
||||||
logger.error(f"Legacy fit error: {e}")
|
# # logger.error(f"Legacy fit error: {e}")
|
||||||
return self
|
# # return self
|
||||||
|
|
||||||
def save(self, filepath: str):
|
# # def save(self, filepath: str):
|
||||||
"""Legacy save method"""
|
# # """Legacy save method"""
|
||||||
try:
|
# # try:
|
||||||
torch.save(self.standardized_cnn.state_dict(), filepath)
|
# # torch.save(self.standardized_cnn.state_dict(), filepath)
|
||||||
logger.info(f"StandardizedCNN saved to {filepath}")
|
# # logger.info(f"StandardizedCNN saved to {filepath}")
|
||||||
except Exception as e:
|
# # except Exception as e:
|
||||||
logger.error(f"Error saving model: {e}")
|
# # logger.error(f"Error saving model: {e}")
|
||||||
|
|
||||||
|
|
||||||
def create_enhanced_cnn_model(input_size: int = 60,
|
# def create_enhanced_cnn_model(input_size: int = 60,
|
||||||
feature_dim: int = 50,
|
# feature_dim: int = 50,
|
||||||
output_size: int = 3,
|
# output_size: int = 3,
|
||||||
base_channels: int = 256,
|
# base_channels: int = 256,
|
||||||
device: str = 'cuda') -> Tuple[StandardizedCNN, CNNModelTrainer]:
|
# device: str = 'cuda') -> Tuple[StandardizedCNN, CNNModelTrainer]:
|
||||||
"""Legacy compatibility function - returns StandardizedCNN"""
|
# """Legacy compatibility function - returns StandardizedCNN"""
|
||||||
warnings.warn(
|
# warnings.warn(
|
||||||
"create_enhanced_cnn_model is deprecated. Use StandardizedCNN() directly.",
|
# "create_enhanced_cnn_model is deprecated. Use StandardizedCNN() directly.",
|
||||||
DeprecationWarning,
|
# DeprecationWarning,
|
||||||
stacklevel=2
|
# stacklevel=2
|
||||||
)
|
# )
|
||||||
|
|
||||||
model = StandardizedCNN()
|
# model = StandardizedCNN()
|
||||||
trainer = CNNModelTrainer(model)
|
# trainer = CNNModelTrainer(model)
|
||||||
|
|
||||||
logger.warning("Legacy create_enhanced_cnn_model called - please use StandardizedCNN directly")
|
# logger.warning("Legacy create_enhanced_cnn_model called - please use StandardizedCNN directly")
|
||||||
return model, trainer
|
# return model, trainer
|
||||||
|
|
||||||
|
|
||||||
# Export compatibility symbols
|
# # Export compatibility symbols
|
||||||
__all__ = [
|
# __all__ = [
|
||||||
'EnhancedCNNModel',
|
# 'EnhancedCNNModel',
|
||||||
'CNNModelTrainer',
|
# 'CNNModelTrainer',
|
||||||
'CNNModel',
|
# # 'CNNModel',
|
||||||
'create_enhanced_cnn_model'
|
# 'create_enhanced_cnn_model'
|
||||||
]
|
# ]
|
||||||
|
@@ -23,11 +23,11 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class DQNNetwork(nn.Module):
|
class DQNNetwork(nn.Module):
|
||||||
"""
|
"""
|
||||||
Massive Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
|
Configurable Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
|
||||||
Handles 7850 input features from multi-timeframe, multi-asset data
|
Handles 7850 input features from multi-timeframe, multi-asset data
|
||||||
TARGET: 50M parameters for enhanced learning capacity
|
Architecture is configurable via config.yaml
|
||||||
"""
|
"""
|
||||||
def __init__(self, input_dim: int, n_actions: int):
|
def __init__(self, input_dim: int, n_actions: int, config: dict = None):
|
||||||
super(DQNNetwork, self).__init__()
|
super(DQNNetwork, self).__init__()
|
||||||
|
|
||||||
# Handle different input dimension formats
|
# Handle different input dimension formats
|
||||||
@@ -41,59 +41,65 @@ class DQNNetwork(nn.Module):
|
|||||||
|
|
||||||
self.n_actions = n_actions
|
self.n_actions = n_actions
|
||||||
|
|
||||||
# MASSIVE network architecture optimized for trading features
|
# Get network architecture from config or use defaults
|
||||||
# Target: ~50M parameters
|
if config and 'network_architecture' in config:
|
||||||
self.feature_extractor = nn.Sequential(
|
arch_config = config['network_architecture']
|
||||||
# Initial feature extraction with massive width
|
feature_layers = arch_config.get('feature_layers', [4096, 3072, 2048, 1536, 1024])
|
||||||
nn.Linear(self.input_size, 8192), # 7850 -> 8192 = ~64M weights
|
regime_head = arch_config.get('regime_head', [512, 256])
|
||||||
nn.LayerNorm(8192),
|
price_direction_head = arch_config.get('price_direction_head', [512, 256])
|
||||||
nn.ReLU(inplace=True),
|
volatility_head = arch_config.get('volatility_head', [512, 128])
|
||||||
nn.Dropout(0.1),
|
value_head = arch_config.get('value_head', [512, 256])
|
||||||
|
advantage_head = arch_config.get('advantage_head', [512, 256])
|
||||||
# Deep feature processing layers
|
dropout_rate = arch_config.get('dropout_rate', 0.1)
|
||||||
nn.Linear(8192, 6144), # 8192 -> 6144 = ~50M weights
|
use_layer_norm = arch_config.get('use_layer_norm', True)
|
||||||
nn.LayerNorm(6144),
|
else:
|
||||||
nn.ReLU(inplace=True),
|
# Default reduced architecture (half the original size)
|
||||||
nn.Dropout(0.1),
|
feature_layers = [4096, 3072, 2048, 1536, 1024]
|
||||||
|
regime_head = [512, 256]
|
||||||
nn.Linear(6144, 4096), # 6144 -> 4096 = ~25M weights
|
price_direction_head = [512, 256]
|
||||||
nn.LayerNorm(4096),
|
volatility_head = [512, 128]
|
||||||
nn.ReLU(inplace=True),
|
value_head = [512, 256]
|
||||||
nn.Dropout(0.1),
|
advantage_head = [512, 256]
|
||||||
|
dropout_rate = 0.1
|
||||||
nn.Linear(4096, 3072), # 4096 -> 3072 = ~12M weights
|
use_layer_norm = True
|
||||||
nn.LayerNorm(3072),
|
|
||||||
nn.ReLU(inplace=True),
|
# Build configurable feature extractor
|
||||||
nn.Dropout(0.1),
|
feature_layers_list = []
|
||||||
|
prev_size = self.input_size
|
||||||
nn.Linear(3072, 2048), # 3072 -> 2048 = ~6M weights
|
|
||||||
nn.LayerNorm(2048),
|
for layer_size in feature_layers:
|
||||||
nn.ReLU(inplace=True),
|
feature_layers_list.append(nn.Linear(prev_size, layer_size))
|
||||||
nn.Dropout(0.1),
|
if use_layer_norm:
|
||||||
)
|
feature_layers_list.append(nn.LayerNorm(layer_size))
|
||||||
|
feature_layers_list.append(nn.ReLU(inplace=True))
|
||||||
|
feature_layers_list.append(nn.Dropout(dropout_rate))
|
||||||
|
prev_size = layer_size
|
||||||
|
|
||||||
|
self.feature_extractor = nn.Sequential(*feature_layers_list)
|
||||||
|
self.feature_size = feature_layers[-1] # Final feature size
|
||||||
|
|
||||||
|
# Build configurable network heads
|
||||||
|
def build_head_layers(input_size, layer_sizes, output_size):
|
||||||
|
layers = []
|
||||||
|
prev_size = input_size
|
||||||
|
for layer_size in layer_sizes:
|
||||||
|
layers.append(nn.Linear(prev_size, layer_size))
|
||||||
|
if use_layer_norm:
|
||||||
|
layers.append(nn.LayerNorm(layer_size))
|
||||||
|
layers.append(nn.ReLU(inplace=True))
|
||||||
|
layers.append(nn.Dropout(dropout_rate))
|
||||||
|
prev_size = layer_size
|
||||||
|
layers.append(nn.Linear(prev_size, output_size))
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
# Market regime detection head
|
# Market regime detection head
|
||||||
self.regime_head = nn.Sequential(
|
self.regime_head = build_head_layers(
|
||||||
nn.Linear(2048, 1024),
|
self.feature_size, regime_head, 4 # trending, ranging, volatile, mixed
|
||||||
nn.LayerNorm(1024),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Dropout(0.1),
|
|
||||||
nn.Linear(1024, 512),
|
|
||||||
nn.LayerNorm(512),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Linear(512, 4) # trending, ranging, volatile, mixed
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Price direction prediction head - outputs direction and confidence
|
# Price direction prediction head - outputs direction and confidence
|
||||||
self.price_direction_head = nn.Sequential(
|
self.price_direction_head = build_head_layers(
|
||||||
nn.Linear(2048, 1024),
|
self.feature_size, price_direction_head, 2 # [direction, confidence]
|
||||||
nn.LayerNorm(1024),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Dropout(0.1),
|
|
||||||
nn.Linear(1024, 512),
|
|
||||||
nn.LayerNorm(512),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Linear(512, 2) # [direction, confidence]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Direction activation (tanh for -1 to 1)
|
# Direction activation (tanh for -1 to 1)
|
||||||
@@ -102,38 +108,18 @@ class DQNNetwork(nn.Module):
|
|||||||
self.confidence_activation = nn.Sigmoid()
|
self.confidence_activation = nn.Sigmoid()
|
||||||
|
|
||||||
# Volatility prediction head
|
# Volatility prediction head
|
||||||
self.volatility_head = nn.Sequential(
|
self.volatility_head = build_head_layers(
|
||||||
nn.Linear(2048, 1024),
|
self.feature_size, volatility_head, 4 # predicted volatility for 4 timeframes
|
||||||
nn.LayerNorm(1024),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Dropout(0.1),
|
|
||||||
nn.Linear(1024, 256),
|
|
||||||
nn.LayerNorm(256),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Linear(256, 4) # predicted volatility for 4 timeframes
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Main Q-value head (dueling architecture)
|
# Main Q-value head (dueling architecture)
|
||||||
self.value_head = nn.Sequential(
|
self.value_head = build_head_layers(
|
||||||
nn.Linear(2048, 1024),
|
self.feature_size, value_head, 1 # Single value for dueling architecture
|
||||||
nn.LayerNorm(1024),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Dropout(0.1),
|
|
||||||
nn.Linear(1024, 512),
|
|
||||||
nn.LayerNorm(512),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Linear(512, 1) # State value
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.advantage_head = nn.Sequential(
|
# Advantage head (dueling architecture)
|
||||||
nn.Linear(2048, 1024),
|
self.advantage_head = build_head_layers(
|
||||||
nn.LayerNorm(1024),
|
self.feature_size, advantage_head, n_actions # Action advantages
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Dropout(0.1),
|
|
||||||
nn.Linear(1024, 512),
|
|
||||||
nn.LayerNorm(512),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Linear(512, n_actions) # Action advantages
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize weights
|
# Initialize weights
|
||||||
@@ -248,7 +234,8 @@ class DQNAgent:
|
|||||||
priority_memory: bool = True,
|
priority_memory: bool = True,
|
||||||
device=None,
|
device=None,
|
||||||
model_name: str = "dqn_agent",
|
model_name: str = "dqn_agent",
|
||||||
enable_checkpoints: bool = True):
|
enable_checkpoints: bool = True,
|
||||||
|
config: dict = None):
|
||||||
|
|
||||||
# Checkpoint management
|
# Checkpoint management
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
@@ -292,8 +279,8 @@ class DQNAgent:
|
|||||||
logger.info(f"DQN Agent using device: {self.device}")
|
logger.info(f"DQN Agent using device: {self.device}")
|
||||||
|
|
||||||
# Initialize models with RL-specific network architecture
|
# Initialize models with RL-specific network architecture
|
||||||
self.policy_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)
|
self.policy_net = DQNNetwork(self.state_dim, self.n_actions, config).to(self.device)
|
||||||
self.target_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)
|
self.target_net = DQNNetwork(self.state_dim, self.n_actions, config).to(self.device)
|
||||||
|
|
||||||
# Ensure models are on the correct device
|
# Ensure models are on the correct device
|
||||||
self.policy_net = self.policy_net.to(self.device)
|
self.policy_net = self.policy_net.to(self.device)
|
||||||
|
185
config.yaml
185
config.yaml
@@ -88,119 +88,14 @@ data:
|
|||||||
market_regime_detection: true
|
market_regime_detection: true
|
||||||
volatility_analysis: true
|
volatility_analysis: true
|
||||||
|
|
||||||
# Enhanced CNN Configuration
|
# Model configurations have been moved to models.yml for better organization
|
||||||
cnn:
|
# See models.yml for all model-specific settings including:
|
||||||
window_size: 20
|
# - CNN configuration
|
||||||
features: ["open", "high", "low", "close", "volume"]
|
# - RL/DQN configuration
|
||||||
timeframes: ["1m", "5m", "15m", "1h", "4h", "1d"]
|
# - Orchestrator settings
|
||||||
hidden_layers: [64, 128, 256]
|
# - Training configuration
|
||||||
dropout: 0.2
|
# - Enhanced training system
|
||||||
learning_rate: 0.001
|
# - Real-time RL COB trader
|
||||||
batch_size: 32
|
|
||||||
epochs: 100
|
|
||||||
confidence_threshold: 0.6
|
|
||||||
early_stopping_patience: 10
|
|
||||||
model_dir: "models/enhanced_cnn" # Ultra-fast scalping weights (500x leverage)
|
|
||||||
timeframe_importance:
|
|
||||||
"1s": 0.60 # Primary scalping signal
|
|
||||||
"1m": 0.20 # Short-term confirmation
|
|
||||||
"1h": 0.15 # Medium-term trend
|
|
||||||
"1d": 0.05 # Long-term direction (minimal)
|
|
||||||
|
|
||||||
# Enhanced RL Agent Configuration
|
|
||||||
rl:
|
|
||||||
state_size: 100 # Will be calculated dynamically based on features
|
|
||||||
action_space: 3 # BUY, HOLD, SELL
|
|
||||||
hidden_size: 256
|
|
||||||
epsilon: 1.0
|
|
||||||
epsilon_decay: 0.995
|
|
||||||
epsilon_min: 0.01
|
|
||||||
learning_rate: 0.0001
|
|
||||||
gamma: 0.99
|
|
||||||
memory_size: 10000
|
|
||||||
batch_size: 64
|
|
||||||
target_update_freq: 1000
|
|
||||||
buffer_size: 10000
|
|
||||||
model_dir: "models/enhanced_rl"
|
|
||||||
# Market regime adaptation
|
|
||||||
market_regime_weights:
|
|
||||||
trending: 1.2 # Higher confidence in trending markets
|
|
||||||
ranging: 0.8 # Lower confidence in ranging markets
|
|
||||||
volatile: 0.6 # Much lower confidence in volatile markets
|
|
||||||
# Prioritized experience replay
|
|
||||||
replay_alpha: 0.6 # Priority exponent
|
|
||||||
replay_beta: 0.4 # Importance sampling exponent
|
|
||||||
|
|
||||||
# Enhanced Orchestrator Settings
|
|
||||||
orchestrator:
|
|
||||||
# Model weights for decision combination
|
|
||||||
cnn_weight: 0.7 # Weight for CNN predictions
|
|
||||||
rl_weight: 0.3 # Weight for RL decisions
|
|
||||||
confidence_threshold: 0.45
|
|
||||||
confidence_threshold_close: 0.35
|
|
||||||
decision_frequency: 30
|
|
||||||
|
|
||||||
# Multi-symbol coordination
|
|
||||||
symbol_correlation_matrix:
|
|
||||||
"ETH/USDT-BTC/USDT": 0.85 # ETH-BTC correlation
|
|
||||||
|
|
||||||
# Perfect move marking
|
|
||||||
perfect_move_threshold: 0.02 # 2% price change to mark as significant
|
|
||||||
perfect_move_buffer_size: 10000
|
|
||||||
|
|
||||||
# RL evaluation settings
|
|
||||||
evaluation_delay: 3600 # Evaluate actions after 1 hour
|
|
||||||
reward_calculation:
|
|
||||||
success_multiplier: 10 # Reward for correct predictions
|
|
||||||
failure_penalty: 5 # Penalty for wrong predictions
|
|
||||||
confidence_scaling: true # Scale rewards by confidence
|
|
||||||
|
|
||||||
# Entry aggressiveness: 0.0 = very conservative (fewer, higher quality trades), 1.0 = very aggressive (more trades)
|
|
||||||
entry_aggressiveness: 0.5
|
|
||||||
# Exit aggressiveness: 0.0 = very conservative (let profits run), 1.0 = very aggressive (quick exits)
|
|
||||||
exit_aggressiveness: 0.5
|
|
||||||
|
|
||||||
# Decision Fusion Configuration
|
|
||||||
decision_fusion:
|
|
||||||
enabled: true # Use neural network decision fusion instead of programmatic
|
|
||||||
mode: "neural" # "neural" or "programmatic"
|
|
||||||
input_size: 128 # Size of input features for decision fusion network
|
|
||||||
hidden_size: 256 # Hidden layer size
|
|
||||||
history_length: 20 # Number of recent decisions to include
|
|
||||||
training_interval: 10 # Train decision fusion every 10 decisions in programmatic mode
|
|
||||||
learning_rate: 0.001 # Learning rate for decision fusion network
|
|
||||||
batch_size: 32 # Training batch size
|
|
||||||
min_samples_for_training: 20 # Lower threshold for faster training in programmatic mode
|
|
||||||
|
|
||||||
# Training Configuration
|
|
||||||
training:
|
|
||||||
learning_rate: 0.001
|
|
||||||
batch_size: 32
|
|
||||||
epochs: 100
|
|
||||||
validation_split: 0.2
|
|
||||||
early_stopping_patience: 10
|
|
||||||
|
|
||||||
# CNN specific training
|
|
||||||
cnn_training_interval: 3600 # Train CNN every hour (was 6 hours)
|
|
||||||
min_perfect_moves: 50 # Reduced from 200 for faster learning
|
|
||||||
|
|
||||||
# RL specific training
|
|
||||||
rl_training_interval: 300 # Train RL every 5 minutes (was 1 hour)
|
|
||||||
min_experiences: 50 # Reduced from 100 for faster learning
|
|
||||||
training_steps_per_cycle: 20 # Increased from 10 for more learning
|
|
||||||
|
|
||||||
model_type: "optimized_short_term"
|
|
||||||
use_realtime: true
|
|
||||||
use_ticks: true
|
|
||||||
checkpoint_dir: "NN/models/saved/realtime_ticks_checkpoints"
|
|
||||||
save_best_model: true
|
|
||||||
save_final_model: false # We only want to keep the best performing model
|
|
||||||
|
|
||||||
# Continuous learning settings
|
|
||||||
continuous_learning: true
|
|
||||||
learning_from_trades: true
|
|
||||||
pattern_recognition: true
|
|
||||||
retrospective_learning: true
|
|
||||||
|
|
||||||
# Universal Trading Configuration (applies to all exchanges)
|
# Universal Trading Configuration (applies to all exchanges)
|
||||||
trading:
|
trading:
|
||||||
@@ -227,69 +122,7 @@ memory:
|
|||||||
model_limit_gb: 4.0 # Per-model memory limit
|
model_limit_gb: 4.0 # Per-model memory limit
|
||||||
cleanup_interval: 1800 # Memory cleanup every 30 minutes
|
cleanup_interval: 1800 # Memory cleanup every 30 minutes
|
||||||
|
|
||||||
# Enhanced Training System Configuration
|
# Enhanced training and real-time RL configurations moved to models.yml
|
||||||
enhanced_training:
|
|
||||||
enabled: true # Enable enhanced real-time training
|
|
||||||
auto_start: true # Automatically start training when orchestrator starts
|
|
||||||
training_intervals:
|
|
||||||
cob_rl_training_interval: 1 # Train COB RL every 1 second (HIGHEST PRIORITY)
|
|
||||||
dqn_training_interval: 5 # Train DQN every 5 seconds
|
|
||||||
cnn_training_interval: 10 # Train CNN every 10 seconds
|
|
||||||
validation_interval: 60 # Validate every minute
|
|
||||||
batch_size: 64 # Training batch size
|
|
||||||
memory_size: 10000 # Experience buffer size
|
|
||||||
min_training_samples: 100 # Minimum samples before training starts
|
|
||||||
adaptation_threshold: 0.1 # Performance threshold for adaptation
|
|
||||||
forward_looking_predictions: true # Enable forward-looking prediction validation
|
|
||||||
|
|
||||||
# COB RL Priority Settings (since order book imbalance predicts price moves)
|
|
||||||
cob_rl_priority: true # Enable COB RL as highest priority model
|
|
||||||
cob_rl_batch_size: 16 # Smaller batches for faster COB updates
|
|
||||||
cob_rl_min_samples: 5 # Lower threshold for COB training
|
|
||||||
|
|
||||||
# Real-time RL COB Trader Configuration
|
|
||||||
realtime_rl:
|
|
||||||
# Model parameters for 400M parameter network (faster startup)
|
|
||||||
model:
|
|
||||||
input_size: 2000 # COB feature dimensions
|
|
||||||
hidden_size: 2048 # Optimized hidden layer size for 400M params
|
|
||||||
num_layers: 8 # Efficient transformer layers for faster training
|
|
||||||
learning_rate: 0.0001 # Higher learning rate for faster convergence
|
|
||||||
weight_decay: 0.00001 # Balanced L2 regularization
|
|
||||||
|
|
||||||
# Inference configuration
|
|
||||||
inference_interval_ms: 200 # Inference every 200ms
|
|
||||||
min_confidence_threshold: 0.7 # Minimum confidence for signal accumulation
|
|
||||||
required_confident_predictions: 3 # Need 3 confident predictions for trade
|
|
||||||
|
|
||||||
# Training configuration
|
|
||||||
training_interval_s: 1.0 # Train every second
|
|
||||||
batch_size: 32 # Training batch size
|
|
||||||
replay_buffer_size: 1000 # Store last 1000 predictions for training
|
|
||||||
|
|
||||||
# Signal accumulation
|
|
||||||
signal_buffer_size: 10 # Buffer size for signal accumulation
|
|
||||||
consensus_threshold: 3 # Need 3 signals in same direction
|
|
||||||
|
|
||||||
# Model checkpointing
|
|
||||||
model_checkpoint_dir: "models/realtime_rl_cob"
|
|
||||||
save_interval_s: 300 # Save models every 5 minutes
|
|
||||||
|
|
||||||
# COB integration
|
|
||||||
symbols: ["BTC/USDT", "ETH/USDT"] # Symbols to trade
|
|
||||||
cob_feature_normalization: "robust" # Feature normalization method
|
|
||||||
|
|
||||||
# Reward engineering for RL
|
|
||||||
reward_structure:
|
|
||||||
correct_direction_base: 1.0 # Base reward for correct prediction
|
|
||||||
confidence_scaling: true # Scale reward by confidence
|
|
||||||
magnitude_bonus: 0.5 # Bonus for predicting magnitude accurately
|
|
||||||
overconfidence_penalty: 1.5 # Penalty multiplier for wrong high-confidence predictions
|
|
||||||
trade_execution_multiplier: 10.0 # Higher weight for actual trade outcomes
|
|
||||||
|
|
||||||
# Performance monitoring
|
|
||||||
statistics_interval_s: 60 # Print stats every minute
|
|
||||||
detailed_logging: true # Enable detailed performance logging
|
|
||||||
|
|
||||||
# Web Dashboard
|
# Web Dashboard
|
||||||
web:
|
web:
|
||||||
|
@@ -24,16 +24,31 @@ class Config:
|
|||||||
self._setup_directories()
|
self._setup_directories()
|
||||||
|
|
||||||
def _load_config(self) -> Dict[str, Any]:
|
def _load_config(self) -> Dict[str, Any]:
|
||||||
"""Load configuration from YAML file"""
|
"""Load configuration from YAML files (config.yaml + models.yml)"""
|
||||||
try:
|
try:
|
||||||
|
# Load main config
|
||||||
if not self.config_path.exists():
|
if not self.config_path.exists():
|
||||||
logger.warning(f"Config file {self.config_path} not found, using defaults")
|
logger.warning(f"Config file {self.config_path} not found, using defaults")
|
||||||
return self._get_default_config()
|
config = self._get_default_config()
|
||||||
|
else:
|
||||||
with open(self.config_path, 'r') as f:
|
with open(self.config_path, 'r') as f:
|
||||||
config = yaml.safe_load(f)
|
config = yaml.safe_load(f)
|
||||||
|
logger.info(f"Loaded main configuration from {self.config_path}")
|
||||||
logger.info(f"Loaded configuration from {self.config_path}")
|
|
||||||
|
# Load models config
|
||||||
|
models_config_path = Path("models.yml")
|
||||||
|
if models_config_path.exists():
|
||||||
|
try:
|
||||||
|
with open(models_config_path, 'r') as f:
|
||||||
|
models_config = yaml.safe_load(f)
|
||||||
|
# Merge models config into main config
|
||||||
|
config.update(models_config)
|
||||||
|
logger.info(f"Loaded models configuration from {models_config_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error loading models.yml: {e}, using main config only")
|
||||||
|
else:
|
||||||
|
logger.info("models.yml not found, using main config only")
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@@ -605,7 +605,9 @@ class TradingOrchestrator:
|
|||||||
|
|
||||||
action_size = self.config.rl.get("action_space", 3)
|
action_size = self.config.rl.get("action_space", 3)
|
||||||
self.rl_agent = DQNAgent(
|
self.rl_agent = DQNAgent(
|
||||||
state_shape=actual_state_size, n_actions=action_size
|
state_shape=actual_state_size,
|
||||||
|
n_actions=action_size,
|
||||||
|
config=self.config.rl
|
||||||
)
|
)
|
||||||
self.rl_agent.to(self.device) # Move DQN agent to the determined device
|
self.rl_agent.to(self.device) # Move DQN agent to the determined device
|
||||||
|
|
||||||
@@ -2182,7 +2184,7 @@ class TradingOrchestrator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Clean up memory periodically
|
# Clean up memory periodically
|
||||||
if len(self.recent_decisions[symbol]) % 200 == 0: # Reduced from 50 to 200
|
if len(self.recent_decisions[symbol]) % 20 == 0: # Reduced from 50 to 20
|
||||||
self.model_registry.cleanup_all_models()
|
self.model_registry.cleanup_all_models()
|
||||||
|
|
||||||
return decision
|
return decision
|
||||||
@@ -2196,55 +2198,108 @@ class TradingOrchestrator:
|
|||||||
):
|
):
|
||||||
"""Add training samples to models based on current predictions and market conditions"""
|
"""Add training samples to models based on current predictions and market conditions"""
|
||||||
try:
|
try:
|
||||||
if not hasattr(self, "cnn_adapter") or not self.cnn_adapter:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Get recent price data to evaluate if predictions would be correct
|
# Get recent price data to evaluate if predictions would be correct
|
||||||
recent_prices = self.data_provider.get_recent_prices(symbol, limit=10)
|
# Use available methods from data provider
|
||||||
if not recent_prices or len(recent_prices) < 2:
|
try:
|
||||||
return
|
# Try to get recent prices using get_price_at_index
|
||||||
|
recent_prices = []
|
||||||
|
for i in range(10):
|
||||||
|
price = self.data_provider.get_price_at_index(symbol, i, '1m')
|
||||||
|
if price is not None:
|
||||||
|
recent_prices.append(price)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
if len(recent_prices) < 2:
|
||||||
|
# Fallback: use current price and a small assumed change
|
||||||
|
price_change_pct = 0.1 # Assume small positive change
|
||||||
|
else:
|
||||||
|
# Calculate recent price change
|
||||||
|
price_change_pct = (
|
||||||
|
(current_price - recent_prices[-2]) / recent_prices[-2] * 100
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not get recent prices for {symbol}: {e}")
|
||||||
|
# Fallback: use current price and a small assumed change
|
||||||
|
price_change_pct = 0.1 # Assume small positive change
|
||||||
|
|
||||||
# Calculate recent price change
|
# Get current position P&L for sophisticated reward calculation
|
||||||
price_change_pct = (
|
current_position_pnl = self._get_current_position_pnl(symbol)
|
||||||
(current_price - recent_prices[-2]) / recent_prices[-2] * 100
|
has_position = self._has_open_position(symbol)
|
||||||
)
|
|
||||||
|
|
||||||
# Add training samples for CNN predictions
|
# Add training samples for CNN predictions using sophisticated reward system
|
||||||
for prediction in predictions:
|
for prediction in predictions:
|
||||||
if "cnn" in prediction.model_name.lower():
|
if "cnn" in prediction.model_name.lower():
|
||||||
# Determine reward based on prediction accuracy
|
# Calculate sophisticated reward using the new PnL penalty/reward system
|
||||||
reward = 0.0
|
sophisticated_reward, was_correct = self._calculate_sophisticated_reward(
|
||||||
|
predicted_action=prediction.action,
|
||||||
if prediction.action == "BUY" and price_change_pct > 0.1:
|
prediction_confidence=prediction.confidence,
|
||||||
reward = min(
|
price_change_pct=price_change_pct,
|
||||||
price_change_pct * 0.1, 1.0
|
time_diff_minutes=1.0, # Assume 1 minute for now
|
||||||
) # Positive reward for correct BUY
|
has_price_prediction=False,
|
||||||
elif prediction.action == "SELL" and price_change_pct < -0.1:
|
symbol=symbol,
|
||||||
reward = min(
|
has_position=has_position,
|
||||||
abs(price_change_pct) * 0.1, 1.0
|
current_position_pnl=current_position_pnl
|
||||||
) # Positive reward for correct SELL
|
|
||||||
elif prediction.action == "HOLD" and abs(price_change_pct) < 0.1:
|
|
||||||
reward = 0.1 # Small positive reward for correct HOLD
|
|
||||||
else:
|
|
||||||
reward = -0.05 # Small negative reward for incorrect prediction
|
|
||||||
|
|
||||||
# Add training sample
|
|
||||||
self.cnn_adapter.add_training_sample(
|
|
||||||
symbol, prediction.action, reward
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"Added CNN training sample: {prediction.action}, reward={reward:.3f}, price_change={price_change_pct:.2f}%"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Trigger training if we have enough samples
|
# Create training record for the new training system
|
||||||
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
|
training_record = {
|
||||||
training_results = self.cnn_adapter.train(epochs=1)
|
"symbol": symbol,
|
||||||
logger.info(
|
"model_name": prediction.model_name,
|
||||||
f"CNN training completed: loss={training_results.get('loss', 0):.4f}, accuracy={training_results.get('accuracy', 0):.4f}"
|
"action": prediction.action,
|
||||||
)
|
"confidence": prediction.confidence,
|
||||||
|
"timestamp": prediction.timestamp,
|
||||||
|
"current_price": current_price,
|
||||||
|
"price_change_pct": price_change_pct,
|
||||||
|
"was_correct": was_correct,
|
||||||
|
"sophisticated_reward": sophisticated_reward,
|
||||||
|
"current_position_pnl": current_position_pnl,
|
||||||
|
"has_position": has_position
|
||||||
|
}
|
||||||
|
|
||||||
|
# Use the new training system instead of old cnn_adapter
|
||||||
|
if hasattr(self, "cnn_model") and self.cnn_model:
|
||||||
|
# Train CNN model directly using the new system
|
||||||
|
training_success = await self._train_cnn_model(
|
||||||
|
model=self.cnn_model,
|
||||||
|
model_name=prediction.model_name,
|
||||||
|
record=training_record,
|
||||||
|
prediction={"action": prediction.action, "confidence": prediction.confidence},
|
||||||
|
reward=sophisticated_reward
|
||||||
|
)
|
||||||
|
|
||||||
|
if training_success:
|
||||||
|
logger.debug(
|
||||||
|
f"CNN training completed: action={prediction.action}, reward={sophisticated_reward:.3f}, "
|
||||||
|
f"price_change={price_change_pct:.2f}%, was_correct={was_correct}, "
|
||||||
|
f"position_pnl={current_position_pnl:.2f}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"CNN training failed for {prediction.model_name}")
|
||||||
|
|
||||||
|
# Also try training through model registry if available
|
||||||
|
elif self.model_registry and prediction.model_name in self.model_registry.models:
|
||||||
|
model = self.model_registry.models[prediction.model_name]
|
||||||
|
training_success = await self._train_cnn_model(
|
||||||
|
model=model,
|
||||||
|
model_name=prediction.model_name,
|
||||||
|
record=training_record,
|
||||||
|
prediction={"action": prediction.action, "confidence": prediction.confidence},
|
||||||
|
reward=sophisticated_reward
|
||||||
|
)
|
||||||
|
|
||||||
|
if training_success:
|
||||||
|
logger.debug(
|
||||||
|
f"CNN training via registry completed: {prediction.model_name}, "
|
||||||
|
f"reward={sophisticated_reward:.3f}, was_correct={was_correct}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"CNN training via registry failed for {prediction.model_name}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error adding training samples from predictions: {e}")
|
logger.error(f"Error adding training samples from predictions: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||||
|
|
||||||
async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
|
async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
|
||||||
"""Get predictions from all registered models with input data storage"""
|
"""Get predictions from all registered models with input data storage"""
|
||||||
|
@@ -9,15 +9,21 @@
|
|||||||
"training_enabled": true
|
"training_enabled": true
|
||||||
},
|
},
|
||||||
"cob_rl": {
|
"cob_rl": {
|
||||||
"inference_enabled": true,
|
"inference_enabled": false,
|
||||||
"training_enabled": true
|
"training_enabled": true
|
||||||
},
|
},
|
||||||
"decision_fusion": {
|
"decision_fusion": {
|
||||||
"inference_enabled": false,
|
"inference_enabled": false,
|
||||||
"training_enabled": false
|
"training_enabled": false
|
||||||
|
},
|
||||||
|
"transformer": {
|
||||||
|
"inference_enabled": false,
|
||||||
|
"training_enabled": true
|
||||||
|
},
|
||||||
|
"dqn_agent": {
|
||||||
|
"inference_enabled": false,
|
||||||
|
"training_enabled": true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
},
|
},
|
||||||
"timestamp": "2025-07-29T15:55:43.690404"
|
"timestamp": "2025-07-29T19:17:32.971226"
|
||||||
}
|
}
|
198
models.yml
Normal file
198
models.yml
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
# Model Configurations
|
||||||
|
# This file contains all model-specific configurations to keep the main config.yaml clean
|
||||||
|
|
||||||
|
# Enhanced CNN Configuration (cnn model do not use yml config. do not change this)
|
||||||
|
# cnn:
|
||||||
|
# window_size: 20
|
||||||
|
# features: ["open", "high", "low", "close", "volume"]
|
||||||
|
# timeframes: ["1s", "1m", "1h", "1d"]
|
||||||
|
# hidden_layers: [64, 128, 256]
|
||||||
|
# dropout: 0.2
|
||||||
|
# learning_rate: 0.001
|
||||||
|
# batch_size: 32
|
||||||
|
# epochs: 100
|
||||||
|
# confidence_threshold: 0.6
|
||||||
|
# early_stopping_patience: 10
|
||||||
|
# model_dir: "models/enhanced_cnn" # Ultra-fast scalping weights (500x leverage)
|
||||||
|
# timeframe_importance:
|
||||||
|
# "1s": 0.60 # Primary scalping signal
|
||||||
|
# "1m": 0.20 # Short-term confirmation
|
||||||
|
# "1h": 0.15 # Medium-term trend
|
||||||
|
# "1d": 0.05 # Long-term direction (minimal)
|
||||||
|
|
||||||
|
# Enhanced RL Agent Configuration
|
||||||
|
rl:
|
||||||
|
state_size: 100 # Will be calculated dynamically based on features
|
||||||
|
action_space: 3 # BUY, HOLD, SELL
|
||||||
|
hidden_size: 256
|
||||||
|
epsilon: 1.0
|
||||||
|
epsilon_decay: 0.995
|
||||||
|
epsilon_min: 0.01
|
||||||
|
learning_rate: 0.0001
|
||||||
|
gamma: 0.99
|
||||||
|
memory_size: 10000
|
||||||
|
batch_size: 64
|
||||||
|
target_update_freq: 1000
|
||||||
|
buffer_size: 10000
|
||||||
|
model_dir: "models/enhanced_rl"
|
||||||
|
|
||||||
|
# DQN Network Architecture Configuration
|
||||||
|
network_architecture:
|
||||||
|
# Feature extractor layers (reduced by half from original)
|
||||||
|
feature_layers: [4096, 3072, 2048, 1536, 1024] # Reduced from [8192, 6144, 4096, 3072, 2048]
|
||||||
|
# Market regime detection head
|
||||||
|
regime_head: [512, 256] # Reduced from [1024, 512]
|
||||||
|
# Price direction prediction head
|
||||||
|
price_direction_head: [512, 256] # Reduced from [1024, 512]
|
||||||
|
# Volatility prediction head
|
||||||
|
volatility_head: [512, 128] # Reduced from [1024, 256]
|
||||||
|
# Main Q-value head (dueling architecture)
|
||||||
|
value_head: [512, 256] # Reduced from [1024, 512]
|
||||||
|
advantage_head: [512, 256] # Reduced from [1024, 512]
|
||||||
|
# Dropout rate
|
||||||
|
dropout_rate: 0.1
|
||||||
|
# Layer normalization
|
||||||
|
use_layer_norm: true
|
||||||
|
|
||||||
|
# Market regime adaptation
|
||||||
|
market_regime_weights:
|
||||||
|
trending: 1.2 # Higher confidence in trending markets
|
||||||
|
ranging: 0.8 # Lower confidence in ranging markets
|
||||||
|
volatile: 0.6 # Much lower confidence in volatile markets
|
||||||
|
# Prioritized experience replay
|
||||||
|
replay_alpha: 0.6 # Priority exponent
|
||||||
|
replay_beta: 0.4 # Importance sampling exponent
|
||||||
|
|
||||||
|
# Real-time RL COB Trader Configuration
|
||||||
|
realtime_rl:
|
||||||
|
# Model parameters for 400M parameter network (faster startup)
|
||||||
|
model:
|
||||||
|
input_size: 2000 # COB feature dimensions
|
||||||
|
hidden_size: 2048 # Optimized hidden layer size for 400M params
|
||||||
|
num_layers: 8 # Efficient transformer layers for faster training
|
||||||
|
learning_rate: 0.0001 # Higher learning rate for faster convergence
|
||||||
|
weight_decay: 0.00001 # Balanced L2 regularization
|
||||||
|
|
||||||
|
# Inference configuration
|
||||||
|
inference_interval_ms: 200 # Inference every 200ms
|
||||||
|
min_confidence_threshold: 0.7 # Minimum confidence for signal accumulation
|
||||||
|
required_confident_predictions: 3 # Need 3 confident predictions for trade
|
||||||
|
|
||||||
|
# Training configuration
|
||||||
|
training_interval_s: 1.0 # Train every second
|
||||||
|
batch_size: 32 # Training batch size
|
||||||
|
replay_buffer_size: 1000 # Store last 1000 predictions for training
|
||||||
|
|
||||||
|
# Signal accumulation
|
||||||
|
signal_buffer_size: 10 # Buffer size for signal accumulation
|
||||||
|
consensus_threshold: 3 # Need 3 signals in same direction
|
||||||
|
|
||||||
|
# Model checkpointing
|
||||||
|
model_checkpoint_dir: "models/realtime_rl_cob"
|
||||||
|
save_interval_s: 300 # Save models every 5 minutes
|
||||||
|
|
||||||
|
# COB integration
|
||||||
|
symbols: ["BTC/USDT", "ETH/USDT"] # Symbols to trade
|
||||||
|
cob_feature_normalization: "robust" # Feature normalization method
|
||||||
|
|
||||||
|
# Reward engineering for RL
|
||||||
|
reward_structure:
|
||||||
|
correct_direction_base: 1.0 # Base reward for correct prediction
|
||||||
|
confidence_scaling: true # Scale reward by confidence
|
||||||
|
magnitude_bonus: 0.5 # Bonus for predicting magnitude accurately
|
||||||
|
overconfidence_penalty: 1.5 # Penalty multiplier for wrong high-confidence predictions
|
||||||
|
trade_execution_multiplier: 10.0 # Higher weight for actual trade outcomes
|
||||||
|
|
||||||
|
# Performance monitoring
|
||||||
|
statistics_interval_s: 60 # Print stats every minute
|
||||||
|
detailed_logging: true # Enable detailed performance logging
|
||||||
|
|
||||||
|
# Enhanced Orchestrator Settings
|
||||||
|
orchestrator:
|
||||||
|
# Model weights for decision combination
|
||||||
|
cnn_weight: 0.7 # Weight for CNN predictions
|
||||||
|
rl_weight: 0.3 # Weight for RL decisions
|
||||||
|
confidence_threshold: 0.45
|
||||||
|
confidence_threshold_close: 0.35
|
||||||
|
decision_frequency: 30
|
||||||
|
|
||||||
|
# Multi-symbol coordination
|
||||||
|
symbol_correlation_matrix:
|
||||||
|
"ETH/USDT-BTC/USDT": 0.85 # ETH-BTC correlation
|
||||||
|
|
||||||
|
# Perfect move marking
|
||||||
|
perfect_move_threshold: 0.02 # 2% price change to mark as significant
|
||||||
|
perfect_move_buffer_size: 10000
|
||||||
|
|
||||||
|
# RL evaluation settings
|
||||||
|
evaluation_delay: 3600 # Evaluate actions after 1 hour
|
||||||
|
reward_calculation:
|
||||||
|
success_multiplier: 10 # Reward for correct predictions
|
||||||
|
failure_penalty: 5 # Penalty for wrong predictions
|
||||||
|
confidence_scaling: true # Scale rewards by confidence
|
||||||
|
|
||||||
|
# Entry aggressiveness: 0.0 = very conservative (fewer, higher quality trades), 1.0 = very aggressive (more trades)
|
||||||
|
entry_aggressiveness: 0.5
|
||||||
|
# Exit aggressiveness: 0.0 = very conservative (let profits run), 1.0 = very aggressive (quick exits)
|
||||||
|
exit_aggressiveness: 0.5
|
||||||
|
|
||||||
|
# Decision Fusion Configuration
|
||||||
|
decision_fusion:
|
||||||
|
enabled: true # Use neural network decision fusion instead of programmatic
|
||||||
|
mode: "neural" # "neural" or "programmatic"
|
||||||
|
input_size: 128 # Size of input features for decision fusion network
|
||||||
|
hidden_size: 256 # Hidden layer size
|
||||||
|
history_length: 20 # Number of recent decisions to include
|
||||||
|
training_interval: 10 # Train decision fusion every 10 decisions in programmatic mode
|
||||||
|
learning_rate: 0.001 # Learning rate for decision fusion network
|
||||||
|
batch_size: 32 # Training batch size
|
||||||
|
min_samples_for_training: 20 # Lower threshold for faster training in programmatic mode
|
||||||
|
|
||||||
|
# Training Configuration
|
||||||
|
training:
|
||||||
|
learning_rate: 0.001
|
||||||
|
batch_size: 32
|
||||||
|
epochs: 100
|
||||||
|
validation_split: 0.2
|
||||||
|
early_stopping_patience: 10
|
||||||
|
|
||||||
|
# CNN specific training
|
||||||
|
cnn_training_interval: 3600 # Train CNN every hour (was 6 hours)
|
||||||
|
min_perfect_moves: 50 # Reduced from 200 for faster learning
|
||||||
|
|
||||||
|
# RL specific training
|
||||||
|
rl_training_interval: 300 # Train RL every 5 minutes (was 1 hour)
|
||||||
|
min_experiences: 50 # Reduced from 100 for faster learning
|
||||||
|
training_steps_per_cycle: 20 # Increased from 10 for more learning
|
||||||
|
|
||||||
|
model_type: "optimized_short_term"
|
||||||
|
use_realtime: true
|
||||||
|
use_ticks: true
|
||||||
|
checkpoint_dir: "NN/models/saved/realtime_ticks_checkpoints"
|
||||||
|
save_best_model: true
|
||||||
|
save_final_model: false # We only want to keep the best performing model
|
||||||
|
|
||||||
|
# Continuous learning settings
|
||||||
|
continuous_learning: true
|
||||||
|
adaptive_learning_rate: true
|
||||||
|
performance_threshold: 0.6
|
||||||
|
|
||||||
|
# Enhanced Training System Configuration
|
||||||
|
enhanced_training:
|
||||||
|
enabled: true # Enable enhanced real-time training
|
||||||
|
auto_start: true # Automatically start training when orchestrator starts
|
||||||
|
training_intervals:
|
||||||
|
cob_rl_training_interval: 1 # Train COB RL every 1 second (HIGHEST PRIORITY)
|
||||||
|
dqn_training_interval: 5 # Train DQN every 5 seconds
|
||||||
|
cnn_training_interval: 10 # Train CNN every 10 seconds
|
||||||
|
validation_interval: 60 # Validate every minute
|
||||||
|
batch_size: 64 # Training batch size
|
||||||
|
memory_size: 10000 # Experience buffer size
|
||||||
|
min_training_samples: 100 # Minimum samples before training starts
|
||||||
|
adaptation_threshold: 0.1 # Performance threshold for adaptation
|
||||||
|
forward_looking_predictions: true # Enable forward-looking prediction validation
|
||||||
|
|
||||||
|
# COB RL Priority Settings (since order book imbalance predicts price moves)
|
||||||
|
cob_rl_priority: true # Enable COB RL as highest priority model
|
||||||
|
cob_rl_batch_size: 16 # Smaller batches for faster COB updates
|
||||||
|
cob_rl_min_samples: 5 # Lower threshold for COB training
|
@@ -328,6 +328,7 @@ class CleanTradingDashboard:
|
|||||||
'https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css',
|
'https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css',
|
||||||
'https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css'
|
'https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css'
|
||||||
])
|
])
|
||||||
|
#, suppress_callback_exceptions=True)
|
||||||
|
|
||||||
# Suppress Dash development mode logging
|
# Suppress Dash development mode logging
|
||||||
self.app.enable_dev_tools(debug=False, dev_tools_silence_routes_logging=True)
|
self.app.enable_dev_tools(debug=False, dev_tools_silence_routes_logging=True)
|
||||||
@@ -864,14 +865,32 @@ class CleanTradingDashboard:
|
|||||||
available_models[model_name] = {'name': model_name, 'type': 'unknown'}
|
available_models[model_name] = {'name': model_name, 'type': 'unknown'}
|
||||||
logger.debug(f"Found {len(toggle_models)} models in toggle states")
|
logger.debug(f"Found {len(toggle_models)} models in toggle states")
|
||||||
|
|
||||||
|
# Apply model name mapping to match orchestrator's internal mapping
|
||||||
|
# This ensures component IDs match what the orchestrator expects
|
||||||
|
mapped_models = {}
|
||||||
|
model_mapping = {
|
||||||
|
'dqn_agent': 'dqn',
|
||||||
|
'enhanced_cnn': 'cnn',
|
||||||
|
'extrema_trainer': 'extrema_trainer',
|
||||||
|
'decision': 'decision_fusion',
|
||||||
|
'cob_rl': 'cob_rl',
|
||||||
|
'transformer': 'transformer'
|
||||||
|
}
|
||||||
|
|
||||||
|
for model_name, model_info in available_models.items():
|
||||||
|
# Use mapped name if available, otherwise use original name
|
||||||
|
mapped_name = model_mapping.get(model_name, model_name)
|
||||||
|
mapped_models[mapped_name] = model_info
|
||||||
|
logger.debug(f"Mapped model name: {model_name} -> {mapped_name}")
|
||||||
|
|
||||||
# Fallback: Add known models if none found
|
# Fallback: Add known models if none found
|
||||||
if not available_models:
|
if not mapped_models:
|
||||||
fallback_models = ['dqn', 'cnn', 'cob_rl', 'decision_fusion', 'transformer']
|
fallback_models = ['dqn', 'cnn', 'cob_rl', 'decision_fusion', 'transformer']
|
||||||
for model_name in fallback_models:
|
for model_name in fallback_models:
|
||||||
available_models[model_name] = {'name': model_name, 'type': 'fallback'}
|
mapped_models[model_name] = {'name': model_name, 'type': 'fallback'}
|
||||||
logger.warning(f"Using fallback models: {fallback_models}")
|
logger.warning(f"Using fallback models: {fallback_models}")
|
||||||
|
|
||||||
return available_models
|
return mapped_models
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting available models: {e}")
|
logger.error(f"Error getting available models: {e}")
|
||||||
@@ -916,13 +935,25 @@ class CleanTradingDashboard:
|
|||||||
enabled = bool(value and len(value) > 0) # Convert list to boolean
|
enabled = bool(value and len(value) > 0) # Convert list to boolean
|
||||||
|
|
||||||
if self.orchestrator:
|
if self.orchestrator:
|
||||||
|
# Map component model name back to orchestrator's expected model name
|
||||||
|
reverse_mapping = {
|
||||||
|
'dqn': 'dqn_agent',
|
||||||
|
'cnn': 'enhanced_cnn',
|
||||||
|
'decision_fusion': 'decision',
|
||||||
|
'extrema_trainer': 'extrema_trainer',
|
||||||
|
'cob_rl': 'cob_rl',
|
||||||
|
'transformer': 'transformer'
|
||||||
|
}
|
||||||
|
|
||||||
|
orchestrator_model_name = reverse_mapping.get(model_name, model_name)
|
||||||
|
|
||||||
# Update orchestrator toggle state
|
# Update orchestrator toggle state
|
||||||
if toggle_type == 'inference':
|
if toggle_type == 'inference':
|
||||||
self.orchestrator.set_model_toggle_state(model_name, inference_enabled=enabled)
|
self.orchestrator.set_model_toggle_state(orchestrator_model_name, inference_enabled=enabled)
|
||||||
elif toggle_type == 'training':
|
elif toggle_type == 'training':
|
||||||
self.orchestrator.set_model_toggle_state(model_name, training_enabled=enabled)
|
self.orchestrator.set_model_toggle_state(orchestrator_model_name, training_enabled=enabled)
|
||||||
|
|
||||||
logger.info(f"Model {model_name} {toggle_type} toggle: {enabled}")
|
logger.info(f"Model {model_name} ({orchestrator_model_name}) {toggle_type} toggle: {enabled}")
|
||||||
|
|
||||||
# Update dashboard state variables for backward compatibility
|
# Update dashboard state variables for backward compatibility
|
||||||
self._update_dashboard_state_variable(model_name, toggle_type, enabled)
|
self._update_dashboard_state_variable(model_name, toggle_type, enabled)
|
||||||
@@ -1333,18 +1364,25 @@ class CleanTradingDashboard:
|
|||||||
error_msg = html.P(f"COB Error: {str(e)}", className="text-danger small")
|
error_msg = html.P(f"COB Error: {str(e)}", className="text-danger small")
|
||||||
return error_msg, error_msg
|
return error_msg, error_msg
|
||||||
|
|
||||||
|
# Original training metrics callback - temporarily disabled for testing
|
||||||
|
# @self.app.callback(
|
||||||
|
# Output('training-metrics', 'children'),
|
||||||
@self.app.callback(
|
@self.app.callback(
|
||||||
Output('training-metrics', 'children'),
|
Output('training-metrics', 'children'),
|
||||||
[Input('slow-interval-component', 'n_intervals')] # OPTIMIZED: Move to 10s interval
|
[Input('slow-interval-component', 'n_intervals'),
|
||||||
|
Input('fast-interval-component', 'n_intervals'), # Add fast interval for testing
|
||||||
|
Input('refresh-training-metrics-btn', 'n_clicks')] # Add manual refresh button
|
||||||
)
|
)
|
||||||
def update_training_metrics(n):
|
def update_training_metrics(slow_intervals, fast_intervals, n_clicks):
|
||||||
"""Update training metrics"""
|
"""Update training metrics"""
|
||||||
|
logger.info(f"update_training_metrics callback triggered with slow_intervals={slow_intervals}, fast_intervals={fast_intervals}, n_clicks={n_clicks}")
|
||||||
try:
|
try:
|
||||||
# Get toggle states from orchestrator
|
# Get toggle states from orchestrator
|
||||||
toggle_states = {}
|
toggle_states = {}
|
||||||
if self.orchestrator:
|
if self.orchestrator:
|
||||||
# Get all available models dynamically
|
# Get all available models dynamically
|
||||||
available_models = self._get_available_models()
|
available_models = self._get_available_models()
|
||||||
|
logger.info(f"Available models: {list(available_models.keys())}")
|
||||||
for model_name in available_models.keys():
|
for model_name in available_models.keys():
|
||||||
toggle_states[model_name] = self.orchestrator.get_model_toggle_state(model_name)
|
toggle_states[model_name] = self.orchestrator.get_model_toggle_state(model_name)
|
||||||
else:
|
else:
|
||||||
@@ -1354,24 +1392,48 @@ class CleanTradingDashboard:
|
|||||||
toggle_states[model_name] = state
|
toggle_states[model_name] = state
|
||||||
# Now using slow-interval-component (10s) - no batching needed
|
# Now using slow-interval-component (10s) - no batching needed
|
||||||
|
|
||||||
|
logger.info(f"Getting training metrics with toggle_states: {toggle_states}")
|
||||||
metrics_data = self._get_training_metrics(toggle_states)
|
metrics_data = self._get_training_metrics(toggle_states)
|
||||||
logger.debug(f"update_training_metrics callback: got metrics_data type={type(metrics_data)}")
|
logger.info(f"update_training_metrics callback: got metrics_data type={type(metrics_data)}")
|
||||||
if metrics_data and isinstance(metrics_data, dict):
|
if metrics_data and isinstance(metrics_data, dict):
|
||||||
logger.debug(f"Metrics data keys: {list(metrics_data.keys())}")
|
logger.info(f"Metrics data keys: {list(metrics_data.keys())}")
|
||||||
if 'loaded_models' in metrics_data:
|
if 'loaded_models' in metrics_data:
|
||||||
logger.debug(f"Loaded models count: {len(metrics_data['loaded_models'])}")
|
logger.info(f"Loaded models count: {len(metrics_data['loaded_models'])}")
|
||||||
logger.debug(f"Loaded model names: {list(metrics_data['loaded_models'].keys())}")
|
logger.info(f"Loaded model names: {list(metrics_data['loaded_models'].keys())}")
|
||||||
else:
|
else:
|
||||||
logger.warning("No 'loaded_models' key in metrics_data!")
|
logger.warning("No 'loaded_models' key in metrics_data!")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Invalid metrics_data: {metrics_data}")
|
logger.warning(f"Invalid metrics_data: {metrics_data}")
|
||||||
return self.component_manager.format_training_metrics(metrics_data)
|
|
||||||
|
logger.info("Formatting training metrics...")
|
||||||
|
formatted_metrics = self.component_manager.format_training_metrics(metrics_data)
|
||||||
|
logger.info(f"Formatted metrics type: {type(formatted_metrics)}, length: {len(formatted_metrics) if isinstance(formatted_metrics, list) else 'N/A'}")
|
||||||
|
return formatted_metrics
|
||||||
except PreventUpdate:
|
except PreventUpdate:
|
||||||
|
logger.info("PreventUpdate raised in training metrics callback")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error updating training metrics: {e}")
|
logger.error(f"Error updating training metrics: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||||
return [html.P(f"Error: {str(e)}", className="text-danger")]
|
return [html.P(f"Error: {str(e)}", className="text-danger")]
|
||||||
|
|
||||||
|
# Test callback for training metrics (commented out - using real callback now)
|
||||||
|
# @self.app.callback(
|
||||||
|
# Output('training-metrics', 'children'),
|
||||||
|
# [Input('refresh-training-metrics-btn', 'n_clicks')],
|
||||||
|
# prevent_initial_call=False
|
||||||
|
# )
|
||||||
|
# def test_training_metrics_callback(n_clicks):
|
||||||
|
# """Test callback for training metrics"""
|
||||||
|
# logger.info(f"test_training_metrics_callback triggered with n_clicks={n_clicks}")
|
||||||
|
# try:
|
||||||
|
# # Return a simple test message
|
||||||
|
# return [html.P("Training metrics test - callback is working!", className="text-success")]
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"Error in test callback: {e}")
|
||||||
|
# return [html.P(f"Error: {str(e)}", className="text-danger")]
|
||||||
|
|
||||||
# Manual trading buttons
|
# Manual trading buttons
|
||||||
@self.app.callback(
|
@self.app.callback(
|
||||||
Output('manual-buy-btn', 'children'),
|
Output('manual-buy-btn', 'children'),
|
||||||
@@ -3651,7 +3713,17 @@ class CleanTradingDashboard:
|
|||||||
available_models = self._get_available_models()
|
available_models = self._get_available_models()
|
||||||
toggle_states = {}
|
toggle_states = {}
|
||||||
for model_name in available_models.keys():
|
for model_name in available_models.keys():
|
||||||
toggle_states[model_name] = self.orchestrator.get_model_toggle_state(model_name)
|
# Map component model name to orchestrator model name for getting toggle state
|
||||||
|
reverse_mapping = {
|
||||||
|
'dqn': 'dqn_agent',
|
||||||
|
'cnn': 'enhanced_cnn',
|
||||||
|
'decision_fusion': 'decision',
|
||||||
|
'extrema_trainer': 'extrema_trainer',
|
||||||
|
'cob_rl': 'cob_rl',
|
||||||
|
'transformer': 'transformer'
|
||||||
|
}
|
||||||
|
orchestrator_model_name = reverse_mapping.get(model_name, model_name)
|
||||||
|
toggle_states[model_name] = self.orchestrator.get_model_toggle_state(orchestrator_model_name)
|
||||||
else:
|
else:
|
||||||
# Fallback to default states for known models
|
# Fallback to default states for known models
|
||||||
toggle_states = {
|
toggle_states = {
|
||||||
@@ -3711,8 +3783,19 @@ class CleanTradingDashboard:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if self.orchestrator:
|
if self.orchestrator:
|
||||||
|
# Map component model name to orchestrator model name for getting statistics
|
||||||
|
reverse_mapping = {
|
||||||
|
'dqn': 'dqn_agent',
|
||||||
|
'cnn': 'enhanced_cnn',
|
||||||
|
'decision_fusion': 'decision',
|
||||||
|
'extrema_trainer': 'extrema_trainer',
|
||||||
|
'cob_rl': 'cob_rl',
|
||||||
|
'transformer': 'transformer'
|
||||||
|
}
|
||||||
|
orchestrator_model_name = reverse_mapping.get(model_name, model_name)
|
||||||
|
|
||||||
# Use the new model statistics system
|
# Use the new model statistics system
|
||||||
model_stats = self.orchestrator.get_model_statistics(model_name.lower())
|
model_stats = self.orchestrator.get_model_statistics(orchestrator_model_name)
|
||||||
if model_stats:
|
if model_stats:
|
||||||
# Last inference time
|
# Last inference time
|
||||||
timing['last_inference'] = model_stats.last_inference_time
|
timing['last_inference'] = model_stats.last_inference_time
|
||||||
@@ -3755,7 +3838,7 @@ class CleanTradingDashboard:
|
|||||||
dqn_prediction_count = len(self.recent_decisions) if signal_generation_active else 0
|
dqn_prediction_count = len(self.recent_decisions) if signal_generation_active else 0
|
||||||
|
|
||||||
# Get latest DQN prediction from orchestrator statistics
|
# Get latest DQN prediction from orchestrator statistics
|
||||||
dqn_stats = orchestrator_stats.get('dqn_agent')
|
dqn_stats = orchestrator_stats.get('dqn_agent') # Use orchestrator's internal name
|
||||||
if dqn_stats and dqn_stats.predictions_history:
|
if dqn_stats and dqn_stats.predictions_history:
|
||||||
# Get the most recent prediction
|
# Get the most recent prediction
|
||||||
latest_pred = list(dqn_stats.predictions_history)[-1]
|
latest_pred = list(dqn_stats.predictions_history)[-1]
|
||||||
@@ -3786,8 +3869,8 @@ class CleanTradingDashboard:
|
|||||||
last_confidence = 0.68
|
last_confidence = 0.68
|
||||||
last_timestamp = datetime.now().strftime('%H:%M:%S')
|
last_timestamp = datetime.now().strftime('%H:%M:%S')
|
||||||
|
|
||||||
# Get real DQN statistics from orchestrator (try both old and new names)
|
# Get real DQN statistics from orchestrator (use orchestrator's internal name)
|
||||||
dqn_stats = orchestrator_stats.get('dqn_agent') or orchestrator_stats.get('dqn')
|
dqn_stats = orchestrator_stats.get('dqn_agent')
|
||||||
dqn_current_loss = dqn_stats.current_loss if dqn_stats else None
|
dqn_current_loss = dqn_stats.current_loss if dqn_stats else None
|
||||||
dqn_best_loss = dqn_stats.best_loss if dqn_stats else None
|
dqn_best_loss = dqn_stats.best_loss if dqn_stats else None
|
||||||
dqn_accuracy = dqn_stats.accuracy if dqn_stats else None
|
dqn_accuracy = dqn_stats.accuracy if dqn_stats else None
|
||||||
@@ -3867,8 +3950,8 @@ class CleanTradingDashboard:
|
|||||||
cnn_state = model_states.get('cnn', {})
|
cnn_state = model_states.get('cnn', {})
|
||||||
cnn_timing = get_model_timing_info('CNN')
|
cnn_timing = get_model_timing_info('CNN')
|
||||||
|
|
||||||
# Get real CNN statistics from orchestrator (try both old and new names)
|
# Get real CNN statistics from orchestrator (use orchestrator's internal name)
|
||||||
cnn_stats = orchestrator_stats.get('enhanced_cnn') or orchestrator_stats.get('cnn')
|
cnn_stats = orchestrator_stats.get('enhanced_cnn')
|
||||||
cnn_active = cnn_stats is not None
|
cnn_active = cnn_stats is not None
|
||||||
|
|
||||||
# Get latest CNN prediction from orchestrator statistics
|
# Get latest CNN prediction from orchestrator statistics
|
||||||
@@ -4095,7 +4178,10 @@ class CleanTradingDashboard:
|
|||||||
# 4. COB RL Model Status - using orchestrator SSOT
|
# 4. COB RL Model Status - using orchestrator SSOT
|
||||||
cob_state = model_states.get('cob_rl', {})
|
cob_state = model_states.get('cob_rl', {})
|
||||||
cob_timing = get_model_timing_info('COB_RL')
|
cob_timing = get_model_timing_info('COB_RL')
|
||||||
cob_active = True
|
|
||||||
|
# Get real COB RL statistics from orchestrator (use orchestrator's internal name)
|
||||||
|
cob_stats = orchestrator_stats.get('cob_rl')
|
||||||
|
cob_active = cob_stats is not None
|
||||||
cob_predictions_count = len(self.recent_decisions) * 2
|
cob_predictions_count = len(self.recent_decisions) * 2
|
||||||
|
|
||||||
# Get COB RL toggle states
|
# Get COB RL toggle states
|
||||||
@@ -4154,10 +4240,8 @@ class CleanTradingDashboard:
|
|||||||
decision_inference_enabled = decision_toggle_state.get("inference_enabled", True)
|
decision_inference_enabled = decision_toggle_state.get("inference_enabled", True)
|
||||||
decision_training_enabled = decision_toggle_state.get("training_enabled", True)
|
decision_training_enabled = decision_toggle_state.get("training_enabled", True)
|
||||||
|
|
||||||
# Get real decision fusion statistics from orchestrator
|
# Get real decision fusion statistics from orchestrator (use orchestrator's internal name)
|
||||||
decision_stats = None
|
decision_stats = orchestrator_stats.get('decision')
|
||||||
if self.orchestrator and hasattr(self.orchestrator, 'model_statistics'):
|
|
||||||
decision_stats = self.orchestrator.model_statistics.get('decision_fusion')
|
|
||||||
|
|
||||||
# Get real last prediction
|
# Get real last prediction
|
||||||
last_prediction = 'HOLD'
|
last_prediction = 'HOLD'
|
||||||
|
@@ -140,7 +140,8 @@ class DashboardComponentManager:
|
|||||||
# Create table headers
|
# Create table headers
|
||||||
headers = html.Thead([
|
headers = html.Thead([
|
||||||
html.Tr([
|
html.Tr([
|
||||||
html.Th("Time", className="small"),
|
html.Th("Entry Time", className="small"),
|
||||||
|
html.Th("Exit Time", className="small"),
|
||||||
html.Th("Side", className="small"),
|
html.Th("Side", className="small"),
|
||||||
html.Th("Size", className="small"),
|
html.Th("Size", className="small"),
|
||||||
html.Th("Entry", className="small"),
|
html.Th("Entry", className="small"),
|
||||||
@@ -158,6 +159,7 @@ class DashboardComponentManager:
|
|||||||
if hasattr(trade, 'entry_time'):
|
if hasattr(trade, 'entry_time'):
|
||||||
# This is a trade object
|
# This is a trade object
|
||||||
entry_time = getattr(trade, 'entry_time', 'Unknown')
|
entry_time = getattr(trade, 'entry_time', 'Unknown')
|
||||||
|
exit_time = getattr(trade, 'exit_time', 'Unknown')
|
||||||
side = getattr(trade, 'side', 'UNKNOWN')
|
side = getattr(trade, 'side', 'UNKNOWN')
|
||||||
size = getattr(trade, 'size', 0)
|
size = getattr(trade, 'size', 0)
|
||||||
entry_price = getattr(trade, 'entry_price', 0)
|
entry_price = getattr(trade, 'entry_price', 0)
|
||||||
@@ -168,6 +170,7 @@ class DashboardComponentManager:
|
|||||||
else:
|
else:
|
||||||
# This is a dictionary format
|
# This is a dictionary format
|
||||||
entry_time = trade.get('entry_time', 'Unknown')
|
entry_time = trade.get('entry_time', 'Unknown')
|
||||||
|
exit_time = trade.get('exit_time', 'Unknown')
|
||||||
side = trade.get('side', 'UNKNOWN')
|
side = trade.get('side', 'UNKNOWN')
|
||||||
size = trade.get('quantity', trade.get('size', 0)) # Try 'quantity' first, then 'size'
|
size = trade.get('quantity', trade.get('size', 0)) # Try 'quantity' first, then 'size'
|
||||||
entry_price = trade.get('entry_price', 0)
|
entry_price = trade.get('entry_price', 0)
|
||||||
@@ -176,11 +179,17 @@ class DashboardComponentManager:
|
|||||||
fees = trade.get('fees', 0)
|
fees = trade.get('fees', 0)
|
||||||
hold_time_seconds = trade.get('hold_time_seconds', 0.0)
|
hold_time_seconds = trade.get('hold_time_seconds', 0.0)
|
||||||
|
|
||||||
# Format time
|
# Format entry time
|
||||||
if isinstance(entry_time, datetime):
|
if isinstance(entry_time, datetime):
|
||||||
time_str = entry_time.strftime('%H:%M:%S')
|
entry_time_str = entry_time.strftime('%H:%M:%S')
|
||||||
else:
|
else:
|
||||||
time_str = str(entry_time)
|
entry_time_str = str(entry_time)
|
||||||
|
|
||||||
|
# Format exit time
|
||||||
|
if isinstance(exit_time, datetime):
|
||||||
|
exit_time_str = exit_time.strftime('%H:%M:%S')
|
||||||
|
else:
|
||||||
|
exit_time_str = str(exit_time)
|
||||||
|
|
||||||
# Determine P&L color
|
# Determine P&L color
|
||||||
pnl_class = "text-success" if pnl >= 0 else "text-danger"
|
pnl_class = "text-success" if pnl >= 0 else "text-danger"
|
||||||
@@ -197,7 +206,8 @@ class DashboardComponentManager:
|
|||||||
net_pnl = pnl - fees
|
net_pnl = pnl - fees
|
||||||
|
|
||||||
row = html.Tr([
|
row = html.Tr([
|
||||||
html.Td(time_str, className="small"),
|
html.Td(entry_time_str, className="small"),
|
||||||
|
html.Td(exit_time_str, className="small"),
|
||||||
html.Td(side, className=f"small {side_class}"),
|
html.Td(side, className=f"small {side_class}"),
|
||||||
html.Td(f"${position_size_usd:.2f}", className="small"), # Show size in USD
|
html.Td(f"${position_size_usd:.2f}", className="small"), # Show size in USD
|
||||||
html.Td(f"${entry_price:.2f}", className="small"),
|
html.Td(f"${entry_price:.2f}", className="small"),
|
||||||
@@ -714,11 +724,11 @@ class DashboardComponentManager:
|
|||||||
"""Format training metrics for display - Enhanced with loaded models"""
|
"""Format training metrics for display - Enhanced with loaded models"""
|
||||||
try:
|
try:
|
||||||
# DEBUG: Log what we're receiving
|
# DEBUG: Log what we're receiving
|
||||||
logger.debug(f"format_training_metrics received: {type(metrics_data)}")
|
logger.info(f"format_training_metrics received: {type(metrics_data)}")
|
||||||
if metrics_data:
|
if metrics_data:
|
||||||
logger.debug(f"Metrics keys: {list(metrics_data.keys()) if isinstance(metrics_data, dict) else 'Not a dict'}")
|
logger.info(f"Metrics keys: {list(metrics_data.keys()) if isinstance(metrics_data, dict) else 'Not a dict'}")
|
||||||
if isinstance(metrics_data, dict) and 'loaded_models' in metrics_data:
|
if isinstance(metrics_data, dict) and 'loaded_models' in metrics_data:
|
||||||
logger.debug(f"Loaded models: {list(metrics_data['loaded_models'].keys())}")
|
logger.info(f"Loaded models: {list(metrics_data['loaded_models'].keys())}")
|
||||||
|
|
||||||
if not metrics_data or 'error' in metrics_data:
|
if not metrics_data or 'error' in metrics_data:
|
||||||
logger.warning(f"No training data or error in metrics_data: {metrics_data}")
|
logger.warning(f"No training data or error in metrics_data: {metrics_data}")
|
||||||
@@ -772,6 +782,7 @@ class DashboardComponentManager:
|
|||||||
checkpoint_status = "LOADED" if model_info.get('checkpoint_loaded', False) else "FRESH"
|
checkpoint_status = "LOADED" if model_info.get('checkpoint_loaded', False) else "FRESH"
|
||||||
|
|
||||||
# Model card
|
# Model card
|
||||||
|
logger.info(f"Creating model card for {model_name} with toggles: inference={model_info.get('inference_enabled', True)}, training={model_info.get('training_enabled', True)}")
|
||||||
model_card = html.Div([
|
model_card = html.Div([
|
||||||
# Header with model name and toggle
|
# Header with model name and toggle
|
||||||
html.Div([
|
html.Div([
|
||||||
@@ -1043,10 +1054,15 @@ class DashboardComponentManager:
|
|||||||
html.Span(f"{enhanced_stats['recent_validation_score']:.3f}", className="text-primary small fw-bold")
|
html.Span(f"{enhanced_stats['recent_validation_score']:.3f}", className="text-primary small fw-bold")
|
||||||
], className="mb-1"))
|
], className="mb-1"))
|
||||||
|
|
||||||
|
logger.info(f"format_training_metrics returning {len(content)} components")
|
||||||
|
for i, component in enumerate(content[:3]): # Log first 3 components
|
||||||
|
logger.info(f" Component {i}: {type(component)}")
|
||||||
return content
|
return content
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error formatting training metrics: {e}")
|
logger.error(f"Error formatting training metrics: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||||
return [html.P(f"Error: {str(e)}", className="text-danger small")]
|
return [html.P(f"Error: {str(e)}", className="text-danger small")]
|
||||||
|
|
||||||
def _format_cnn_pivot_prediction(self, model_info):
|
def _format_cnn_pivot_prediction(self, model_info):
|
||||||
|
@@ -17,11 +17,32 @@ class DashboardLayoutManager:
|
|||||||
|
|
||||||
def create_main_layout(self):
|
def create_main_layout(self):
|
||||||
"""Create the main dashboard layout"""
|
"""Create the main dashboard layout"""
|
||||||
return html.Div([
|
try:
|
||||||
self._create_header(),
|
print("Creating main layout...")
|
||||||
self._create_interval_component(),
|
header = self._create_header()
|
||||||
self._create_main_content()
|
print("Header created")
|
||||||
], className="container-fluid")
|
interval_component = self._create_interval_component()
|
||||||
|
print("Interval component created")
|
||||||
|
main_content = self._create_main_content()
|
||||||
|
print("Main content created")
|
||||||
|
|
||||||
|
layout = html.Div([
|
||||||
|
header,
|
||||||
|
interval_component,
|
||||||
|
main_content
|
||||||
|
], className="container-fluid")
|
||||||
|
|
||||||
|
print("Main layout created successfully")
|
||||||
|
return layout
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error creating main layout: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
# Return a simple error layout
|
||||||
|
return html.Div([
|
||||||
|
html.H1("Dashboard Error", className="text-danger"),
|
||||||
|
html.P(f"Error creating layout: {str(e)}", className="text-danger")
|
||||||
|
])
|
||||||
|
|
||||||
def _create_header(self):
|
def _create_header(self):
|
||||||
"""Create the dashboard header"""
|
"""Create the dashboard header"""
|
||||||
@@ -52,7 +73,15 @@ class DashboardLayoutManager:
|
|||||||
dcc.Interval(
|
dcc.Interval(
|
||||||
id='slow-interval-component',
|
id='slow-interval-component',
|
||||||
interval=10000, # Update every 10 seconds (0.1 Hz) - OPTIMIZED
|
interval=10000, # Update every 10 seconds (0.1 Hz) - OPTIMIZED
|
||||||
n_intervals=0
|
n_intervals=0,
|
||||||
|
disabled=False
|
||||||
|
),
|
||||||
|
# Fast interval for testing (5 seconds)
|
||||||
|
dcc.Interval(
|
||||||
|
id='fast-interval-component',
|
||||||
|
interval=5000, # Update every 5 seconds for testing
|
||||||
|
n_intervals=0,
|
||||||
|
disabled=False
|
||||||
),
|
),
|
||||||
# WebSocket-based updates for high-frequency data (no interval needed)
|
# WebSocket-based updates for high-frequency data (no interval needed)
|
||||||
html.Div(id='websocket-updates-container', style={'display': 'none'})
|
html.Div(id='websocket-updates-container', style={'display': 'none'})
|
||||||
@@ -357,10 +386,16 @@ class DashboardLayoutManager:
|
|||||||
html.Div([
|
html.Div([
|
||||||
html.Div([
|
html.Div([
|
||||||
html.Div([
|
html.Div([
|
||||||
html.H6([
|
html.Div([
|
||||||
html.I(className="fas fa-brain me-2"),
|
html.H6([
|
||||||
"Models & Training Progress",
|
html.I(className="fas fa-brain me-2"),
|
||||||
], className="card-title mb-2"),
|
"Models & Training Progress",
|
||||||
|
], className="card-title mb-2"),
|
||||||
|
html.Button([
|
||||||
|
html.I(className="fas fa-sync-alt me-1"),
|
||||||
|
"Refresh"
|
||||||
|
], id="refresh-training-metrics-btn", className="btn btn-sm btn-outline-primary")
|
||||||
|
], className="d-flex justify-content-between align-items-center mb-2"),
|
||||||
html.Div(
|
html.Div(
|
||||||
id="training-metrics",
|
id="training-metrics",
|
||||||
style={"height": "300px", "overflowY": "auto"},
|
style={"height": "300px", "overflowY": "auto"},
|
||||||
|
Reference in New Issue
Block a user