new overhaul
training/enhanced_cnn_trainer.py (new file, 566 lines)
@@ -0,0 +1,566 @@
"""
|
||||
Enhanced CNN Trainer with Perfect Move Learning
|
||||
|
||||
This trainer implements:
|
||||
1. Training on marked perfect moves with known outcomes
|
||||
2. Multi-timeframe CNN model training with confidence scoring
|
||||
3. Backpropagation on optimal moves when future outcomes are known
|
||||
4. Progressive learning from real trading experience
|
||||
5. Symbol-specific and timeframe-specific model fine-tuning
|
||||
"""
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from pathlib import Path
|
||||
|
||||
from core.config import get_config
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import PerfectMove, EnhancedTradingOrchestrator
|
||||
from models import CNNModelInterface
|
||||
import models
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PerfectMoveDataset(Dataset):
    """Dataset for training on perfect moves with known outcomes"""

    def __init__(self, perfect_moves: List[PerfectMove], data_provider: DataProvider):
        """
        Initialize dataset from perfect moves

        Args:
            perfect_moves: List of perfect moves with known outcomes
            data_provider: Data provider to fetch additional context
        """
        self.perfect_moves = perfect_moves
        self.data_provider = data_provider
        self.samples = []
        self._prepare_samples()

    def _prepare_samples(self):
        """Prepare training samples from perfect moves"""
        logger.info(f"Preparing {len(self.perfect_moves)} perfect move samples")

        for move in self.perfect_moves:
            try:
                # Get feature matrix at the time of the decision
                feature_matrix = self.data_provider.get_feature_matrix(
                    symbol=move.symbol,
                    timeframes=[move.timeframe],
                    window_size=20,
                    end_time=move.timestamp
                )

                if feature_matrix is not None:
                    # Convert optimal action to label
                    action_to_label = {'SELL': 0, 'HOLD': 1, 'BUY': 2}
                    label = action_to_label.get(move.optimal_action, 1)

                    # Create confidence target (what confidence should have been)
                    confidence_target = move.confidence_should_have_been

                    sample = {
                        'features': feature_matrix,
                        'action_label': label,
                        'confidence_target': confidence_target,
                        'symbol': move.symbol,
                        'timeframe': move.timeframe,
                        'outcome': move.actual_outcome,
                        'timestamp': move.timestamp
                    }
                    self.samples.append(sample)

            except Exception as e:
                logger.warning(f"Error preparing sample for perfect move: {e}")

        logger.info(f"Prepared {len(self.samples)} valid training samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        # Convert to tensors
        features = torch.FloatTensor(sample['features'])
        action_label = torch.LongTensor([sample['action_label']])
        confidence_target = torch.FloatTensor([sample['confidence_target']])

        return {
            'features': features,
            'action_label': action_label,
            'confidence_target': confidence_target,
            'metadata': {
                'symbol': sample['symbol'],
                'timeframe': sample['timeframe'],
                'outcome': sample['outcome'],
                # Stringified: the default DataLoader collate cannot batch
                # raw datetime objects
                'timestamp': str(sample['timestamp'])
            }
        }
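# Usage sketch for the dataset (hypothetical names; assumes
# DataProvider.get_feature_matrix returns a [window_size, n_features] array,
# e.g. [20, 5] for 20 bars of OHLCV features):
#
#   dataset = PerfectMoveDataset(moves, provider)
#   loader = DataLoader(dataset, batch_size=32, shuffle=True)
#   batch = next(iter(loader))
#   batch['features'].shape      # torch.Size([32, 20, 5])
#   batch['action_label'].shape  # torch.Size([32, 1]); squeezed during training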
class EnhancedCNNModel(nn.Module, CNNModelInterface):
    """Enhanced CNN model with timeframe-specific predictions and confidence scoring"""

    def __init__(self, config: Dict[str, Any]):
        nn.Module.__init__(self)
        CNNModelInterface.__init__(self, config)

        self.timeframes = config.get('timeframes', ['1h', '4h', '1d'])
        self.n_features = len(config.get('features', ['open', 'high', 'low', 'close', 'volume']))
        self.window_size = config.get('window_size', 20)

        # Build the neural network
        self._build_network()

        # Initialize device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

        # Training components. forward() returns softmax probabilities, so the
        # action loss is NLLLoss over log-probabilities rather than
        # CrossEntropyLoss, which expects raw logits.
        self.optimizer = optim.Adam(self.parameters(), lr=config.get('learning_rate', 0.001))
        self.action_criterion = nn.NLLLoss()
        self.confidence_criterion = nn.MSELoss()

        logger.info(f"Enhanced CNN model initialized for {len(self.timeframes)} timeframes")
    def _build_network(self):
        """Build the CNN architecture"""
        # Convolutional feature extraction
        self.conv_layers = nn.Sequential(
            # First conv block
            nn.Conv1d(self.n_features, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),

            # Second conv block
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),

            # Third conv block
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),

            # Global average pooling
            nn.AdaptiveAvgPool1d(1)
        )

        # Timeframe-specific heads
        self.timeframe_heads = nn.ModuleDict()
        for timeframe in self.timeframes:
            self.timeframe_heads[timeframe] = nn.Sequential(
                nn.Linear(256, 128),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Dropout(0.3)
            )

        # Action prediction heads (one per timeframe)
        self.action_heads = nn.ModuleDict()
        for timeframe in self.timeframes:
            # SELL, HOLD, BUY -- matches the dataset's action_to_label order
            self.action_heads[timeframe] = nn.Linear(64, 3)

        # Confidence prediction heads (one per timeframe)
        self.confidence_heads = nn.ModuleDict()
        for timeframe in self.timeframes:
            self.confidence_heads[timeframe] = nn.Sequential(
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Linear(32, 1),
                nn.Sigmoid()  # Output between 0 and 1
            )
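    # Shape walk-through (a sketch, assuming window_size=20 and 5 OHLCV features):
    # input [batch, 20, 5] -> transpose -> [batch, 5, 20] -> conv blocks ->
    # [batch, 256, 20] -> AdaptiveAvgPool1d(1) -> [batch, 256, 1] -> squeeze ->
    # [batch, 256] -> timeframe head -> [batch, 64] -> action head [batch, 3],
    # confidence head [batch, 1].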
    def forward(self, x, timeframe: str = None):
        """
        Forward pass through the network

        Args:
            x: Input tensor [batch_size, window_size, features]
            timeframe: Specific timeframe to predict for

        Returns:
            action_probs: Action probabilities
            confidence: Confidence score
        """
        # Reshape for conv1d: [batch, features, sequence]
        x = x.transpose(1, 2)

        # Extract features
        features = self.conv_layers(x)  # [batch, 256, 1]
        features = features.squeeze(-1)  # [batch, 256]

        if timeframe and timeframe in self.timeframe_heads:
            # Timeframe-specific prediction
            tf_features = self.timeframe_heads[timeframe](features)
            action_logits = self.action_heads[timeframe](tf_features)
            confidence = self.confidence_heads[timeframe](tf_features)

            action_probs = torch.softmax(action_logits, dim=1)
            return action_probs, confidence.squeeze(-1)
        else:
            # Multi-timeframe prediction (average across timeframes)
            all_action_probs = []
            all_confidences = []

            for tf in self.timeframes:
                tf_features = self.timeframe_heads[tf](features)
                action_logits = self.action_heads[tf](tf_features)
                confidence = self.confidence_heads[tf](tf_features)

                action_probs = torch.softmax(action_logits, dim=1)
                all_action_probs.append(action_probs)
                all_confidences.append(confidence.squeeze(-1))

            # Average predictions across timeframes
            avg_action_probs = torch.stack(all_action_probs).mean(dim=0)
            avg_confidence = torch.stack(all_confidences).mean(dim=0)

            return avg_action_probs, avg_confidence
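    # Example (sketch): x of shape [8, 20, 5] with timeframe='4h' yields
    # action_probs of shape [8, 3] and confidence of shape [8]; with
    # timeframe=None, every configured timeframe head is evaluated and the
    # results are averaged.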
    def predict(self, features: np.ndarray) -> Tuple[np.ndarray, float]:
        """Predict action probabilities and confidence"""
        self.eval()
        with torch.no_grad():
            x = torch.FloatTensor(features).to(self.device)
            if len(x.shape) == 2:
                x = x.unsqueeze(0)  # Add batch dimension

            action_probs, confidence = self.forward(x)

        return action_probs[0].cpu().numpy(), confidence[0].cpu().item()

    def predict_timeframe(self, features: np.ndarray, timeframe: str) -> Tuple[np.ndarray, float]:
        """Predict for specific timeframe"""
        self.eval()
        with torch.no_grad():
            x = torch.FloatTensor(features).to(self.device)
            if len(x.shape) == 2:
                x = x.unsqueeze(0)  # Add batch dimension

            action_probs, confidence = self.forward(x, timeframe)

        return action_probs[0].cpu().numpy(), confidence[0].cpu().item()
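    # Usage sketch (hypothetical values): with `window` as a [20, 5] ndarray,
    #   probs, conf = model.predict_timeframe(window, '1h')
    # probs is a length-3 array ordered [SELL, HOLD, BUY] to match the
    # dataset's action_to_label mapping, and conf is a float in [0, 1].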
    def get_memory_usage(self) -> int:
        """Get memory usage in MB"""
        if torch.cuda.is_available():
            return torch.cuda.memory_allocated(self.device) // (1024 * 1024)
        else:
            # Rough estimate for CPU
            param_count = sum(p.numel() for p in self.parameters())
            return (param_count * 4) // (1024 * 1024)  # 4 bytes per float32
    def train(self, training_data: Any = None) -> Any:
        """Train the model (placeholder for interface compatibility).

        Note: this overrides nn.Module.train(mode), which self.eval() and the
        trainer's self.model.train() rely on, so None/boolean arguments are
        forwarded to nn.Module.train to keep train/eval mode switching intact.
        """
        if training_data is None or isinstance(training_data, bool):
            return nn.Module.train(self, True if training_data is None else training_data)
        return {}
class EnhancedCNNTrainer:
    """Enhanced CNN trainer using perfect moves and real market outcomes"""

    def __init__(self, config: Optional[Dict] = None, orchestrator: EnhancedTradingOrchestrator = None):
        """Initialize the enhanced trainer"""
        self.config = config or get_config()
        self.orchestrator = orchestrator
        self.data_provider = DataProvider(self.config)

        # Training parameters
        self.learning_rate = self.config.training.get('learning_rate', 0.001)
        self.batch_size = self.config.training.get('batch_size', 32)
        self.epochs = self.config.training.get('epochs', 100)
        self.patience = self.config.training.get('early_stopping_patience', 10)

        # Model
        self.model = EnhancedCNNModel(self.config.cnn)

        # Training history
        self.training_history = {
            'train_loss': [],
            'val_loss': [],
            'train_accuracy': [],
            'val_accuracy': [],
            'confidence_accuracy': []
        }

        # Create save directory
        models_path = self.config.cnn.get('model_dir', "models/enhanced_cnn")
        self.save_dir = Path(models_path)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        logger.info("Enhanced CNN trainer initialized")
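    # Config surface consumed above (a summary, not additional settings):
    # training.learning_rate, training.batch_size, training.epochs,
    # training.early_stopping_patience, and cnn.model_dir for checkpoints
    # and plots.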
    def train_on_perfect_moves(self, min_samples: int = 100) -> Dict[str, Any]:
        """Train the model on perfect moves from the orchestrator"""
        if not self.orchestrator:
            raise ValueError("Orchestrator required for perfect move training")

        # Get perfect moves from orchestrator
        perfect_moves = []
        for symbol in self.config.symbols:
            symbol_moves = self.orchestrator.get_perfect_moves_for_training(symbol=symbol)
            perfect_moves.extend(symbol_moves)

        if len(perfect_moves) < min_samples:
            logger.warning(f"Not enough perfect moves for training: {len(perfect_moves)} < {min_samples}")
            return {'error': 'insufficient_data', 'samples': len(perfect_moves)}

        logger.info(f"Training on {len(perfect_moves)} perfect moves")

        # Create dataset
        dataset = PerfectMoveDataset(perfect_moves, self.data_provider)

        # Split into train/validation
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)

        # Training loop
        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(self.epochs):
            # Training phase
            train_loss, train_acc = self._train_epoch(train_loader)

            # Validation phase
            val_loss, val_acc, conf_acc = self._validate_epoch(val_loader)

            # Update history
            self.training_history['train_loss'].append(train_loss)
            self.training_history['val_loss'].append(val_loss)
            self.training_history['train_accuracy'].append(train_acc)
            self.training_history['val_accuracy'].append(val_acc)
            self.training_history['confidence_accuracy'].append(conf_acc)

            # Log progress
            logger.info(f"Epoch {epoch+1}/{self.epochs}: "
                        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
                        f"Conf Acc: {conf_acc:.4f}")

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                self._save_model('best_model.pt')
            else:
                patience_counter += 1
                if patience_counter >= self.patience:
                    logger.info(f"Early stopping at epoch {epoch+1}")
                    break

        # Save final model
        self._save_model('final_model.pt')

        # Generate training report
        return self._generate_training_report()
    def _train_epoch(self, train_loader: DataLoader) -> Tuple[float, float]:
        """Train for one epoch"""
        self.model.train()
        total_loss = 0.0
        correct_predictions = 0
        total_predictions = 0

        for batch in train_loader:
            features = batch['features'].to(self.model.device)
            action_labels = batch['action_label'].to(self.model.device).squeeze(-1)
            confidence_targets = batch['confidence_target'].to(self.model.device).squeeze(-1)

            # Zero gradients
            self.model.optimizer.zero_grad()

            # Forward pass (returns softmax probabilities)
            action_probs, confidence_pred = self.model(features)

            # Calculate losses; NLLLoss expects log-probabilities, and the
            # small epsilon guards against log(0)
            action_loss = self.model.action_criterion(torch.log(action_probs + 1e-8), action_labels)
            confidence_loss = self.model.confidence_criterion(confidence_pred, confidence_targets)

            # Combined loss
            total_loss_batch = action_loss + 0.5 * confidence_loss

            # Backward pass
            total_loss_batch.backward()
            self.model.optimizer.step()

            # Track metrics
            total_loss += total_loss_batch.item()
            predicted_actions = torch.argmax(action_probs, dim=1)
            correct_predictions += (predicted_actions == action_labels).sum().item()
            total_predictions += action_labels.size(0)

        avg_loss = total_loss / len(train_loader)
        accuracy = correct_predictions / total_predictions

        return avg_loss, accuracy
    def _validate_epoch(self, val_loader: DataLoader) -> Tuple[float, float, float]:
        """Validate for one epoch"""
        self.model.eval()
        total_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        confidence_errors = []

        with torch.no_grad():
            for batch in val_loader:
                features = batch['features'].to(self.model.device)
                action_labels = batch['action_label'].to(self.model.device).squeeze(-1)
                confidence_targets = batch['confidence_target'].to(self.model.device).squeeze(-1)

                # Forward pass
                action_probs, confidence_pred = self.model(features)

                # Calculate losses (same log-probability form as training)
                action_loss = self.model.action_criterion(torch.log(action_probs + 1e-8), action_labels)
                confidence_loss = self.model.confidence_criterion(confidence_pred, confidence_targets)
                total_loss_batch = action_loss + 0.5 * confidence_loss

                # Track metrics
                total_loss += total_loss_batch.item()
                predicted_actions = torch.argmax(action_probs, dim=1)
                correct_predictions += (predicted_actions == action_labels).sum().item()
                total_predictions += action_labels.size(0)

                # Track confidence accuracy
                conf_errors = torch.abs(confidence_pred - confidence_targets)
                confidence_errors.extend(conf_errors.cpu().numpy())

        avg_loss = total_loss / len(val_loader)
        accuracy = correct_predictions / total_predictions
        confidence_accuracy = 1.0 - np.mean(confidence_errors)  # 1 - mean absolute error

        return avg_loss, accuracy, confidence_accuracy
    def _save_model(self, filename: str):
        """Save the model"""
        save_path = self.save_dir / filename
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.model.optimizer.state_dict(),
            'config': self.config.cnn,
            'training_history': self.training_history
        }, save_path)
        logger.info(f"Model saved to {save_path}")

    def load_model(self, filename: str) -> bool:
        """Load a saved model"""
        load_path = self.save_dir / filename
        if not load_path.exists():
            logger.error(f"Model file not found: {load_path}")
            return False

        try:
            checkpoint = torch.load(load_path, map_location=self.model.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.training_history = checkpoint.get('training_history', {})
            logger.info(f"Model loaded from {load_path}")
            return True
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            return False
    def _generate_training_report(self) -> Dict[str, Any]:
        """Generate comprehensive training report"""
        if not self.training_history['train_loss']:
            return {'error': 'no_training_data'}

        # Calculate final metrics
        final_train_loss = self.training_history['train_loss'][-1]
        final_val_loss = self.training_history['val_loss'][-1]
        final_train_acc = self.training_history['train_accuracy'][-1]
        final_val_acc = self.training_history['val_accuracy'][-1]
        final_conf_acc = self.training_history['confidence_accuracy'][-1]

        # Best metrics
        best_val_loss = min(self.training_history['val_loss'])
        best_val_acc = max(self.training_history['val_accuracy'])
        best_conf_acc = max(self.training_history['confidence_accuracy'])

        report = {
            'training_completed': True,
            'epochs_trained': len(self.training_history['train_loss']),
            'final_metrics': {
                'train_loss': final_train_loss,
                'val_loss': final_val_loss,
                'train_accuracy': final_train_acc,
                'val_accuracy': final_val_acc,
                'confidence_accuracy': final_conf_acc
            },
            'best_metrics': {
                'val_loss': best_val_loss,
                'val_accuracy': best_val_acc,
                'confidence_accuracy': best_conf_acc
            },
            'model_info': {
                'timeframes': self.model.timeframes,
                'memory_usage_mb': self.model.get_memory_usage(),
                'device': str(self.model.device)
            }
        }

        # Generate plots
        self._plot_training_history()

        logger.info("Training completed successfully")
        logger.info(f"Final validation accuracy: {final_val_acc:.4f}")
        logger.info(f"Final confidence accuracy: {final_conf_acc:.4f}")

        return report
    def _plot_training_history(self):
        """Plot training history"""
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        fig.suptitle('Enhanced CNN Training History')

        # Loss plot
        axes[0, 0].plot(self.training_history['train_loss'], label='Train Loss')
        axes[0, 0].plot(self.training_history['val_loss'], label='Val Loss')
        axes[0, 0].set_title('Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()

        # Accuracy plot
        axes[0, 1].plot(self.training_history['train_accuracy'], label='Train Accuracy')
        axes[0, 1].plot(self.training_history['val_accuracy'], label='Val Accuracy')
        axes[0, 1].set_title('Action Accuracy')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].legend()

        # Confidence accuracy plot
        axes[1, 0].plot(self.training_history['confidence_accuracy'], label='Confidence Accuracy')
        axes[1, 0].set_title('Confidence Prediction Accuracy')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Accuracy')
        axes[1, 0].legend()

        # Learning curves comparison
        axes[1, 1].plot(self.training_history['val_loss'], label='Validation Loss')
        axes[1, 1].plot(self.training_history['confidence_accuracy'], label='Confidence Accuracy')
        axes[1, 1].set_title('Model Performance Overview')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].legend()

        plt.tight_layout()
        plt.savefig(self.save_dir / 'training_history.png', dpi=300, bbox_inches='tight')
        plt.close()

        logger.info(f"Training plots saved to {self.save_dir / 'training_history.png'}")
    def get_model(self) -> EnhancedCNNModel:
        """Get the trained model"""
        return self.model
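# Minimal end-to-end sketch. Assumptions (not shown in this file): get_config()
# returns a config with `symbols`, `training`, and `cnn` sections, and the
# EnhancedTradingOrchestrator construction below is hypothetical wiring.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    config = get_config()
    orchestrator = EnhancedTradingOrchestrator(config)  # hypothetical construction

    trainer = EnhancedCNNTrainer(config=config, orchestrator=orchestrator)
    report = trainer.train_on_perfect_moves(min_samples=100)

    if 'error' in report:
        logger.warning(f"Training skipped: {report}")
    else:
        logger.info(f"Best val accuracy: {report['best_metrics']['val_accuracy']:.4f}")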