wip training
@@ -140,7 +140,7 @@ Training:
|
||||
|
||||
### 4. Orchestrator
|
||||
|
||||
The Orchestrator serves as the central coordination hub of the multi-modal trading system, responsible for data subscription management, model inference coordination, output storage, and training pipeline orchestration.
|
||||
The Orchestrator serves as the central coordination hub of the multi-modal trading system, responsible for data subscription management, model inference coordination, output storage, training pipeline orchestration, and inference-training feedback loop management.
|
||||
|
||||
#### Key Classes and Interfaces
|
||||
|
||||
@@ -245,6 +245,47 @@ The Orchestrator coordinates training for all models by managing the prediction-
|
||||
- The checkpoint manager keeps only the 5 to 10 best checkpoints for each model and deletes the least performant ones. It stores metadata alongside each checkpoint so that performance can be ranked.
|
||||
- The best checkpoint is loaded automatically at startup if any checkpoints are stored.
|
||||
|
||||
##### 5. Inference Data Validation and Storage
|
||||
|
||||
The Orchestrator implements comprehensive inference data validation and persistent storage:
|
||||
|
||||
**Input Data Validation**:
|
||||
- Validates complete OHLCV dataframes for all required timeframes before inference
|
||||
- Checks input data dimensions against model requirements
|
||||
- Logs missing components and prevents prediction on incomplete data
|
||||
- Raises validation errors with specific details about expected vs actual dimensions
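
The validation rules above could be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the Orchestrator's actual code: `InferenceValidationError`, `validate_input_shapes`, and the `(rows, columns)` shape convention are hypothetical names introduced for the example, and the default of 300 rows by 5 OHLCV columns is an assumed model requirement.

```python
from typing import Dict, Optional, Sequence, Tuple

Shape = Tuple[int, int]  # (rows, columns), e.g. the `.shape` of a pandas DataFrame


class InferenceValidationError(ValueError):
    """Hypothetical error type raised when model input data is unusable."""


def validate_input_shapes(
    actual: Dict[str, Optional[Shape]],
    required_timeframes: Sequence[str] = ("1s", "1m", "1h", "1d"),
    expected: Shape = (300, 5),  # assumed: 300 bars x (open, high, low, close, volume)
) -> None:
    """Raise InferenceValidationError listing every missing or mis-sized timeframe."""
    problems = []
    for timeframe in required_timeframes:
        shape = actual.get(timeframe)
        if shape is None:
            problems.append(f"{timeframe}: missing")
        elif tuple(shape) != tuple(expected):
            problems.append(
                f"{timeframe}: expected {tuple(expected)}, got {tuple(shape)}"
            )
    if problems:
        # One error carries all findings so the log shows the full picture.
        raise InferenceValidationError("; ".join(problems))
```

A caller might pass `{tf: df.shape for tf, df in ohlcv_data.items()}` and skip inference whenever the error is raised.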
|
||||
|
||||
**Inference History Storage**:
|
||||
- Stores complete input data packages with each prediction in persistent storage
|
||||
- Includes timestamp, symbol, input features, prediction outputs, confidence scores, and model internal states
|
||||
- Maintains compressed storage to minimize footprint while preserving accessibility
|
||||
- Implements efficient query mechanisms by symbol, timeframe, and date range
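
One possible shape for compressed, queryable history storage is gzip-compressed JSON Lines, one file per symbol and day. The sketch below is an assumption for illustration only: `append_record`, `query_records`, and the file naming scheme are hypothetical, and records are assumed to hold JSON-serializable values plus a `datetime` under `timestamp`.

```python
import gzip
import json
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterator, Optional


def _history_path(root: Path, symbol: str, day: datetime) -> Path:
    # Replace "/" so a symbol containing a slash stays a single file name.
    return root / f"{symbol.replace('/', '-')}_{day:%Y%m%d}.jsonl.gz"


def append_record(root: Path, record: Dict[str, Any]) -> Path:
    """Append one JSON line to the gzip file for the record's symbol and day."""
    root.mkdir(parents=True, exist_ok=True)
    path = _history_path(root, record["symbol"], record["timestamp"])
    line = dict(record, timestamp=record["timestamp"].isoformat())
    # "at" mode appends a new gzip member; gzip reads all members back as one stream.
    with gzip.open(path, "at", encoding="utf-8") as handle:
        handle.write(json.dumps(line) + "\n")
    return path


def query_records(
    root: Path,
    symbol: str,
    start: datetime,
    end: datetime,
    timeframe: Optional[str] = None,
) -> Iterator[Dict[str, Any]]:
    """Yield records for `symbol` with start <= timestamp <= end, oldest file first."""
    pattern = f"{symbol.replace('/', '-')}_*.jsonl.gz"
    for path in sorted(root.glob(pattern)):
        with gzip.open(path, "rt", encoding="utf-8") as handle:
            for raw in handle:
                record = json.loads(raw)
                record["timestamp"] = datetime.fromisoformat(record["timestamp"])
                if timeframe is not None and record.get("timeframe") != timeframe:
                    continue
                if start <= record["timestamp"] <= end:
                    yield record
```

This version scans every file for the symbol. Because the day is encoded in the file name, a fuller implementation could skip files outside the requested date range.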
|
||||
|
||||
**Storage Management**:
|
||||
- Applies configurable retention policies to manage storage limits
|
||||
- Archives or removes oldest entries when limits are reached
|
||||
- Prioritizes keeping most recent and valuable training examples during storage pressure
|
||||
- Provides data completeness metrics and validation results in logs
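
A retention pass along these lines might look like the sketch below. The text does not define what makes a training example "valuable", so the numeric `value` field is purely an assumption, as is the function name `apply_retention`.

```python
from typing import Any, Dict, List, Tuple


def apply_retention(
    records: List[Dict[str, Any]], max_records: int
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Split records into (kept, evicted) so that at most `max_records` are kept.

    Records are ranked by an assumed numeric `value` field first and by
    recency second; the lowest-ranked records are evicted.
    """
    if max_records < 0:
        raise ValueError("max_records must be non-negative")
    ranked = sorted(
        records,
        key=lambda r: (r.get("value", 0.0), r["timestamp"]),
        reverse=True,
    )
    kept, evicted = ranked[:max_records], ranked[max_records:]
    # Return both lists in chronological order so callers can archive or replay them.
    kept.sort(key=lambda r: r["timestamp"])
    evicted.sort(key=lambda r: r["timestamp"])
    return kept, evicted
```

The evicted list can then be archived or deleted, depending on the configured policy.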
|
||||
|
||||
##### 6. Inference-Training Feedback Loop
|
||||
|
||||
The Orchestrator manages the continuous learning cycle through inference-training feedback:
|
||||
|
||||
**Prediction Outcome Evaluation**:
|
||||
- Evaluates prediction accuracy against actual price movements after sufficient time has passed
|
||||
- Creates training examples using stored inference data paired with actual market outcomes
|
||||
- Feeds prediction-result pairs back to respective models for learning
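
As a rough illustration of outcome evaluation, the sketch below scores a BUY, SELL, or HOLD prediction against the later price. The 0.1% threshold mirrors the value used in the orchestrator change in this commit; the function name `evaluate_prediction` and its signature are assumptions made for the example.

```python
from typing import Tuple


def evaluate_prediction(
    action: str,
    price_at_prediction: float,
    price_at_evaluation: float,
    threshold_pct: float = 0.1,
) -> Tuple[bool, float]:
    """Return (was_correct, price_change_pct) for a BUY, SELL, or HOLD prediction."""
    if price_at_prediction <= 0:
        raise ValueError("price_at_prediction must be positive")
    change_pct = (price_at_evaluation - price_at_prediction) / price_at_prediction * 100
    if action == "BUY":
        correct = change_pct > threshold_pct  # price rose by more than the threshold
    elif action == "SELL":
        correct = change_pct < -threshold_pct  # price fell by more than the threshold
    elif action == "HOLD":
        correct = abs(change_pct) < threshold_pct  # price stayed within the band
    else:
        raise ValueError(f"unknown action: {action!r}")
    return correct, change_pct
```

The price at prediction time should come from the stored inference record, so that the comparison covers the whole waiting period.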
|
||||
|
||||
**Adaptive Learning Signals**:
|
||||
- Provides positive reinforcement signals for accurate predictions
|
||||
- Delivers corrective training signals for inaccurate predictions to help models learn from mistakes
|
||||
- Retrieves last inference data for each model to compare predictions against actual outcomes
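
The signals above could be packaged as a replay-style experience, sketched below. The reward values (1.0 and -0.5) and the `SELL`, `HOLD`, `BUY` ordering follow the orchestrator change in this commit; `build_learning_signal` is a hypothetical helper, and reusing the state as `next_state` is a simplification.

```python
from typing import Any, Dict

ACTIONS = ("SELL", "HOLD", "BUY")  # action ordering assumed for the RL agent


def build_learning_signal(state: Any, action: str, was_correct: bool) -> Dict[str, Any]:
    """Turn an evaluated prediction into a replay-style experience dictionary."""
    # Asymmetric reward: accurate predictions are reinforced more strongly
    # than inaccurate ones are penalized.
    reward = 1.0 if was_correct else -0.5
    return {
        "state": state,
        "action": ACTIONS.index(action),  # raises ValueError for unknown actions
        "reward": reward,
        "next_state": state,  # simplification: no follow-up state is tracked
        "done": True,
    }
```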
|
||||
|
||||
**Continuous Improvement Tracking**:
|
||||
- Tracks and reports accuracy improvements or degradations over time
|
||||
- Monitors model learning progress through the feedback loop
|
||||
- Alerts administrators when data flow issues are detected with specific error details and remediation suggestions
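
Improvement tracking could be approximated with a rolling accuracy window, as in this sketch. The class name `AccuracyTracker`, the window size, and the degradation tolerance are all assumptions chosen for illustration.

```python
from collections import deque
from typing import Deque, Optional


class AccuracyTracker:
    """Rolling accuracy over the last `window` evaluated predictions."""

    def __init__(self, window: int = 100, min_samples: int = 20, tolerance: float = 0.05):
        self.outcomes: Deque[bool] = deque(maxlen=window)
        self.min_samples = min_samples
        self.tolerance = tolerance
        self.best_accuracy: Optional[float] = None

    def record(self, was_correct: bool) -> float:
        """Add one outcome and return the current rolling accuracy."""
        self.outcomes.append(was_correct)
        accuracy = sum(self.outcomes) / len(self.outcomes)
        # Only track a "best" value once enough samples have been seen.
        if len(self.outcomes) >= self.min_samples:
            if self.best_accuracy is None or accuracy > self.best_accuracy:
                self.best_accuracy = accuracy
        return accuracy

    def is_degraded(self) -> bool:
        """True when rolling accuracy has fallen more than `tolerance` below its best."""
        if self.best_accuracy is None or len(self.outcomes) < self.min_samples:
            return False
        current = sum(self.outcomes) / len(self.outcomes)
        return current < self.best_accuracy - self.tolerance
```

A monitoring loop might call `record` after each evaluated prediction and raise an alert when `is_degraded` returns `True`.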
|
||||
|
||||
##### 5. Decision Making and Trading Actions
|
||||
|
||||
Beyond coordination, the Orchestrator makes final trading decisions:
|
||||
|
@@ -130,4 +130,46 @@ The Multi-Modal Trading System is an advanced algorithmic trading platform that
|
||||
5. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all data providers.
|
||||
6. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all trading executors.
|
||||
7. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all risk management components.
|
||||
8. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all dashboard components.
|
||||
|
||||
### Requirement 9: Model Inference Data Validation and Storage
|
||||
|
||||
**User Story:** As a trading system developer, I want to ensure that all model predictions include complete input data validation and persistent storage, so that I can verify models receive correct inputs and track their performance over time.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN a model makes a prediction THEN the system SHALL validate that the input data contains complete OHLCV dataframes for all required timeframes
|
||||
2. WHEN input data is incomplete THEN the system SHALL log the missing components and SHALL NOT proceed with prediction
|
||||
3. WHEN input validation passes THEN the system SHALL store the complete input data package with the prediction in persistent storage
|
||||
4. IF input data dimensions are incorrect THEN the system SHALL raise a validation error with specific details about expected vs actual dimensions
|
||||
5. WHEN a model completes inference THEN the system SHALL store the complete input data, model outputs, confidence scores, and metadata in a persistent inference history
|
||||
6. WHEN storing inference data THEN the system SHALL include timestamp, symbol, input features, prediction outputs, and model internal states
|
||||
7. IF inference history storage fails THEN the system SHALL log the error and continue operation without breaking the prediction flow
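
Criterion 7 amounts to isolating storage failures from the prediction path. A minimal sketch, assuming a hypothetical `store_safely` wrapper around whatever storage callable the system uses:

```python
import logging
from typing import Any, Callable, Dict

logger = logging.getLogger(__name__)


def store_safely(store: Callable[[Dict[str, Any]], None], record: Dict[str, Any]) -> bool:
    """Call `store(record)`; log and swallow any failure so inference can continue."""
    try:
        store(record)
        return True
    except Exception:
        # logger.exception records the traceback for later diagnosis.
        logger.exception("Inference history storage failed for %s", record.get("symbol"))
        return False
```

The boolean return value lets the caller count failures for monitoring without interrupting the prediction flow.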
|
||||
|
||||
### Requirement 10: Inference-Training Feedback Loop
|
||||
|
||||
**User Story:** As a machine learning engineer, I want the system to automatically train models using their previous inference data compared to actual market outcomes, so that models continuously improve their accuracy through real-world feedback.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN sufficient time has passed after a prediction THEN the system SHALL evaluate the prediction accuracy against actual price movements
|
||||
2. WHEN a prediction outcome is determined THEN the system SHALL create a training example using the stored inference data and actual outcome
|
||||
3. WHEN training examples are created THEN the system SHALL feed them back to the respective models for learning
|
||||
4. IF the prediction was accurate THEN the system SHALL reinforce the model's decision pathway through positive training signals
|
||||
5. IF the prediction was inaccurate THEN the system SHALL provide corrective training signals to help the model learn from mistakes
|
||||
6. WHEN the system needs training data THEN it SHALL retrieve the last inference data for each model to compare predictions against actual market outcomes
|
||||
7. WHEN models are trained on inference feedback THEN the system SHALL track and report accuracy improvements or degradations over time
|
||||
|
||||
### Requirement 11: Inference History Management and Monitoring
|
||||
|
||||
**User Story:** As a system administrator, I want comprehensive logging and monitoring of the inference-training feedback loop with configurable retention policies, so that I can track model learning progress and manage storage efficiently.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN inference data is stored THEN the system SHALL log the storage operation with data completeness metrics and validation results
|
||||
2. WHEN training occurs based on previous inference THEN the system SHALL log the training outcome and model performance changes
|
||||
3. WHEN the system detects data flow issues THEN it SHALL alert administrators with specific error details and suggested remediation
|
||||
4. WHEN inference history reaches configured limits THEN the system SHALL archive or remove oldest entries based on retention policy
|
||||
5. WHEN storing inference data THEN the system SHALL compress data to minimize storage footprint while maintaining accessibility
|
||||
6. WHEN retrieving historical inference data THEN the system SHALL provide efficient query mechanisms by symbol, timeframe, and date range
|
||||
7. IF storage space is critically low THEN the system SHALL prioritize keeping the most recent and most valuable training examples
|
@@ -135,6 +135,9 @@
|
||||
- Add thread-safe access to multi-rate data streams
|
||||
- _Requirements: 4.1, 1.6, 8.5_
|
||||
|
||||
|
||||
|
||||
|
||||
- [ ] 4.2. Implement model inference coordination
|
||||
- Create ModelInferenceCoordinator class
|
||||
- Trigger model inference based on data availability and requirements
|
||||
@@ -176,6 +179,84 @@
|
||||
- Provide model performance monitoring and alerting
|
||||
- _Requirements: 4.6, 8.2, 8.3_
|
||||
|
||||
## Model Inference Data Validation and Storage
|
||||
|
||||
- [ ] 5. Implement comprehensive inference data validation system
|
||||
- Create InferenceDataValidator class for input validation
|
||||
- Validate complete OHLCV dataframes for all required timeframes
|
||||
- Check input data dimensions against model requirements
|
||||
- Log missing components and prevent prediction on incomplete data
|
||||
- _Requirements: 9.1, 9.2, 9.3, 9.4_
|
||||
|
||||
- [ ] 5.1. Implement input data validation for all models
|
||||
- Create validation methods for CNN, RL, and future model inputs
|
||||
- Validate OHLCV data completeness (300 frames for 1s, 1m, 1h, 1d)
|
||||
- Validate COB data structure (±20 buckets, MA calculations)
|
||||
- Raise specific validation errors with expected vs actual dimensions
|
||||
- Ensure validation occurs before any model inference
|
||||
- _Requirements: 9.1, 9.4_
|
||||
|
||||
- [ ] 5.2. Implement persistent inference history storage
|
||||
- Create InferenceHistoryStore class for persistent storage
|
||||
- Store complete input data packages with each prediction
|
||||
- Include timestamp, symbol, input features, prediction outputs, confidence scores
|
||||
- Store model internal states for cross-model feeding
|
||||
- Implement compressed storage to minimize footprint
|
||||
- _Requirements: 9.5, 9.6_
|
||||
|
||||
- [ ] 5.3. Implement inference history query and retrieval system
|
||||
- Create efficient query mechanisms by symbol, timeframe, and date range
|
||||
- Implement data retrieval for training pipeline consumption
|
||||
- Add data completeness metrics and validation results in storage
|
||||
- Handle storage failures gracefully without breaking prediction flow
|
||||
- _Requirements: 9.7, 11.6_
|
||||
|
||||
## Inference-Training Feedback Loop Implementation
|
||||
|
||||
- [ ] 6. Implement prediction outcome evaluation system
|
||||
- Create PredictionOutcomeEvaluator class
|
||||
- Evaluate prediction accuracy against actual price movements
|
||||
- Create training examples using stored inference data and actual outcomes
|
||||
- Feed prediction-result pairs back to respective models
|
||||
- _Requirements: 10.1, 10.2, 10.3_
|
||||
|
||||
- [ ] 6.1. Implement adaptive learning signal generation
|
||||
- Create positive reinforcement signals for accurate predictions
|
||||
- Generate corrective training signals for inaccurate predictions
|
||||
- Retrieve last inference data for each model for outcome comparison
|
||||
- Implement model-specific learning signal formats
|
||||
- _Requirements: 10.4, 10.5, 10.6_
|
||||
|
||||
- [ ] 6.2. Implement continuous improvement tracking
|
||||
- Track and report accuracy improvements/degradations over time
|
||||
- Monitor model learning progress through feedback loop
|
||||
- Create performance metrics for inference-training effectiveness
|
||||
- Generate alerts for learning regression or stagnation
|
||||
- _Requirements: 10.7_
|
||||
|
||||
## Inference History Management and Monitoring
|
||||
|
||||
- [ ] 7. Implement comprehensive inference logging and monitoring
|
||||
- Create InferenceMonitor class for logging and alerting
|
||||
- Log inference data storage operations with completeness metrics
|
||||
- Log training outcomes and model performance changes
|
||||
- Alert administrators on data flow issues with specific error details
|
||||
- _Requirements: 11.1, 11.2, 11.3_
|
||||
|
||||
- [ ] 7.1. Implement configurable retention policies
|
||||
- Create RetentionPolicyManager class
|
||||
- Archive or remove oldest entries when limits are reached
|
||||
- Prioritize keeping most recent and valuable training examples
|
||||
- Implement storage space monitoring and alerts
|
||||
- _Requirements: 11.4, 11.7_
|
||||
|
||||
- [ ] 7.2. Implement efficient historical data management
|
||||
- Compress inference data to minimize storage footprint
|
||||
- Maintain accessibility for training and analysis
|
||||
- Implement efficient query mechanisms for historical analysis
|
||||
- Add data archival and restoration capabilities
|
||||
- _Requirements: 11.5, 11.6_
|
||||
|
||||
## Trading Executor Implementation
|
||||
|
||||
- [ ] 5. Design and implement the trading executor
|
||||
|
@@ -27,6 +27,7 @@ import shutil
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import pandas as pd
|
||||
|
||||
from .config import get_config
|
||||
from .data_provider import DataProvider
|
||||
@@ -202,9 +203,18 @@ class TradingOrchestrator:
|
||||
# Training tracking
|
||||
self.last_trained_symbols: Dict[str, datetime] = {}
|
||||
|
||||
# INFERENCE DATA STORAGE - Store model inputs and outputs for training
|
||||
self.inference_history: Dict[str, deque] = {} # {symbol: deque of inference records}
|
||||
self.max_inference_history = 1000 # Keep last 1000 inference records per symbol
|
||||
|
||||
# Initialize inference history for each symbol
|
||||
for symbol in self.symbols:
|
||||
self.inference_history[symbol] = deque(maxlen=self.max_inference_history)
|
||||
|
||||
# ENHANCED: Real-time Training System Integration
|
||||
self.enhanced_training_system = None # Will be set to EnhancedRealtimeTrainingSystem if available
|
||||
self.training_enabled: bool = enhanced_rl_training and ENHANCED_TRAINING_AVAILABLE
|
||||
# Enable training by default - don't depend on external training system
|
||||
self.training_enabled: bool = enhanced_rl_training
|
||||
|
||||
logger.info("Enhanced TradingOrchestrator initialized with full ML capabilities")
|
||||
logger.info(f"Enhanced RL training: {enhanced_rl_training}")
|
||||
@@ -1023,34 +1033,409 @@ class TradingOrchestrator:
|
||||
return None
|
||||
|
||||
async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
|
||||
"""Get predictions from all registered models"""
|
||||
"""Get predictions from all registered models with input data storage"""
|
||||
predictions = []
|
||||
current_time = datetime.now()
|
||||
|
||||
# Collect input data for all models
|
||||
input_data = await self._collect_model_input_data(symbol)
|
||||
|
||||
        for model_name, model in self.model_registry.models.items():
            try:
                model_predictions = []
                model_input = None

                if isinstance(model, CNNModelInterface):
                    # Get CNN predictions for each timeframe
                    model_predictions = await self._get_cnn_predictions(model, symbol)
                    # Store input data for CNN
                    model_input = input_data.get('cnn_input')

                elif isinstance(model, RLAgentInterface):
                    # Get RL prediction
                    rl_prediction = await self._get_rl_prediction(model, symbol)
                    if rl_prediction:
                        model_predictions = [rl_prediction]
                    # Store input data for RL
                    model_input = input_data.get('rl_input')

                else:
                    # Generic model interface
                    generic_prediction = await self._get_generic_prediction(model, symbol)
                    if generic_prediction:
                        model_predictions = [generic_prediction]
                    # Store input data for generic model
                    model_input = input_data.get('generic_input')

                predictions.extend(model_predictions)

                # Store inference data for training, one record per prediction.
                # The CNN returns one prediction per timeframe, so its records
                # must be stored here as well.
                if model_input is not None:
                    for prediction in model_predictions:
                        self._store_inference_data(symbol, model_name, model_input, prediction, current_time)

            except Exception as e:
                logger.error(f"Error getting prediction from {model_name}: {e}")
                continue
|
||||
|
||||
# Trigger training based on previous inference data
|
||||
await self._trigger_model_training(symbol)
|
||||
|
||||
return predictions
|
||||
|
||||
async def _collect_model_input_data(self, symbol: str) -> Dict[str, Any]:
|
||||
"""Collect comprehensive input data for all models"""
|
||||
try:
|
||||
input_data = {}
|
||||
|
||||
# Get current market data from data provider
|
||||
current_price = self.data_provider.get_current_price(symbol)
|
||||
|
||||
# Collect OHLCV data for multiple timeframes
|
||||
ohlcv_data = {}
|
||||
timeframes = ['1s', '1m', '1h', '1d']
|
||||
for tf in timeframes:
|
||||
df = self.data_provider.get_historical_data(symbol, tf, limit=300)
|
||||
if df is not None and not df.empty:
|
||||
ohlcv_data[tf] = df
|
||||
|
||||
# Collect COB data if available
|
||||
cob_data = self.get_cob_snapshot(symbol)
|
||||
|
||||
# Collect technical indicators
|
||||
technical_indicators = {}
|
||||
if '1h' in ohlcv_data:
|
||||
df = ohlcv_data['1h']
|
||||
if len(df) > 20:
|
||||
technical_indicators['sma_20'] = df['close'].rolling(20).mean().iloc[-1]
|
||||
technical_indicators['rsi'] = self._calculate_rsi(df['close'])
|
||||
|
||||
# Prepare CNN input
|
||||
cnn_input = self._prepare_cnn_input_data(ohlcv_data, cob_data, technical_indicators)
|
||||
|
||||
# Prepare RL input
|
||||
rl_input = self._prepare_rl_input_data(ohlcv_data, cob_data, technical_indicators)
|
||||
|
||||
# Prepare generic input
|
||||
generic_input = {
|
||||
'symbol': symbol,
|
||||
'current_price': current_price,
|
||||
'ohlcv_data': ohlcv_data,
|
||||
'cob_data': cob_data,
|
||||
'technical_indicators': technical_indicators
|
||||
}
|
||||
|
||||
input_data = {
|
||||
'cnn_input': cnn_input,
|
||||
'rl_input': rl_input,
|
||||
'generic_input': generic_input,
|
||||
'timestamp': datetime.now(),
|
||||
'symbol': symbol
|
||||
}
|
||||
|
||||
return input_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error collecting model input data for {symbol}: {e}")
|
||||
return {}
|
||||
|
||||
def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray:
|
||||
"""Prepare standardized input data for CNN models"""
|
||||
try:
|
||||
# Create feature matrix from OHLCV data
|
||||
features = []
|
||||
|
||||
# Add OHLCV features for each timeframe
|
||||
for tf in ['1s', '1m', '1h', '1d']:
|
||||
if tf in ohlcv_data and not ohlcv_data[tf].empty:
|
||||
df = ohlcv_data[tf].tail(50) # Last 50 bars
|
||||
features.extend([
|
||||
df['close'].pct_change().fillna(0).values,
|
||||
df['volume'].values / df['volume'].max() if df['volume'].max() > 0 else np.zeros(len(df))
|
||||
])
|
||||
|
||||
# Add technical indicators
|
||||
for key, value in technical_indicators.items():
|
||||
if not np.isnan(value):
|
||||
features.append([value])
|
||||
|
||||
# Flatten and pad/truncate to standard size
|
||||
if features:
|
||||
feature_array = np.concatenate([np.array(f).flatten() for f in features])
|
||||
# Pad or truncate to 300 features
|
||||
if len(feature_array) < 300:
|
||||
feature_array = np.pad(feature_array, (0, 300 - len(feature_array)), 'constant')
|
||||
else:
|
||||
feature_array = feature_array[:300]
|
||||
return feature_array.reshape(1, -1)
|
||||
else:
|
||||
return np.zeros((1, 300))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing CNN input data: {e}")
|
||||
return np.zeros((1, 300))
|
||||
|
||||
def _prepare_rl_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray:
|
||||
"""Prepare standardized input data for RL models"""
|
||||
try:
|
||||
# Create state representation
|
||||
state_features = []
|
||||
|
||||
# Add price and volume features
|
||||
if '1m' in ohlcv_data and not ohlcv_data['1m'].empty:
|
||||
df = ohlcv_data['1m'].tail(20)
|
||||
state_features.extend([
|
||||
df['close'].pct_change().fillna(0).values,
|
||||
df['volume'].pct_change().fillna(0).values,
|
||||
(df['high'] - df['low']) / df['close'] # Volatility proxy
|
||||
])
|
||||
|
||||
# Add technical indicators
|
||||
for key, value in technical_indicators.items():
|
||||
if not np.isnan(value):
|
||||
state_features.append(value)
|
||||
|
||||
# Flatten and standardize size
|
||||
if state_features:
|
||||
state_array = np.concatenate([np.array(f).flatten() for f in state_features])
|
||||
# Pad or truncate to expected RL state size
|
||||
expected_size = 100 # Adjust based on your RL model
|
||||
if len(state_array) < expected_size:
|
||||
state_array = np.pad(state_array, (0, expected_size - len(state_array)), 'constant')
|
||||
else:
|
||||
state_array = state_array[:expected_size]
|
||||
return state_array
|
||||
else:
|
||||
return np.zeros(100)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing RL input data: {e}")
|
||||
return np.zeros(100)
|
||||
|
||||
def _store_inference_data(self, symbol: str, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime):
|
||||
"""Store comprehensive inference data for future training with persistent storage"""
|
||||
try:
|
||||
# Get current market context for complete replay capability
|
||||
current_price = self.data_provider.get_current_price(symbol)
|
||||
|
||||
# Create comprehensive inference record with ALL data needed for model replay
|
||||
inference_record = {
|
||||
'timestamp': timestamp,
|
||||
'symbol': symbol,
|
||||
'model_name': model_name,
|
||||
'current_price': current_price,
|
||||
|
||||
# Complete model input data
|
||||
'model_input': {
|
||||
'raw_input': model_input,
|
||||
'input_shape': model_input.shape if hasattr(model_input, 'shape') else None,
|
||||
'input_type': str(type(model_input))
|
||||
},
|
||||
|
||||
# Complete prediction data
|
||||
'prediction': {
|
||||
'action': prediction.action,
|
||||
'confidence': prediction.confidence,
|
||||
'probabilities': prediction.probabilities,
|
||||
'timeframe': prediction.timeframe
|
||||
},
|
||||
|
||||
# Market context at prediction time
|
||||
'market_context': {
|
||||
'price': current_price,
|
||||
'timestamp': timestamp.isoformat(),
|
||||
'symbol': symbol
|
||||
},
|
||||
|
||||
# Model metadata
|
||||
'metadata': {
|
||||
'model_metadata': prediction.metadata or {},
|
||||
'orchestrator_state': {
|
||||
'confidence_threshold': self.confidence_threshold,
|
||||
'training_enabled': self.training_enabled
|
||||
}
|
||||
},
|
||||
|
||||
# Training outcome (will be filled later)
|
||||
'training_outcome': None,
|
||||
'outcome_evaluated': False
|
||||
}
|
||||
|
||||
# Store in memory (inference history)
|
||||
if symbol in self.inference_history:
|
||||
self.inference_history[symbol].append(inference_record)
|
||||
logger.debug(f"Stored inference data for {model_name} on {symbol}")
|
||||
|
||||
# Persistent storage to disk (for long-term training data)
|
||||
self._save_inference_to_disk(inference_record)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error storing inference data: {e}")
|
||||
|
||||
    def _save_inference_to_disk(self, inference_record: Dict):
        """Save inference record to persistent storage"""
        try:
            # Create inference data directory
            inference_dir = Path("training_data/inference_history")
            inference_dir.mkdir(parents=True, exist_ok=True)

            # Create filename with symbol, model name, timeframe and timestamp.
            # Microseconds and the timeframe keep records written close together
            # from overwriting each other; "/" is replaced so a symbol containing
            # a slash does not turn into a subdirectory.
            timestamp_str = inference_record['timestamp'].strftime('%Y%m%d_%H%M%S_%f')
            safe_symbol = str(inference_record['symbol']).replace('/', '-')
            timeframe = inference_record.get('prediction', {}).get('timeframe') or 'na'
            filename = f"{safe_symbol}_{inference_record['model_name']}_{timeframe}_{timestamp_str}.json"
            filepath = inference_dir / filename

            # Convert numpy arrays to lists for JSON serialization
            serializable_record = self._make_json_serializable(inference_record)

            # Save to JSON file
            with open(filepath, 'w') as f:
                json.dump(serializable_record, f, indent=2)

            logger.debug(f"Saved inference record to disk: {filepath}")

        except Exception as e:
            logger.error(f"Error saving inference to disk: {e}")
|
||||
|
||||
    def _make_json_serializable(self, obj):
        """Convert object to JSON-serializable format"""
        if isinstance(obj, dict):
            return {k: self._make_json_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return [self._make_json_serializable(item) for item in obj]
        elif isinstance(obj, pd.DataFrame):
            # The generic input carries OHLCV DataFrames, which json cannot encode directly
            return self._make_json_serializable(obj.to_dict(orient='split'))
        elif isinstance(obj, pd.Series):
            return self._make_json_serializable(obj.tolist())
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, (np.integer, np.floating, np.bool_)):
            return obj.item()
        elif isinstance(obj, datetime):
            return obj.isoformat()
        else:
            return obj
|
||||
|
||||
    async def _trigger_model_training(self, symbol: str):
        """Trigger training for models based on previous inference data"""
        try:
            if not self.training_enabled or symbol not in self.inference_history:
                return

            # Get recent inference records
            recent_records = list(self.inference_history[symbol])
            if len(recent_records) < 2:
                return  # Need at least 2 records to compare

            # Get current price for outcome evaluation
            current_price = self.data_provider.get_current_price(symbol)
            if current_price is None:
                return

            # Process records that are old enough to evaluate outcomes
            cutoff_time = datetime.now() - timedelta(minutes=5)  # 5 minutes ago

            for record in recent_records:
                # Skip records that were already evaluated so that each
                # prediction produces a single training signal
                if record.get('outcome_evaluated'):
                    continue
                if record['timestamp'] < cutoff_time:
                    await self._evaluate_and_train_on_record(record, current_price)
                    record['outcome_evaluated'] = True

        except Exception as e:
            logger.error(f"Error triggering model training for {symbol}: {e}")
|
||||
|
||||
    async def _evaluate_and_train_on_record(self, record: Dict, current_price: float):
        """Evaluate prediction outcome and train model"""
        try:
            model_name = record['model_name']
            prediction = record['prediction']
            symbol = record['symbol']

            # Calculate price change since prediction.
            # This is a simplified outcome evaluation; it could be made more sophisticated.
            # Prefer the price stored with the record at prediction time: the latest
            # 1m close is near the current price and would give a change close to zero.
            prediction_price = record.get('current_price')
            if not prediction_price:
                historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
                if historical_data is None or historical_data.empty:
                    return
                prediction_price = historical_data['close'].iloc[-1]  # Simplified fallback

            price_change_pct = (current_price - prediction_price) / prediction_price * 100

            # Determine if prediction was correct
            predicted_action = prediction['action']
            was_correct = False

            if predicted_action == 'BUY' and price_change_pct > 0.1:  # Price went up
                was_correct = True
            elif predicted_action == 'SELL' and price_change_pct < -0.1:  # Price went down
                was_correct = True
            elif predicted_action == 'HOLD' and abs(price_change_pct) < 0.1:  # Price stayed stable
                was_correct = True

            # Fill in the outcome placeholder created when the record was stored
            record['training_outcome'] = {
                'was_correct': was_correct,
                'price_change_pct': price_change_pct
            }

            # Update model performance tracking
            if model_name not in self.model_performance:
                self.model_performance[model_name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}

            self.model_performance[model_name]['total'] += 1
            if was_correct:
                self.model_performance[model_name]['correct'] += 1

            self.model_performance[model_name]['accuracy'] = (
                self.model_performance[model_name]['correct'] /
                self.model_performance[model_name]['total']
            )

            # Train the specific model based on outcome
            await self._train_model_on_outcome(record, was_correct, price_change_pct)

            logger.debug(f"Evaluated {model_name} prediction: {'✓' if was_correct else '✗'} "
                         f"({prediction['action']}, {price_change_pct:.2f}% change)")

        except Exception as e:
            logger.error(f"Error evaluating and training on record: {e}")
|
||||
|
||||
    async def _train_model_on_outcome(self, record: Dict, was_correct: bool, price_change_pct: float):
        """Train specific model based on prediction outcome"""
        try:
            model_name = record['model_name']
            # The stored record wraps the input in a dict with shape and type
            # details; the models need the raw input itself
            model_input = record['model_input']['raw_input']
            prediction = record['prediction']

            # Create training signal based on outcome
            reward = 1.0 if was_correct else -0.5

            # Train RL models
            if 'dqn' in model_name.lower() and self.rl_agent:
                if hasattr(self.rl_agent, 'add_experience'):
                    action_idx = ['SELL', 'HOLD', 'BUY'].index(prediction['action'])
                    self.rl_agent.add_experience(
                        state=model_input,
                        action=action_idx,
                        reward=reward,
                        next_state=model_input,  # Simplified
                        done=True
                    )
                    logger.debug(f"Added RL training experience: reward={reward}")

            # Train CNN models
            elif 'cnn' in model_name.lower() and self.cnn_model:
                if hasattr(self.cnn_model, 'train_on_outcome'):
                    target = 1 if was_correct else 0
                    self.cnn_model.train_on_outcome(model_input, target)
                    logger.debug(f"Trained CNN on outcome: target={target}")

        except Exception as e:
            logger.error(f"Error training model on outcome: {e}")
|
||||
|
||||
    def _calculate_rsi(self, prices: pd.Series, period: int = 14) -> float:
        """Calculate RSI indicator"""
        try:
            delta = prices.diff()
            gain = delta.where(delta > 0, 0).rolling(window=period).mean()
            loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
            rs = gain / loss
            rsi = 100 - (100 / (1 + rs))
            latest = rsi.iloc[-1] if not rsi.empty else np.nan
            # Fall back to the neutral value when RSI is undefined
            # (fewer than `period` bars, or a completely flat window)
            return float(latest) if np.isfinite(latest) else 50.0
        except Exception:
            return 50.0
|
||||
|
||||
async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]:
|
||||
"""Get predictions from CNN model for all timeframes with enhanced COB features"""
|
||||
predictions = []
|
||||
|