cob data providers tests
This commit is contained in:
@ -48,6 +48,16 @@ class EnhancedCNNAdapter:
|
||||
self.learning_rate = 0.0001
|
||||
self.model_name = "enhanced_cnn_v1"
|
||||
|
||||
# Enhanced metrics tracking
|
||||
self.last_inference_time = None
|
||||
self.last_inference_duration = 0.0
|
||||
self.last_prediction_output = None
|
||||
self.last_training_time = None
|
||||
self.last_training_duration = 0.0
|
||||
self.last_training_loss = 0.0
|
||||
self.inference_count = 0
|
||||
self.training_count = 0
|
||||
|
||||
# Create checkpoint directory if it doesn't exist
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
@ -181,6 +191,10 @@ class EnhancedCNNAdapter:
|
||||
ModelOutput: Standardized model output
|
||||
"""
|
||||
try:
|
||||
# Track inference timing
|
||||
start_time = datetime.now()
|
||||
inference_start = start_time.timestamp()
|
||||
|
||||
# Convert BaseDataInput to features
|
||||
features = self._convert_base_data_to_features(base_data)
|
||||
|
||||
@ -204,6 +218,18 @@ class EnhancedCNNAdapter:
|
||||
actions = ['BUY', 'SELL', 'HOLD']
|
||||
action = actions[action_idx]
|
||||
|
||||
# Extract pivot price prediction (simplified - take first value from price_pred)
|
||||
pivot_price = None
|
||||
if price_pred is not None and len(price_pred.squeeze()) > 0:
|
||||
# Get current price from base_data for context
|
||||
current_price = 0.0
|
||||
if base_data.ohlcv_1s and len(base_data.ohlcv_1s) > 0:
|
||||
current_price = base_data.ohlcv_1s[-1].close
|
||||
|
||||
# Calculate pivot price as current price + predicted change
|
||||
price_change_pct = float(price_pred.squeeze()[0].item()) # First prediction value
|
||||
pivot_price = current_price * (1 + price_change_pct * 0.01) # Convert percentage to price
|
||||
|
||||
# Create predictions dictionary
|
||||
predictions = {
|
||||
'action': action,
|
||||
@ -211,7 +237,8 @@ class EnhancedCNNAdapter:
|
||||
'sell_probability': float(action_probs[0, 1].item()),
|
||||
'hold_probability': float(action_probs[0, 2].item()),
|
||||
'extrema': extrema_pred.squeeze(0).cpu().numpy().tolist(),
|
||||
'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist()
|
||||
'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist(),
|
||||
'pivot_price': pivot_price
|
||||
}
|
||||
|
||||
# Create hidden states dictionary
|
||||
@ -219,11 +246,31 @@ class EnhancedCNNAdapter:
|
||||
'features': features_refined.squeeze(0).cpu().numpy().tolist()
|
||||
}
|
||||
|
||||
# Calculate inference duration
|
||||
end_time = datetime.now()
|
||||
inference_duration = (end_time.timestamp() - inference_start) * 1000 # Convert to milliseconds
|
||||
|
||||
# Update metrics
|
||||
self.last_inference_time = start_time
|
||||
self.last_inference_duration = inference_duration
|
||||
self.inference_count += 1
|
||||
|
||||
# Store last prediction output for dashboard
|
||||
self.last_prediction_output = {
|
||||
'action': action,
|
||||
'confidence': confidence,
|
||||
'pivot_price': pivot_price,
|
||||
'timestamp': start_time,
|
||||
'symbol': base_data.symbol
|
||||
}
|
||||
|
||||
# Create metadata dictionary
|
||||
metadata = {
|
||||
'model_version': '1.0',
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'input_shape': features.shape
|
||||
'timestamp': start_time.isoformat(),
|
||||
'input_shape': features.shape,
|
||||
'inference_duration_ms': inference_duration,
|
||||
'inference_count': self.inference_count
|
||||
}
|
||||
|
||||
# Create ModelOutput
|
||||
@ -231,7 +278,7 @@ class EnhancedCNNAdapter:
|
||||
model_type='cnn',
|
||||
model_name=self.model_name,
|
||||
symbol=base_data.symbol,
|
||||
timestamp=datetime.now(),
|
||||
timestamp=start_time,
|
||||
confidence=confidence,
|
||||
predictions=predictions,
|
||||
hidden_states=hidden_states,
|
||||
@ -294,6 +341,10 @@ class EnhancedCNNAdapter:
|
||||
Dict[str, float]: Training metrics
|
||||
"""
|
||||
try:
|
||||
# Track training timing
|
||||
training_start_time = datetime.now()
|
||||
training_start = training_start_time.timestamp()
|
||||
|
||||
with self.training_lock:
|
||||
# Check if we have enough data
|
||||
if len(self.training_data) < self.batch_size:
|
||||
@ -378,15 +429,27 @@ class EnhancedCNNAdapter:
|
||||
avg_loss = total_loss / (len(self.training_data) / self.batch_size)
|
||||
accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
|
||||
|
||||
# Calculate training duration
|
||||
training_end_time = datetime.now()
|
||||
training_duration = (training_end_time.timestamp() - training_start) * 1000 # Convert to milliseconds
|
||||
|
||||
# Update training metrics
|
||||
self.last_training_time = training_start_time
|
||||
self.last_training_duration = training_duration
|
||||
self.last_training_loss = avg_loss
|
||||
self.training_count += 1
|
||||
|
||||
# Save checkpoint
|
||||
self._save_checkpoint(avg_loss, accuracy)
|
||||
|
||||
logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={len(self.training_data)}")
|
||||
logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={len(self.training_data)}, duration={training_duration:.1f}ms")
|
||||
|
||||
return {
|
||||
'loss': avg_loss,
|
||||
'accuracy': accuracy,
|
||||
'samples': len(self.training_data)
|
||||
'samples': len(self.training_data),
|
||||
'duration_ms': training_duration,
|
||||
'training_count': self.training_count
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
Reference in New Issue
Block a user