cash works again!

This commit is contained in:
Dobromir Popov
2025-05-25 00:28:52 +03:00
parent d418f6ce59
commit cf825239cd
18 changed files with 1970 additions and 1331 deletions

View File

@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Tuple, Any
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
from core.config import get_config
from core.data_provider import DataProvider
@ -293,7 +294,14 @@ class EnhancedCNNTrainer:
'train_accuracy': [],
'val_accuracy': [],
'confidence_accuracy': []
} # Create save directory models_path = self.config.cnn.get('model_dir', "models/enhanced_cnn") self.save_dir = Path(models_path) self.save_dir.mkdir(parents=True, exist_ok=True) logger.info("Enhanced CNN trainer initialized")
}
# Create save directory
models_path = self.config.cnn.get('model_dir', "models/enhanced_cnn")
self.save_dir = Path(models_path)
self.save_dir.mkdir(parents=True, exist_ok=True)
logger.info("Enhanced CNN trainer initialized")
def train_on_perfect_moves(self, min_samples: int = 100) -> Dict[str, Any]:
"""Train the model on perfect moves from the orchestrator"""
@ -563,4 +571,233 @@ class EnhancedCNNTrainer:
def get_model(self) -> EnhancedCNNModel:
"""Get the trained model"""
return self.model
return self.model
def __del__(self):
"""Cleanup"""
self.close_tensorboard()
def main():
"""Main function for standalone CNN live training with backtesting and analysis"""
import argparse
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
parser = argparse.ArgumentParser(description='Enhanced CNN Live Training with Backtesting and Analysis')
parser.add_argument('--symbols', type=str, nargs='+', default=['ETH/USDT', 'BTC/USDT'],
help='Trading symbols to train on')
parser.add_argument('--timeframes', type=str, nargs='+', default=['1m', '5m', '15m', '1h'],
help='Timeframes to use for training')
parser.add_argument('--epochs', type=int, default=100,
help='Number of training epochs')
parser.add_argument('--batch-size', type=int, default=32,
help='Training batch size')
parser.add_argument('--learning-rate', type=float, default=0.001,
help='Learning rate')
parser.add_argument('--save-path', type=str, default='models/enhanced_cnn/live_trained_model.pt',
help='Path to save the trained model')
parser.add_argument('--enable-backtesting', action='store_true', default=True,
help='Enable backtesting after training')
parser.add_argument('--enable-analysis', action='store_true', default=True,
help='Enable detailed analysis and reporting')
parser.add_argument('--enable-live-validation', action='store_true', default=True,
help='Enable live validation during training')
args = parser.parse_args()
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger.info("="*80)
logger.info("🧠 ENHANCED CNN LIVE TRAINING WITH BACKTESTING & ANALYSIS")
logger.info("="*80)
logger.info(f"Symbols: {args.symbols}")
logger.info(f"Timeframes: {args.timeframes}")
logger.info(f"Epochs: {args.epochs}")
logger.info(f"Batch Size: {args.batch_size}")
logger.info(f"Learning Rate: {args.learning_rate}")
logger.info(f"Save Path: {args.save_path}")
logger.info(f"Backtesting: {'Enabled' if args.enable_backtesting else 'Disabled'}")
logger.info(f"Analysis: {'Enabled' if args.enable_analysis else 'Disabled'}")
logger.info(f"Live Validation: {'Enabled' if args.enable_live_validation else 'Disabled'}")
logger.info("="*80)
try:
# Update config with command line arguments
config = get_config()
config.update('symbols', args.symbols)
config.update('timeframes', args.timeframes)
config.update('training', {
**config.training,
'epochs': args.epochs,
'batch_size': args.batch_size,
'learning_rate': args.learning_rate
})
# Initialize enhanced trainer
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from core.data_provider import DataProvider
data_provider = DataProvider(config)
orchestrator = EnhancedTradingOrchestrator(data_provider)
trainer = EnhancedCNNTrainer(config, orchestrator)
# Phase 1: Data Collection and Preparation
logger.info("📊 Phase 1: Collecting and preparing training data...")
training_data = trainer.collect_training_data(args.symbols, lookback_days=30)
logger.info(f" Collected {len(training_data)} training samples")
# Phase 2: Model Training
logger.info("🧠 Phase 2: Training Enhanced CNN Model...")
training_results = trainer.train_on_perfect_moves(min_samples=1000)
logger.info("Training Results:")
logger.info(f" Best Validation Accuracy: {training_results['best_val_accuracy']:.4f}")
logger.info(f" Best Validation Loss: {training_results['best_val_loss']:.4f}")
logger.info(f" Total Epochs: {training_results['epochs_completed']}")
logger.info(f" Training Time: {training_results['total_time']:.2f}s")
# Phase 3: Model Evaluation
logger.info("📈 Phase 3: Model Evaluation...")
evaluation_results = trainer.evaluate_model(args.symbols[:1]) # Use first symbol for evaluation
logger.info("Evaluation Results:")
logger.info(f" Test Accuracy: {evaluation_results['test_accuracy']:.4f}")
logger.info(f" Test Loss: {evaluation_results['test_loss']:.4f}")
logger.info(f" Confidence Score: {evaluation_results['avg_confidence']:.4f}")
# Phase 4: Backtesting (if enabled)
if args.enable_backtesting:
logger.info("📊 Phase 4: Backtesting...")
# Create backtest environment
from trading.backtest_environment import BacktestEnvironment
backtest_env = BacktestEnvironment(
symbols=args.symbols,
timeframes=args.timeframes,
initial_balance=10000.0,
data_provider=data_provider
)
# Run backtest
backtest_results = backtest_env.run_backtest_with_model(
model=trainer.model,
lookback_days=7, # Test on last 7 days
max_trades_per_day=50
)
logger.info("Backtesting Results:")
logger.info(f" Total Returns: {backtest_results['total_return']:.2f}%")
logger.info(f" Win Rate: {backtest_results['win_rate']:.2f}%")
logger.info(f" Sharpe Ratio: {backtest_results['sharpe_ratio']:.4f}")
logger.info(f" Max Drawdown: {backtest_results['max_drawdown']:.2f}%")
logger.info(f" Total Trades: {backtest_results['total_trades']}")
logger.info(f" Profit Factor: {backtest_results['profit_factor']:.4f}")
# Phase 5: Analysis and Reporting (if enabled)
if args.enable_analysis:
logger.info("📋 Phase 5: Analysis and Reporting...")
# Generate comprehensive analysis report
analysis_report = trainer.generate_analysis_report(
training_results=training_results,
evaluation_results=evaluation_results,
backtest_results=backtest_results if args.enable_backtesting else None
)
# Save analysis report
report_path = Path(args.save_path).parent / "analysis_report.json"
report_path.parent.mkdir(parents=True, exist_ok=True)
with open(report_path, 'w') as f:
json.dump(analysis_report, f, indent=2, default=str)
logger.info(f" Analysis report saved: {report_path}")
# Generate performance plots
plots_dir = Path(args.save_path).parent / "plots"
plots_dir.mkdir(parents=True, exist_ok=True)
trainer.generate_performance_plots(
training_results=training_results,
evaluation_results=evaluation_results,
save_dir=plots_dir
)
logger.info(f" Performance plots saved: {plots_dir}")
# Phase 6: Model Saving
logger.info("💾 Phase 6: Saving trained model...")
model_path = Path(args.save_path)
model_path.parent.mkdir(parents=True, exist_ok=True)
trainer.model.save(str(model_path))
logger.info(f" Model saved: {model_path}")
# Save training metadata
metadata = {
'training_config': {
'symbols': args.symbols,
'timeframes': args.timeframes,
'epochs': args.epochs,
'batch_size': args.batch_size,
'learning_rate': args.learning_rate
},
'training_results': training_results,
'evaluation_results': evaluation_results
}
if args.enable_backtesting:
metadata['backtest_results'] = backtest_results
metadata_path = model_path.with_suffix('.json')
with open(metadata_path, 'w') as f:
json.dump(metadata, f, indent=2, default=str)
logger.info(f" Training metadata saved: {metadata_path}")
# Phase 7: Live Validation (if enabled)
if args.enable_live_validation:
logger.info("🔄 Phase 7: Live Validation...")
# Test model on recent live data
live_validation_results = trainer.run_live_validation(
symbols=args.symbols[:1], # Use first symbol
validation_hours=2 # Validate on last 2 hours
)
logger.info("Live Validation Results:")
logger.info(f" Prediction Accuracy: {live_validation_results['accuracy']:.2f}%")
logger.info(f" Average Confidence: {live_validation_results['avg_confidence']:.4f}")
logger.info(f" Predictions Made: {live_validation_results['total_predictions']}")
logger.info("="*80)
logger.info("🎉 ENHANCED CNN LIVE TRAINING COMPLETED SUCCESSFULLY!")
logger.info("="*80)
logger.info(f"📊 Model Path: {model_path}")
logger.info(f"📋 Metadata: {metadata_path}")
if args.enable_analysis:
logger.info(f"📈 Analysis Report: {report_path}")
logger.info(f"📊 Performance Plots: {plots_dir}")
logger.info("="*80)
except KeyboardInterrupt:
logger.info("Training interrupted by user")
return 1
except Exception as e:
logger.error(f"Training failed: {e}")
import traceback
logger.error(traceback.format_exc())
return 1
return 0
if __name__ == "__main__":
sys.exit(main())