430 lines
16 KiB
Python
430 lines
16 KiB
Python
#!/usr/bin/env python3
"""
Clean Trading System - Main Entry Point

Unified entry point for the clean trading architecture with these modes:

    - test:  Smoke-test the data provider (historical data, multi-timeframe
             feature matrix, health check)
    - cnn:   Train CNN models only
    - rl:    Train RL agents only
    - train: Train both CNN and RL models (hybrid trainer)
    - trade: Live trading mode (currently a setup stub; order execution is
             not integrated yet)
    - web:   Web dashboard with real-time charts

Usage:

    python main_clean.py --mode [test|cnn|rl|train|trade|web] --symbol ETH/USDT

Other options:

    --port PORT   Web dashboard port (web mode only, default 8050)
    --demo        Feed simulated scalping decisions to the web dashboard

Note: --symbol is shown in the startup banner, but main() does not pass it
to the mode runners; each mode uses its own built-in symbol list.
"""
|
|
|
|
import asyncio
|
|
import argparse
|
|
import logging
|
|
import sys
|
|
from pathlib import Path
|
|
from threading import Thread
|
|
import time
|
|
|
|
# Put the directory holding this file at the very front of sys.path so the
# project-local packages imported below (core.*, and later training.* / web.*)
# are looked up there first.
project_root = Path(__file__).parents[0]
sys.path.insert(0, f"{project_root}")
|
|
|
|
from core.config import get_config, setup_logging
|
|
from core.data_provider import DataProvider
|
|
from core.orchestrator import TradingOrchestrator
|
|
|
|
logger = logging.getLogger(name=__name__)  # module-wide logger shared by every mode runner
|
|
|
|
def run_data_test(symbol: str = 'ETH/USDT'):
    """Smoke-test the DataProvider: historical data, feature matrix, health check.

    Each check logs a ``[SUCCESS]`` or ``[FAILED]`` line. A check that yields
    no data is logged as a failure but does not stop the remaining checks.
    The final summary line reports whether every check produced data.

    Args:
        symbol: Trading pair to exercise (default ``'ETH/USDT'``).

    Raises:
        Exception: Any error raised by the provider is logged with its
            traceback and re-raised.
    """
    try:
        # NOTE(review): the returned config is not used in this function; the
        # call is kept in case loading the configuration has side effects.
        get_config()
        logger.info("Testing Enhanced Data Provider...")

        # Test data provider with multiple timeframes
        data_provider = DataProvider(
            symbols=[symbol],
            timeframes=['1s', '1m', '1h', '4h']  # Include 1s for scalping
        )
        all_ok = True

        # Historical data: OHLCV columns plus whatever indicators the provider adds.
        logger.info("Testing historical data fetching...")
        df = data_provider.get_historical_data(symbol, '1h', limit=100)
        if df is not None and len(df) > 0:
            logger.info(f"[SUCCESS] Historical data: {len(df)} candles loaded")
            logger.info(f" Columns: {len(df.columns)} total")
            logger.info(f" Date range: {df['timestamp'].min()} to {df['timestamp'].max()}")

            # Everything beyond the basic OHLCV columns counts as an indicator.
            basic_cols = {'timestamp', 'open', 'high', 'low', 'close', 'volume'}
            indicators = [col for col in df.columns if col not in basic_cols]
            logger.info(f" Technical indicators: {len(indicators)}")
        else:
            # An empty frame is as useless as None, so report both as failures.
            all_ok = False
            logger.error("[FAILED] Failed to load historical data")

        # Multi-timeframe feature matrix
        logger.info("Testing multi-timeframe feature matrix...")
        feature_matrix = data_provider.get_feature_matrix(symbol, ['1h', '4h'], window_size=20)
        if feature_matrix is not None:
            logger.info(f"[SUCCESS] Feature matrix shape: {feature_matrix.shape}")
            # The breakdown assumes (timeframes, window, features). Only unpack
            # when there really are three axes, so an unexpected shape does not
            # abort the whole test with an IndexError.
            if len(feature_matrix.shape) == 3:
                n_timeframes, window, n_features = feature_matrix.shape
                logger.info(f" Timeframes: {n_timeframes}")
                logger.info(f" Window size: {window}")
                logger.info(f" Features: {n_features}")
        else:
            all_ok = False
            logger.error("[FAILED] Failed to create feature matrix")

        # Health check
        health = data_provider.health_check()
        logger.info("[SUCCESS] Data provider health check completed")
        logger.debug(f"Health check result: {health}")

        if all_ok:
            logger.info("Enhanced data provider test completed successfully!")
        else:
            logger.warning("Enhanced data provider test completed with failures (see [FAILED] lines above)")

    except Exception as e:
        logger.exception(f"Error in data test: {e}")
        raise
|
|
|
|
def run_cnn_training():
    """Train the CNN model only, then plot its history and evaluate it.

    Trains on ETH/USDT and BTC/USDT across the 1s/1m/5m/1h/4h timeframes and
    saves the model to ``models/cnn/scalping_cnn_trained.pt``. Saving the
    training-history plot and evaluating on ETH/USDT are best-effort: their
    failures are logged as warnings.

    Raises:
        Exception: Any error during setup or training is logged with its
            traceback and re-raised.
    """
    # Single source of truth: the provider must load data for exactly the
    # symbols the trainer is asked to train on.
    symbols = ['ETH/USDT', 'BTC/USDT']
    save_path = 'models/cnn/scalping_cnn_trained.pt'
    plot_path = 'models/cnn/training_history.png'

    try:
        logger.info("Starting CNN Training Mode...")

        # Initialize components (copy the list so the provider cannot alias it).
        data_provider = DataProvider(
            symbols=list(symbols),
            timeframes=['1s', '1m', '5m', '1h', '4h']
        )

        # Imported lazily so the other modes do not need the training package.
        from training.cnn_trainer import CNNTrainer
        trainer = CNNTrainer(data_provider)

        # Configure training
        trainer.num_samples = 20000  # Training samples
        trainer.batch_size = 64
        trainer.num_epochs = 100
        trainer.patience = 15

        # Create the output directory up front so a bad path fails now,
        # not after a long training run.
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)

        logger.info(f"Training CNN for symbols: {symbols}")
        logger.info(f"Will save to: {save_path}")

        results = trainer.train(symbols, save_path)

        logger.info("CNN Training Results:")
        logger.info(f" Best validation accuracy: {results['best_val_accuracy']:.4f}")
        logger.info(f" Best validation loss: {results['best_val_loss']:.4f}")
        logger.info(f" Total epochs: {results['total_epochs']}")
        logger.info(f" Training time: {results['total_time']:.2f} seconds")

        # Plot training history (best effort)
        try:
            trainer.plot_training_history(plot_path)
            logger.info(f"Training plots saved to: {plot_path}")
        except Exception as e:
            logger.warning(f"Could not save training plots: {e}")

        # Evaluate on test data (best effort)
        try:
            logger.info("Evaluating CNN on test data...")
            test_symbols = ['ETH/USDT']  # Use subset for testing
            eval_results = trainer.evaluate_model(test_symbols)

            logger.info("CNN Evaluation Results:")
            logger.info(f" Test accuracy: {eval_results['test_accuracy']:.4f}")
            logger.info(f" Test loss: {eval_results['test_loss']:.4f}")
            logger.info(f" Average confidence: {eval_results['avg_confidence']:.4f}")
        except Exception as e:
            logger.warning(f"Could not run evaluation: {e}")

        logger.info("CNN training completed successfully!")

    except Exception as e:
        logger.exception(f"Error in CNN training: {e}")
        raise
|
|
|
|
def run_rl_training():
    """Train the RL scalping agent, then plot progress and backtest it.

    Trains a single agent on ETH/USDT (1s/1m/5m/1h timeframes) for up to
    1000 episodes and saves it to ``models/rl/scalping_agent_trained.pt``.
    Saving the progress plot and running the 50-episode backtest are
    best-effort: their failures are logged as warnings.

    Raises:
        Exception: Any error during setup, training or result logging is
            logged with its traceback and re-raised.
    """
    save_path = 'models/rl/scalping_agent_trained.pt'
    plot_path = 'models/rl/training_progress.png'

    try:
        logger.info("Starting RL Training Mode...")

        # Initialize components for RL
        data_provider = DataProvider(
            symbols=['ETH/USDT'],
            timeframes=['1s', '1m', '5m', '1h']  # Focus on scalping timeframes
        )

        # Imported lazily so the other modes do not need the training package.
        from training.rl_trainer import RLTrainer
        trainer = RLTrainer(data_provider)

        # Configure training
        trainer.num_episodes = 1000
        trainer.max_steps_per_episode = 1000
        trainer.evaluation_frequency = 50
        trainer.save_frequency = 100

        # Create the output directory up front so a bad path fails now,
        # not after a long training run.
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)

        logger.info("Training RL agent for scalping")
        logger.info(f"Will save to: {save_path}")

        results = trainer.train(save_path)

        logger.info("RL Training Results:")
        logger.info(f" Best reward: {results['best_reward']:.4f}")
        logger.info(f" Best balance: ${results['best_balance']:.2f}")
        logger.info(f" Total episodes: {results['total_episodes']}")
        logger.info(f" Training time: {results['total_time']:.2f} seconds")
        logger.info(f" Final epsilon: {results['agent_config']['epsilon_final']:.4f}")

        final_eval = results['final_evaluation']
        logger.info("Final Evaluation:")
        logger.info(f" Win rate: {final_eval['win_rate']:.2%}")
        logger.info(f" Average PnL: {final_eval['avg_pnl_percentage']:.2f}%")
        logger.info(f" Average trades: {final_eval['avg_trades']:.1f}")

        # Plot training progress (best effort)
        try:
            trainer.plot_training_progress(plot_path)
            logger.info(f"Training plots saved to: {plot_path}")
        except Exception as e:
            logger.warning(f"Could not save training plots: {e}")

        # Backtest the trained agent from its saved checkpoint (best effort)
        try:
            logger.info("Backtesting trained agent...")
            backtest_results = trainer.backtest_agent(save_path, test_episodes=50)

            analysis = backtest_results['analysis']
            logger.info("Backtest Results:")
            logger.info(f" Win rate: {analysis['win_rate']:.2%}")
            logger.info(f" Average PnL: {analysis['avg_pnl']:.2f}%")
            logger.info(f" Sharpe ratio: {analysis['sharpe_ratio']:.4f}")
            logger.info(f" Max drawdown: {analysis['max_drawdown']:.2f}%")
        except Exception as e:
            logger.warning(f"Could not run backtest: {e}")

        logger.info("RL training completed successfully!")

    except Exception as e:
        logger.exception(f"Error in RL training: {e}")
        raise
|
|
|
|
def run_combined_training():
    """Train the CNN and RL models together via ``HybridTrainer.train_hybrid``.

    Uses ETH/USDT and BTC/USDT across the 1s/1m/5m/1h/4h timeframes. Saves the
    CNN to ``models/cnn/hybrid_cnn_trained.pt`` and the RL agent to
    ``models/rl/hybrid_rl_trained.pt``, then logs per-phase and total results.

    Raises:
        Exception: Any error during setup, training or result logging is
            logged with its traceback and re-raised.
    """
    # Single source of truth: the provider must load data for exactly the
    # symbols the hybrid trainer is asked to train on.
    symbols = ['ETH/USDT', 'BTC/USDT']
    cnn_save_path = 'models/cnn/hybrid_cnn_trained.pt'
    rl_save_path = 'models/rl/hybrid_rl_trained.pt'

    try:
        logger.info("Starting Hybrid CNN + RL Training Mode...")

        data_provider = DataProvider(
            symbols=list(symbols),
            timeframes=['1s', '1m', '5m', '1h', '4h']
        )

        # Imported lazily so the other modes do not need the training package.
        from training.rl_trainer import HybridTrainer
        trainer = HybridTrainer(data_provider)

        # Create both output directories before the (long) training run.
        for path in (cnn_save_path, rl_save_path):
            Path(path).parent.mkdir(parents=True, exist_ok=True)

        logger.info(f"Training hybrid system for symbols: {symbols}")

        results = trainer.train_hybrid(symbols, cnn_save_path, rl_save_path)

        cnn_results = results['cnn_results']
        rl_results = results['rl_results']

        logger.info("Hybrid Training Results:")
        logger.info("CNN Phase:")
        logger.info(f" Best accuracy: {cnn_results['best_val_accuracy']:.4f}")
        logger.info(f" Training time: {cnn_results['total_time']:.2f}s")

        logger.info("RL Phase:")
        logger.info(f" Best reward: {rl_results['best_reward']:.4f}")
        # The value is the *best* balance (same key the RL-only mode labels
        # "Best balance"), not the final one.
        logger.info(f" Best balance: ${rl_results['best_balance']:.2f}")
        logger.info(f" Training time: {rl_results['total_time']:.2f}s")

        logger.info(f"Total training time: {results['total_time']:.2f}s")
        logger.info("Hybrid training completed successfully!")

    except Exception as e:
        logger.exception(f"Error in hybrid training: {e}")
        raise
|
|
|
|
def run_live_trading(symbol: str = 'ETH/USDT'):
    """Set up components for live trading (currently a stub).

    Builds a DataProvider for *symbol* on the 1s/1m/5m/15m timeframes and a
    TradingOrchestrator around it, then logs that the mode is ready. This
    function does not start a data stream or place orders; that still has to
    be integrated.

    Args:
        symbol: Trading pair to set up (default ``'ETH/USDT'``).

    Raises:
        Exception: Any error during component construction is logged with
            its traceback and re-raised.
    """
    try:
        logger.info("Starting Live Trading Mode...")

        # Initialize for live trading with 1s scalping focus
        data_provider = DataProvider(
            symbols=[symbol],
            timeframes=['1s', '1m', '5m', '15m']
        )
        # Constructed but not used yet; kept for the pending execution integration.
        orchestrator = TradingOrchestrator(data_provider)  # noqa: F841

        # TODO: nothing below actually starts streaming; wire up the
        # provider's real-time feed and the orchestrator's decision loop here.
        logger.info("Starting real-time data streaming...")

        logger.info("Live trading mode ready!")
        logger.info("Note: Integrate this with your actual trading execution")

    except Exception as e:
        # Include the traceback, matching the other mode runners.
        logger.exception(f"Error in live trading: {e}")
        raise
|
|
|
|
def _scalping_demo_loop(dashboard, symbol='ETH/USDT', base_price=3000.0, interval=3.0):
    """Push randomly generated demo scalping decisions to *dashboard* forever.

    Intended to run on a daemon thread; it never returns. Each iteration:
      - perturbs the base price by up to +/-5, floored at 1000;
      - picks BUY/SELL/HOLD with weights 0.3/0.3/0.4;
      - draws a confidence in [0.7, 0.95];
      - hands a ``TradingDecision`` to ``dashboard.add_trading_decision``.
    An error in one iteration is logged, and the loop retries after 5 s.

    Args:
        dashboard: Object exposing ``add_trading_decision(decision)``.
        symbol: Symbol stamped on every generated decision.
        base_price: Starting price of the random walk.
        interval: Seconds to sleep between successful decisions.
    """
    import random
    import time
    from datetime import datetime

    from core.orchestrator import TradingDecision

    actions = ['BUY', 'SELL', 'HOLD']
    action_weights = [0.3, 0.3, 0.4]  # More holds in scalping

    while True:
        try:
            # Simulate small price movements for scalping
            price_change = random.uniform(-5, 5)
            current_price = max(base_price + price_change, 1000)

            action = random.choices(actions, weights=action_weights)[0]
            confidence = random.uniform(0.7, 0.95)

            decision = TradingDecision(
                action=action,
                confidence=confidence,
                symbol=symbol,
                price=current_price,
                timestamp=datetime.now(),
                reasoning={'scalping_demo': True, 'timeframe': '1s'},
                memory_usage={'demo': 0},
            )
            dashboard.add_trading_decision(decision)
            logger.info(f"Scalping: {action} {symbol} @${current_price:.2f} (conf: {confidence:.2f})")

            # Occasionally let the base price follow the latest quote.
            if random.random() < 0.2:
                base_price = current_price

            time.sleep(interval)
        except Exception as e:
            logger.error(f"Error in scalping demo: {e}")
            time.sleep(5)


def run_web_dashboard(port: int = 8050, demo_mode: bool = True, symbol: str = 'ETH/USDT'):
    """Run the enhanced web dashboard, optionally fed by simulated decisions.

    Args:
        port: Port passed to ``dashboard.run``.
        demo_mode: When True, start a daemon thread running
            ``_scalping_demo_loop``, which pushes a random decision for
            *symbol* to the dashboard about every 3 s.
        symbol: Trading pair to load and (in demo mode) simulate.

    Raises:
        Exception: Any error during setup or while the dashboard runs is
            logged with its traceback and re-raised.
    """
    try:
        # Imported lazily so the non-web modes do not need the web package.
        from web.dashboard import TradingDashboard

        logger.info("Starting Enhanced Web Dashboard...")

        # Initialize components with 1s scalping focus
        data_provider = DataProvider(
            symbols=[symbol],
            timeframes=['1s', '1m', '5m', '1h', '4h']
        )
        orchestrator = TradingOrchestrator(data_provider)
        dashboard = TradingDashboard(data_provider, orchestrator)

        if demo_mode:
            logger.info("Starting scalping demo mode...")
            # Daemon thread: it must not keep the process alive once the
            # dashboard server exits.
            demo_thread = Thread(
                target=_scalping_demo_loop,
                args=(dashboard, symbol),
                daemon=True,
                name='scalping-demo',
            )
            demo_thread.start()

        # NOTE(review): presumably blocks for as long as the server runs.
        dashboard.run(port=port, debug=False)

    except Exception as e:
        logger.exception(f"Error running web dashboard: {e}")
        raise
|
|
|
|
async def main():
    """Parse command-line arguments and dispatch to the selected mode.

    Each ``--mode`` choice maps to one runner:
      - ``test``  -> ``run_data_test``
      - ``cnn``   -> ``run_cnn_training``
      - ``rl``    -> ``run_rl_training``
      - ``train`` -> ``run_combined_training``
      - ``trade`` -> ``run_live_trading``
      - ``web``   -> ``run_web_dashboard(port=--port, demo_mode=--demo)``

    Nothing is awaited: every mode runs synchronously inside this coroutine.

    Returns:
        int: 0 on success or when interrupted with Ctrl-C; 1 if the selected
        mode raised (the traceback is logged).
    """
    parser = argparse.ArgumentParser(description='Clean Trading System - Unified Entry Point')
    parser.add_argument('--mode',
                        choices=['test', 'cnn', 'rl', 'train', 'trade', 'web'],
                        default='test',
                        help='Operation mode')
    parser.add_argument('--symbol', type=str, default='ETH/USDT',
                        help='Trading symbol (default: ETH/USDT)')
    parser.add_argument('--port', type=int, default=8050,
                        help='Web dashboard port (default: 8050)')
    parser.add_argument('--demo', action='store_true',
                        help='Run web dashboard in demo mode')

    args = parser.parse_args()

    setup_logging()

    # Dispatch table; `choices` above guarantees args.mode is a key.
    runners = {
        'test': run_data_test,
        'cnn': run_cnn_training,
        'rl': run_rl_training,
        'train': run_combined_training,
        'trade': run_live_trading,
        'web': lambda: run_web_dashboard(port=args.port, demo_mode=args.demo),
    }

    try:
        logger.info("=" * 60)
        logger.info("CLEAN TRADING SYSTEM - UNIFIED LAUNCH")
        logger.info(f"Mode: {args.mode.upper()}")
        logger.info(f"Symbol: {args.symbol}")
        logger.info("=" * 60)

        # main() does not forward --symbol to any runner. Warn rather than let
        # a non-default choice be silently ignored.
        if args.symbol != parser.get_default('symbol'):
            logger.warning(
                f"--symbol {args.symbol} is not forwarded to the '{args.mode}' mode; "
                "it uses its own built-in symbol list"
            )

        runners[args.mode]()

        logger.info("Operation completed successfully!")

    except KeyboardInterrupt:
        logger.info("System shutdown requested by user")
    except Exception as e:
        logger.exception(f"Fatal error: {e}")
        return 1

    return 0
|
|
|
|
if __name__ == "__main__":
    # main() resolves to an int exit code, which becomes the process status.
    exit_code = asyncio.run(main())
    sys.exit(exit_code)