gogo2/main_clean.py
2025-05-25 00:28:52 +03:00

381 lines
15 KiB
Python

#!/usr/bin/env python3
"""
Clean Trading System - Main Entry Point
Unified entry point for the clean trading architecture with these modes:
- test: Test data provider and orchestrator
- cnn: Train CNN models only
- rl: Train RL agents only
- train: Train both CNN and RL models
- trade: Live trading mode
- web: Web dashboard with real-time charts
Usage:
python main_clean.py --mode [test|cnn|rl|train|trade|web] --symbol ETH/USDT
"""
import asyncio
import argparse
import logging
import sys
from pathlib import Path
from threading import Thread
import time
# Put the directory containing this script at the front of sys.path so the
# project-local packages (core.* below; training.* and web.* imported lazily
# inside the mode functions) resolve when the script is run directly.
# This must stay ahead of the `core` imports that follow.
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import get_config, setup_logging, Config
from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator
# Module-level logger. Handlers/levels are presumably configured by
# setup_logging(), which main() calls before dispatching to a mode.
logger = logging.getLogger(__name__)
def run_data_test():
    """Smoke-test the data provider end to end.

    Builds a ``DataProvider`` for ETH/USDT on the 1s/1m/1h/4h timeframes and
    runs three checks in order:

    1. ``get_historical_data`` for 100 1h candles (logs candle/column counts,
       the timestamp range and how many non-OHLCV columns, i.e. indicators,
       are present);
    2. ``get_feature_matrix`` over 1h/4h with a 20-candle window (logs its
       shape, labelled as timeframes x window x features);
    3. ``health_check`` (the returned value is logged).

    A failing check is logged with a ``[FAILED]`` line but does not stop the
    remaining checks. The closing summary names every failed check instead of
    unconditionally reporting success.

    Raises:
        Exception: anything raised by the data provider is logged together
            with its traceback and re-raised to the caller.
    """
    try:
        logger.info("Testing Enhanced Data Provider...")
        # Test data provider with multiple timeframes
        data_provider = DataProvider(
            symbols=['ETH/USDT'],
            timeframes=['1s', '1m', '1h', '4h'],  # Include 1s for scalping
        )
        failed_checks = []

        # Test historical data. An empty frame counts as a failure, matching
        # the `is None or .empty` check used by run_web_dashboard.
        logger.info("Testing historical data fetching...")
        df = data_provider.get_historical_data('ETH/USDT', '1h', limit=100)
        if df is not None and not df.empty:
            logger.info(f"[SUCCESS] Historical data: {len(df)} candles loaded")
            logger.info(f" Columns: {len(df.columns)} total")
            logger.info(f" Date range: {df['timestamp'].min()} to {df['timestamp'].max()}")
            # Every column beyond timestamp + OHLCV is counted as an indicator.
            basic_cols = {'timestamp', 'open', 'high', 'low', 'close', 'volume'}
            indicators = [col for col in df.columns if col not in basic_cols]
            logger.info(f" Technical indicators: {len(indicators)}")
        else:
            logger.error("[FAILED] Failed to load historical data")
            failed_checks.append("historical data")

        # Test multi-timeframe feature matrix
        logger.info("Testing multi-timeframe feature matrix...")
        feature_matrix = data_provider.get_feature_matrix('ETH/USDT', ['1h', '4h'], window_size=20)
        if feature_matrix is not None:
            # Axes are reported as (timeframes, window, features).
            logger.info(f"[SUCCESS] Feature matrix shape: {feature_matrix.shape}")
            logger.info(f" Timeframes: {feature_matrix.shape[0]}")
            logger.info(f" Window size: {feature_matrix.shape[1]}")
            logger.info(f" Features: {feature_matrix.shape[2]}")
        else:
            logger.error("[FAILED] Failed to create feature matrix")
            failed_checks.append("feature matrix")

        # Test health check; log what it returned instead of discarding it.
        health = data_provider.health_check()
        logger.info("[SUCCESS] Data provider health check completed")
        logger.info(f" Health: {health}")

        if failed_checks:
            logger.warning(
                "Enhanced data provider test completed with failures: "
                + ", ".join(failed_checks)
            )
        else:
            logger.info("Enhanced data provider test completed successfully!")
    except Exception as e:
        # logger.exception records the message and the active traceback together.
        logger.exception(f"Error in data test: {e}")
        raise
def _resolve_training_symbols(symbol, configured_symbols, default_symbol="ETH/USDT"):
    """Return the ordered list of symbols to train on, without duplicates of *symbol*.

    If *symbol* equals *default_symbol* (the CLI default), the configured
    symbols are returned unchanged (as a new list). Otherwise *symbol* comes
    first, followed by the configured symbols with any repeat of *symbol*
    removed, so an already-configured symbol is not trained on twice.
    """
    configured = list(configured_symbols)
    if symbol == default_symbol:
        return configured
    return [symbol] + [s for s in configured if s != symbol]


def run_cnn_training(config: "Config", symbol: str):
    """Train the scalping CNN, log its results, then run a quick evaluation.

    Args:
        config: Project configuration. It is handed to ``CNNTrainer``, and its
            ``symbols`` attribute supplies the configured symbol list.
        symbol: Value of ``--symbol``. The CLI default (ETH/USDT) means "use
            the configured symbols". Any other value is trained first, ahead
            of the configured symbols, and is not duplicated if already listed.

    The model is saved to ``models/cnn/scalping_cnn_trained.pt``. Evaluation
    uses only the first training symbol. ``trainer.close_tensorboard()`` is
    always called, even when training or evaluation fails.

    Raises:
        Exception: any training/evaluation error is logged and re-raised.
    """
    logger.info("Starting CNN Training Mode...")
    # Imported here so the training package is only loaded when this mode runs.
    from training.cnn_trainer import CNNTrainer

    trainer = CNNTrainer(config)

    symbols = _resolve_training_symbols(symbol, config.symbols)
    save_path = "models/cnn/scalping_cnn_trained.pt"
    logger.info(f"Training CNN for symbols: {symbols}")
    logger.info(f"Will save to: {save_path}")
    logger.info("🔗 Monitor training: tensorboard --logdir=runs")
    try:
        # Train model with TensorBoard logging
        results = trainer.train(symbols, save_path=save_path)
        logger.info("CNN Training Results:")
        logger.info(f" Best validation accuracy: {results['best_val_accuracy']:.4f}")
        logger.info(f" Best validation loss: {results['best_val_loss']:.4f}")
        logger.info(f" Total epochs: {results['total_epochs']}")
        logger.info(f" Training time: {results['training_time']:.2f} seconds")
        logger.info(f" TensorBoard logs: {results['tensorboard_dir']}")
        logger.info("📊 View training progress: tensorboard --logdir=runs")

        logger.info("Evaluating CNN on test data...")
        # Quick evaluation on the first symbol only.
        test_results = trainer.evaluate(symbols[:1])
        logger.info("CNN Evaluation Results:")
        logger.info(f" Test accuracy: {test_results['test_accuracy']:.4f}")
        logger.info(f" Test loss: {test_results['test_loss']:.4f}")
        logger.info(f" Average confidence: {test_results['avg_confidence']:.4f}")
        logger.info("CNN training completed successfully!")
    except Exception as e:
        logger.error(f"CNN training failed: {e}")
        raise
    finally:
        trainer.close_tensorboard()
def run_rl_training():
    """Train the scalping RL agent on ETH/USDT, then plot and backtest it.

    Steps:
      1. Build a ``DataProvider`` for ETH/USDT on the 1s/1m/5m/1h timeframes.
      2. Configure an ``RLTrainer`` (1000 episodes x up to 1000 steps,
         evaluation_frequency=50, save_frequency=100, presumably counted in
         episodes) and train it, saving to
         ``models/rl/scalping_agent_trained.pt``.
      3. Log the training summary and the final evaluation.
      4. Best-effort: save a training-progress plot and run a 50-episode
         backtest of the saved agent. Failures in these two steps are logged
         as warnings only.

    Raises:
        Exception: errors from setup, training or result reporting are logged
            with their traceback and re-raised.
    """
    try:
        logger.info("Starting RL Training Mode...")
        # Initialize components for RL
        data_provider = DataProvider(
            symbols=['ETH/USDT'],
            timeframes=['1s', '1m', '5m', '1h'],  # Focus on scalping timeframes
        )
        # Imported here so the training package is only loaded when this mode runs.
        from training.rl_trainer import RLTrainer
        trainer = RLTrainer(data_provider)

        # Configure training
        trainer.num_episodes = 1000
        trainer.max_steps_per_episode = 1000
        trainer.evaluation_frequency = 50
        trainer.save_frequency = 100

        # Train the agent
        save_path = 'models/rl/scalping_agent_trained.pt'
        logger.info("Training RL agent for scalping")
        logger.info(f"Will save to: {save_path}")
        results = trainer.train(save_path)

        # Log results
        logger.info("RL Training Results:")
        logger.info(f" Best reward: {results['best_reward']:.4f}")
        logger.info(f" Best balance: ${results['best_balance']:.2f}")
        logger.info(f" Total episodes: {results['total_episodes']}")
        logger.info(f" Training time: {results['total_time']:.2f} seconds")
        logger.info(f" Final epsilon: {results['agent_config']['epsilon_final']:.4f}")

        # Final evaluation results
        final_eval = results['final_evaluation']
        logger.info("Final Evaluation:")
        logger.info(f" Win rate: {final_eval['win_rate']:.2%}")
        logger.info(f" Average PnL: {final_eval['avg_pnl_percentage']:.2f}%")
        logger.info(f" Average trades: {final_eval['avg_trades']:.1f}")

        # Plot training progress (best effort: a plotting failure is not fatal)
        try:
            plot_path = 'models/rl/training_progress.png'
            trainer.plot_training_progress(plot_path)
            logger.info(f"Training plots saved to: {plot_path}")
        except Exception as e:
            logger.warning(f"Could not save training plots: {e}")

        # Backtest the trained agent (best effort as well)
        try:
            logger.info("Backtesting trained agent...")
            backtest_results = trainer.backtest_agent(save_path, test_episodes=50)
            analysis = backtest_results['analysis']
            logger.info("Backtest Results:")
            logger.info(f" Win rate: {analysis['win_rate']:.2%}")
            logger.info(f" Average PnL: {analysis['avg_pnl']:.2f}%")
            logger.info(f" Sharpe ratio: {analysis['sharpe_ratio']:.4f}")
            logger.info(f" Max drawdown: {analysis['max_drawdown']:.2f}%")
        except Exception as e:
            logger.warning(f"Could not run backtest: {e}")

        logger.info("RL training completed successfully!")
    except Exception as e:
        # logger.exception records the message and the active traceback together.
        logger.exception(f"Error in RL training: {e}")
        raise
def run_combined_training():
    """Train the hybrid CNN + RL system on ETH/USDT and BTC/USDT.

    Builds a ``DataProvider`` over the 1s/1m/5m/1h/4h timeframes, hands it to
    ``HybridTrainer.train_hybrid`` with the CNN and RL save paths, and logs
    per-phase and total results.

    Raises:
        Exception: any setup/training/reporting error is logged with its
            traceback and re-raised.
    """
    # Single source of truth for the symbols used by both the provider and the trainer.
    symbols = ['ETH/USDT', 'BTC/USDT']
    try:
        logger.info("Starting Hybrid CNN + RL Training Mode...")
        # Initialize data provider (given its own copy of the symbol list)
        data_provider = DataProvider(
            symbols=list(symbols),
            timeframes=['1s', '1m', '5m', '1h', '4h'],
        )
        # Imported here so the training package is only loaded when this mode runs.
        from training.rl_trainer import HybridTrainer
        trainer = HybridTrainer(data_provider)

        # Define save paths
        cnn_save_path = 'models/cnn/hybrid_cnn_trained.pt'
        rl_save_path = 'models/rl/hybrid_rl_trained.pt'

        # Train hybrid system
        logger.info(f"Training hybrid system for symbols: {symbols}")
        results = trainer.train_hybrid(symbols, cnn_save_path, rl_save_path)

        # Log results
        cnn_results = results['cnn_results']
        rl_results = results['rl_results']
        # NOTE(review): run_cnn_training reads the CNN trainer's duration from
        # 'training_time', whereas this summary used 'total_time'. Accept either
        # so a finished run is not reported as a failure by a KeyError here;
        # confirm which key HybridTrainer actually returns.
        cnn_time = cnn_results.get('total_time', cnn_results.get('training_time'))

        logger.info("Hybrid Training Results:")
        logger.info("CNN Phase:")
        logger.info(f" Best accuracy: {cnn_results['best_val_accuracy']:.4f}")
        if cnn_time is not None:
            logger.info(f" Training time: {cnn_time:.2f}s")
        else:
            logger.info(" Training time: n/a")
        logger.info("RL Phase:")
        logger.info(f" Best reward: {rl_results['best_reward']:.4f}")
        logger.info(f" Final balance: ${rl_results['best_balance']:.2f}")
        logger.info(f" Training time: {rl_results['total_time']:.2f}s")
        logger.info(f"Total training time: {results['total_time']:.2f}s")
        logger.info("Hybrid training completed successfully!")
    except Exception as e:
        # logger.exception records the message and the active traceback together.
        logger.exception(f"Error in hybrid training: {e}")
        raise
def run_live_trading():
    """Set up the live-trading components; currently a scaffold only.

    Creates an ETH/USDT ``DataProvider`` on the 1s/1m/5m/15m timeframes and a
    ``TradingOrchestrator`` around it, then logs readiness messages. Nothing
    here starts streaming or places orders, and the orchestrator is
    constructed but not used further.

    Raises:
        Exception: any setup error is logged and re-raised.
    """
    live_symbols = ['ETH/USDT']
    live_timeframes = ['1s', '1m', '5m', '15m']  # 1s first: scalping focus
    try:
        logger.info("Starting Live Trading Mode...")
        provider = DataProvider(symbols=live_symbols, timeframes=live_timeframes)
        _orchestrator = TradingOrchestrator(provider)  # constructed, not yet driven
        # Placeholder: real-time streaming / execution still has to be wired in.
        for message in (
            "Starting real-time data streaming...",
            "Live trading mode ready!",
            "Note: Integrate this with your actual trading execution",
        ):
            logger.info(message)
    except Exception as err:
        logger.error(f"Error in live trading: {err}")
        raise
def run_web_dashboard(host: str = '127.0.0.1', port: int = 8050):
    """Launch the web dashboard backed by the real data provider.

    First checks that ETH/USDT 1m candles can be fetched: a fresh fetch is
    tried, then cached data, and a warning is logged if neither yields rows.
    It then builds a ``TradingOrchestrator`` and serves
    ``web.dashboard.TradingDashboard``. If that first launch raises, one
    fallback attempt is made with freshly built components and no data check.

    Args:
        host: Interface the dashboard server binds to.
        port: TCP port for the dashboard server (``--port`` on the CLI).

    Raises:
        Exception: the fallback attempt's exception, if the fallback also fails.
    """
    url = f"http://{host}:{port}"
    try:
        logger.info("Starting Web Dashboard Mode with REAL LIVE DATA...")
        # Initialize with real data provider
        data_provider = DataProvider()

        # Verify we have real data connection: fresh first, then cached.
        logger.info("🔍 Verifying REAL data connection...")
        test_data = data_provider.get_historical_data('ETH/USDT', '1m', limit=10, refresh=True)
        if test_data is None or test_data.empty:
            logger.warning("⚠️ No fresh data available - trying cached data...")
            test_data = data_provider.get_historical_data('ETH/USDT', '1m', limit=10, refresh=False)

        if test_data is None or test_data.empty:
            logger.warning("⚠️ No data available - starting dashboard with demo mode...")
        else:
            logger.info("✅ Data connection verified")
            logger.info(f"✅ Fetched {len(test_data)} candles for validation")

        # Initialize orchestrator with real data only
        orchestrator = TradingOrchestrator(data_provider)

        from web.dashboard import TradingDashboard
        dashboard = TradingDashboard(data_provider, orchestrator)
        logger.info("🎯 LAUNCHING DASHBOARD")
        logger.info(f"🌐 Access at: {url}")
        dashboard.run(host=host, port=port, debug=False)
    except Exception as e:
        # Keep the primary traceback: it is otherwise lost if the fallback succeeds.
        logger.exception(f"Error in web dashboard: {e}")
        logger.error("Dashboard stopped - trying fallback mode")
        # Try a simpler fallback (no data verification)
        try:
            from web.dashboard import TradingDashboard
            data_provider = DataProvider()
            orchestrator = TradingOrchestrator(data_provider)
            dashboard = TradingDashboard(data_provider, orchestrator)
            dashboard.run(host=host, port=port, debug=False)
        except Exception as fallback_error:
            logger.error(f"Fallback dashboard also failed: {fallback_error}")
            raise


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this entry point."""
    parser = argparse.ArgumentParser(description='Clean Trading System - Unified Entry Point')
    parser.add_argument('--mode',
                        choices=['test', 'cnn', 'rl', 'train', 'trade', 'web'],
                        default='test',
                        help='Operation mode')
    parser.add_argument('--symbol', type=str, default='ETH/USDT',
                        help='Trading symbol (default: ETH/USDT)')
    parser.add_argument('--port', type=int, default=8050,
                        help='Web dashboard port (default: 8050)')
    # NOTE(review): --demo is parsed but nothing in this file reads args.demo yet.
    parser.add_argument('--demo', action='store_true',
                        help='Run web dashboard in demo mode')
    return parser


async def main():
    """Parse CLI arguments, configure logging and run the selected mode.

    Returns:
        Process exit status: 0 on success or Ctrl-C, 1 on any other error
        (which is logged with its traceback).
    """
    args = _build_arg_parser().parse_args()
    # Setup logging
    setup_logging()
    try:
        logger.info("=" * 60)
        logger.info("CLEAN TRADING SYSTEM - UNIFIED LAUNCH")
        logger.info(f"Mode: {args.mode.upper()}")
        # NOTE(review): only the 'cnn' mode consumes --symbol; the other modes
        # use their own hard-coded symbol lists.
        logger.info(f"Symbol: {args.symbol}")
        logger.info("=" * 60)

        # Route to appropriate mode
        if args.mode == 'test':
            run_data_test()
        elif args.mode == 'cnn':
            run_cnn_training(get_config(), args.symbol)
        elif args.mode == 'rl':
            run_rl_training()
        elif args.mode == 'train':
            run_combined_training()
        elif args.mode == 'trade':
            run_live_trading()
        elif args.mode == 'web':
            # --port was previously parsed but ignored; honour it here.
            run_web_dashboard(port=args.port)

        logger.info("Operation completed successfully!")
    except KeyboardInterrupt:
        logger.info("System shutdown requested by user")
    except Exception as e:
        logger.exception(f"Fatal error: {e}")
        return 1
    return 0
if __name__ == "__main__":
    # main() is a coroutine; its integer result becomes the process exit status.
    exit_status = asyncio.run(main())
    sys.exit(exit_status)