cash works again!
This commit is contained in:
@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Tuple, Any
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
from core.config import get_config
|
||||
from core.data_provider import DataProvider
|
||||
@ -293,7 +294,14 @@ class EnhancedCNNTrainer:
|
||||
'train_accuracy': [],
|
||||
'val_accuracy': [],
|
||||
'confidence_accuracy': []
|
||||
} # Create save directory models_path = self.config.cnn.get('model_dir', "models/enhanced_cnn") self.save_dir = Path(models_path) self.save_dir.mkdir(parents=True, exist_ok=True) logger.info("Enhanced CNN trainer initialized")
|
||||
}
|
||||
|
||||
# Create save directory
|
||||
models_path = self.config.cnn.get('model_dir', "models/enhanced_cnn")
|
||||
self.save_dir = Path(models_path)
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info("Enhanced CNN trainer initialized")
|
||||
|
||||
def train_on_perfect_moves(self, min_samples: int = 100) -> Dict[str, Any]:
|
||||
"""Train the model on perfect moves from the orchestrator"""
|
||||
@ -563,4 +571,233 @@ class EnhancedCNNTrainer:
|
||||
|
||||
def get_model(self) -> EnhancedCNNModel:
|
||||
"""Get the trained model"""
|
||||
return self.model
|
||||
return self.model
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup"""
|
||||
self.close_tensorboard()
|
||||
|
||||
def main():
|
||||
"""Main function for standalone CNN live training with backtesting and analysis"""
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
parser = argparse.ArgumentParser(description='Enhanced CNN Live Training with Backtesting and Analysis')
|
||||
parser.add_argument('--symbols', type=str, nargs='+', default=['ETH/USDT', 'BTC/USDT'],
|
||||
help='Trading symbols to train on')
|
||||
parser.add_argument('--timeframes', type=str, nargs='+', default=['1m', '5m', '15m', '1h'],
|
||||
help='Timeframes to use for training')
|
||||
parser.add_argument('--epochs', type=int, default=100,
|
||||
help='Number of training epochs')
|
||||
parser.add_argument('--batch-size', type=int, default=32,
|
||||
help='Training batch size')
|
||||
parser.add_argument('--learning-rate', type=float, default=0.001,
|
||||
help='Learning rate')
|
||||
parser.add_argument('--save-path', type=str, default='models/enhanced_cnn/live_trained_model.pt',
|
||||
help='Path to save the trained model')
|
||||
parser.add_argument('--enable-backtesting', action='store_true', default=True,
|
||||
help='Enable backtesting after training')
|
||||
parser.add_argument('--enable-analysis', action='store_true', default=True,
|
||||
help='Enable detailed analysis and reporting')
|
||||
parser.add_argument('--enable-live-validation', action='store_true', default=True,
|
||||
help='Enable live validation during training')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
logger.info("="*80)
|
||||
logger.info("🧠 ENHANCED CNN LIVE TRAINING WITH BACKTESTING & ANALYSIS")
|
||||
logger.info("="*80)
|
||||
logger.info(f"Symbols: {args.symbols}")
|
||||
logger.info(f"Timeframes: {args.timeframes}")
|
||||
logger.info(f"Epochs: {args.epochs}")
|
||||
logger.info(f"Batch Size: {args.batch_size}")
|
||||
logger.info(f"Learning Rate: {args.learning_rate}")
|
||||
logger.info(f"Save Path: {args.save_path}")
|
||||
logger.info(f"Backtesting: {'Enabled' if args.enable_backtesting else 'Disabled'}")
|
||||
logger.info(f"Analysis: {'Enabled' if args.enable_analysis else 'Disabled'}")
|
||||
logger.info(f"Live Validation: {'Enabled' if args.enable_live_validation else 'Disabled'}")
|
||||
logger.info("="*80)
|
||||
|
||||
try:
|
||||
# Update config with command line arguments
|
||||
config = get_config()
|
||||
config.update('symbols', args.symbols)
|
||||
config.update('timeframes', args.timeframes)
|
||||
config.update('training', {
|
||||
**config.training,
|
||||
'epochs': args.epochs,
|
||||
'batch_size': args.batch_size,
|
||||
'learning_rate': args.learning_rate
|
||||
})
|
||||
|
||||
# Initialize enhanced trainer
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
data_provider = DataProvider(config)
|
||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||
trainer = EnhancedCNNTrainer(config, orchestrator)
|
||||
|
||||
# Phase 1: Data Collection and Preparation
|
||||
logger.info("📊 Phase 1: Collecting and preparing training data...")
|
||||
training_data = trainer.collect_training_data(args.symbols, lookback_days=30)
|
||||
logger.info(f" Collected {len(training_data)} training samples")
|
||||
|
||||
# Phase 2: Model Training
|
||||
logger.info("🧠 Phase 2: Training Enhanced CNN Model...")
|
||||
training_results = trainer.train_on_perfect_moves(min_samples=1000)
|
||||
|
||||
logger.info("Training Results:")
|
||||
logger.info(f" Best Validation Accuracy: {training_results['best_val_accuracy']:.4f}")
|
||||
logger.info(f" Best Validation Loss: {training_results['best_val_loss']:.4f}")
|
||||
logger.info(f" Total Epochs: {training_results['epochs_completed']}")
|
||||
logger.info(f" Training Time: {training_results['total_time']:.2f}s")
|
||||
|
||||
# Phase 3: Model Evaluation
|
||||
logger.info("📈 Phase 3: Model Evaluation...")
|
||||
evaluation_results = trainer.evaluate_model(args.symbols[:1]) # Use first symbol for evaluation
|
||||
|
||||
logger.info("Evaluation Results:")
|
||||
logger.info(f" Test Accuracy: {evaluation_results['test_accuracy']:.4f}")
|
||||
logger.info(f" Test Loss: {evaluation_results['test_loss']:.4f}")
|
||||
logger.info(f" Confidence Score: {evaluation_results['avg_confidence']:.4f}")
|
||||
|
||||
# Phase 4: Backtesting (if enabled)
|
||||
if args.enable_backtesting:
|
||||
logger.info("📊 Phase 4: Backtesting...")
|
||||
|
||||
# Create backtest environment
|
||||
from trading.backtest_environment import BacktestEnvironment
|
||||
backtest_env = BacktestEnvironment(
|
||||
symbols=args.symbols,
|
||||
timeframes=args.timeframes,
|
||||
initial_balance=10000.0,
|
||||
data_provider=data_provider
|
||||
)
|
||||
|
||||
# Run backtest
|
||||
backtest_results = backtest_env.run_backtest_with_model(
|
||||
model=trainer.model,
|
||||
lookback_days=7, # Test on last 7 days
|
||||
max_trades_per_day=50
|
||||
)
|
||||
|
||||
logger.info("Backtesting Results:")
|
||||
logger.info(f" Total Returns: {backtest_results['total_return']:.2f}%")
|
||||
logger.info(f" Win Rate: {backtest_results['win_rate']:.2f}%")
|
||||
logger.info(f" Sharpe Ratio: {backtest_results['sharpe_ratio']:.4f}")
|
||||
logger.info(f" Max Drawdown: {backtest_results['max_drawdown']:.2f}%")
|
||||
logger.info(f" Total Trades: {backtest_results['total_trades']}")
|
||||
logger.info(f" Profit Factor: {backtest_results['profit_factor']:.4f}")
|
||||
|
||||
# Phase 5: Analysis and Reporting (if enabled)
|
||||
if args.enable_analysis:
|
||||
logger.info("📋 Phase 5: Analysis and Reporting...")
|
||||
|
||||
# Generate comprehensive analysis report
|
||||
analysis_report = trainer.generate_analysis_report(
|
||||
training_results=training_results,
|
||||
evaluation_results=evaluation_results,
|
||||
backtest_results=backtest_results if args.enable_backtesting else None
|
||||
)
|
||||
|
||||
# Save analysis report
|
||||
report_path = Path(args.save_path).parent / "analysis_report.json"
|
||||
report_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(report_path, 'w') as f:
|
||||
json.dump(analysis_report, f, indent=2, default=str)
|
||||
|
||||
logger.info(f" Analysis report saved: {report_path}")
|
||||
|
||||
# Generate performance plots
|
||||
plots_dir = Path(args.save_path).parent / "plots"
|
||||
plots_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
trainer.generate_performance_plots(
|
||||
training_results=training_results,
|
||||
evaluation_results=evaluation_results,
|
||||
save_dir=plots_dir
|
||||
)
|
||||
|
||||
logger.info(f" Performance plots saved: {plots_dir}")
|
||||
|
||||
# Phase 6: Model Saving
|
||||
logger.info("💾 Phase 6: Saving trained model...")
|
||||
model_path = Path(args.save_path)
|
||||
model_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
trainer.model.save(str(model_path))
|
||||
logger.info(f" Model saved: {model_path}")
|
||||
|
||||
# Save training metadata
|
||||
metadata = {
|
||||
'training_config': {
|
||||
'symbols': args.symbols,
|
||||
'timeframes': args.timeframes,
|
||||
'epochs': args.epochs,
|
||||
'batch_size': args.batch_size,
|
||||
'learning_rate': args.learning_rate
|
||||
},
|
||||
'training_results': training_results,
|
||||
'evaluation_results': evaluation_results
|
||||
}
|
||||
|
||||
if args.enable_backtesting:
|
||||
metadata['backtest_results'] = backtest_results
|
||||
|
||||
metadata_path = model_path.with_suffix('.json')
|
||||
with open(metadata_path, 'w') as f:
|
||||
json.dump(metadata, f, indent=2, default=str)
|
||||
|
||||
logger.info(f" Training metadata saved: {metadata_path}")
|
||||
|
||||
# Phase 7: Live Validation (if enabled)
|
||||
if args.enable_live_validation:
|
||||
logger.info("🔄 Phase 7: Live Validation...")
|
||||
|
||||
# Test model on recent live data
|
||||
live_validation_results = trainer.run_live_validation(
|
||||
symbols=args.symbols[:1], # Use first symbol
|
||||
validation_hours=2 # Validate on last 2 hours
|
||||
)
|
||||
|
||||
logger.info("Live Validation Results:")
|
||||
logger.info(f" Prediction Accuracy: {live_validation_results['accuracy']:.2f}%")
|
||||
logger.info(f" Average Confidence: {live_validation_results['avg_confidence']:.4f}")
|
||||
logger.info(f" Predictions Made: {live_validation_results['total_predictions']}")
|
||||
|
||||
logger.info("="*80)
|
||||
logger.info("🎉 ENHANCED CNN LIVE TRAINING COMPLETED SUCCESSFULLY!")
|
||||
logger.info("="*80)
|
||||
logger.info(f"📊 Model Path: {model_path}")
|
||||
logger.info(f"📋 Metadata: {metadata_path}")
|
||||
if args.enable_analysis:
|
||||
logger.info(f"📈 Analysis Report: {report_path}")
|
||||
logger.info(f"📊 Performance Plots: {plots_dir}")
|
||||
logger.info("="*80)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Training interrupted by user")
|
||||
return 1
|
||||
except Exception as e:
|
||||
logger.error(f"Training failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
@ -8,7 +8,27 @@ Comprehensive training pipeline for scalping RL agents:
|
||||
- Memory-efficient training loops
|
||||
"""
|
||||
|
||||
import torchimport numpy as npimport pandas as pdimport loggingfrom typing import Dict, List, Tuple, Optional, Anyimport timefrom pathlib import Pathimport matplotlib.pyplot as pltfrom collections import dequeimport randomfrom torch.utils.tensorboard import SummaryWriter# Add project importsimport sysimport ossys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))from core.config import get_configfrom core.data_provider import DataProviderfrom models.rl.scalping_agent import ScalpingEnvironment, ScalpingRLAgentfrom utils.model_utils import robust_save, robust_load
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import logging
|
||||
from typing import Dict, List, Tuple, Optional, Any
|
||||
import time
|
||||
from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
from collections import deque
|
||||
import random
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
# Add project imports
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from core.config import get_config
|
||||
from core.data_provider import DataProvider
|
||||
from models.rl.scalping_agent import ScalpingEnvironment, ScalpingRLAgent
|
||||
from utils.model_utils import robust_save, robust_load
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
Reference in New Issue
Block a user