wip training

@@ -140,7 +140,7 @@ Training:

### 4. Orchestrator

The Orchestrator serves as the central coordination hub of the multi-modal trading system, responsible for data subscription management, model inference coordination, output storage, training pipeline orchestration, and inference-training feedback loop management.

#### Key Classes and Interfaces

@@ -245,6 +245,47 @@ The Orchestrator coordinates training for all models by managing the prediction-

- The checkpoint manager keeps only the top 5-10 best checkpoints per model, deleting the least performant ones. It stores metadata alongside each checkpoint to rank performance (a minimal sketch follows).
- The best stored checkpoint is loaded automatically at startup, if any exist.
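
A minimal sketch of that pruning pass, assuming each checkpoint has a sidecar metadata file with a scalar performance score (the file layout and field name are illustrative, not the project's actual schema):

```python
# Sketch only: keep the top-K checkpoints per model, ranked by a metric
# stored alongside each checkpoint. Paths and the metadata schema are
# illustrative assumptions.
import json
from pathlib import Path

def prune_checkpoints(model_dir: Path, keep: int = 10) -> None:
    """Delete all but the `keep` best checkpoints in model_dir.

    Expects each checkpoint <name>.pt to have a sibling <name>.json
    containing {"performance": <float>} used for ranking.
    """
    ranked = []
    for meta_file in model_dir.glob("*.json"):
        metadata = json.loads(meta_file.read_text())
        ranked.append((metadata.get("performance", float("-inf")), meta_file))

    # Best first; everything past the cutoff is deleted
    ranked.sort(key=lambda pair: pair[0], reverse=True)
    for _, meta_file in ranked[keep:]:
        meta_file.with_suffix(".pt").unlink(missing_ok=True)
        meta_file.unlink(missing_ok=True)
```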

##### 5. Inference Data Validation and Storage

The Orchestrator implements comprehensive inference data validation and persistent storage:

**Input Data Validation**:
- Validates complete OHLCV dataframes for all required timeframes before inference (see the validation sketch after this section)
- Checks input data dimensions against model requirements
- Logs missing components and prevents prediction on incomplete data
- Raises validation errors with specific details about expected vs. actual dimensions

**Inference History Storage**:
- Stores the complete input data package with each prediction in persistent storage
- Includes timestamp, symbol, input features, prediction outputs, confidence scores, and model internal states
- Maintains compressed storage to minimize footprint while preserving accessibility
- Implements efficient query mechanisms by symbol, timeframe, and date range

**Storage Management**:
- Applies configurable retention policies to manage storage limits
- Archives or removes the oldest entries when limits are reached
- Prioritizes keeping the most recent and most valuable training examples under storage pressure
- Provides data completeness metrics and validation results in logs
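
A minimal sketch of the validation gate described above, reusing the `InferenceDataValidator` name and the 300-frame requirement from the tasks document (the exact interface is an assumption):

```python
# Sketch only: reject incomplete OHLCV input before any model inference.
# Class and exception names are illustrative, not the final API.
from dataclasses import dataclass
from typing import Dict
import pandas as pd

REQUIRED_TIMEFRAMES = ('1s', '1m', '1h', '1d')

class InferenceValidationError(Exception):
    pass

@dataclass
class InferenceDataValidator:
    required_frames: int = 300  # frames per timeframe, per task 5.1

    def validate_ohlcv(self, symbol: str, ohlcv: Dict[str, pd.DataFrame]) -> None:
        # Req 9.2: report missing components and do not proceed with prediction
        missing = [tf for tf in REQUIRED_TIMEFRAMES if tf not in ohlcv or ohlcv[tf].empty]
        if missing:
            raise InferenceValidationError(f"{symbol}: missing OHLCV timeframes {missing}")
        for tf, df in ohlcv.items():
            if len(df) < self.required_frames:
                # Req 9.4: report expected vs. actual dimensions
                raise InferenceValidationError(
                    f"{symbol} {tf}: expected {self.required_frames} frames, got {len(df)}"
                )
```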

##### 6. Inference-Training Feedback Loop

The Orchestrator manages the continuous learning cycle through inference-training feedback:

**Prediction Outcome Evaluation**:
- Evaluates prediction accuracy against actual price movements after sufficient time has passed (see the sketch after this section)
- Creates training examples by pairing stored inference data with actual market outcomes
- Feeds prediction-result pairs back to the respective models for learning

**Adaptive Learning Signals**:
- Provides positive reinforcement signals for accurate predictions
- Delivers corrective training signals for inaccurate predictions so models learn from mistakes
- Retrieves the last inference data for each model to compare predictions against actual outcomes

**Continuous Improvement Tracking**:
- Tracks and reports accuracy improvements or degradations over time
- Monitors model learning progress through the feedback loop
- Alerts administrators when data flow issues are detected, with specific error details and remediation suggestions
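
The core of the outcome evaluation reduces to a small pure function (a sketch; the ±0.1% thresholds mirror the orchestrator code later in this commit, while the function name and reward values are illustrative):

```python
# Sketch: map a predicted action plus the realized price move to a
# training signal. Thresholds match _evaluate_and_train_on_record below.
def outcome_reward(action: str, price_change_pct: float) -> float:
    was_correct = (
        (action == 'BUY' and price_change_pct > 0.1)
        or (action == 'SELL' and price_change_pct < -0.1)
        or (action == 'HOLD' and abs(price_change_pct) < 0.1)
    )
    # Positive reinforcement for correct calls, corrective signal otherwise
    return 1.0 if was_correct else -0.5
```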

##### 7. Decision Making and Trading Actions

Beyond coordination, the Orchestrator makes final trading decisions:

@@ -131,3 +131,45 @@ The Multi-Modal Trading System is an advanced algorithmic trading platform that

6. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all trading executors.
7. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all risk management components.
8. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all dashboard components.

### Requirement 9: Model Inference Data Validation and Storage

**User Story:** As a trading system developer, I want to ensure that all model predictions include complete input data validation and persistent storage, so that I can verify models receive correct inputs and track their performance over time.

#### Acceptance Criteria

1. WHEN a model makes a prediction THEN the system SHALL validate that the input data contains complete OHLCV dataframes for all required timeframes
2. WHEN input data is incomplete THEN the system SHALL log the missing components and SHALL NOT proceed with prediction
3. WHEN input validation passes THEN the system SHALL store the complete input data package with the prediction in persistent storage
4. IF input data dimensions are incorrect THEN the system SHALL raise a validation error with specific details about expected vs. actual dimensions
5. WHEN a model completes inference THEN the system SHALL store the complete input data, model outputs, confidence scores, and metadata in a persistent inference history
6. WHEN storing inference data THEN the system SHALL include timestamp, symbol, input features, prediction outputs, and model internal states
7. IF inference history storage fails THEN the system SHALL log the error and continue operation without breaking the prediction flow

### Requirement 10: Inference-Training Feedback Loop

**User Story:** As a machine learning engineer, I want the system to automatically train models using their previous inference data compared to actual market outcomes, so that models continuously improve their accuracy through real-world feedback.

#### Acceptance Criteria

1. WHEN sufficient time has passed after a prediction THEN the system SHALL evaluate the prediction accuracy against actual price movements
2. WHEN a prediction outcome is determined THEN the system SHALL create a training example using the stored inference data and actual outcome
3. WHEN training examples are created THEN the system SHALL feed them back to the respective models for learning
4. IF the prediction was accurate THEN the system SHALL reinforce the model's decision pathway through positive training signals
5. IF the prediction was inaccurate THEN the system SHALL provide corrective training signals to help the model learn from mistakes
6. WHEN the system needs training data THEN it SHALL retrieve the last inference data for each model to compare predictions against actual market outcomes
7. WHEN models are trained on inference feedback THEN the system SHALL track and report accuracy improvements or degradations over time

### Requirement 11: Inference History Management and Monitoring

**User Story:** As a system administrator, I want comprehensive logging and monitoring of the inference-training feedback loop with configurable retention policies, so that I can track model learning progress and manage storage efficiently.

#### Acceptance Criteria

1. WHEN inference data is stored THEN the system SHALL log the storage operation with data completeness metrics and validation results
2. WHEN training occurs based on previous inference THEN the system SHALL log the training outcome and model performance changes
3. WHEN the system detects data flow issues THEN it SHALL alert administrators with specific error details and suggested remediation
4. WHEN inference history reaches configured limits THEN the system SHALL archive or remove the oldest entries based on the retention policy
5. WHEN storing inference data THEN the system SHALL compress data to minimize storage footprint while maintaining accessibility
6. WHEN retrieving historical inference data THEN the system SHALL provide efficient query mechanisms by symbol, timeframe, and date range
7. IF storage space is critically low THEN the system SHALL prioritize keeping the most recent and most valuable training examples

@@ -135,6 +135,9 @@

  - Add thread-safe access to multi-rate data streams
  - _Requirements: 4.1, 1.6, 8.5_

- [ ] 4.2. Implement model inference coordination
  - Create ModelInferenceCoordinator class
  - Trigger model inference based on data availability and requirements

@@ -176,6 +179,84 @@

  - Provide model performance monitoring and alerting
  - _Requirements: 4.6, 8.2, 8.3_

## Model Inference Data Validation and Storage

- [ ] 5. Implement comprehensive inference data validation system
  - Create InferenceDataValidator class for input validation
  - Validate complete OHLCV dataframes for all required timeframes
  - Check input data dimensions against model requirements
  - Log missing components and prevent prediction on incomplete data
  - _Requirements: 9.1, 9.2, 9.3, 9.4_

- [ ] 5.1. Implement input data validation for all models
  - Create validation methods for CNN, RL, and future model inputs
  - Validate OHLCV data completeness (300 frames for 1s, 1m, 1h, 1d)
  - Validate COB data structure (±20 buckets, MA calculations)
  - Raise specific validation errors with expected vs. actual dimensions
  - Ensure validation occurs before any model inference
  - _Requirements: 9.1, 9.4_

- [ ] 5.2. Implement persistent inference history storage (see the storage sketch after this task)
  - Create InferenceHistoryStore class for persistent storage
  - Store complete input data packages with each prediction
  - Include timestamp, symbol, input features, prediction outputs, confidence scores
  - Store model internal states for cross-model feeding
  - Implement compressed storage to minimize footprint
  - _Requirements: 9.5, 9.6_
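
A minimal sketch of compressed persistent storage, assuming gzip-compressed JSON records (the `InferenceHistoryStore` name comes from the task; the directory layout and record schema are assumptions):

```python
# Sketch: gzip-compressed JSON storage for inference records. The class name
# matches task 5.2; directory layout and record schema are assumptions.
import gzip
import json
from datetime import datetime
from pathlib import Path
from typing import Any, Dict

class InferenceHistoryStore:
    def __init__(self, root: str = "training_data/inference_history"):
        self.root = Path(root)
        self.root.mkdir(parents=True, exist_ok=True)

    def save(self, record: Dict[str, Any]) -> Path:
        """Write one inference record as <symbol>_<model>_<ts>.json.gz."""
        ts = record["timestamp"]
        ts_str = ts.strftime("%Y%m%d_%H%M%S") if isinstance(ts, datetime) else str(ts)
        symbol = str(record["symbol"]).replace("/", "_")  # keep paths flat
        path = self.root / f"{symbol}_{record['model_name']}_{ts_str}.json.gz"
        with gzip.open(path, "wt", encoding="utf-8") as f:
            json.dump(record, f, default=str)  # default=str handles datetimes
        return path
```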

- [ ] 5.3. Implement inference history query and retrieval system (see the query sketch after this task)
  - Create efficient query mechanisms by symbol, timeframe, and date range
  - Implement data retrieval for training pipeline consumption
  - Add data completeness metrics and validation results in storage
  - Handle storage failures gracefully without breaking prediction flow
  - _Requirements: 9.7, 11.6_
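
One way to query by symbol and date range over the per-record files written above (a sketch under the same assumed layout; the filename parsing would need to match the real naming scheme):

```python
# Sketch: query stored records by symbol and date range, using the
# <symbol>_<model>_<YYYYmmdd_HHMMSS>.json.gz layout assumed above.
import gzip
import json
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterator

def query_inference_history(root: str, symbol: str,
                            start: datetime, end: datetime) -> Iterator[Dict[str, Any]]:
    safe_symbol = symbol.replace("/", "_")
    for path in sorted(Path(root).glob(f"{safe_symbol}_*.json.gz")):
        # Timestamp is the last two underscore-separated fields of the stem
        date_part, time_part = path.stem.replace(".json", "").rsplit("_", 2)[-2:]
        ts = datetime.strptime(f"{date_part}_{time_part}", "%Y%m%d_%H%M%S")
        if start <= ts <= end:
            with gzip.open(path, "rt", encoding="utf-8") as f:
                yield json.load(f)
```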

## Inference-Training Feedback Loop Implementation

- [ ] 6. Implement prediction outcome evaluation system
  - Create PredictionOutcomeEvaluator class
  - Evaluate prediction accuracy against actual price movements
  - Create training examples using stored inference data and actual outcomes
  - Feed prediction-result pairs back to respective models
  - _Requirements: 10.1, 10.2, 10.3_

- [ ] 6.1. Implement adaptive learning signal generation
  - Create positive reinforcement signals for accurate predictions
  - Generate corrective training signals for inaccurate predictions
  - Retrieve last inference data for each model for outcome comparison
  - Implement model-specific learning signal formats
  - _Requirements: 10.4, 10.5, 10.6_

- [ ] 6.2. Implement continuous improvement tracking
  - Track and report accuracy improvements/degradations over time
  - Monitor model learning progress through feedback loop
  - Create performance metrics for inference-training effectiveness
  - Generate alerts for learning regression or stagnation
  - _Requirements: 10.7_

## Inference History Management and Monitoring

- [ ] 7. Implement comprehensive inference logging and monitoring
  - Create InferenceMonitor class for logging and alerting
  - Log inference data storage operations with completeness metrics
  - Log training outcomes and model performance changes
  - Alert administrators on data flow issues with specific error details
  - _Requirements: 11.1, 11.2, 11.3_

- [ ] 7.1. Implement configurable retention policies (see the pruning sketch after this task)
  - Create RetentionPolicyManager class
  - Archive or remove oldest entries when limits are reached
  - Prioritize keeping most recent and valuable training examples
  - Implement storage space monitoring and alerts
  - _Requirements: 11.4, 11.7_
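
A minimal sketch of a size-based retention pass (the `RetentionPolicyManager` name comes from the task; the byte cap and archive directory are assumptions):

```python
# Sketch: enforce a byte-size cap on the inference history directory by
# archiving the oldest files first. Class name matches task 7.1; the cap
# and archive location are illustrative.
import shutil
from pathlib import Path

class RetentionPolicyManager:
    def __init__(self, root: str, max_bytes: int = 1_000_000_000,
                 archive_dir: str = "training_data/archive"):
        self.root = Path(root)
        self.max_bytes = max_bytes
        self.archive_dir = Path(archive_dir)
        self.archive_dir.mkdir(parents=True, exist_ok=True)

    def enforce(self) -> int:
        """Archive oldest records until total size fits the cap; return count archived."""
        files = sorted(self.root.glob("*.json.gz"), key=lambda p: p.stat().st_mtime)
        total = sum(p.stat().st_size for p in files)
        archived = 0
        for path in files:  # oldest first
            if total <= self.max_bytes:
                break
            total -= path.stat().st_size
            shutil.move(str(path), self.archive_dir / path.name)
            archived += 1
        return archived
```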

- [ ] 7.2. Implement efficient historical data management
  - Compress inference data to minimize storage footprint
  - Maintain accessibility for training and analysis
  - Implement efficient query mechanisms for historical analysis
  - Add data archival and restoration capabilities
  - _Requirements: 11.5, 11.6_

## Trading Executor Implementation

- [ ] 8. Design and implement the trading executor

@@ -27,6 +27,7 @@ import shutil

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd

from .config import get_config
from .data_provider import DataProvider

@@ -202,9 +203,18 @@ class TradingOrchestrator:

        # Training tracking
        self.last_trained_symbols: Dict[str, datetime] = {}

        # INFERENCE DATA STORAGE - Store model inputs and outputs for training
        self.inference_history: Dict[str, deque] = {}  # {symbol: deque of inference records}
        self.max_inference_history = 1000  # Keep last 1000 inference records per symbol

        # Initialize inference history for each symbol
        for symbol in self.symbols:
            self.inference_history[symbol] = deque(maxlen=self.max_inference_history)

        # ENHANCED: Real-time Training System Integration
        self.enhanced_training_system = None  # Will be set to EnhancedRealtimeTrainingSystem if available
        # Enable training by default - don't depend on external training system
        self.training_enabled: bool = enhanced_rl_training

        logger.info("Enhanced TradingOrchestrator initialized with full ML capabilities")
        logger.info(f"Enhanced RL training: {enhanced_rl_training}")

@@ -1023,34 +1033,409 @@ class TradingOrchestrator:

        return None

    async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
        """Get predictions from all registered models with input data storage"""
        predictions = []
        current_time = datetime.now()

        # Collect input data for all models
        input_data = await self._collect_model_input_data(symbol)

        for model_name, model in self.model_registry.models.items():
            try:
                prediction = None
                model_input = None

                if isinstance(model, CNNModelInterface):
                    # Get CNN predictions for each timeframe
                    cnn_predictions = await self._get_cnn_predictions(model, symbol)
                    predictions.extend(cnn_predictions)
                    # Store input data for CNN; without assigning a prediction
                    # here, CNN inferences would never reach _store_inference_data
                    if cnn_predictions:
                        prediction = cnn_predictions[-1]
                    model_input = input_data.get('cnn_input')
                elif isinstance(model, RLAgentInterface):
                    # Get RL prediction
                    rl_prediction = await self._get_rl_prediction(model, symbol)
                    if rl_prediction:
                        predictions.append(rl_prediction)
                        prediction = rl_prediction
                    # Store input data for RL
                    model_input = input_data.get('rl_input')
                else:
                    # Generic model interface
                    generic_prediction = await self._get_generic_prediction(model, symbol)
                    if generic_prediction:
                        predictions.append(generic_prediction)
                        prediction = generic_prediction
                    # Store input data for generic model
                    model_input = input_data.get('generic_input')

                # Store inference data for training
                if prediction and model_input is not None:
                    self._store_inference_data(symbol, model_name, model_input, prediction, current_time)

            except Exception as e:
                logger.error(f"Error getting prediction from {model_name}: {e}")
                continue

        # Trigger training based on previous inference data
        await self._trigger_model_training(symbol)

        return predictions

    async def _collect_model_input_data(self, symbol: str) -> Dict[str, Any]:
        """Collect comprehensive input data for all models"""
        try:
            # Get current market data from the data provider
            current_price = self.data_provider.get_current_price(symbol)

            # Collect OHLCV data for multiple timeframes
            ohlcv_data = {}
            timeframes = ['1s', '1m', '1h', '1d']
            for tf in timeframes:
                df = self.data_provider.get_historical_data(symbol, tf, limit=300)
                if df is not None and not df.empty:
                    ohlcv_data[tf] = df

            # Collect COB data if available
            cob_data = self.get_cob_snapshot(symbol)

            # Collect technical indicators
            technical_indicators = {}
            if '1h' in ohlcv_data:
                df = ohlcv_data['1h']
                if len(df) > 20:
                    technical_indicators['sma_20'] = df['close'].rolling(20).mean().iloc[-1]
                    technical_indicators['rsi'] = self._calculate_rsi(df['close'])

            # Prepare model-specific inputs
            cnn_input = self._prepare_cnn_input_data(ohlcv_data, cob_data, technical_indicators)
            rl_input = self._prepare_rl_input_data(ohlcv_data, cob_data, technical_indicators)
            generic_input = {
                'symbol': symbol,
                'current_price': current_price,
                'ohlcv_data': ohlcv_data,
                'cob_data': cob_data,
                'technical_indicators': technical_indicators
            }

            return {
                'cnn_input': cnn_input,
                'rl_input': rl_input,
                'generic_input': generic_input,
                'timestamp': datetime.now(),
                'symbol': symbol
            }

        except Exception as e:
            logger.error(f"Error collecting model input data for {symbol}: {e}")
            return {}

    def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray:
        """Prepare standardized input data for CNN models"""
        try:
            # Create feature matrix from OHLCV data
            features = []

            # Add OHLCV features for each timeframe
            for tf in ['1s', '1m', '1h', '1d']:
                if tf in ohlcv_data and not ohlcv_data[tf].empty:
                    df = ohlcv_data[tf].tail(50)  # Last 50 bars
                    features.extend([
                        df['close'].pct_change().fillna(0).values,
                        df['volume'].values / df['volume'].max() if df['volume'].max() > 0 else np.zeros(len(df))
                    ])

            # Add technical indicators
            for key, value in technical_indicators.items():
                if not np.isnan(value):
                    features.append([value])

            # Flatten and pad/truncate to a standard size
            if features:
                feature_array = np.concatenate([np.array(f).flatten() for f in features])
                # Pad or truncate to 300 features
                if len(feature_array) < 300:
                    feature_array = np.pad(feature_array, (0, 300 - len(feature_array)), 'constant')
                else:
                    feature_array = feature_array[:300]
                return feature_array.reshape(1, -1)
            else:
                return np.zeros((1, 300))

        except Exception as e:
            logger.error(f"Error preparing CNN input data: {e}")
            return np.zeros((1, 300))

    def _prepare_rl_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray:
        """Prepare standardized input data for RL models"""
        try:
            # Create state representation
            state_features = []

            # Add price and volume features
            if '1m' in ohlcv_data and not ohlcv_data['1m'].empty:
                df = ohlcv_data['1m'].tail(20)
                state_features.extend([
                    df['close'].pct_change().fillna(0).values,
                    df['volume'].pct_change().fillna(0).values,
                    (df['high'] - df['low']) / df['close']  # Volatility proxy
                ])

            # Add technical indicators
            for key, value in technical_indicators.items():
                if not np.isnan(value):
                    state_features.append(value)

            # Flatten and standardize size
            if state_features:
                state_array = np.concatenate([np.array(f).flatten() for f in state_features])
                # Pad or truncate to the expected RL state size
                expected_size = 100  # Adjust based on your RL model
                if len(state_array) < expected_size:
                    state_array = np.pad(state_array, (0, expected_size - len(state_array)), 'constant')
                else:
                    state_array = state_array[:expected_size]
                return state_array
            else:
                return np.zeros(100)

        except Exception as e:
            logger.error(f"Error preparing RL input data: {e}")
            return np.zeros(100)

    def _store_inference_data(self, symbol: str, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime):
        """Store comprehensive inference data for future training with persistent storage"""
        try:
            # Get current market context for complete replay capability
            current_price = self.data_provider.get_current_price(symbol)

            # Create a comprehensive inference record with ALL data needed for model replay
            inference_record = {
                'timestamp': timestamp,
                'symbol': symbol,
                'model_name': model_name,
                'current_price': current_price,

                # Complete model input data
                'model_input': {
                    'raw_input': model_input,
                    'input_shape': model_input.shape if hasattr(model_input, 'shape') else None,
                    'input_type': str(type(model_input))
                },

                # Complete prediction data
                'prediction': {
                    'action': prediction.action,
                    'confidence': prediction.confidence,
                    'probabilities': prediction.probabilities,
                    'timeframe': prediction.timeframe
                },

                # Market context at prediction time
                'market_context': {
                    'price': current_price,
                    'timestamp': timestamp.isoformat(),
                    'symbol': symbol
                },

                # Model metadata
                'metadata': {
                    'model_metadata': prediction.metadata or {},
                    'orchestrator_state': {
                        'confidence_threshold': self.confidence_threshold,
                        'training_enabled': self.training_enabled
                    }
                },

                # Training outcome (filled in later by the feedback loop)
                'training_outcome': None,
                'outcome_evaluated': False
            }

            # Store in memory (inference history)
            if symbol in self.inference_history:
                self.inference_history[symbol].append(inference_record)
                logger.debug(f"Stored inference data for {model_name} on {symbol}")

            # Persistent storage to disk (for long-term training data)
            self._save_inference_to_disk(inference_record)

        except Exception as e:
            logger.error(f"Error storing inference data: {e}")

    def _save_inference_to_disk(self, inference_record: Dict):
        """Save inference record to persistent storage"""
        try:
            # Create the inference data directory
            inference_dir = Path("training_data/inference_history")
            inference_dir.mkdir(parents=True, exist_ok=True)

            # Create a filename with timestamp and model name; sanitize the
            # symbol so pairs like "ETH/USDT" don't create nested paths
            timestamp_str = inference_record['timestamp'].strftime('%Y%m%d_%H%M%S')
            symbol_safe = inference_record['symbol'].replace('/', '_')
            filename = f"{symbol_safe}_{inference_record['model_name']}_{timestamp_str}.json"
            filepath = inference_dir / filename

            # Convert numpy arrays to lists for JSON serialization
            serializable_record = self._make_json_serializable(inference_record)

            # Save to a JSON file
            with open(filepath, 'w') as f:
                json.dump(serializable_record, f, indent=2)

            logger.debug(f"Saved inference record to disk: {filepath}")

        except Exception as e:
            logger.error(f"Error saving inference to disk: {e}")

    def _make_json_serializable(self, obj):
        """Convert an object to a JSON-serializable format"""
        if isinstance(obj, dict):
            return {k: self._make_json_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self._make_json_serializable(item) for item in obj]
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, (np.integer, np.floating)):
            return obj.item()
        elif isinstance(obj, datetime):
            return obj.isoformat()
        else:
            return obj

    async def _trigger_model_training(self, symbol: str):
        """Trigger training for models based on previous inference data"""
        try:
            if not self.training_enabled or symbol not in self.inference_history:
                return

            # Get recent inference records
            recent_records = list(self.inference_history[symbol])
            if len(recent_records) < 2:
                return  # Need at least 2 records to compare

            # Get the current price for outcome evaluation
            current_price = self.data_provider.get_current_price(symbol)
            if current_price is None:
                return

            # Process records that are old enough to evaluate outcomes and
            # have not been evaluated yet (avoids re-training on the same record)
            cutoff_time = datetime.now() - timedelta(minutes=5)

            for record in recent_records:
                if record['timestamp'] < cutoff_time and not record.get('outcome_evaluated'):
                    await self._evaluate_and_train_on_record(record, current_price)

        except Exception as e:
            logger.error(f"Error triggering model training for {symbol}: {e}")

    async def _evaluate_and_train_on_record(self, record: Dict, current_price: float):
        """Evaluate prediction outcome and train the model on it"""
        try:
            model_name = record['model_name']
            prediction = record['prediction']
            symbol = record['symbol']

            # Get the historical price at prediction time (simplified outcome
            # evaluation - this could be made more sophisticated)
            historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
            if historical_data is None or historical_data.empty:
                return

            # Find the closest price to the prediction timestamp
            prediction_price = historical_data['close'].iloc[-1]  # Simplified
            price_change_pct = (current_price - prediction_price) / prediction_price * 100

            # Determine whether the prediction was correct
            predicted_action = prediction['action']
            was_correct = False
            if predicted_action == 'BUY' and price_change_pct > 0.1:  # Price went up
                was_correct = True
            elif predicted_action == 'SELL' and price_change_pct < -0.1:  # Price went down
                was_correct = True
            elif predicted_action == 'HOLD' and abs(price_change_pct) < 0.1:  # Price stayed stable
                was_correct = True

            # Update model performance tracking
            if model_name not in self.model_performance:
                self.model_performance[model_name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}

            self.model_performance[model_name]['total'] += 1
            if was_correct:
                self.model_performance[model_name]['correct'] += 1

            self.model_performance[model_name]['accuracy'] = (
                self.model_performance[model_name]['correct'] /
                self.model_performance[model_name]['total']
            )

            # Train the specific model based on the outcome, then mark the
            # record so it is not evaluated again
            await self._train_model_on_outcome(record, was_correct, price_change_pct)
            record['outcome_evaluated'] = True

            logger.debug(f"Evaluated {model_name} prediction: {'✓' if was_correct else '✗'} "
                         f"({prediction['action']}, {price_change_pct:.2f}% change)")

        except Exception as e:
            logger.error(f"Error evaluating and training on record: {e}")

    async def _train_model_on_outcome(self, record: Dict, was_correct: bool, price_change_pct: float):
        """Train a specific model based on its prediction outcome"""
        try:
            model_name = record['model_name']
            # Unwrap the raw input stored by _store_inference_data
            model_input = record['model_input'].get('raw_input')
            prediction = record['prediction']

            # Create a training signal based on the outcome
            reward = 1.0 if was_correct else -0.5

            # Train RL models
            if 'dqn' in model_name.lower() and self.rl_agent:
                if hasattr(self.rl_agent, 'add_experience'):
                    action_idx = ['SELL', 'HOLD', 'BUY'].index(prediction['action'])
                    self.rl_agent.add_experience(
                        state=model_input,
                        action=action_idx,
                        reward=reward,
                        next_state=model_input,  # Simplified
                        done=True
                    )
                    logger.debug(f"Added RL training experience: reward={reward}")

            # Train CNN models
            elif 'cnn' in model_name.lower() and self.cnn_model:
                if hasattr(self.cnn_model, 'train_on_outcome'):
                    target = 1 if was_correct else 0
                    self.cnn_model.train_on_outcome(model_input, target)
                    logger.debug(f"Trained CNN on outcome: target={target}")

        except Exception as e:
            logger.error(f"Error training model on outcome: {e}")

    def _calculate_rsi(self, prices: pd.Series, period: int = 14) -> float:
        """Calculate the RSI indicator"""
        try:
            delta = prices.diff()
            gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
            loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
            rs = gain / loss
            rsi = 100 - (100 / (1 + rs))
            return rsi.iloc[-1] if not rsi.empty else 50.0
        except Exception:
            return 50.0

    async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]:
        """Get predictions from CNN model for all timeframes with enhanced COB features"""
        predictions = []