cob data providers tests

This commit is contained in:
Dobromir Popov
2025-07-23 22:49:54 +03:00
parent c30267bf0b
commit 4765b1b1e1
5 changed files with 298 additions and 294 deletions

View File

@ -48,6 +48,16 @@ class EnhancedCNNAdapter:
self.learning_rate = 0.0001
self.model_name = "enhanced_cnn_v1"
# Enhanced metrics tracking
self.last_inference_time = None
self.last_inference_duration = 0.0
self.last_prediction_output = None
self.last_training_time = None
self.last_training_duration = 0.0
self.last_training_loss = 0.0
self.inference_count = 0
self.training_count = 0
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
@ -181,6 +191,10 @@ class EnhancedCNNAdapter:
ModelOutput: Standardized model output
"""
try:
# Track inference timing
start_time = datetime.now()
inference_start = start_time.timestamp()
# Convert BaseDataInput to features
features = self._convert_base_data_to_features(base_data)
@ -204,6 +218,18 @@ class EnhancedCNNAdapter:
actions = ['BUY', 'SELL', 'HOLD']
action = actions[action_idx]
# Extract pivot price prediction (simplified - take first value from price_pred)
pivot_price = None
if price_pred is not None and len(price_pred.squeeze()) > 0:
# Get current price from base_data for context
current_price = 0.0
if base_data.ohlcv_1s and len(base_data.ohlcv_1s) > 0:
current_price = base_data.ohlcv_1s[-1].close
# Calculate pivot price as current price + predicted change
price_change_pct = float(price_pred.squeeze()[0].item()) # First prediction value
pivot_price = current_price * (1 + price_change_pct * 0.01) # Convert percentage to price
# Create predictions dictionary
predictions = {
'action': action,
@ -211,7 +237,8 @@ class EnhancedCNNAdapter:
'sell_probability': float(action_probs[0, 1].item()),
'hold_probability': float(action_probs[0, 2].item()),
'extrema': extrema_pred.squeeze(0).cpu().numpy().tolist(),
'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist()
'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist(),
'pivot_price': pivot_price
}
# Create hidden states dictionary
@ -219,11 +246,31 @@ class EnhancedCNNAdapter:
'features': features_refined.squeeze(0).cpu().numpy().tolist()
}
# Calculate inference duration
end_time = datetime.now()
inference_duration = (end_time.timestamp() - inference_start) * 1000 # Convert to milliseconds
# Update metrics
self.last_inference_time = start_time
self.last_inference_duration = inference_duration
self.inference_count += 1
# Store last prediction output for dashboard
self.last_prediction_output = {
'action': action,
'confidence': confidence,
'pivot_price': pivot_price,
'timestamp': start_time,
'symbol': base_data.symbol
}
# Create metadata dictionary
metadata = {
'model_version': '1.0',
'timestamp': datetime.now().isoformat(),
'input_shape': features.shape
'timestamp': start_time.isoformat(),
'input_shape': features.shape,
'inference_duration_ms': inference_duration,
'inference_count': self.inference_count
}
# Create ModelOutput
@ -231,7 +278,7 @@ class EnhancedCNNAdapter:
model_type='cnn',
model_name=self.model_name,
symbol=base_data.symbol,
timestamp=datetime.now(),
timestamp=start_time,
confidence=confidence,
predictions=predictions,
hidden_states=hidden_states,
@ -294,6 +341,10 @@ class EnhancedCNNAdapter:
Dict[str, float]: Training metrics
"""
try:
# Track training timing
training_start_time = datetime.now()
training_start = training_start_time.timestamp()
with self.training_lock:
# Check if we have enough data
if len(self.training_data) < self.batch_size:
@ -378,15 +429,27 @@ class EnhancedCNNAdapter:
avg_loss = total_loss / (len(self.training_data) / self.batch_size)
accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
# Calculate training duration
training_end_time = datetime.now()
training_duration = (training_end_time.timestamp() - training_start) * 1000 # Convert to milliseconds
# Update training metrics
self.last_training_time = training_start_time
self.last_training_duration = training_duration
self.last_training_loss = avg_loss
self.training_count += 1
# Save checkpoint
self._save_checkpoint(avg_loss, accuracy)
logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={len(self.training_data)}")
logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={len(self.training_data)}, duration={training_duration:.1f}ms")
return {
'loss': avg_loss,
'accuracy': accuracy,
'samples': len(self.training_data)
'samples': len(self.training_data),
'duration_ms': training_duration,
'training_count': self.training_count
}
except Exception as e: