3 Commits

SHA1 Message Date
d17af5ca4b inference data storage 2025-07-24 15:31:57 +03:00
fa07265a16 wip training 2025-07-24 15:27:32 +03:00
b3edd21f1b cnn training stats on dash 2025-07-24 14:28:28 +03:00
9 changed files with 767 additions and 29 deletions

View File

@@ -140,7 +140,7 @@ Training:
### 4. Orchestrator
The Orchestrator serves as the central coordination hub of the multi-modal trading system, responsible for data subscription management, model inference coordination, output storage, and training pipeline orchestration.
The Orchestrator serves as the central coordination hub of the multi-modal trading system, responsible for data subscription management, model inference coordination, output storage, training pipeline orchestration, and inference-training feedback loop management.
#### Key Classes and Interfaces
@@ -245,6 +245,47 @@ The Orchestrator coordinates training for all models by managing the prediction-
- The checkpoint manager keeps only the top 5 to 10 best checkpoints per model, deleting the least performant ones. It stores metadata alongside the checkpoints to judge their performance.
- The best checkpoint is automatically loaded at startup when stored checkpoints exist.
##### 5. Inference Data Validation and Storage
The Orchestrator implements comprehensive inference data validation and persistent storage:
**Input Data Validation**:
- Validates complete OHLCV dataframes for all required timeframes before inference
- Checks input data dimensions against model requirements
- Logs missing components and prevents prediction on incomplete data
- Raises validation errors with specific details about expected vs actual dimensions
**Inference History Storage**:
- Stores complete input data packages with each prediction in persistent storage
- Includes timestamp, symbol, input features, prediction outputs, confidence scores, and model internal states
- Maintains compressed storage to minimize footprint while preserving accessibility
- Implements efficient query mechanisms by symbol, timeframe, and date range
**Storage Management**:
- Applies configurable retention policies to manage storage limits
- Archives or removes oldest entries when limits are reached
- Prioritizes keeping most recent and valuable training examples during storage pressure
- Provides data completeness metrics and validation results in logs
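A minimal sketch of this validation step, using the `InferenceDataValidator` name from the implementation task list; the required timeframes and 300-bar depth follow the specification above, while the method names and error wording are illustrative assumptions.

```python
import numpy as np
import pandas as pd
from typing import Dict, Tuple

REQUIRED_TIMEFRAMES = ("1s", "1m", "1h", "1d")
REQUIRED_BARS = 300  # per the 300-bar OHLCV inputs described in the task list


class InferenceDataValidator:
    """Validates model inputs before inference (illustrative sketch)."""

    def validate_ohlcv(self, symbol: str, ohlcv: Dict[str, pd.DataFrame]) -> None:
        # Reject incomplete data rather than predicting on it
        missing = [tf for tf in REQUIRED_TIMEFRAMES if tf not in ohlcv or ohlcv[tf].empty]
        if missing:
            raise ValueError(f"{symbol}: missing OHLCV timeframes {missing}")
        for tf in REQUIRED_TIMEFRAMES:
            if len(ohlcv[tf]) < REQUIRED_BARS:
                raise ValueError(
                    f"{symbol} {tf}: expected {REQUIRED_BARS} bars, got {len(ohlcv[tf])}"
                )

    def validate_feature_shape(self, features: np.ndarray, expected: Tuple[int, ...]) -> None:
        # Report expected vs actual dimensions, as required above
        if features.shape != expected:
            raise ValueError(f"expected input shape {expected}, got {features.shape}")
```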
##### 6. Inference-Training Feedback Loop
The Orchestrator manages the continuous learning cycle through inference-training feedback:
**Prediction Outcome Evaluation**:
- Evaluates prediction accuracy against actual price movements after sufficient time has passed
- Creates training examples using stored inference data paired with actual market outcomes
- Feeds prediction-result pairs back to respective models for learning
**Adaptive Learning Signals**:
- Provides positive reinforcement signals for accurate predictions
- Delivers corrective training signals for inaccurate predictions to help models learn from mistakes
- Retrieves last inference data for each model to compare predictions against actual outcomes
**Continuous Improvement Tracking**:
- Tracks and reports accuracy improvements or degradations over time
- Monitors model learning progress through the feedback loop
- Alerts administrators when data flow issues are detected with specific error details and remediation suggestions
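A rough sketch of the outcome-evaluation step; the 0.1% move threshold and the 1.0 / -0.5 reward values mirror the orchestrator code later in this diff, and the record layout is the one stored at inference time.

```python
from typing import Any, Dict


def evaluate_prediction(record: Dict[str, Any], current_price: float,
                        move_threshold_pct: float = 0.1) -> Dict[str, Any]:
    """Compare a stored prediction against the realized price move (sketch)."""
    entry_price = record["current_price"]  # price captured at inference time
    change_pct = (current_price - entry_price) / entry_price * 100
    action = record["prediction"]["action"]

    was_correct = (
        (action == "BUY" and change_pct > move_threshold_pct)
        or (action == "SELL" and change_pct < -move_threshold_pct)
        or (action == "HOLD" and abs(change_pct) <= move_threshold_pct)
    )
    # Positive reinforcement for accurate predictions, corrective signal otherwise
    reward = 1.0 if was_correct else -0.5
    return {"was_correct": was_correct, "price_change_pct": change_pct, "reward": reward}
```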
##### 5. Decision Making and Trading Actions
Beyond coordination, the Orchestrator makes final trading decisions:

View File

@@ -131,3 +131,45 @@ The Multi-Modal Trading System is an advanced algorithmic trading platform that
6. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all trading executors.
7. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all risk management components.
8. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all dashboard components.
### Requirement 9: Model Inference Data Validation and Storage
**User Story:** As a trading system developer, I want to ensure that all model predictions include complete input data validation and persistent storage, so that I can verify models receive correct inputs and track their performance over time.
#### Acceptance Criteria
1. WHEN a model makes a prediction THEN the system SHALL validate that the input data contains complete OHLCV dataframes for all required timeframes
2. WHEN input data is incomplete THEN the system SHALL log the missing components and SHALL NOT proceed with prediction
3. WHEN input validation passes THEN the system SHALL store the complete input data package with the prediction in persistent storage
4. IF input data dimensions are incorrect THEN the system SHALL raise a validation error with specific details about expected vs actual dimensions
5. WHEN a model completes inference THEN the system SHALL store the complete input data, model outputs, confidence scores, and metadata in a persistent inference history
6. WHEN storing inference data THEN the system SHALL include timestamp, symbol, input features, prediction outputs, and model internal states
7. IF inference history storage fails THEN the system SHALL log the error and continue operation without breaking the prediction flow
### Requirement 10: Inference-Training Feedback Loop
**User Story:** As a machine learning engineer, I want the system to automatically train models using their previous inference data compared to actual market outcomes, so that models continuously improve their accuracy through real-world feedback.
#### Acceptance Criteria
1. WHEN sufficient time has passed after a prediction THEN the system SHALL evaluate the prediction accuracy against actual price movements
2. WHEN a prediction outcome is determined THEN the system SHALL create a training example using the stored inference data and actual outcome
3. WHEN training examples are created THEN the system SHALL feed them back to the respective models for learning
4. IF the prediction was accurate THEN the system SHALL reinforce the model's decision pathway through positive training signals
5. IF the prediction was inaccurate THEN the system SHALL provide corrective training signals to help the model learn from mistakes
6. WHEN the system needs training data THEN it SHALL retrieve the last inference data for each model to compare predictions against actual market outcomes
7. WHEN models are trained on inference feedback THEN the system SHALL track and report accuracy improvements or degradations over time
### Requirement 11: Inference History Management and Monitoring
**User Story:** As a system administrator, I want comprehensive logging and monitoring of the inference-training feedback loop with configurable retention policies, so that I can track model learning progress and manage storage efficiently.
#### Acceptance Criteria
1. WHEN inference data is stored THEN the system SHALL log the storage operation with data completeness metrics and validation results
2. WHEN training occurs based on previous inference THEN the system SHALL log the training outcome and model performance changes
3. WHEN the system detects data flow issues THEN it SHALL alert administrators with specific error details and suggested remediation
4. WHEN inference history reaches configured limits THEN the system SHALL archive or remove oldest entries based on retention policy
5. WHEN storing inference data THEN the system SHALL compress data to minimize storage footprint while maintaining accessibility
6. WHEN retrieving historical inference data THEN the system SHALL provide efficient query mechanisms by symbol, timeframe, and date range
7. IF storage space is critically low THEN the system SHALL prioritize keeping the most recent and most valuable training examples
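As one way to satisfy criteria 4, 5, and 7, a sketch assuming gzip-compressed JSON files on disk; the directory layout, filename scheme, and entry limit are placeholders rather than the system's actual storage format.

```python
import gzip
import json
from pathlib import Path
from typing import Any, Dict


def store_compressed(record: Dict[str, Any], directory: Path,
                     max_entries: int = 10_000) -> Path:
    """Write a gzip-compressed inference record and enforce a simple retention limit."""
    directory.mkdir(parents=True, exist_ok=True)
    sym = str(record["symbol"]).replace("/", "")          # filename-safe symbol
    ts = str(record["timestamp"]).replace(":", "").replace(" ", "_")
    path = directory / f"{sym}_{record['model_name']}_{ts}.json.gz"
    with gzip.open(path, "wt", encoding="utf-8") as f:
        json.dump(record, f, default=str)  # fall back to str() for datetimes etc.

    # Retention: once the limit is exceeded, drop the oldest files first
    files = sorted(directory.glob("*.json.gz"), key=lambda p: p.stat().st_mtime)
    for old in files[:-max_entries]:
        old.unlink()
    return path
```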

View File

@@ -135,6 +135,9 @@
- Add thread-safe access to multi-rate data streams
- _Requirements: 4.1, 1.6, 8.5_
- [ ] 4.2. Implement model inference coordination
- Create ModelInferenceCoordinator class
- Trigger model inference based on data availability and requirements
@@ -176,6 +179,85 @@
- Provide model performance monitoring and alerting
- _Requirements: 4.6, 8.2, 8.3_
## Model Inference Data Validation and Storage
- [x] 5. Implement comprehensive inference data validation system
- Create InferenceDataValidator class for input validation
- Validate complete OHLCV dataframes for all required timeframes
- Check input data dimensions against model requirements
- Log missing components and prevent prediction on incomplete data
- _Requirements: 9.1, 9.2, 9.3, 9.4_
- [ ] 5.1. Implement input data validation for all models
- Create validation methods for CNN, RL, and future model inputs
- Validate OHLCV data completeness (300 frames for 1s, 1m, 1h, 1d)
- Validate COB data structure (±20 buckets, MA calculations)
- Raise specific validation errors with expected vs actual dimensions
- Ensure validation occurs before any model inference
- _Requirements: 9.1, 9.4_
- [ ] 5.2. Implement persistent inference history storage
- Create InferenceHistoryStore class for persistent storage
- Store complete input data packages with each prediction
- Include timestamp, symbol, input features, prediction outputs, confidence scores
- Store model internal states for cross-model feeding
- Implement compressed storage to minimize footprint
- _Requirements: 9.5, 9.6_
- [ ] 5.3. Implement inference history query and retrieval system
- Create efficient query mechanisms by symbol, timeframe, and date range
- Implement data retrieval for training pipeline consumption
- Add data completeness metrics and validation results in storage
- Handle storage failures gracefully without breaking prediction flow
- _Requirements: 9.7, 11.6_
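A sketch of the retrieval side under the same assumed gzip-JSON layout used in the requirements document's example; the timeframe filter is omitted for brevity and the real storage format may differ.

```python
import gzip
import json
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterator, Optional


def query_inference_history(directory: Path, symbol: str,
                            start: Optional[datetime] = None,
                            end: Optional[datetime] = None) -> Iterator[Dict[str, Any]]:
    """Yield stored inference records for a symbol within an optional date range."""
    for path in sorted(directory.glob(f"{symbol.replace('/', '')}_*.json.gz")):
        try:
            with gzip.open(path, "rt", encoding="utf-8") as f:
                record = json.load(f)
        except (OSError, json.JSONDecodeError):
            continue  # storage failures must not break the prediction/training flow
        ts = datetime.fromisoformat(record["timestamp"])  # assumes ISO-formatted timestamps
        if (start is None or ts >= start) and (end is None or ts <= end):
            yield record
```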
## Inference-Training Feedback Loop Implementation
- [ ] 6. Implement prediction outcome evaluation system
- Create PredictionOutcomeEvaluator class
- Evaluate prediction accuracy against actual price movements
- Create training examples using stored inference data and actual outcomes
- Feed prediction-result pairs back to respective models
- _Requirements: 10.1, 10.2, 10.3_
- [ ] 6.1. Implement adaptive learning signal generation
- Create positive reinforcement signals for accurate predictions
- Generate corrective training signals for inaccurate predictions
- Retrieve last inference data for each model for outcome comparison
- Implement model-specific learning signal formats
- _Requirements: 10.4, 10.5, 10.6_
- [ ] 6.2. Implement continuous improvement tracking
- Track and report accuracy improvements/degradations over time
- Monitor model learning progress through feedback loop
- Create performance metrics for inference-training effectiveness
- Generate alerts for learning regression or stagnation
- _Requirements: 10.7_
## Inference History Management and Monitoring
- [ ] 7. Implement comprehensive inference logging and monitoring
- Create InferenceMonitor class for logging and alerting
- Log inference data storage operations with completeness metrics
- Log training outcomes and model performance changes
- Alert administrators on data flow issues with specific error details
- _Requirements: 11.1, 11.2, 11.3_
- [ ] 7.1. Implement configurable retention policies
- Create RetentionPolicyManager class
- Archive or remove oldest entries when limits are reached
- Prioritize keeping most recent and valuable training examples
- Implement storage space monitoring and alerts
- _Requirements: 11.4, 11.7_
- [ ] 7.2. Implement efficient historical data management
- Compress inference data to minimize storage footprint
- Maintain accessibility for training and analysis
- Implement efficient query mechanisms for historical analysis
- Add data archival and restoration capabilities
- _Requirements: 11.5, 11.6_
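A bare-bones sketch of the monitoring piece, assuming the `InferenceMonitor` name from the task above; the metric fields and log formats are placeholders.

```python
import logging

logger = logging.getLogger("inference_monitor")


class InferenceMonitor:
    """Logs storage operations and raises administrator alerts (illustrative)."""

    def log_storage(self, symbol: str, model_name: str,
                    completeness: float, validation_passed: bool) -> None:
        logger.info("stored inference %s/%s completeness=%.1f%% valid=%s",
                    symbol, model_name, completeness * 100, validation_passed)

    def alert_data_flow_issue(self, detail: str, remediation: str) -> None:
        # Pair specific error details with a suggested remediation, per Requirement 11.3
        logger.error("data flow issue: %s | suggested remediation: %s", detail, remediation)
```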
## Trading Executor Implementation
- [ ] 5. Design and implement the trading executor

View File

@@ -6,6 +6,14 @@ system:
log_level: "INFO" # DEBUG, INFO, WARNING, ERROR
session_timeout: 3600 # Session timeout in seconds
# Cold Start Mode Configuration
cold_start:
enabled: true # Enable cold start mode logic
inference_interval: 0.5 # Inference interval (seconds) during cold start
training_interval: 2 # Training interval (seconds) during cold start
heavy_adjustments: true # Allow more aggressive parameter/training adjustments
log_cold_start: true # Log when in cold start mode
# Exchange Configuration
exchanges:
primary: "bybit" # Primary exchange: mexc, deribit, binance, bybit

View File

@@ -25,8 +25,8 @@ import math
from collections import defaultdict
from .multi_exchange_cob_provider import MultiExchangeCOBProvider, COBSnapshot, ConsolidatedOrderBookLevel
from .data_provider import DataProvider, MarketTick
from .enhanced_cob_websocket import EnhancedCOBWebSocket
# Import DataProvider and MarketTick only when needed to avoid circular import
logger = logging.getLogger(__name__)
@@ -35,7 +35,7 @@ class COBIntegration:
Integration layer for Multi-Exchange COB data with gogo2 trading system
"""
def __init__(self, data_provider: Optional[DataProvider] = None, symbols: Optional[List[str]] = None):
def __init__(self, data_provider: Optional['DataProvider'] = None, symbols: Optional[List[str]] = None):
"""
Initialize COB Integration

View File

@@ -124,6 +124,15 @@ class Config:
'epochs': 100,
'validation_split': 0.2,
'early_stopping_patience': 10
},
'cold_start': {
'enabled': True,
'min_ticks': 100,
'min_candles': 100,
'inference_interval': 0.5,
'training_interval': 2,
'heavy_adjustments': True,
'log_cold_start': True
}
}
@@ -210,6 +219,19 @@ class Config:
'early_stopping_patience': self._config.get('training', {}).get('early_stopping_patience', 10)
}
@property
def cold_start(self) -> Dict[str, Any]:
"""Get cold start mode settings"""
return self._config.get('cold_start', {
'enabled': True,
'min_ticks': 100,
'min_candles': 100,
'inference_interval': 0.5,
'training_interval': 2,
'heavy_adjustments': True,
'log_cold_start': True
})
def get(self, key: str, default: Any = None) -> Any:
"""Get configuration value by key with optional default"""
return self._config.get(key, default)
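For reference, a hypothetical consumer of the `cold_start` property above could throttle its inference and training loops from these settings; `get_config()` is assumed to return a `Config` instance as it does elsewhere in the codebase, and the loop itself is purely illustrative.

```python
import time

from core.config import get_config  # import path assumed from the other core modules


def run_cold_start_loop(run_inference, run_training) -> None:
    """Drive inference/training at the configured cold-start cadence (illustrative)."""
    cold_start = get_config().cold_start
    if not cold_start.get("enabled", True):
        return
    last_training = 0.0
    while True:
        run_inference()
        now = time.monotonic()
        if now - last_training >= cold_start.get("training_interval", 2):
            run_training()
            last_training = now
        time.sleep(cold_start.get("inference_interval", 0.5))
```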

View File

@@ -30,12 +30,14 @@ from dataclasses import dataclass, field
import ta
from threading import Thread, Lock
from collections import deque
import math
from .config import get_config
from .tick_aggregator import RealTimeTickAggregator, RawTick, OHLCVBar
from .cnn_monitor import log_cnn_prediction
from .williams_market_structure import WilliamsMarketStructure, PivotPoint, TrendLevel
from .enhanced_cob_websocket import EnhancedCOBWebSocket, get_enhanced_cob_websocket
from .cob_integration import COBIntegration
logger = logging.getLogger(__name__)
@@ -203,6 +205,7 @@ class DataProvider:
self._load_all_pivot_bounds()
# Centralized data collection for models and dashboard
self.cob_integration = COBIntegration(data_provider=self, symbols=self.symbols)
self.cob_data_cache: Dict[str, deque] = {} # COB data for models
self.training_data_cache: Dict[str, deque] = {} # Training data for models
self.model_data_subscribers: Dict[str, List[Callable]] = {} # Model-specific data callbacks
@@ -229,6 +232,11 @@ class DataProvider:
self.training_data_collection_active = False
self.training_data_thread = None
# Price-level bucketing
self.bucketed_cob_data: Dict[str, Dict] = {}
self.bucket_sizes = [1, 10] # $1 and $10 buckets
self.bucketed_cob_callbacks: Dict[int, List[Callable]] = {size: [] for size in self.bucket_sizes}
logger.info(f"DataProvider initialized for symbols: {self.symbols}")
logger.info(f"Timeframes: {self.timeframes}")
logger.info("Centralized data distribution enabled")
@@ -242,6 +250,21 @@ class DataProvider:
self.retry_delay = 60 # 1 minute retry delay for 451 errors
self.max_retries = 3
# Start COB integration
self.start_cob_integration()
def start_cob_integration(self):
"""Starts the COB integration in a background thread."""
cob_thread = Thread(target=self._run_cob_integration, daemon=True)
cob_thread.start()
def _run_cob_integration(self):
"""Runs the asyncio event loop for COB integration."""
try:
asyncio.run(self.cob_integration.start())
except Exception as e:
logger.error(f"Error running COB integration: {e}")
def _ensure_datetime_index(self, df: pd.DataFrame) -> pd.DataFrame:
"""Ensure dataframe has proper datetime index"""
if df is None or df.empty:
@@ -1397,30 +1420,14 @@ class DataProvider:
logger.warning(f"Error saving cache for {symbol} {timeframe}: {e}")
async def start_real_time_streaming(self):
"""Start real-time data streaming using Enhanced COB WebSocket"""
"""Start real-time data streaming using COBIntegration"""
if self.is_streaming:
logger.warning("Real-time streaming already active")
return
self.is_streaming = True
logger.info("Starting Enhanced COB WebSocket streaming")
try:
# Initialize Enhanced COB WebSocket
self.enhanced_cob_websocket = await get_enhanced_cob_websocket(
symbols=self.symbols,
dashboard_callback=self._on_websocket_status_update
)
# Add COB data callback
self.enhanced_cob_websocket.add_cob_callback(self._on_enhanced_cob_data)
logger.info("Enhanced COB WebSocket streaming started successfully")
except Exception as e:
logger.error(f"Error starting Enhanced COB WebSocket: {e}")
# Fallback to old WebSocket method
await self._start_fallback_websocket_streaming()
logger.info("Starting real-time streaming via COBIntegration")
# COBIntegration is started in the constructor
async def stop_real_time_streaming(self):
"""Stop real-time data streaming"""
@@ -1430,6 +1437,14 @@ class DataProvider:
logger.info("Stopping Enhanced COB WebSocket streaming")
self.is_streaming = False
# Stop COB Integration
if self.cob_integration:
try:
await self.cob_integration.stop()
logger.info("COB Integration stopped")
except Exception as e:
logger.error(f"Error stopping COB Integration: {e}")
# Stop Enhanced COB WebSocket
if self.enhanced_cob_websocket:
try:
@@ -1453,6 +1468,7 @@ class DataProvider:
async def _on_enhanced_cob_data(self, symbol: str, cob_data: Dict):
"""Handle COB data from Enhanced WebSocket"""
try:
# This method will now be called by COBIntegration
# Ensure cob_websocket_data is initialized
if not hasattr(self, 'cob_websocket_data'):
self.cob_websocket_data = {}
@@ -1460,6 +1476,9 @@ class DataProvider:
# Store the latest COB data
self.cob_websocket_data[symbol] = cob_data
# Trigger bucketing
self._update_price_buckets(symbol, cob_data)
# Ensure cob_data_cache is initialized
if not hasattr(self, 'cob_data_cache'):
self.cob_data_cache = {}
@@ -4159,3 +4178,61 @@ class DataProvider:
}
return summary
def _update_price_buckets(self, symbol: str, cob_data: Dict):
"""Update price-level buckets based on new COB data."""
try:
bids = cob_data.get('bids', [])
asks = cob_data.get('asks', [])
for size in self.bucket_sizes:
bid_buckets = self._calculate_buckets(bids, size)
ask_buckets = self._calculate_buckets(asks, size)
bucketed_data = {
'symbol': symbol,
'timestamp': datetime.now(),
'bucket_size': size,
'bids': bid_buckets,
'asks': ask_buckets
}
if symbol not in self.bucketed_cob_data:
self.bucketed_cob_data[symbol] = {}
self.bucketed_cob_data[symbol][size] = bucketed_data
# Distribute to subscribers
self._distribute_bucketed_data(symbol, size, bucketed_data)
except Exception as e:
logger.error(f"Error updating price buckets for {symbol}: {e}")
def _calculate_buckets(self, levels: List[Dict], bucket_size: int) -> Dict[float, float]:
"""Calculates aggregated volume for price buckets."""
buckets = {}
for level in levels:
price = level.get('price', 0)
volume = level.get('volume', 0)
if price > 0 and volume > 0:
bucket = math.floor(price / bucket_size) * bucket_size
if bucket not in buckets:
buckets[bucket] = 0
buckets[bucket] += volume
return buckets
def subscribe_to_bucketed_cob(self, bucket_size: int, callback: Callable):
"""Subscribe to bucketed COB data."""
if bucket_size in self.bucketed_cob_callbacks:
self.bucketed_cob_callbacks[bucket_size].append(callback)
logger.info(f"New subscriber for ${bucket_size} bucketed COB data.")
else:
logger.warning(f"Bucket size {bucket_size} not supported.")
def _distribute_bucketed_data(self, symbol: str, bucket_size: int, data: Dict):
"""Distribute bucketed data to subscribers."""
if bucket_size in self.bucketed_cob_callbacks:
for callback in self.bucketed_cob_callbacks[bucket_size]:
try:
callback(symbol, data)
except Exception as e:
logger.error(f"Error in bucketed COB callback: {e}")

View File

@@ -27,6 +27,8 @@ import shutil
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from pathlib import Path
from .config import get_config
from .data_provider import DataProvider
@@ -202,9 +204,18 @@ class TradingOrchestrator:
# Training tracking
self.last_trained_symbols: Dict[str, datetime] = {}
# INFERENCE DATA STORAGE - Store model inputs and outputs for training
self.inference_history: Dict[str, deque] = {} # {symbol: deque of inference records}
self.max_inference_history = 1000 # Keep last 1000 inference records per symbol
# Initialize inference history for each symbol
for symbol in self.symbols:
self.inference_history[symbol] = deque(maxlen=self.max_inference_history)
# ENHANCED: Real-time Training System Integration
self.enhanced_training_system = None # Will be set to EnhancedRealtimeTrainingSystem if available
self.training_enabled: bool = enhanced_rl_training and ENHANCED_TRAINING_AVAILABLE
# Enable training by default - don't depend on external training system
self.training_enabled: bool = enhanced_rl_training
logger.info("Enhanced TradingOrchestrator initialized with full ML capabilities")
logger.info(f"Enhanced RL training: {enhanced_rl_training}")
@@ -1023,34 +1034,489 @@ class TradingOrchestrator:
return None
async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
"""Get predictions from all registered models"""
"""Get predictions from all registered models with input data storage"""
predictions = []
current_time = datetime.now()
# Collect input data for all models
input_data = await self._collect_model_input_data(symbol)
for model_name, model in self.model_registry.models.items():
try:
prediction = None
model_input = None
if isinstance(model, CNNModelInterface):
# Get CNN predictions for each timeframe
cnn_predictions = await self._get_cnn_predictions(model, symbol)
predictions.extend(cnn_predictions)
# Store input data for CNN
model_input = input_data.get('cnn_input')
elif isinstance(model, RLAgentInterface):
# Get RL prediction
rl_prediction = await self._get_rl_prediction(model, symbol)
if rl_prediction:
predictions.append(rl_prediction)
prediction = rl_prediction
# Store input data for RL
model_input = input_data.get('rl_input')
else:
# Generic model interface
generic_prediction = await self._get_generic_prediction(model, symbol)
if generic_prediction:
predictions.append(generic_prediction)
prediction = generic_prediction
# Store input data for generic model
model_input = input_data.get('generic_input')
# Store inference data for training
if prediction and model_input is not None:
self._store_inference_data(symbol, model_name, model_input, prediction, current_time)
except Exception as e:
logger.error(f"Error getting prediction from {model_name}: {e}")
continue
# Trigger training based on previous inference data
await self._trigger_model_training(symbol)
return predictions
async def _collect_model_input_data(self, symbol: str) -> Dict[str, Any]:
"""Collect comprehensive input data for all models"""
try:
input_data = {}
# Get current market data from data provider
current_price = self.data_provider.get_current_price(symbol)
# Collect OHLCV data for multiple timeframes
ohlcv_data = {}
timeframes = ['1s', '1m', '1h', '1d']
for tf in timeframes:
df = self.data_provider.get_historical_data(symbol, tf, limit=300)
if df is not None and not df.empty:
ohlcv_data[tf] = df
# Collect COB data if available
cob_data = self.get_cob_snapshot(symbol)
# Collect technical indicators
technical_indicators = {}
if '1h' in ohlcv_data:
df = ohlcv_data['1h']
if len(df) > 20:
technical_indicators['sma_20'] = df['close'].rolling(20).mean().iloc[-1]
technical_indicators['rsi'] = self._calculate_rsi(df['close'])
# Prepare CNN input
cnn_input = self._prepare_cnn_input_data(ohlcv_data, cob_data, technical_indicators)
# Prepare RL input
rl_input = self._prepare_rl_input_data(ohlcv_data, cob_data, technical_indicators)
# Prepare generic input
generic_input = {
'symbol': symbol,
'current_price': current_price,
'ohlcv_data': ohlcv_data,
'cob_data': cob_data,
'technical_indicators': technical_indicators
}
input_data = {
'cnn_input': cnn_input,
'rl_input': rl_input,
'generic_input': generic_input,
'timestamp': datetime.now(),
'symbol': symbol
}
return input_data
except Exception as e:
logger.error(f"Error collecting model input data for {symbol}: {e}")
return {}
def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray:
"""Prepare standardized input data for CNN models"""
try:
# Create feature matrix from OHLCV data
features = []
# Add OHLCV features for each timeframe
for tf in ['1s', '1m', '1h', '1d']:
if tf in ohlcv_data and not ohlcv_data[tf].empty:
df = ohlcv_data[tf].tail(50) # Last 50 bars
features.extend([
df['close'].pct_change().fillna(0).values,
df['volume'].values / df['volume'].max() if df['volume'].max() > 0 else np.zeros(len(df))
])
# Add technical indicators
for key, value in technical_indicators.items():
if not np.isnan(value):
features.append([value])
# Flatten and pad/truncate to standard size
if features:
feature_array = np.concatenate([np.array(f).flatten() for f in features])
# Pad or truncate to 300 features
if len(feature_array) < 300:
feature_array = np.pad(feature_array, (0, 300 - len(feature_array)), 'constant')
else:
feature_array = feature_array[:300]
return feature_array.reshape(1, -1)
else:
return np.zeros((1, 300))
except Exception as e:
logger.error(f"Error preparing CNN input data: {e}")
return np.zeros((1, 300))
def _prepare_rl_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray:
"""Prepare standardized input data for RL models"""
try:
# Create state representation
state_features = []
# Add price and volume features
if '1m' in ohlcv_data and not ohlcv_data['1m'].empty:
df = ohlcv_data['1m'].tail(20)
state_features.extend([
df['close'].pct_change().fillna(0).values,
df['volume'].pct_change().fillna(0).values,
(df['high'] - df['low']) / df['close'] # Volatility proxy
])
# Add technical indicators
for key, value in technical_indicators.items():
if not np.isnan(value):
state_features.append(value)
# Flatten and standardize size
if state_features:
state_array = np.concatenate([np.array(f).flatten() for f in state_features])
# Pad or truncate to expected RL state size
expected_size = 100 # Adjust based on your RL model
if len(state_array) < expected_size:
state_array = np.pad(state_array, (0, expected_size - len(state_array)), 'constant')
else:
state_array = state_array[:expected_size]
return state_array
else:
return np.zeros(100)
except Exception as e:
logger.error(f"Error preparing RL input data: {e}")
return np.zeros(100)
def _store_inference_data(self, symbol: str, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime):
"""Store comprehensive inference data for future training with persistent storage"""
try:
# Get current market context for complete replay capability
current_price = self.data_provider.get_current_price(symbol)
# Create comprehensive inference record with ALL data needed for model replay
inference_record = {
'timestamp': timestamp,
'symbol': symbol,
'model_name': model_name,
'current_price': current_price,
# Complete model input data
'model_input': {
'raw_input': model_input,
'input_shape': model_input.shape if hasattr(model_input, 'shape') else None,
'input_type': str(type(model_input))
},
# Complete prediction data
'prediction': {
'action': prediction.action,
'confidence': prediction.confidence,
'probabilities': prediction.probabilities,
'timeframe': prediction.timeframe
},
# Market context at prediction time
'market_context': {
'price': current_price,
'timestamp': timestamp.isoformat(),
'symbol': symbol
},
# Model metadata
'metadata': {
'model_metadata': prediction.metadata or {},
'orchestrator_state': {
'confidence_threshold': self.confidence_threshold,
'training_enabled': self.training_enabled
}
},
# Training outcome (will be filled later)
'training_outcome': None,
'outcome_evaluated': False
}
# Store in memory (inference history)
if symbol in self.inference_history:
self.inference_history[symbol].append(inference_record)
logger.debug(f"Stored inference data for {model_name} on {symbol}")
# Persistent storage to disk (for long-term training data)
self._save_inference_to_disk(inference_record)
except Exception as e:
logger.error(f"Error storing inference data: {e}")
def _save_inference_to_disk(self, inference_record: Dict):
"""Save inference record to persistent storage"""
try:
# Create inference data directory
inference_dir = Path("training_data/inference_history")
inference_dir.mkdir(parents=True, exist_ok=True)
# Create filename with timestamp and model name
timestamp_str = inference_record['timestamp'].strftime('%Y%m%d_%H%M%S')
filename = f"{inference_record['symbol']}_{inference_record['model_name']}_{timestamp_str}.json"
filepath = inference_dir / filename
# Convert numpy arrays to lists for JSON serialization
serializable_record = self._make_json_serializable(inference_record)
# Save to JSON file
with open(filepath, 'w') as f:
json.dump(serializable_record, f, indent=2)
logger.debug(f"Saved inference record to disk: {filepath}")
except Exception as e:
logger.error(f"Error saving inference to disk: {e}")
def _make_json_serializable(self, obj):
"""Convert object to JSON-serializable format"""
if isinstance(obj, dict):
return {k: self._make_json_serializable(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [self._make_json_serializable(item) for item in obj]
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, (np.integer, np.floating)):
return obj.item()
elif isinstance(obj, datetime):
return obj.isoformat()
else:
return obj
def load_inference_history_from_disk(self, symbol: str, days_back: int = 7) -> List[Dict]:
"""Load inference history from disk for training replay"""
try:
inference_dir = Path("training_data/inference_history")
if not inference_dir.exists():
return []
# Get files for the symbol from the last N days
cutoff_date = datetime.now() - timedelta(days=days_back)
inference_records = []
for filepath in inference_dir.glob(f"{symbol}_*.json"):
try:
# Extract timestamp from filename
filename_parts = filepath.stem.split('_')
if len(filename_parts) >= 3:
timestamp_str = f"{filename_parts[-2]}_{filename_parts[-1]}"
file_timestamp = datetime.strptime(timestamp_str, '%Y%m%d_%H%M%S')
if file_timestamp >= cutoff_date:
with open(filepath, 'r') as f:
record = json.load(f)
inference_records.append(record)
except Exception as e:
logger.warning(f"Error loading inference file {filepath}: {e}")
continue
# Sort by timestamp
inference_records.sort(key=lambda x: x['timestamp'])
logger.info(f"Loaded {len(inference_records)} inference records for {symbol} from disk")
return inference_records
except Exception as e:
logger.error(f"Error loading inference history from disk: {e}")
return []
def get_model_training_data(self, model_name: str, symbol: str = None) -> List[Dict]:
"""Get training data for a specific model"""
try:
training_data = []
# Get from memory first
if symbol:
symbols_to_check = [symbol]
else:
symbols_to_check = self.symbols
for sym in symbols_to_check:
if sym in self.inference_history:
for record in self.inference_history[sym]:
if record['model_name'] == model_name:
training_data.append(record)
# Also load from disk for more comprehensive training data
for sym in symbols_to_check:
disk_records = self.load_inference_history_from_disk(sym)
for record in disk_records:
if record['model_name'] == model_name:
training_data.append(record)
# Remove duplicates and sort by timestamp
seen_timestamps = set()
unique_data = []
for record in training_data:
timestamp_key = f"{record['timestamp']}_{record['symbol']}"
if timestamp_key not in seen_timestamps:
seen_timestamps.add(timestamp_key)
unique_data.append(record)
# Disk records store ISO strings while in-memory records store datetimes; normalize before sorting
unique_data.sort(key=lambda x: x['timestamp'] if isinstance(x['timestamp'], datetime) else datetime.fromisoformat(x['timestamp']))
logger.info(f"Retrieved {len(unique_data)} training records for {model_name}")
return unique_data
except Exception as e:
logger.error(f"Error getting model training data: {e}")
return []
async def _trigger_model_training(self, symbol: str):
"""Trigger training for models based on previous inference data"""
try:
if not self.training_enabled or symbol not in self.inference_history:
return
# Get recent inference records
recent_records = list(self.inference_history[symbol])
if len(recent_records) < 2:
return # Need at least 2 records to compare
# Get current price for outcome evaluation
current_price = self.data_provider.get_current_price(symbol)
if current_price is None:
return
# Process records that are old enough to evaluate outcomes
cutoff_time = datetime.now() - timedelta(minutes=5) # 5 minutes ago
for record in recent_records:
if record['timestamp'] < cutoff_time:
await self._evaluate_and_train_on_record(record, current_price)
except Exception as e:
logger.error(f"Error triggering model training for {symbol}: {e}")
async def _evaluate_and_train_on_record(self, record: Dict, current_price: float):
"""Evaluate prediction outcome and train model"""
try:
model_name = record['model_name']
prediction = record['prediction']
timestamp = record['timestamp']
# Calculate price change since prediction
# This is a simplified outcome evaluation - you might want to make it more sophisticated
time_diff = (datetime.now() - timestamp).total_seconds() / 60 # minutes
# Get historical price at prediction time (simplified)
symbol = record['symbol']
historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
if historical_data is None or historical_data.empty:
return
# Prefer the price captured at prediction time; fall back to the latest close (simplified)
prediction_price = record.get('current_price') or historical_data['close'].iloc[-1]
price_change_pct = (current_price - prediction_price) / prediction_price * 100
# Determine if prediction was correct
predicted_action = prediction['action']
was_correct = False
if predicted_action == 'BUY' and price_change_pct > 0.1: # Price went up
was_correct = True
elif predicted_action == 'SELL' and price_change_pct < -0.1: # Price went down
was_correct = True
elif predicted_action == 'HOLD' and abs(price_change_pct) < 0.1: # Price stayed stable
was_correct = True
# Update model performance tracking
if model_name not in self.model_performance:
self.model_performance[model_name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}
self.model_performance[model_name]['total'] += 1
if was_correct:
self.model_performance[model_name]['correct'] += 1
self.model_performance[model_name]['accuracy'] = (
self.model_performance[model_name]['correct'] /
self.model_performance[model_name]['total']
)
# Train the specific model based on outcome
await self._train_model_on_outcome(record, was_correct, price_change_pct)
logger.debug(f"Evaluated {model_name} prediction: {'' if was_correct else ''} "
f"({prediction['action']}, {price_change_pct:.2f}% change)")
except Exception as e:
logger.error(f"Error evaluating and training on record: {e}")
async def _train_model_on_outcome(self, record: Dict, was_correct: bool, price_change_pct: float):
"""Train specific model based on prediction outcome"""
try:
model_name = record['model_name']
model_input = record['model_input']
prediction = record['prediction']
# Create training signal based on outcome
reward = 1.0 if was_correct else -0.5
# Train RL models
if 'dqn' in model_name.lower() and self.rl_agent:
if hasattr(self.rl_agent, 'add_experience'):
action_idx = ['SELL', 'HOLD', 'BUY'].index(prediction['action'])
self.rl_agent.add_experience(
state=model_input,
action=action_idx,
reward=reward,
next_state=model_input, # Simplified
done=True
)
logger.debug(f"Added RL training experience: reward={reward}")
# Train CNN models
elif 'cnn' in model_name.lower() and self.cnn_model:
if hasattr(self.cnn_model, 'train_on_outcome'):
target = 1 if was_correct else 0
self.cnn_model.train_on_outcome(model_input, target)
logger.debug(f"Trained CNN on outcome: target={target}")
except Exception as e:
logger.error(f"Error training model on outcome: {e}")
def _calculate_rsi(self, prices: pd.Series, period: int = 14) -> float:
"""Calculate RSI indicator"""
try:
delta = prices.diff()
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
rs = gain / loss
rsi = 100 - (100 / (1 + rs))
return rsi.iloc[-1] if not rsi.empty else 50.0
except Exception:
return 50.0
async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]:
"""Get predictions from CNN model for all timeframes with enhanced COB features"""
predictions = []

View File

@@ -5859,7 +5859,7 @@ class CleanTradingDashboard:
from core.data_models import OHLCVBar
# Get data from data provider
df = self.data_provider.get_candles(symbol, timeframe)
df = self.data_provider.get_historical_data(symbol, timeframe)
if df is None or len(df) == 0:
return []
@@ -6106,7 +6106,7 @@ class CleanTradingDashboard:
def _get_recent_price_history(self, symbol: str, count: int) -> List[float]:
"""Get recent price history for reward calculation"""
try:
df = self.data_provider.get_candles(symbol, '1s')
df = self.data_provider.get_historical_data(symbol, '1s')
if df is None or len(df) == 0:
return []