inrefence predictions fix
This commit is contained in:
@ -15,6 +15,7 @@ from threading import Lock
|
||||
|
||||
from .data_models import BaseDataInput, ModelOutput, create_model_output
|
||||
from NN.models.enhanced_cnn import EnhancedCNN
|
||||
from utils.inference_logger import log_model_inference
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -339,6 +340,42 @@ class EnhancedCNNAdapter:
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
# Log inference with full input data for training feedback
|
||||
log_model_inference(
|
||||
model_name=self.model_name,
|
||||
symbol=base_data.symbol,
|
||||
action=action,
|
||||
confidence=confidence,
|
||||
probabilities={
|
||||
'BUY': predictions['buy_probability'],
|
||||
'SELL': predictions['sell_probability'],
|
||||
'HOLD': predictions['hold_probability']
|
||||
},
|
||||
input_features=features.cpu().numpy(), # Store full feature vector
|
||||
processing_time_ms=inference_duration,
|
||||
checkpoint_id=None, # Could be enhanced to track checkpoint
|
||||
metadata={
|
||||
'base_data_input': {
|
||||
'symbol': base_data.symbol,
|
||||
'timestamp': base_data.timestamp.isoformat(),
|
||||
'ohlcv_1s_count': len(base_data.ohlcv_1s),
|
||||
'ohlcv_1m_count': len(base_data.ohlcv_1m),
|
||||
'ohlcv_1h_count': len(base_data.ohlcv_1h),
|
||||
'ohlcv_1d_count': len(base_data.ohlcv_1d),
|
||||
'btc_ohlcv_1s_count': len(base_data.btc_ohlcv_1s),
|
||||
'has_cob_data': base_data.cob_data is not None,
|
||||
'technical_indicators_count': len(base_data.technical_indicators),
|
||||
'pivot_points_count': len(base_data.pivot_points),
|
||||
'last_predictions_count': len(base_data.last_predictions)
|
||||
},
|
||||
'model_predictions': {
|
||||
'pivot_price': pivot_price,
|
||||
'extrema_prediction': predictions['extrema'],
|
||||
'price_prediction': predictions['price_prediction']
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return model_output
|
||||
|
||||
except Exception as e:
|
||||
@ -401,7 +438,7 @@ class EnhancedCNNAdapter:
|
||||
|
||||
def train(self, epochs: int = 1) -> Dict[str, float]:
|
||||
"""
|
||||
Train the model with collected data
|
||||
Train the model with collected data and inference history
|
||||
|
||||
Args:
|
||||
epochs: Number of epochs to train for
|
||||
@ -415,6 +452,9 @@ class EnhancedCNNAdapter:
|
||||
training_start = training_start_time.timestamp()
|
||||
|
||||
with self.training_lock:
|
||||
# Get additional training data from inference history
|
||||
self._load_training_data_from_inference_history()
|
||||
|
||||
# Check if we have enough data
|
||||
if len(self.training_data) < self.batch_size:
|
||||
logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
|
||||
@ -583,3 +623,100 @@ class EnhancedCNNAdapter:
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint: {e}")
|
||||
|
||||
def _load_training_data_from_inference_history(self):
|
||||
"""Load training data from inference history for continuous learning"""
|
||||
try:
|
||||
from utils.database_manager import get_database_manager
|
||||
|
||||
db_manager = get_database_manager()
|
||||
|
||||
# Get recent inference records with input features
|
||||
inference_records = db_manager.get_inference_records_for_training(
|
||||
model_name=self.model_name,
|
||||
hours_back=24, # Last 24 hours
|
||||
limit=1000
|
||||
)
|
||||
|
||||
if not inference_records:
|
||||
logger.debug("No inference records found for training")
|
||||
return
|
||||
|
||||
# Convert inference records to training samples
|
||||
# For now, use a simple approach: treat high-confidence predictions as ground truth
|
||||
for record in inference_records:
|
||||
if record.input_features is not None and record.confidence > 0.7:
|
||||
# Convert action to index
|
||||
actions = ['BUY', 'SELL', 'HOLD']
|
||||
if record.action in actions:
|
||||
action_idx = actions.index(record.action)
|
||||
|
||||
# Use confidence as a proxy for reward (high confidence = good prediction)
|
||||
reward = record.confidence * 2 - 1 # Scale to [-1, 1]
|
||||
|
||||
# Convert features to tensor
|
||||
features_tensor = torch.tensor(record.input_features, dtype=torch.float32, device=self.device)
|
||||
|
||||
# Add to training data if not already present (avoid duplicates)
|
||||
sample_exists = any(
|
||||
torch.equal(features_tensor, existing[0])
|
||||
for existing in self.training_data
|
||||
)
|
||||
|
||||
if not sample_exists:
|
||||
self.training_data.append((features_tensor, action_idx, reward))
|
||||
|
||||
logger.info(f"Loaded {len(inference_records)} inference records for training, total training samples: {len(self.training_data)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading training data from inference history: {e}")
|
||||
|
||||
def evaluate_predictions_against_outcomes(self, hours_back: int = 1) -> Dict[str, float]:
|
||||
"""
|
||||
Evaluate past predictions against actual market outcomes
|
||||
|
||||
Args:
|
||||
hours_back: How many hours back to evaluate
|
||||
|
||||
Returns:
|
||||
Dict with evaluation metrics
|
||||
"""
|
||||
try:
|
||||
from utils.database_manager import get_database_manager
|
||||
|
||||
db_manager = get_database_manager()
|
||||
|
||||
# Get inference records from the specified time period
|
||||
inference_records = db_manager.get_inference_records_for_training(
|
||||
model_name=self.model_name,
|
||||
hours_back=hours_back,
|
||||
limit=100
|
||||
)
|
||||
|
||||
if not inference_records:
|
||||
return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}
|
||||
|
||||
# For now, use a simple evaluation based on confidence
|
||||
# In a real implementation, this would compare against actual price movements
|
||||
correct_predictions = 0
|
||||
total_predictions = len(inference_records)
|
||||
|
||||
# Simple heuristic: high confidence predictions are more likely to be correct
|
||||
for record in inference_records:
|
||||
if record.confidence > 0.8: # High confidence threshold
|
||||
correct_predictions += 1
|
||||
elif record.confidence > 0.6: # Medium confidence
|
||||
correct_predictions += 0.5
|
||||
|
||||
accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
|
||||
|
||||
logger.info(f"Prediction evaluation: {correct_predictions:.1f}/{total_predictions} = {accuracy:.3f} accuracy")
|
||||
|
||||
return {
|
||||
'accuracy': accuracy,
|
||||
'total_predictions': total_predictions,
|
||||
'correct_predictions': correct_predictions
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating predictions: {e}")
|
||||
return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}
|
||||
|
Reference in New Issue
Block a user