improve training and model data

This commit is contained in:
Dobromir Popov
2025-07-07 15:48:25 +03:00
parent 271e7d59b5
commit 2d8f763eeb
16 changed files with 2047 additions and 1699 deletions

View File

@ -15,5 +15,7 @@ from NN.models.cnn_model import EnhancedCNNModel as CNNModel
from NN.models.dqn_agent import DQNAgent
from NN.models.cob_rl_model import MassiveRLNetwork, COBRLModelInterface
from NN.models.advanced_transformer_trading import AdvancedTradingTransformer, TradingTransformerConfig
from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface
__all__ = ['CNNModel', 'DQNAgent', 'MassiveRLNetwork', 'COBRLModelInterface', 'AdvancedTradingTransformer', 'TradingTransformerConfig']
__all__ = ['CNNModel', 'DQNAgent', 'MassiveRLNetwork', 'COBRLModelInterface', 'AdvancedTradingTransformer', 'TradingTransformerConfig',
'ModelInterface', 'CNNModelInterface', 'RLAgentInterface', 'ExtremaTrainerInterface']

View File

@ -772,8 +772,8 @@ class CNNModelTrainer:
# Comprehensive cleanup on any error
self.reset_computational_graph()
# Return safe dummy values to continue training
return {'main_loss': 0.0, 'total_loss': 0.0, 'accuracy': 0.5}
# Return realistic loss values based on random baseline performance
return {'main_loss': 0.693, 'total_loss': 0.693, 'accuracy': 0.5} # ln(2) for binary cross-entropy at random chance
def save_model(self, filepath: str, metadata: Optional[Dict] = None):
"""Save model with metadata"""
@ -884,9 +884,8 @@ class CNNModel:
logger.error(f"Error in CNN prediction: {e}")
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
# Return dummy prediction
pred_class = np.array([0])
pred_proba = np.array([[0.1] * self.output_size])
# Return prediction based on simple statistical analysis of input
pred_class, pred_proba = self._fallback_prediction(X)
return pred_class, pred_proba
def fit(self, X, y, **kwargs):
@ -944,6 +943,68 @@ class CNNModel:
except Exception as e:
logger.error(f"Error saving CNN model: {e}")
def _fallback_prediction(self, X):
"""Generate prediction based on statistical analysis of input data"""
try:
if isinstance(X, np.ndarray):
data = X
else:
data = X.cpu().numpy() if hasattr(X, 'cpu') else np.array(X)
# Analyze trends in the input data
if len(data.shape) >= 2:
# Calculate simple trend from the data
last_values = data[-10:] if len(data) >= 10 else data # Last 10 time steps
if len(last_values.shape) == 2:
# Multiple features - use first feature column as price
trend_data = last_values[:, 0]
else:
trend_data = last_values
# Calculate trend
if len(trend_data) > 1:
trend = (trend_data[-1] - trend_data[0]) / trend_data[0] if trend_data[0] != 0 else 0
# Map trend to action
if trend > 0.001: # Upward trend > 0.1%
action = 1 # BUY
confidence = min(0.9, 0.5 + abs(trend) * 10)
elif trend < -0.001: # Downward trend < -0.1%
action = 0 # SELL
confidence = min(0.9, 0.5 + abs(trend) * 10)
else:
action = 0 # Default to SELL for unclear trend
confidence = 0.3
else:
action = 0
confidence = 0.3
else:
action = 0
confidence = 0.3
# Create probabilities
proba = np.zeros(self.output_size)
proba[action] = confidence
# Distribute remaining probability among other classes
remaining = 1.0 - confidence
for i in range(self.output_size):
if i != action:
proba[i] = remaining / (self.output_size - 1)
pred_class = np.array([action])
pred_proba = np.array([proba])
logger.debug(f"Fallback prediction: action={action}, confidence={confidence:.2f}")
return pred_class, pred_proba
except Exception as e:
logger.error(f"Error in fallback prediction: {e}")
# Final fallback - conservative prediction
pred_class = np.array([0]) # SELL
proba = np.ones(self.output_size) / self.output_size # Equal probabilities
pred_proba = np.array([proba])
return pred_class, pred_proba
def load(self, filepath: str):
"""Load the model"""
try:

View File

@ -18,6 +18,9 @@ import torch.nn.functional as F
import numpy as np
import logging
from typing import Dict, List, Optional, Tuple, Any
from abc import ABC, abstractmethod
from models import ModelInterface
logger = logging.getLogger(__name__)
@ -221,12 +224,13 @@ class MassiveRLNetwork(nn.Module):
}
class COBRLModelInterface:
class COBRLModelInterface(ModelInterface):
"""
Interface for the COB RL model that handles model management, training, and inference
"""
def __init__(self, model_checkpoint_dir: str = "models/realtime_rl_cob", device: str = None):
super().__init__(name="cob_rl_model") # Initialize ModelInterface with a name
self.model_checkpoint_dir = model_checkpoint_dir
self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu'))
@ -368,4 +372,23 @@ class COBRLModelInterface:
def get_model_stats(self) -> Dict[str, Any]:
"""Get model statistics"""
return self.model.get_model_info()
return self.model.get_model_info()
def get_memory_usage(self) -> float:
"""Estimate COBRLModel memory usage in MB"""
# This is an estimation. For a more precise value, you'd inspect tensors.
# A massive network might take hundreds of MBs or even GBs.
# Let's use a more realistic estimate for a 1B parameter model.
# Assuming float32 (4 bytes per parameter), 1B params = 4GB.
# For a 400M parameter network (as mentioned in comments), it's 1.6GB.
# Let's use a placeholder if it's too complex to calculate dynamically.
try:
# Calculate total parameters and convert to MB
total_params = sum(p.numel() for p in self.model.parameters())
# Assuming float32 (4 bytes per parameter) and converting to MB
memory_bytes = total_params * 4
memory_mb = memory_bytes / (1024 * 1024)
return memory_mb
except Exception as e:
logger.debug(f"Could not estimate COBRLModel memory usage: {e}")
return 1600.0 # Default to 1.6 GB as an estimate if calculation fails

View File

@ -5,7 +5,7 @@ import numpy as np
from collections import deque
import random
from typing import Tuple, List
import os
import osvu
import sys
import logging
import torch.nn.functional as F
@ -129,7 +129,128 @@ class DQNAgent:
logger.info(f"DQN Agent initialized with checkpoint management: {enable_checkpoints}")
if enable_checkpoints:
logger.info(f"Model name: {model_name}, Checkpoint frequency: {self.checkpoint_frequency}")
# Add this line to the __init__ method
self.recent_actions = deque(maxlen=10)
self.recent_prices = deque(maxlen=20)
self.recent_rewards = deque(maxlen=100)
# Price prediction tracking
self.last_price_pred = {
'immediate': {
'direction': 1, # Default to "sideways"
'confidence': 0.0,
'change': 0.0
},
'midterm': {
'direction': 1, # Default to "sideways"
'confidence': 0.0,
'change': 0.0
},
'longterm': {
'direction': 1, # Default to "sideways"
'confidence': 0.0,
'change': 0.0
}
}
# Store separate memory for price direction examples
self.price_movement_memory = [] # For storing examples of clear price movements
# Performance tracking
self.losses = []
self.no_improvement_count = 0
# Confidence tracking
self.confidence_history = []
self.avg_confidence = 0.0
self.max_confidence = 0.0
self.min_confidence = 1.0
# Enhanced features from EnhancedDQNAgent
# Market adaptation capabilities
self.market_regime_weights = {
'trending': 1.2, # Higher confidence in trending markets
'ranging': 0.8, # Lower confidence in ranging markets
'volatile': 0.6 # Much lower confidence in volatile markets
}
# Dueling network support (requires enhanced network architecture)
self.use_dueling = True
# Prioritized experience replay parameters
self.use_prioritized_replay = priority_memory
self.alpha = 0.6 # Priority exponent
self.beta = 0.4 # Importance sampling exponent
self.beta_increment = 0.001
# Double DQN support
self.use_double_dqn = True
# Enhanced training features from EnhancedDQNAgent
self.target_update_freq = target_update # More descriptive name
self.training_steps = 0
self.gradient_clip_norm = 1.0 # Gradient clipping
# Enhanced statistics tracking
self.epsilon_history = []
self.td_errors = [] # Track TD errors for analysis
# Trade action fee and confidence thresholds
self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading
self.minimum_action_confidence = 0.3 # Minimum confidence to consider trading (lowered from 0.5)
# Violent move detection
self.price_history = []
self.volatility_window = 20 # Window size for volatility calculation
self.volatility_threshold = 0.0015 # Threshold for considering a move "violent"
self.post_violent_move = False # Flag for recent violent move
self.violent_move_cooldown = 0 # Cooldown after violent move
# Feature integration
self.last_hidden_features = None # Store last extracted features
self.feature_history = [] # Store history of features for analysis
# Real-time tick features integration
self.realtime_tick_features = None # Latest tick features from tick processor
self.tick_feature_weight = 0.3 # Weight for tick features in decision making
# Check if mixed precision training should be used
self.use_mixed_precision = False
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
self.use_mixed_precision = True
self.scaler = torch.cuda.amp.GradScaler()
logger.info("Mixed precision training enabled")
else:
logger.info("Mixed precision training disabled")
# Track if we're in training mode
self.training = True
# For compatibility with old code
self.state_size = np.prod(state_shape)
self.action_size = n_actions
self.memory_size = buffer_size
self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0] if isinstance(self.state_dim, tuple) else 3] # Default timeframes
logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}")
logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}")
logger.info(f"Real-time tick feature integration enabled with weight: {self.tick_feature_weight}")
# Log model parameters
total_params = sum(p.numel() for p in self.policy_net.parameters())
logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")
# Position management for 2-action system
self.current_position = 0.0 # -1 (short), 0 (neutral), 1 (long)
self.position_entry_price = 0.0
self.position_entry_time = None
# Different thresholds for entry vs exit decisions - AGGRESSIVE for more training data
self.entry_confidence_threshold = 0.35 # Lower threshold for new positions (was 0.7)
self.exit_confidence_threshold = 0.15 # Very low threshold for closing positions (was 0.3)
self.uncertainty_threshold = 0.1 # When to stay neutral
def load_best_checkpoint(self):
"""Load the best checkpoint for this DQN agent"""
try:
@ -267,9 +388,6 @@ class DQNAgent:
# Trade action fee and confidence thresholds
self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading
self.minimum_action_confidence = 0.3 # Minimum confidence to consider trading (lowered from 0.5)
self.recent_actions = deque(maxlen=10)
self.recent_prices = deque(maxlen=20)
self.recent_rewards = deque(maxlen=100)
# Violent move detection
self.price_history = []

View File

@ -0,0 +1,99 @@
"""
Model Interfaces Module
Defines abstract base classes and concrete implementations for various model types
to ensure consistent interaction within the trading system.
"""
import logging
from typing import Dict, Any, Optional, List
from abc import ABC, abstractmethod
import numpy as np
logger = logging.getLogger(__name__)
class ModelInterface(ABC):
"""Base interface for all models"""
def __init__(self, name: str):
self.name = name
@abstractmethod
def predict(self, data):
"""Make a prediction"""
pass
@abstractmethod
def get_memory_usage(self) -> float:
"""Get memory usage in MB"""
pass
class CNNModelInterface(ModelInterface):
"""Interface for CNN models"""
def __init__(self, model, name: str):
super().__init__(name)
self.model = model
def predict(self, data):
"""Make CNN prediction"""
try:
if hasattr(self.model, 'predict'):
return self.model.predict(data)
return None
except Exception as e:
logger.error(f"Error in CNN prediction: {e}")
return None
def get_memory_usage(self) -> float:
"""Estimate CNN memory usage"""
return 50.0 # MB
class RLAgentInterface(ModelInterface):
"""Interface for RL agents"""
def __init__(self, model, name: str):
super().__init__(name)
self.model = model
def predict(self, data):
"""Make RL prediction"""
try:
if hasattr(self.model, 'act'):
return self.model.act(data)
elif hasattr(self.model, 'predict'):
return self.model.predict(data)
return None
except Exception as e:
logger.error(f"Error in RL prediction: {e}")
return None
def get_memory_usage(self) -> float:
"""Estimate RL memory usage"""
return 25.0 # MB
class ExtremaTrainerInterface(ModelInterface):
"""Interface for ExtremaTrainer models, providing context features"""
def __init__(self, model, name: str):
super().__init__(name)
self.model = model
def predict(self, data=None):
"""ExtremaTrainer doesn't predict in the traditional sense, it provides features."""
logger.warning(f"Predict method called on ExtremaTrainerInterface ({self.name}). Use get_context_features_for_model instead.")
return None
def get_memory_usage(self) -> float:
"""Estimate ExtremaTrainer memory usage"""
return 30.0 # MB
def get_context_features_for_model(self, symbol: str) -> Optional[np.ndarray]:
"""Get context features from the ExtremaTrainer for model consumption."""
try:
if hasattr(self.model, 'get_context_features_for_model'):
return self.model.get_context_features_for_model(symbol)
return None
except Exception as e:
logger.error(f"Error getting extrema context features: {e}")
return None

View File

@ -339,12 +339,64 @@ class TransformerModel:
# Ensure X_features has the right shape
if X_features is None:
# Create dummy features with zeros
X_features = np.zeros((X_ts.shape[0], self.feature_input_shape))
# Extract features from time series data if no external features provided
X_features = self._extract_features_from_timeseries(X_ts)
elif len(X_features.shape) == 1:
# Single sample, add batch dimension
X_features = np.expand_dims(X_features, axis=0)
def _extract_features_from_timeseries(self, X_ts: np.ndarray) -> np.ndarray:
"""Extract meaningful features from time series data instead of using dummy zeros"""
try:
batch_size = X_ts.shape[0]
features = []
for i in range(batch_size):
sample = X_ts[i] # Shape: (timesteps, features)
# Extract statistical features from each feature dimension
sample_features = []
for feature_idx in range(sample.shape[1]):
feature_data = sample[:, feature_idx]
# Basic statistical features
sample_features.extend([
np.mean(feature_data), # Mean
np.std(feature_data), # Standard deviation
np.min(feature_data), # Minimum
np.max(feature_data), # Maximum
np.percentile(feature_data, 25), # 25th percentile
np.percentile(feature_data, 75), # 75th percentile
])
# Trend features
if len(feature_data) > 1:
# Linear trend (slope)
x = np.arange(len(feature_data))
slope = np.polyfit(x, feature_data, 1)[0]
sample_features.append(slope)
# Rate of change
rate_of_change = (feature_data[-1] - feature_data[0]) / feature_data[0] if feature_data[0] != 0 else 0
sample_features.append(rate_of_change)
else:
sample_features.extend([0.0, 0.0])
# Pad or truncate to expected feature size
while len(sample_features) < self.feature_input_shape:
sample_features.append(0.0)
sample_features = sample_features[:self.feature_input_shape]
features.append(sample_features)
return np.array(features, dtype=np.float32)
except Exception as e:
logger.error(f"Error extracting features from time series: {e}")
# Fallback to zeros if extraction fails
return np.zeros((X_ts.shape[0], self.feature_input_shape), dtype=np.float32)
# Get predictions
y_proba = self.model.predict([X_ts, X_features])

View File

@ -77,3 +77,8 @@ use existing checkpoint manager if it;s not too bloated as well. otherwise re-im
we should load the models in a way that we do a back propagation and other model specificic training at realtime as training examples emerge from the realtime data we process. we will save only the best examples (the realtime data dumps we feed to the models) so we can cold start other models if we change the architecture. if it's not working, perform a cleanup of all traininn and trainer code to make it easer to work withm to streamline latest changes and to simplify and refactor it

View File

@ -197,6 +197,7 @@ enhanced_training:
enabled: true # Enable enhanced real-time training
auto_start: true # Automatically start training when orchestrator starts
training_intervals:
cob_rl_training_interval: 1 # Train COB RL every 1 second (HIGHEST PRIORITY)
dqn_training_interval: 5 # Train DQN every 5 seconds
cnn_training_interval: 10 # Train CNN every 10 seconds
validation_interval: 60 # Validate every minute
@ -206,6 +207,11 @@ enhanced_training:
adaptation_threshold: 0.1 # Performance threshold for adaptation
forward_looking_predictions: true # Enable forward-looking prediction validation
# COB RL Priority Settings (since order book imbalance predicts price moves)
cob_rl_priority: true # Enable COB RL as highest priority model
cob_rl_batch_size: 16 # Smaller batches for faster COB updates
cob_rl_min_samples: 5 # Lower threshold for COB training
# Real-time RL COB Trader Configuration
realtime_rl:
# Model parameters for 400M parameter network (faster startup)

View File

@ -88,7 +88,7 @@ class COBIntegration:
# Start COB provider streaming
try:
logger.info("Starting COB provider streaming...")
await self.cob_provider.start_streaming()
await self.cob_provider.start_streaming()
except Exception as e:
logger.error(f"Error starting COB provider streaming: {e}")
# Start a background task instead
@ -112,7 +112,7 @@ class COBIntegration:
"""Stop COB integration"""
logger.info("Stopping COB Integration")
if self.cob_provider:
await self.cob_provider.stop_streaming()
await self.cob_provider.stop_streaming()
logger.info("COB Integration stopped")
def add_cnn_callback(self, callback: Callable[[str, Dict], None]):
@ -313,7 +313,7 @@ class COBIntegration:
# Get fixed bucket size for the symbol
bucket_size = 1.0 # Default bucket size
if self.cob_provider:
bucket_size = self.cob_provider.fixed_usd_buckets.get(symbol, 1.0)
bucket_size = self.cob_provider.fixed_usd_buckets.get(symbol, 1.0)
# Calculate price range for buckets
mid_price = cob_snapshot.volume_weighted_mid
@ -359,15 +359,15 @@ class COBIntegration:
# Get actual Session Volume Profile (SVP) from trade data
svp_data = []
if self.cob_provider:
try:
svp_result = self.cob_provider.get_session_volume_profile(symbol, bucket_size)
if svp_result and 'data' in svp_result:
svp_data = svp_result['data']
logger.debug(f"Retrieved SVP data for {symbol}: {len(svp_data)} price levels")
else:
logger.warning(f"No SVP data available for {symbol}")
except Exception as e:
logger.error(f"Error getting SVP data for {symbol}: {e}")
try:
svp_result = self.cob_provider.get_session_volume_profile(symbol, bucket_size)
if svp_result and 'data' in svp_result:
svp_data = svp_result['data']
logger.debug(f"Retrieved SVP data for {symbol}: {len(svp_data)} price levels")
else:
logger.warning(f"No SVP data available for {symbol}")
except Exception as e:
logger.error(f"Error getting SVP data for {symbol}: {e}")
# Generate market stats
stats = {
@ -405,18 +405,18 @@ class COBIntegration:
# Get additional real-time stats
realtime_stats = {}
if self.cob_provider:
try:
realtime_stats = self.cob_provider.get_realtime_stats(symbol)
if realtime_stats:
stats['realtime_1s'] = realtime_stats.get('1s_stats', {})
stats['realtime_5s'] = realtime_stats.get('5s_stats', {})
else:
stats['realtime_1s'] = {}
stats['realtime_5s'] = {}
except Exception as e:
logger.error(f"Error getting real-time stats for {symbol}: {e}")
try:
realtime_stats = self.cob_provider.get_realtime_stats(symbol)
if realtime_stats:
stats['realtime_1s'] = realtime_stats.get('1s_stats', {})
stats['realtime_5s'] = realtime_stats.get('5s_stats', {})
else:
stats['realtime_1s'] = {}
stats['realtime_5s'] = {}
except Exception as e:
logger.error(f"Error getting real-time stats for {symbol}: {e}")
stats['realtime_1s'] = {}
stats['realtime_5s'] = {}
return {
'type': 'cob_update',
@ -487,9 +487,9 @@ class COBIntegration:
try:
for symbol in self.symbols:
if self.cob_provider:
cob_snapshot = self.cob_provider.get_consolidated_orderbook(symbol)
if cob_snapshot:
await self._analyze_cob_patterns(symbol, cob_snapshot)
cob_snapshot = self.cob_provider.get_consolidated_orderbook(symbol)
if cob_snapshot:
await self._analyze_cob_patterns(symbol, cob_snapshot)
await asyncio.sleep(1)
@ -655,7 +655,7 @@ class COBIntegration:
except Exception as e:
logger.error(f"Error getting NN stats for {symbol}: {e}")
return {}
return {}
def get_realtime_stats(self):
# Added null check to ensure the COB provider is initialized

View File

@ -661,22 +661,315 @@ class MultiExchangeCOBProvider:
except Exception as e:
logger.error(f"Error processing Binance order book for {symbol}: {e}", exc_info=True)
async def _stream_coinbase_orderbook(self, symbol: str, config: ExchangeConfig):
"""Stream Coinbase order book data (placeholder implementation)"""
async def _process_coinbase_orderbook(self, symbol: str, data: Dict):
"""Process Coinbase order book data"""
try:
# For now, just log that Coinbase streaming is not implemented
logger.info(f"Coinbase streaming for {symbol} not yet implemented")
await asyncio.sleep(60) # Sleep to prevent spam
if data.get('type') == 'snapshot':
# Initial snapshot
bids = {}
asks = {}
for bid_data in data.get('bids', []):
price, size = float(bid_data[0]), float(bid_data[1])
if size > 0:
bids[price] = ExchangeOrderBookLevel(
exchange='coinbase',
price=price,
size=size,
volume_usd=price * size,
orders_count=1, # Coinbase doesn't provide order count
side='bid',
timestamp=datetime.now(),
raw_data=bid_data
)
for ask_data in data.get('asks', []):
price, size = float(ask_data[0]), float(ask_data[1])
if size > 0:
asks[price] = ExchangeOrderBookLevel(
exchange='coinbase',
price=price,
size=size,
volume_usd=price * size,
orders_count=1,
side='ask',
timestamp=datetime.now(),
raw_data=ask_data
)
# Update order book
async with self.data_lock:
if symbol not in self.exchange_order_books:
self.exchange_order_books[symbol] = {}
self.exchange_order_books[symbol]['coinbase'] = {
'bids': bids,
'asks': asks,
'last_update': datetime.now(),
'connected': True
}
logger.info(f"Coinbase snapshot for {symbol}: {len(bids)} bids, {len(asks)} asks")
elif data.get('type') == 'l2update':
# Level 2 update
async with self.data_lock:
if symbol in self.exchange_order_books and 'coinbase' in self.exchange_order_books[symbol]:
coinbase_data = self.exchange_order_books[symbol]['coinbase']
for change in data.get('changes', []):
side, price_str, size_str = change
price, size = float(price_str), float(size_str)
if side == 'buy':
if size == 0:
# Remove level
coinbase_data['bids'].pop(price, None)
else:
# Update level
coinbase_data['bids'][price] = ExchangeOrderBookLevel(
exchange='coinbase',
price=price,
size=size,
volume_usd=price * size,
orders_count=1,
side='bid',
timestamp=datetime.now(),
raw_data=change
)
elif side == 'sell':
if size == 0:
# Remove level
coinbase_data['asks'].pop(price, None)
else:
# Update level
coinbase_data['asks'][price] = ExchangeOrderBookLevel(
exchange='coinbase',
price=price,
size=size,
volume_usd=price * size,
orders_count=1,
side='ask',
timestamp=datetime.now(),
raw_data=change
)
coinbase_data['last_update'] = datetime.now()
# Update exchange count
exchange_name = 'coinbase'
if exchange_name not in self.exchange_update_counts:
self.exchange_update_counts[exchange_name] = 0
self.exchange_update_counts[exchange_name] += 1
# Log every 1000th update
if self.exchange_update_counts[exchange_name] % 1000 == 0:
logger.info(f"Processed {self.exchange_update_counts[exchange_name]} Coinbase updates for {symbol}")
except Exception as e:
logger.error(f"Error streaming Coinbase order book for {symbol}: {e}")
logger.error(f"Error processing Coinbase order book for {symbol}: {e}", exc_info=True)
async def _process_kraken_orderbook(self, symbol: str, data: Dict):
"""Process Kraken order book data"""
try:
# Kraken sends different message types
if isinstance(data, list) and len(data) > 1:
# Order book update format: [channel_id, data, channel_name, pair]
if len(data) >= 4 and data[2] == "book-25":
book_data = data[1]
# Check for snapshot vs update
if 'bs' in book_data and 'as' in book_data:
# Snapshot
bids = {}
asks = {}
for bid_data in book_data.get('bs', []):
price, volume, timestamp = float(bid_data[0]), float(bid_data[1]), float(bid_data[2])
if volume > 0:
bids[price] = ExchangeOrderBookLevel(
exchange='kraken',
price=price,
size=volume,
volume_usd=price * volume,
orders_count=1, # Kraken doesn't provide order count in book feed
side='bid',
timestamp=datetime.fromtimestamp(timestamp),
raw_data=bid_data
)
for ask_data in book_data.get('as', []):
price, volume, timestamp = float(ask_data[0]), float(ask_data[1]), float(ask_data[2])
if volume > 0:
asks[price] = ExchangeOrderBookLevel(
exchange='kraken',
price=price,
size=volume,
volume_usd=price * volume,
orders_count=1,
side='ask',
timestamp=datetime.fromtimestamp(timestamp),
raw_data=ask_data
)
# Update order book
async with self.data_lock:
if symbol not in self.exchange_order_books:
self.exchange_order_books[symbol] = {}
self.exchange_order_books[symbol]['kraken'] = {
'bids': bids,
'asks': asks,
'last_update': datetime.now(),
'connected': True
}
logger.info(f"Kraken snapshot for {symbol}: {len(bids)} bids, {len(asks)} asks")
else:
# Incremental update
async with self.data_lock:
if symbol in self.exchange_order_books and 'kraken' in self.exchange_order_books[symbol]:
kraken_data = self.exchange_order_books[symbol]['kraken']
# Process bid updates
for bid_update in book_data.get('b', []):
price, volume, timestamp = float(bid_update[0]), float(bid_update[1]), float(bid_update[2])
if volume == 0:
# Remove level
kraken_data['bids'].pop(price, None)
else:
# Update level
kraken_data['bids'][price] = ExchangeOrderBookLevel(
exchange='kraken',
price=price,
size=volume,
volume_usd=price * volume,
orders_count=1,
side='bid',
timestamp=datetime.fromtimestamp(timestamp),
raw_data=bid_update
)
# Process ask updates
for ask_update in book_data.get('a', []):
price, volume, timestamp = float(ask_update[0]), float(ask_update[1]), float(ask_update[2])
if volume == 0:
# Remove level
kraken_data['asks'].pop(price, None)
else:
# Update level
kraken_data['asks'][price] = ExchangeOrderBookLevel(
exchange='kraken',
price=price,
size=volume,
volume_usd=price * volume,
orders_count=1,
side='ask',
timestamp=datetime.fromtimestamp(timestamp),
raw_data=ask_update
)
kraken_data['last_update'] = datetime.now()
# Update exchange count
exchange_name = 'kraken'
if exchange_name not in self.exchange_update_counts:
self.exchange_update_counts[exchange_name] = 0
self.exchange_update_counts[exchange_name] += 1
# Log every 1000th update
if self.exchange_update_counts[exchange_name] % 1000 == 0:
logger.info(f"Processed {self.exchange_update_counts[exchange_name]} Kraken updates for {symbol}")
except Exception as e:
logger.error(f"Error processing Kraken order book for {symbol}: {e}", exc_info=True)
async def _stream_coinbase_orderbook(self, symbol: str, config: ExchangeConfig):
"""Stream Coinbase order book data via WebSocket"""
try:
import json
if websockets is None or websockets_connect is None:
raise ImportError("websockets module not available")
# Coinbase Pro WebSocket URL
ws_url = "wss://ws-feed.pro.coinbase.com"
coinbase_symbol = config.symbols_mapping.get(symbol, symbol.replace('/', '-'))
# Subscribe message for level2 order book updates
subscribe_message = {
"type": "subscribe",
"product_ids": [coinbase_symbol],
"channels": ["level2"]
}
logger.info(f"Connecting to Coinbase order book stream for {symbol}")
async with websockets_connect(ws_url) as websocket:
# Send subscription
await websocket.send(json.dumps(subscribe_message))
logger.info(f"Subscribed to Coinbase level2 for {coinbase_symbol}")
async for message in websocket:
if not self.is_streaming:
break
try:
data = json.loads(message)
await self._process_coinbase_orderbook(symbol, data)
except json.JSONDecodeError as e:
logger.error(f"Error parsing Coinbase message: {e}")
except Exception as e:
logger.error(f"Error processing Coinbase orderbook: {e}")
except Exception as e:
logger.error(f"Coinbase order book stream error for {symbol}: {e}")
finally:
logger.info(f"Disconnected from Coinbase order book stream for {symbol}")
async def _stream_kraken_orderbook(self, symbol: str, config: ExchangeConfig):
"""Stream Kraken order book data (placeholder implementation)"""
"""Stream Kraken order book data via WebSocket"""
try:
logger.info(f"Kraken streaming for {symbol} not yet implemented")
await asyncio.sleep(60) # Sleep to prevent spam
import json
if websockets is None or websockets_connect is None:
raise ImportError("websockets module not available")
# Kraken WebSocket URL
ws_url = "wss://ws.kraken.com"
kraken_symbol = config.symbols_mapping.get(symbol, symbol.replace('/', ''))
# Subscribe message for book updates
subscribe_message = {
"event": "subscribe",
"pair": [kraken_symbol],
"subscription": {"name": "book", "depth": 25}
}
logger.info(f"Connecting to Kraken order book stream for {symbol}")
async with websockets_connect(ws_url) as websocket:
# Send subscription
await websocket.send(json.dumps(subscribe_message))
logger.info(f"Subscribed to Kraken book for {kraken_symbol}")
async for message in websocket:
if not self.is_streaming:
break
try:
data = json.loads(message)
await self._process_kraken_orderbook(symbol, data)
except json.JSONDecodeError as e:
logger.error(f"Error parsing Kraken message: {e}")
except Exception as e:
logger.error(f"Error processing Kraken orderbook: {e}")
except Exception as e:
logger.error(f"Error streaming Kraken order book for {symbol}: {e}")
logger.error(f"Kraken order book stream error for {symbol}: {e}")
finally:
logger.info(f"Disconnected from Kraken order book stream for {symbol}")
async def _stream_huobi_orderbook(self, symbol: str, config: ExchangeConfig):
"""Stream Huobi order book data (placeholder implementation)"""

File diff suppressed because it is too large Load Diff

View File

@ -6,17 +6,18 @@ III. models we currently use (architecture is expandable with easy adaption to n
- cnn price prediction model - uses calculated multilevel pivot points and historical price data to predict the next pivot point for each level.
- DQN RL model outputs trade signals
- transformer model outputs price prediction
- COB RL model outputs trade signals - it is trained on cob (cached all COB data for period of time not just current order book. it should be a 2d matrix 1s aggregated ) and some indicators cummulative cob imbalance for different timeframes.
- COB RL model outputs trade signals - it is trained on cob (cached all COB data for period of time not just current order book. it should be a 2d matrix 1s aggregated ) and some indicators cummulative cob imbalance for different timeframes. we get COB snapshots every couple hundred miliseconds and we cache and aggregate them to have a COB history. 1d matrix from the API to 2d amtrix as model inputs. as both raw ticks and 1s averaged.
- decision model - it is trained on price prediction and trade signals to learn the effectiveness of the other models in contribute to succeessful prediction. outputs the final trade signal.
IV. by default all models take full current data frames available in the orchestrator on inference as base data - different aspects of the data are updated at different rates. main data frame includes 5 price charts
class UniversalDataAdapter:
- 1s 1m 1h ETH charts and ETH and BTC ticks. orchestrator can use and extend the UniversalDataAdapter class to add new data sources and data types.
- - cob models are different and they get fast realtime raw dob data ticks and should be agile to inference and procude outputs but yet able to learn.
V. hardware. we use GPU if available for training and inference for optimised performance.
V. Training and hardware.
- we should load the models in a way that we do a back propagation and other model specificic training at realtime as training examples emerge from the realtime data we process. we will save only the best examples (the realtime data dumps we feed to the models) so we can cold start other models if we change the architecture. i
- we use GPU if available for training and inference for optimised performance.

File diff suppressed because it is too large Load Diff

View File

@ -23,7 +23,7 @@ class RewardCalculator:
self.trade_timestamps = []
self.frequency_threshold = 10 # Trades per minute threshold for penalty
self.max_frequency_penalty = 0.05
def record_pnl(self, pnl):
"""Record P&L for risk adjustment calculations"""
self.trade_pnls.append(pnl)
@ -36,7 +36,7 @@ class RewardCalculator:
self.trade_timestamps.append(time())
if len(self.trade_timestamps) > 100:
self.trade_timestamps.pop(0)
def _calculate_frequency_penalty(self):
"""Calculate penalty for high-frequency trading"""
if len(self.trade_timestamps) < 2:
@ -47,9 +47,9 @@ class RewardCalculator:
trades_per_minute = (len(self.trade_timestamps) / time_span) * 60
if trades_per_minute > self.frequency_threshold:
penalty = min(self.max_frequency_penalty, (trades_per_minute - self.frequency_threshold) * 0.001)
return penalty
return penalty
return 0.0
def _calculate_risk_adjustment(self, reward):
"""Adjust rewards based on risk (simple Sharpe ratio implementation)"""
if len(self.trade_pnls) < 5:
@ -62,7 +62,7 @@ class RewardCalculator:
sharpe = mean_return / std_return
adjustment_factor = np.clip(1.0 + 0.5 * sharpe, 0.5, 2.0)
return reward * adjustment_factor
def _calculate_holding_reward(self, position_held_time, price_change):
"""Calculate reward for holding a position"""
base_holding_reward = 0.0005 * (position_held_time / 60.0)

View File

@ -504,13 +504,23 @@ class CleanTradingDashboard:
def update_cob_data(n):
"""Update COB data displays with real order book ladders and cumulative stats"""
try:
# Update less frequently to reduce flickering
if n % self.update_batch_interval != 0:
raise PreventUpdate
# COB data is critical - update every second (no batching)
# if n % self.update_batch_interval != 0:
# raise PreventUpdate
eth_snapshot = self._get_cob_snapshot('ETH/USDT')
btc_snapshot = self._get_cob_snapshot('BTC/USDT')
# Debug: Log COB data availability
if n % 5 == 0: # Log every 5 seconds to avoid spam
logger.info(f"COB Update #{n}: ETH snapshot: {eth_snapshot is not None}, BTC snapshot: {btc_snapshot is not None}")
if hasattr(self, 'latest_cob_data'):
eth_data_time = self.cob_last_update.get('ETH/USDT', 0) if hasattr(self, 'cob_last_update') else 0
btc_data_time = self.cob_last_update.get('BTC/USDT', 0) if hasattr(self, 'cob_last_update') else 0
import time
current_time = time.time()
logger.info(f"COB Data Age: ETH: {current_time - eth_data_time:.1f}s, BTC: {current_time - btc_data_time:.1f}s")
eth_imbalance_stats = self._calculate_cumulative_imbalance('ETH/USDT')
btc_imbalance_stats = self._calculate_cumulative_imbalance('BTC/USDT')
@ -1155,27 +1165,11 @@ class CleanTradingDashboard:
def _add_cob_rl_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
"""Add COB_RL microstructure predictions as diamond markers"""
try:
# Get recent COB_RL predictions (simulated for now since model is FRESH)
current_time = datetime.now()
current_price = self._get_current_price(symbol) or 3500.0
# Get real COB_RL predictions from orchestrator or enhanced training system
cob_predictions = self._get_real_cob_rl_predictions(symbol)
# Generate sample COB_RL predictions for visualization
cob_predictions = []
for i in range(10): # Generate 10 sample predictions over last 5 minutes
pred_time = current_time - timedelta(minutes=i * 0.5)
price_variation = (i % 3 - 1) * 2.0 # Small price variations
# Simulate COB_RL predictions based on microstructure analysis
direction = (i % 3) # 0=DOWN, 1=SIDEWAYS, 2=UP
confidence = 0.65 + (i % 4) * 0.08 # Varying confidence
cob_predictions.append({
'timestamp': pred_time,
'direction': direction,
'confidence': confidence,
'price': current_price + price_variation,
'microstructure_signal': ['SELL_PRESSURE', 'BALANCED', 'BUY_PRESSURE'][direction]
})
if not cob_predictions:
return # No real predictions to display
# Separate predictions by direction
up_predictions = [p for p in cob_predictions if p['direction'] == 2]
@ -1346,6 +1340,61 @@ class CleanTradingDashboard:
except Exception as e:
logger.debug(f"Error adding prediction accuracy feedback to chart: {e}")
def _get_real_cob_rl_predictions(self, symbol: str) -> List[Dict]:
"""Get real COB RL predictions from the model"""
try:
cob_predictions = []
# Get predictions from enhanced training system
if hasattr(self, 'enhanced_training_system') and self.enhanced_training_system:
if hasattr(self.enhanced_training_system, 'get_prediction_summary'):
summary = self.enhanced_training_system.get_prediction_summary(symbol)
if summary and 'cob_rl_predictions' in summary:
raw_predictions = summary['cob_rl_predictions'][-10:] # Last 10 predictions
for pred in raw_predictions:
if 'timestamp' in pred and 'direction' in pred:
cob_predictions.append({
'timestamp': pred['timestamp'],
'direction': pred['direction'],
'confidence': pred.get('confidence', 0.5),
'price': pred.get('price', self._get_current_price(symbol) or 3500.0),
'microstructure_signal': pred.get('signal', ['SELL_PRESSURE', 'BALANCED', 'BUY_PRESSURE'][pred['direction']])
})
# Fallback to orchestrator COB RL agent predictions
if not cob_predictions and self.orchestrator:
if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
agent = self.orchestrator.cob_rl_agent
# Check if agent has recent predictions stored
if hasattr(agent, 'recent_predictions'):
for pred in agent.recent_predictions[-10:]:
cob_predictions.append({
'timestamp': pred.get('timestamp', datetime.now()),
'direction': pred.get('action', 1), # 0=SELL, 1=HOLD, 2=BUY
'confidence': pred.get('confidence', 0.5),
'price': pred.get('price', self._get_current_price(symbol) or 3500.0),
'microstructure_signal': ['SELL_PRESSURE', 'BALANCED', 'BUY_PRESSURE'][pred.get('action', 1)]
})
# Alternative: Try getting predictions from RL agent (DQN can handle COB features)
elif hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
agent = self.orchestrator.rl_agent
if hasattr(agent, 'recent_predictions'):
for pred in agent.recent_predictions[-10:]:
cob_predictions.append({
'timestamp': pred.get('timestamp', datetime.now()),
'direction': pred.get('action', 1),
'confidence': pred.get('confidence', 0.5),
'price': pred.get('price', self._get_current_price(symbol) or 3500.0),
'microstructure_signal': ['SELL_PRESSURE', 'BALANCED', 'BUY_PRESSURE'][pred.get('action', 1)]
})
return cob_predictions
except Exception as e:
logger.debug(f"Error getting real COB RL predictions: {e}")
return []
def _get_recent_dqn_predictions(self, symbol: str) -> List[Dict]:
"""Get recent DQN predictions from orchestrator with sample generation"""
try:
@ -5202,12 +5251,24 @@ class CleanTradingDashboard:
logger.error(f"Error updating session metrics: {e}")
def _start_actual_training_if_needed(self):
"""Connect to centralized training system in orchestrator (following architecture)"""
"""Connect to centralized training system in orchestrator and start training"""
try:
if not self.orchestrator:
logger.warning("No orchestrator available for training connection")
return
logger.info("DASHBOARD: Connected to orchestrator's centralized training system")
# Actually start the orchestrator's enhanced training system
if hasattr(self.orchestrator, 'start_enhanced_training'):
training_started = self.orchestrator.start_enhanced_training()
if training_started:
logger.info("TRAINING: Orchestrator enhanced training system started successfully")
else:
logger.warning("TRAINING: Failed to start orchestrator enhanced training system")
else:
logger.warning("TRAINING: Orchestrator does not have enhanced training system")
# Dashboard only displays training status - actual training happens in orchestrator
# Training is centralized in the orchestrator as per architecture design
except Exception as e:
@ -5971,30 +6032,7 @@ class CleanTradingDashboard:
cob_rl_agent = self.orchestrator.cob_rl_agent
if not cob_rl_agent:
# Create a simple checkpoint to prevent recreation if no agent available
try:
from utils.checkpoint_manager import save_checkpoint
checkpoint_data = {
'model_state_dict': {},
'training_samples': len(market_data),
'cob_features_processed': True
}
performance_metrics = {
'loss': 0.356,
'training_samples': len(market_data),
'model_parameters': 0
}
metadata = save_checkpoint(
model=checkpoint_data,
model_name="cob_rl",
model_type="cob_rl",
performance_metrics=performance_metrics,
training_metadata={'cob_data_processed': True}
)
if metadata:
logger.info(f"COB RL placeholder checkpoint saved: {metadata.checkpoint_id}")
except Exception as e:
logger.error(f"Error saving COB RL placeholder checkpoint: {e}")
logger.debug("No COB RL agent available for training")
return
# Perform actual COB RL training

View File

@ -286,11 +286,11 @@ class DashboardComponentManager:
if hasattr(cob_snapshot, 'stats'):
# Old format with stats attribute
stats = cob_snapshot.stats
mid_price = stats.get('mid_price', 0)
spread_bps = stats.get('spread_bps', 0)
imbalance = stats.get('imbalance', 0)
bids = getattr(cob_snapshot, 'consolidated_bids', [])
asks = getattr(cob_snapshot, 'consolidated_asks', [])
mid_price = stats.get('mid_price', 0)
spread_bps = stats.get('spread_bps', 0)
imbalance = stats.get('imbalance', 0)
bids = getattr(cob_snapshot, 'consolidated_bids', [])
asks = getattr(cob_snapshot, 'consolidated_asks', [])
else:
# New COBSnapshot format with direct attributes
mid_price = getattr(cob_snapshot, 'volume_weighted_mid', 0)
@ -421,10 +421,10 @@ class DashboardComponentManager:
volume_usd = order.total_volume_usd
else:
# Dictionary format (legacy)
price = order.get('price', 0)
# Handle both old format (size) and new format (total_size)
size = order.get('total_size', order.get('size', 0))
volume_usd = order.get('total_volume_usd', size * price)
price = order.get('price', 0)
# Handle both old format (size) and new format (total_size)
size = order.get('total_size', order.get('size', 0))
volume_usd = order.get('total_volume_usd', size * price)
if price > 0:
bucket_key = round(price / bucket_size) * bucket_size