a bit of cleanup
This commit is contained in:
352
main_clean.py
352
main_clean.py
@ -1,17 +1,15 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Clean Trading System - Main Entry Point
|
||||
Clean Trading System - Streamlined Entry Point
|
||||
|
||||
Unified entry point for the clean trading architecture with these modes:
|
||||
- test: Test data provider and orchestrator
|
||||
- cnn: Train CNN models only
|
||||
- rl: Train RL agents only
|
||||
- train: Train both CNN and RL models
|
||||
- trade: Live trading mode
|
||||
- web: Web dashboard with real-time charts
|
||||
Simplified entry point with only essential modes:
|
||||
- test: Test data provider and core components
|
||||
- web: Live trading dashboard with integrated training pipeline
|
||||
|
||||
Streamlined Flow: Data -> Indicators/Pivots -> CNN -> RL -> Orchestrator -> Execution
|
||||
|
||||
Usage:
|
||||
python main_clean.py --mode [test|cnn|rl|train|trade|web] --symbol ETH/USDT
|
||||
python main_clean.py --mode [test|web] --symbol ETH/USDT
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@ -28,20 +26,19 @@ sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.config import get_config, setup_logging, Config
|
||||
from core.data_provider import DataProvider
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def run_data_test():
|
||||
"""Test the enhanced data provider functionality"""
|
||||
"""Test the enhanced data provider and core components"""
|
||||
try:
|
||||
config = get_config()
|
||||
logger.info("Testing Enhanced Data Provider...")
|
||||
logger.info("Testing Enhanced Data Provider and Core Components...")
|
||||
|
||||
# Test data provider with multiple timeframes
|
||||
data_provider = DataProvider(
|
||||
symbols=['ETH/USDT'],
|
||||
timeframes=['1s', '1m', '1h', '4h'] # Include 1s for scalping
|
||||
timeframes=['1s', '1m', '1h', '4h']
|
||||
)
|
||||
|
||||
# Test historical data
|
||||
@ -70,321 +67,149 @@ def run_data_test():
|
||||
else:
|
||||
logger.error("[FAILED] Failed to create feature matrix")
|
||||
|
||||
# Test CNN model availability
|
||||
try:
|
||||
from NN.models.cnn_model import CNNModel
|
||||
cnn = CNNModel(n_actions=2) # 2-action system
|
||||
logger.info("[SUCCESS] CNN model initialized with 2 actions (BUY/SELL)")
|
||||
except Exception as e:
|
||||
logger.warning(f"[WARNING] CNN model not available: {e}")
|
||||
|
||||
# Test RL agent availability
|
||||
try:
|
||||
from NN.models.dqn_agent import DQNAgent
|
||||
agent = DQNAgent(state_shape=(50,), n_actions=2) # 2-action system
|
||||
logger.info("[SUCCESS] RL Agent initialized with 2 actions (BUY/SELL)")
|
||||
except Exception as e:
|
||||
logger.warning(f"[WARNING] RL Agent not available: {e}")
|
||||
|
||||
# Test orchestrator
|
||||
try:
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||
logger.info("[SUCCESS] Enhanced Trading Orchestrator initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"[WARNING] Enhanced Orchestrator not available: {e}")
|
||||
|
||||
# Test health check
|
||||
health = data_provider.health_check()
|
||||
logger.info(f"[SUCCESS] Data provider health check completed")
|
||||
|
||||
logger.info("Enhanced data provider test completed successfully!")
|
||||
logger.info("[SUCCESS] Core system test completed successfully!")
|
||||
logger.info("2-Action System: BUY/SELL only (no HOLD)")
|
||||
logger.info("Streamlined Flow: Data -> Indicators -> CNN -> RL -> Orchestrator -> Execution")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in data test: {e}")
|
||||
logger.error(f"Error in system test: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
def run_cnn_training(config: Config, symbol: str):
|
||||
"""Run CNN training mode with TensorBoard monitoring"""
|
||||
logger.info("Starting CNN Training Mode...")
|
||||
|
||||
# Import CNNTrainer
|
||||
from training.cnn_trainer import CNNTrainer
|
||||
|
||||
# Initialize data provider and trainer
|
||||
data_provider = DataProvider(config)
|
||||
trainer = CNNTrainer(config)
|
||||
|
||||
# Use configured symbols or provided symbol
|
||||
symbols = config.symbols if symbol == "ETH/USDT" else [symbol] + config.symbols
|
||||
save_path = f"models/cnn/scalping_cnn_trained.pt"
|
||||
|
||||
logger.info(f"Training CNN for symbols: {symbols}")
|
||||
logger.info(f"Will save to: {save_path}")
|
||||
logger.info(f"🔗 Monitor training: tensorboard --logdir=runs")
|
||||
|
||||
try:
|
||||
# Train model with TensorBoard logging
|
||||
results = trainer.train(symbols, save_path=save_path)
|
||||
|
||||
logger.info("CNN Training Results:")
|
||||
logger.info(f" Best validation accuracy: {results['best_val_accuracy']:.4f}")
|
||||
logger.info(f" Best validation loss: {results['best_val_loss']:.4f}")
|
||||
logger.info(f" Total epochs: {results['total_epochs']}")
|
||||
logger.info(f" Training time: {results['training_time']:.2f} seconds")
|
||||
logger.info(f" TensorBoard logs: {results['tensorboard_dir']}")
|
||||
|
||||
logger.info(f"📊 View training progress: tensorboard --logdir=runs")
|
||||
logger.info("Evaluating CNN on test data...")
|
||||
|
||||
# Quick evaluation on same symbols
|
||||
test_results = trainer.evaluate(symbols[:1]) # Use first symbol for quick test
|
||||
logger.info("CNN Evaluation Results:")
|
||||
logger.info(f" Test accuracy: {test_results['test_accuracy']:.4f}")
|
||||
logger.info(f" Test loss: {test_results['test_loss']:.4f}")
|
||||
logger.info(f" Average confidence: {test_results['avg_confidence']:.4f}")
|
||||
|
||||
logger.info("CNN training completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"CNN training failed: {e}")
|
||||
raise
|
||||
finally:
|
||||
trainer.close_tensorboard()
|
||||
|
||||
def run_rl_training():
|
||||
"""Train RL agents only with comprehensive pipeline"""
|
||||
try:
|
||||
logger.info("Starting RL Training Mode...")
|
||||
|
||||
# Initialize components for RL
|
||||
data_provider = DataProvider(
|
||||
symbols=['ETH/USDT'],
|
||||
timeframes=['1s', '1m', '5m', '1h'] # Focus on scalping timeframes
|
||||
)
|
||||
|
||||
# Import and create RL trainer
|
||||
from training.rl_trainer import RLTrainer
|
||||
trainer = RLTrainer(data_provider)
|
||||
|
||||
# Configure training
|
||||
trainer.num_episodes = 1000
|
||||
trainer.max_steps_per_episode = 1000
|
||||
trainer.evaluation_frequency = 50
|
||||
trainer.save_frequency = 100
|
||||
|
||||
# Train the agent
|
||||
save_path = 'models/rl/scalping_agent_trained.pt'
|
||||
|
||||
logger.info(f"Training RL agent for scalping")
|
||||
logger.info(f"Will save to: {save_path}")
|
||||
|
||||
results = trainer.train(save_path)
|
||||
|
||||
# Log results
|
||||
logger.info("RL Training Results:")
|
||||
logger.info(f" Best reward: {results['best_reward']:.4f}")
|
||||
logger.info(f" Best balance: ${results['best_balance']:.2f}")
|
||||
logger.info(f" Total episodes: {results['total_episodes']}")
|
||||
logger.info(f" Training time: {results['total_time']:.2f} seconds")
|
||||
logger.info(f" Final epsilon: {results['agent_config']['epsilon_final']:.4f}")
|
||||
|
||||
# Final evaluation results
|
||||
final_eval = results['final_evaluation']
|
||||
logger.info("Final Evaluation:")
|
||||
logger.info(f" Win rate: {final_eval['win_rate']:.2%}")
|
||||
logger.info(f" Average PnL: {final_eval['avg_pnl_percentage']:.2f}%")
|
||||
logger.info(f" Average trades: {final_eval['avg_trades']:.1f}")
|
||||
|
||||
# Plot training progress
|
||||
try:
|
||||
plot_path = 'models/rl/training_progress.png'
|
||||
trainer.plot_training_progress(plot_path)
|
||||
logger.info(f"Training plots saved to: {plot_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not save training plots: {e}")
|
||||
|
||||
# Backtest the trained agent
|
||||
try:
|
||||
logger.info("Backtesting trained agent...")
|
||||
backtest_results = trainer.backtest_agent(save_path, test_episodes=50)
|
||||
|
||||
analysis = backtest_results['analysis']
|
||||
logger.info("Backtest Results:")
|
||||
logger.info(f" Win rate: {analysis['win_rate']:.2%}")
|
||||
logger.info(f" Average PnL: {analysis['avg_pnl']:.2f}%")
|
||||
logger.info(f" Sharpe ratio: {analysis['sharpe_ratio']:.4f}")
|
||||
logger.info(f" Max drawdown: {analysis['max_drawdown']:.2f}%")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not run backtest: {e}")
|
||||
|
||||
logger.info("RL training completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in RL training: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
def run_combined_training():
|
||||
"""Train both CNN and RL models with hybrid approach"""
|
||||
try:
|
||||
logger.info("Starting Hybrid CNN + RL Training Mode...")
|
||||
|
||||
# Initialize data provider
|
||||
data_provider = DataProvider(
|
||||
symbols=['ETH/USDT', 'BTC/USDT'],
|
||||
timeframes=['1s', '1m', '5m', '1h', '4h']
|
||||
)
|
||||
|
||||
# Import and create hybrid trainer
|
||||
from training.rl_trainer import HybridTrainer
|
||||
trainer = HybridTrainer(data_provider)
|
||||
|
||||
# Define save paths
|
||||
cnn_save_path = 'models/cnn/hybrid_cnn_trained.pt'
|
||||
rl_save_path = 'models/rl/hybrid_rl_trained.pt'
|
||||
|
||||
# Train hybrid system
|
||||
symbols = ['ETH/USDT', 'BTC/USDT']
|
||||
logger.info(f"Training hybrid system for symbols: {symbols}")
|
||||
|
||||
results = trainer.train_hybrid(symbols, cnn_save_path, rl_save_path)
|
||||
|
||||
# Log results
|
||||
cnn_results = results['cnn_results']
|
||||
rl_results = results['rl_results']
|
||||
|
||||
logger.info("Hybrid Training Results:")
|
||||
logger.info("CNN Phase:")
|
||||
logger.info(f" Best accuracy: {cnn_results['best_val_accuracy']:.4f}")
|
||||
logger.info(f" Training time: {cnn_results['total_time']:.2f}s")
|
||||
|
||||
logger.info("RL Phase:")
|
||||
logger.info(f" Best reward: {rl_results['best_reward']:.4f}")
|
||||
logger.info(f" Final balance: ${rl_results['best_balance']:.2f}")
|
||||
logger.info(f" Training time: {rl_results['total_time']:.2f}s")
|
||||
|
||||
logger.info(f"Total training time: {results['total_time']:.2f}s")
|
||||
|
||||
logger.info("Hybrid training completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in hybrid training: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
def run_live_trading():
|
||||
"""Run live trading mode"""
|
||||
try:
|
||||
logger.info("Starting Live Trading Mode...")
|
||||
|
||||
# Initialize for live trading with 1s scalping focus
|
||||
data_provider = DataProvider(
|
||||
symbols=['ETH/USDT'],
|
||||
timeframes=['1s', '1m', '5m', '15m']
|
||||
)
|
||||
orchestrator = TradingOrchestrator(data_provider)
|
||||
|
||||
# Start real-time data streaming
|
||||
logger.info("Starting real-time data streaming...")
|
||||
|
||||
# This would integrate with your live trading logic
|
||||
logger.info("Live trading mode ready!")
|
||||
logger.info("Note: Integrate this with your actual trading execution")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in live trading: {e}")
|
||||
raise
|
||||
|
||||
def run_web_dashboard():
|
||||
"""Run the web dashboard with real live data"""
|
||||
"""Run the streamlined web dashboard with integrated training pipeline"""
|
||||
try:
|
||||
logger.info("Starting Web Dashboard Mode with REAL LIVE DATA...")
|
||||
logger.info("Starting Streamlined Trading Dashboard...")
|
||||
logger.info("2-Action System: BUY/SELL with intelligent position management")
|
||||
logger.info("Integrated Training Pipeline: Live data -> Models -> Trading")
|
||||
|
||||
# Get configuration
|
||||
config = get_config()
|
||||
|
||||
# Initialize core components with enhanced RL support
|
||||
from core.tick_aggregator import RealTimeTickAggregator
|
||||
# Initialize core components for streamlined pipeline
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator # Use enhanced version
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from core.trading_executor import TradingExecutor
|
||||
|
||||
# Create tick aggregator for real-time data - fix parameter name
|
||||
tick_aggregator = RealTimeTickAggregator(
|
||||
symbols=['ETHUSDC', 'BTCUSDT', 'MXUSDT'],
|
||||
tick_buffer_size=10000 # Changed from buffer_size to tick_buffer_size
|
||||
)
|
||||
|
||||
# Create data provider
|
||||
data_provider = DataProvider()
|
||||
|
||||
# Verify data connection with real data
|
||||
logger.info("[DATA] Verifying REAL data connection...")
|
||||
# Verify data connection
|
||||
logger.info("[DATA] Verifying live data connection...")
|
||||
symbol = config.get('symbols', ['ETH/USDT'])[0]
|
||||
test_df = data_provider.get_historical_data(symbol, '1m', limit=10)
|
||||
if test_df is not None and len(test_df) > 0:
|
||||
logger.info("[SUCCESS] Data connection verified")
|
||||
logger.info(f"[SUCCESS] Fetched {len(test_df)} candles for validation")
|
||||
else:
|
||||
logger.error("[ERROR] Data connection failed - no real data available")
|
||||
logger.error("[ERROR] Data connection failed - no live data available")
|
||||
return
|
||||
|
||||
# Load model registry - create simple fallback
|
||||
# Load model registry for integrated pipeline
|
||||
try:
|
||||
from core.model_registry import get_model_registry
|
||||
model_registry = get_model_registry()
|
||||
logger.info("[MODELS] Model registry loaded for integrated training")
|
||||
except ImportError:
|
||||
model_registry = {} # Fallback empty registry
|
||||
model_registry = {}
|
||||
logger.warning("Model registry not available, using empty registry")
|
||||
|
||||
# Create ENHANCED trading orchestrator for RL training
|
||||
# Create streamlined orchestrator with 2-action system
|
||||
orchestrator = EnhancedTradingOrchestrator(
|
||||
data_provider=data_provider,
|
||||
symbols=config.get('symbols', ['ETH/USDT']),
|
||||
enhanced_rl_training=True, # Enable enhanced RL
|
||||
enhanced_rl_training=True,
|
||||
model_registry=model_registry
|
||||
)
|
||||
logger.info("Enhanced RL Trading Orchestrator initialized")
|
||||
logger.info("Enhanced Trading Orchestrator with 2-Action System initialized")
|
||||
|
||||
# Create trading executor (handles MEXC integration)
|
||||
# Create trading executor for live execution
|
||||
trading_executor = TradingExecutor()
|
||||
|
||||
# Import and create enhanced dashboard
|
||||
# Import and create streamlined dashboard
|
||||
from web.dashboard import TradingDashboard
|
||||
dashboard = TradingDashboard(
|
||||
data_provider=data_provider,
|
||||
orchestrator=orchestrator, # Enhanced orchestrator
|
||||
orchestrator=orchestrator,
|
||||
trading_executor=trading_executor
|
||||
)
|
||||
|
||||
# Start the dashboard
|
||||
# Start the integrated dashboard
|
||||
port = config.get('web', {}).get('port', 8050)
|
||||
host = config.get('web', {}).get('host', '127.0.0.1')
|
||||
|
||||
logger.info(f"TRADING: Starting Live Scalping Dashboard at http://{host}:{port}")
|
||||
logger.info("Enhanced RL Training: ENABLED")
|
||||
logger.info("Real Market Data: ENABLED")
|
||||
logger.info("MEXC Integration: ENABLED")
|
||||
logger.info("CNN Training: ENABLED at Williams pivot points")
|
||||
logger.info(f"Starting Streamlined Dashboard at http://{host}:{port}")
|
||||
logger.info("Live Data Processing: ENABLED")
|
||||
logger.info("Integrated CNN Training: ENABLED")
|
||||
logger.info("Integrated RL Training: ENABLED")
|
||||
logger.info("Real-time Indicators & Pivots: ENABLED")
|
||||
logger.info("Live Trading Execution: ENABLED")
|
||||
logger.info("2-Action System: BUY/SELL with position intelligence")
|
||||
logger.info("Pipeline: Data -> Indicators -> CNN -> RL -> Orchestrator -> Execution")
|
||||
|
||||
dashboard.run(host=host, port=port, debug=False)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in web dashboard: {e}")
|
||||
logger.error("Dashboard stopped - trying fallback mode")
|
||||
logger.error(f"Error in streamlined dashboard: {e}")
|
||||
logger.error("Dashboard stopped - trying minimal fallback")
|
||||
|
||||
try:
|
||||
# Fallback to basic dashboard function - use working import
|
||||
# Minimal fallback dashboard
|
||||
from web.dashboard import TradingDashboard
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
# Create minimal dashboard
|
||||
data_provider = DataProvider()
|
||||
dashboard = TradingDashboard(data_provider)
|
||||
logger.info("Using fallback dashboard")
|
||||
logger.info("Using minimal fallback dashboard")
|
||||
dashboard.run(host='127.0.0.1', port=8050, debug=False)
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"Fallback dashboard also failed: {fallback_error}")
|
||||
logger.error(f"Fallback dashboard failed: {fallback_error}")
|
||||
logger.error(f"Fatal error: {e}")
|
||||
import traceback
|
||||
logger.error("Traceback (most recent call last):")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def main():
|
||||
"""Main entry point with clean mode selection"""
|
||||
parser = argparse.ArgumentParser(description='Clean Trading System - Unified Entry Point')
|
||||
"""Main entry point with streamlined mode selection"""
|
||||
parser = argparse.ArgumentParser(description='Streamlined Trading System - Integrated Pipeline')
|
||||
parser.add_argument('--mode',
|
||||
choices=['test', 'cnn', 'rl', 'train', 'trade', 'web'],
|
||||
default='test',
|
||||
help='Operation mode')
|
||||
choices=['test', 'web'],
|
||||
default='web',
|
||||
help='Operation mode: test (system check) or web (live trading)')
|
||||
parser.add_argument('--symbol', type=str, default='ETH/USDT',
|
||||
help='Trading symbol (default: ETH/USDT)')
|
||||
help='Primary trading symbol (default: ETH/USDT)')
|
||||
parser.add_argument('--port', type=int, default=8050,
|
||||
help='Web dashboard port (default: 8050)')
|
||||
parser.add_argument('--demo', action='store_true',
|
||||
help='Run web dashboard in demo mode')
|
||||
parser.add_argument('--debug', action='store_true',
|
||||
help='Enable debug mode')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -392,27 +217,22 @@ async def main():
|
||||
setup_logging()
|
||||
|
||||
try:
|
||||
logger.info("=" * 60)
|
||||
logger.info("CLEAN TRADING SYSTEM - UNIFIED LAUNCH")
|
||||
logger.info("=" * 70)
|
||||
logger.info("STREAMLINED TRADING SYSTEM - INTEGRATED PIPELINE")
|
||||
logger.info(f"Mode: {args.mode.upper()}")
|
||||
logger.info(f"Symbol: {args.symbol}")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Primary Symbol: {args.symbol}")
|
||||
if args.mode == 'web':
|
||||
logger.info("Integrated Flow: Data -> Indicators -> CNN -> RL -> Execution")
|
||||
logger.info("2-Action System: BUY/SELL with intelligent position management")
|
||||
logger.info("=" * 70)
|
||||
|
||||
# Route to appropriate mode
|
||||
if args.mode == 'test':
|
||||
run_data_test()
|
||||
elif args.mode == 'cnn':
|
||||
run_cnn_training(get_config(), args.symbol)
|
||||
elif args.mode == 'rl':
|
||||
run_rl_training()
|
||||
elif args.mode == 'train':
|
||||
run_combined_training()
|
||||
elif args.mode == 'trade':
|
||||
run_live_trading()
|
||||
elif args.mode == 'web':
|
||||
run_web_dashboard()
|
||||
|
||||
logger.info("Operation completed successfully!")
|
||||
logger.info("[SUCCESS] Operation completed successfully!")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("System shutdown requested by user")
|
||||
|
Reference in New Issue
Block a user