cash works again!

Dobromir Popov
2025-05-25 00:28:52 +03:00
parent d418f6ce59
commit cf825239cd
18 changed files with 1970 additions and 1331 deletions

View File

@@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Tuple, Any
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
from core.config import get_config
from core.data_provider import DataProvider
@@ -293,7 +294,14 @@ class EnhancedCNNTrainer:
'train_accuracy': [],
'val_accuracy': [],
'confidence_accuracy': []
}
# Create save directory
models_path = self.config.cnn.get('model_dir', "models/enhanced_cnn")
self.save_dir = Path(models_path)
self.save_dir.mkdir(parents=True, exist_ok=True)
logger.info("Enhanced CNN trainer initialized")
def train_on_perfect_moves(self, min_samples: int = 100) -> Dict[str, Any]:
"""Train the model on perfect moves from the orchestrator"""
@@ -563,4 +571,233 @@ class EnhancedCNNTrainer:
def get_model(self) -> EnhancedCNNModel:
"""Get the trained model"""
return self.model
def __del__(self):
"""Cleanup"""
self.close_tensorboard()
def main():
"""Main function for standalone CNN live training with backtesting and analysis"""
import argparse
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
parser = argparse.ArgumentParser(description='Enhanced CNN Live Training with Backtesting and Analysis')
parser.add_argument('--symbols', type=str, nargs='+', default=['ETH/USDT', 'BTC/USDT'],
help='Trading symbols to train on')
parser.add_argument('--timeframes', type=str, nargs='+', default=['1m', '5m', '15m', '1h'],
help='Timeframes to use for training')
parser.add_argument('--epochs', type=int, default=100,
help='Number of training epochs')
parser.add_argument('--batch-size', type=int, default=32,
help='Training batch size')
parser.add_argument('--learning-rate', type=float, default=0.001,
help='Learning rate')
parser.add_argument('--save-path', type=str, default='models/enhanced_cnn/live_trained_model.pt',
help='Path to save the trained model')
parser.add_argument('--enable-backtesting', action='store_true', default=True,
help='Enable backtesting after training')
parser.add_argument('--enable-analysis', action='store_true', default=True,
help='Enable detailed analysis and reporting')
parser.add_argument('--enable-live-validation', action='store_true', default=True,
help='Enable live validation during training')
args = parser.parse_args()
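# NOTE: with action='store_true' and default=True, the three --enable-* flags above
# are effectively always on; there is no corresponding flag to turn them off from the CLI.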
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger.info("="*80)
logger.info("🧠 ENHANCED CNN LIVE TRAINING WITH BACKTESTING & ANALYSIS")
logger.info("="*80)
logger.info(f"Symbols: {args.symbols}")
logger.info(f"Timeframes: {args.timeframes}")
logger.info(f"Epochs: {args.epochs}")
logger.info(f"Batch Size: {args.batch_size}")
logger.info(f"Learning Rate: {args.learning_rate}")
logger.info(f"Save Path: {args.save_path}")
logger.info(f"Backtesting: {'Enabled' if args.enable_backtesting else 'Disabled'}")
logger.info(f"Analysis: {'Enabled' if args.enable_analysis else 'Disabled'}")
logger.info(f"Live Validation: {'Enabled' if args.enable_live_validation else 'Disabled'}")
logger.info("="*80)
try:
# Update config with command line arguments
config = get_config()
config.update('symbols', args.symbols)
config.update('timeframes', args.timeframes)
config.update('training', {
**config.training,
'epochs': args.epochs,
'batch_size': args.batch_size,
'learning_rate': args.learning_rate
})
# Initialize enhanced trainer
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from core.data_provider import DataProvider
data_provider = DataProvider(config)
orchestrator = EnhancedTradingOrchestrator(data_provider)
trainer = EnhancedCNNTrainer(config, orchestrator)
# Phase 1: Data Collection and Preparation
logger.info("📊 Phase 1: Collecting and preparing training data...")
training_data = trainer.collect_training_data(args.symbols, lookback_days=30)
logger.info(f" Collected {len(training_data)} training samples")
# Phase 2: Model Training
logger.info("🧠 Phase 2: Training Enhanced CNN Model...")
training_results = trainer.train_on_perfect_moves(min_samples=1000)
logger.info("Training Results:")
logger.info(f" Best Validation Accuracy: {training_results['best_val_accuracy']:.4f}")
logger.info(f" Best Validation Loss: {training_results['best_val_loss']:.4f}")
logger.info(f" Total Epochs: {training_results['epochs_completed']}")
logger.info(f" Training Time: {training_results['total_time']:.2f}s")
# Phase 3: Model Evaluation
logger.info("📈 Phase 3: Model Evaluation...")
evaluation_results = trainer.evaluate_model(args.symbols[:1]) # Use first symbol for evaluation
logger.info("Evaluation Results:")
logger.info(f" Test Accuracy: {evaluation_results['test_accuracy']:.4f}")
logger.info(f" Test Loss: {evaluation_results['test_loss']:.4f}")
logger.info(f" Confidence Score: {evaluation_results['avg_confidence']:.4f}")
# Phase 4: Backtesting (if enabled)
if args.enable_backtesting:
logger.info("📊 Phase 4: Backtesting...")
# Create backtest environment
from trading.backtest_environment import BacktestEnvironment
backtest_env = BacktestEnvironment(
symbols=args.symbols,
timeframes=args.timeframes,
initial_balance=10000.0,
data_provider=data_provider
)
# Run backtest
backtest_results = backtest_env.run_backtest_with_model(
model=trainer.model,
lookback_days=7, # Test on last 7 days
max_trades_per_day=50
)
logger.info("Backtesting Results:")
logger.info(f" Total Returns: {backtest_results['total_return']:.2f}%")
logger.info(f" Win Rate: {backtest_results['win_rate']:.2f}%")
logger.info(f" Sharpe Ratio: {backtest_results['sharpe_ratio']:.4f}")
logger.info(f" Max Drawdown: {backtest_results['max_drawdown']:.2f}%")
logger.info(f" Total Trades: {backtest_results['total_trades']}")
logger.info(f" Profit Factor: {backtest_results['profit_factor']:.4f}")
# Phase 5: Analysis and Reporting (if enabled)
if args.enable_analysis:
logger.info("📋 Phase 5: Analysis and Reporting...")
# Generate comprehensive analysis report
analysis_report = trainer.generate_analysis_report(
training_results=training_results,
evaluation_results=evaluation_results,
backtest_results=backtest_results if args.enable_backtesting else None
)
# Save analysis report
report_path = Path(args.save_path).parent / "analysis_report.json"
report_path.parent.mkdir(parents=True, exist_ok=True)
with open(report_path, 'w') as f:
json.dump(analysis_report, f, indent=2, default=str)
logger.info(f" Analysis report saved: {report_path}")
# Generate performance plots
plots_dir = Path(args.save_path).parent / "plots"
plots_dir.mkdir(parents=True, exist_ok=True)
trainer.generate_performance_plots(
training_results=training_results,
evaluation_results=evaluation_results,
save_dir=plots_dir
)
logger.info(f" Performance plots saved: {plots_dir}")
# Phase 6: Model Saving
logger.info("💾 Phase 6: Saving trained model...")
model_path = Path(args.save_path)
model_path.parent.mkdir(parents=True, exist_ok=True)
trainer.model.save(str(model_path))
logger.info(f" Model saved: {model_path}")
# Save training metadata
metadata = {
'training_config': {
'symbols': args.symbols,
'timeframes': args.timeframes,
'epochs': args.epochs,
'batch_size': args.batch_size,
'learning_rate': args.learning_rate
},
'training_results': training_results,
'evaluation_results': evaluation_results
}
if args.enable_backtesting:
metadata['backtest_results'] = backtest_results
metadata_path = model_path.with_suffix('.json')
with open(metadata_path, 'w') as f:
json.dump(metadata, f, indent=2, default=str)
logger.info(f" Training metadata saved: {metadata_path}")
# Phase 7: Live Validation (if enabled)
if args.enable_live_validation:
logger.info("🔄 Phase 7: Live Validation...")
# Test model on recent live data
live_validation_results = trainer.run_live_validation(
symbols=args.symbols[:1], # Use first symbol
validation_hours=2 # Validate on last 2 hours
)
logger.info("Live Validation Results:")
logger.info(f" Prediction Accuracy: {live_validation_results['accuracy']:.2f}%")
logger.info(f" Average Confidence: {live_validation_results['avg_confidence']:.4f}")
logger.info(f" Predictions Made: {live_validation_results['total_predictions']}")
logger.info("="*80)
logger.info("🎉 ENHANCED CNN LIVE TRAINING COMPLETED SUCCESSFULLY!")
logger.info("="*80)
logger.info(f"📊 Model Path: {model_path}")
logger.info(f"📋 Metadata: {metadata_path}")
if args.enable_analysis:
logger.info(f"📈 Analysis Report: {report_path}")
logger.info(f"📊 Performance Plots: {plots_dir}")
logger.info("="*80)
except KeyboardInterrupt:
logger.info("Training interrupted by user")
return 1
except Exception as e:
logger.error(f"Training failed: {e}")
import traceback
logger.error(traceback.format_exc())
return 1
return 0
if __name__ == "__main__":
sys.exit(main())
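
For reference, the flow that main() drives through argparse can also be exercised programmatically. The sketch below is a minimal example under stated assumptions: it reuses only calls that appear above (get_config, DataProvider, EnhancedTradingOrchestrator, EnhancedCNNTrainer, train_on_perfect_moves, model.save); the module name enhanced_cnn_trainer and the save path are illustrative, not confirmed by this diff.

    # Minimal programmatic sketch of the training flow shown in main() above.
    from core.config import get_config
    from core.data_provider import DataProvider
    from core.enhanced_orchestrator import EnhancedTradingOrchestrator
    from enhanced_cnn_trainer import EnhancedCNNTrainer  # module name assumed

    config = get_config()
    data_provider = DataProvider(config)
    orchestrator = EnhancedTradingOrchestrator(data_provider)
    trainer = EnhancedCNNTrainer(config, orchestrator)

    # Same threshold main() uses for Phase 2
    results = trainer.train_on_perfect_moves(min_samples=1000)
    print(f"best val accuracy: {results['best_val_accuracy']:.4f}")

    # Same default path as the --save-path argument
    trainer.model.save("models/enhanced_cnn/live_trained_model.pt")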

View File

@@ -8,7 +8,27 @@ Comprehensive training pipeline for scalping RL agents:
- Memory-efficient training loops
"""
import torch
import numpy as np
import pandas as pd
import logging
from typing import Dict, List, Tuple, Optional, Any
import time
from pathlib import Path
import matplotlib.pyplot as plt
from collections import deque
import random
from torch.utils.tensorboard import SummaryWriter
# Add project imports
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.config import get_config
from core.data_provider import DataProvider
from models.rl.scalping_agent import ScalpingEnvironment, ScalpingRLAgent
from utils.model_utils import robust_save, robust_load
logger = logging.getLogger(__name__)
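
The file header above promises memory-efficient training loops for the scalping RL agents. Since the class bodies are not shown in this diff, the following is only a hypothetical sketch built from the imports above: the ScalpingEnvironment/ScalpingRLAgent method names (reset, step, act, remember, replay), their constructor arguments, the robust_save signature, and all paths and dimensions are assumptions, not the repository's actual API.

    # Hypothetical sketch of a bounded-memory training loop using the imports above.
    # All method names and signatures on ScalpingEnvironment / ScalpingRLAgent are assumed.
    def sketch_training_loop(episodes: int = 10, max_steps: int = 500) -> None:
        config = get_config()
        data_provider = DataProvider(config)
        env = ScalpingEnvironment(data_provider)              # constructor signature assumed
        agent = ScalpingRLAgent(state_dim=64, action_dim=3)   # dimensions illustrative
        writer = SummaryWriter(log_dir='runs/scalping_rl_sketch')
        recent_rewards = deque(maxlen=100)                    # bounded history keeps memory flat

        for episode in range(episodes):
            state = env.reset()
            total_reward = 0.0
            for _ in range(max_steps):
                action = agent.act(state)
                next_state, reward, done, _ = env.step(action)
                agent.remember(state, action, reward, next_state, done)
                agent.replay()                                # learn from a sampled mini-batch
                state = next_state
                total_reward += reward
                if done:
                    break
            recent_rewards.append(total_reward)
            writer.add_scalar('reward/episode', total_reward, episode)
            logger.info(f"episode {episode}: reward={total_reward:.2f}")

        robust_save(agent, 'models/rl/scalping_agent_sketch.pt')  # signature assumed
        writer.close()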