318 lines
12 KiB
Python
318 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Clean Trading System - Main Entry Point
|
|
|
|
This is the new clean entry point that demonstrates the consolidated architecture:
|
|
- Single configuration system
|
|
- Clean data provider
|
|
- Modular CNN and RL components
|
|
- Centralized orchestrator
|
|
- Simple web dashboard
|
|
|
|
Usage:
|
|
python main_clean.py --mode [test|orchestrator|web|train|trade] [--symbol ETH/USDT] [--port 8050] [--demo]   (train/trade are not yet implemented)
|
|
"""
|
|
|
|
import asyncio
|
|
import argparse
|
|
import logging
|
|
import sys
|
|
from pathlib import Path
|
|
from threading import Thread
|
|
import time
|
|
|
|
# Prepend this script's directory to the import path so the local
# ``core`` / ``models`` / ``web`` packages resolve when run directly.
project_root = Path(__file__).parent
sys.path[:0] = [str(project_root)]
|
|
|
|
from core.config import get_config, setup_logging
|
|
from core.data_provider import DataProvider
|
|
from core.orchestrator import TradingOrchestrator
|
|
|
|
# Module-level logger shared by every entry point in this script.
logger = logging.getLogger(name=__name__)
|
|
|
|
def run_data_test():
    """Smoke-test the DataProvider against ETH/USDT.

    Exercises historical data fetching, feature-matrix construction and the
    provider health check, logging a ``[SUCCESS]``/``[FAILED]`` line per step.
    A step that yields no data is logged as a failure but does not abort the
    remaining steps; the closing summary reports whether any step failed.

    Raises:
        Exception: any unexpected error is logged with its traceback and
            re-raised.
    """
    try:
        # Called for its side effects only; the returned config is not used here.
        # NOTE(review): presumably this loads/validates configuration before the
        # DataProvider is built — confirm.
        get_config()
        logger.info("Testing Data Provider...")

        data_provider = DataProvider(
            symbols=['ETH/USDT'],
            timeframes=['1h', '4h']
        )
        failures = []  # names of steps that produced no usable result

        # Historical data: None or an empty frame both mean nothing was loaded.
        logger.info("Testing historical data fetching...")
        df = data_provider.get_historical_data('ETH/USDT', '1h', limit=100)
        if df is None or len(df) == 0:
            logger.error("[FAILED] Failed to load historical data")
            failures.append("historical data")
        else:
            logger.info(f"[SUCCESS] Historical data: {len(df)} candles loaded")
            logger.info(f" Columns: {list(df.columns)}")
            # Guard the column lookup so a schema difference does not abort
            # the feature-matrix and health checks below.
            if 'timestamp' in df.columns:
                logger.info(f" Date range: {df['timestamp'].min()} to {df['timestamp'].max()}")
            else:
                logger.warning(" No 'timestamp' column; cannot report date range")

        logger.info("Testing feature matrix...")
        feature_matrix = data_provider.get_feature_matrix('ETH/USDT', ['1h'], window_size=20)
        if feature_matrix is not None:
            logger.info(f"[SUCCESS] Feature matrix shape: {feature_matrix.shape}")
        else:
            logger.error("[FAILED] Failed to create feature matrix")
            failures.append("feature matrix")

        health = data_provider.health_check()
        logger.info(f"[SUCCESS] Data provider health: {health}")

        # Only claim success when every step actually succeeded.
        if failures:
            logger.warning(f"Data provider test completed with failures: {', '.join(failures)}")
        else:
            logger.info("Data provider test completed successfully!")

    except Exception as e:
        logger.exception(f"Error in data test: {e}")
        raise
|
|
|
|
def run_orchestrator_test():
    """Smoke-test the model registry and TradingOrchestrator with mock models.

    Registers a mock CNN and a mock RL agent with the model registry returned
    by ``get_model_registry()``, logs the registry's memory stats, then
    registers the same models with a fresh TradingOrchestrator (weights 0.7 and
    0.3) and logs its performance metrics. Registration failures are logged
    and reflected in the final summary but do not abort the run.

    Raises:
        Exception: any unexpected error is logged with its traceback and
            re-raised.
    """
    try:
        from models import get_model_registry, ModelInterface
        import numpy as np
        import torch
        import math

        logger.info("Testing Modular Orchestrator System...")

        registry = get_model_registry()
        logger.info(f"[SUCCESS] Model registry initialized with {registry.total_memory_limit_mb}MB limit")

        def _tensor_size_mb(tensor):
            """Return the tensor's data size in whole MB, rounded up."""
            # Round up: int() truncation reported these small mocks as 0 MB.
            return math.ceil(tensor.numel() * tensor.element_size() / (1024 * 1024))

        class MockCNNModel(ModelInterface):
            """Mock CNN holding a small random tensor; predictions are fixed-seed."""

            def __init__(self):
                config = {'max_memory_mb': 500}  # 500MB limit
                super().__init__('MockCNN', config)
                self.model_params = torch.randn(1000, 100)  # Small mock model

            def predict(self, features):
                # A private fixed-seed generator keeps the output identical across
                # calls without reseeding NumPy's process-wide global RNG.
                rng = np.random.default_rng(42)
                action_probs = rng.dirichlet([1, 1, 1])  # Random probabilities that sum to 1
                confidence = rng.uniform(0.5, 0.9)
                return action_probs, confidence

            def get_memory_usage(self):
                if hasattr(self, 'model_params'):
                    return _tensor_size_mb(self.model_params)
                return 0

        class MockRLAgent(ModelInterface):
            """Mock RL agent with a small random tensor and fixed-seed predictions."""

            def __init__(self):
                config = {'max_memory_mb': 300}  # 300MB limit
                super().__init__('MockRL', config)
                self.q_network = torch.randn(500, 50)  # Smaller mock RL model

            def predict(self, features):
                rng = np.random.default_rng(123)
                action_probs = rng.dirichlet([2, 1, 2])  # Skews mass toward indices 0 and 2
                confidence = rng.uniform(0.6, 0.8)
                return action_probs, confidence

            def act_with_confidence(self, state):
                action_probs, confidence = self.predict(state)
                action = np.argmax(action_probs)
                return action, confidence

            def get_memory_usage(self):
                if hasattr(self, 'q_network'):
                    return _tensor_size_mb(self.q_network)
                return 0

            def act(self, state):
                return self.act_with_confidence(state)[0]

            def remember(self, state, action, reward, next_state, done):
                pass  # Mock implementation

            def replay(self):
                return 0.0  # Mock implementation

        all_ok = True

        logger.info("Testing model registration...")
        mock_cnn = MockCNNModel()
        mock_rl = MockRLAgent()

        cnn_ok = registry.register_model(mock_cnn)
        rl_ok = registry.register_model(mock_rl)
        if cnn_ok and rl_ok:
            logger.info("[SUCCESS] Both models registered successfully")
        else:
            all_ok = False
            logger.error(f"[FAILED] Model registration failed: CNN={cnn_ok}, RL={rl_ok}")

        memory_stats = registry.get_memory_stats()
        logger.info(f"[SUCCESS] Memory stats: {memory_stats}")

        logger.info("Testing orchestrator integration...")
        data_provider = DataProvider(symbols=['ETH/USDT'], timeframes=['1h'])
        orchestrator = TradingOrchestrator(data_provider)

        cnn_ok = orchestrator.register_model(mock_cnn, weight=0.7)
        rl_ok = orchestrator.register_model(mock_rl, weight=0.3)
        if cnn_ok and rl_ok:
            logger.info("[SUCCESS] Models registered with orchestrator")
        else:
            all_ok = False
            logger.error(f"[FAILED] Orchestrator registration failed: CNN={cnn_ok}, RL={rl_ok}")

        metrics = orchestrator.get_performance_metrics()
        logger.info(f"[SUCCESS] Orchestrator metrics: {metrics}")

        # Only claim success when every registration succeeded.
        if all_ok:
            logger.info("Modular orchestrator test completed successfully!")
        else:
            logger.warning("Modular orchestrator test completed with registration failures")

    except Exception as e:
        logger.exception(f"Error in orchestrator test: {e}")
        raise
|
|
|
|
def _demo_decision_loop(dashboard, interval_s: float = 5.0, error_backoff_s: float = 10.0):
    """Push a random ETH/USDT TradingDecision into ``dashboard`` forever.

    Meant to run on a daemon thread; it never returns. An error in one
    iteration is logged and followed by a back-off sleep instead of ending
    the loop.

    Args:
        dashboard: object exposing ``add_trading_decision(decision)``.
        interval_s: seconds to wait between decisions.
        error_backoff_s: seconds to wait after a failed iteration.
    """
    import random
    from datetime import datetime
    from core.orchestrator import TradingDecision

    actions = ['BUY', 'SELL', 'HOLD']
    base_price = 3000.0

    while True:
        try:
            # Jitter +/-50 around the current base price, floored at 1000.
            price_change = random.uniform(-50, 50)
            current_price = max(base_price + price_change, 1000)

            action = random.choice(actions)
            confidence = random.uniform(0.6, 0.95)

            decision = TradingDecision(
                action=action,
                confidence=confidence,
                symbol='ETH/USDT',
                price=current_price,
                timestamp=datetime.now(),
                reasoning={'demo_mode': True, 'random_decision': True},
                memory_usage={'demo': 0}
            )

            dashboard.add_trading_decision(decision)
            logger.info(f"Demo decision: {action} ETH/USDT @${current_price:.2f} (confidence: {confidence:.2f})")

            # Occasionally re-anchor so the simulated price drifts over time.
            if random.random() < 0.1:
                base_price = current_price

            time.sleep(interval_s)

        except Exception as e:
            logger.error(f"Error in demo thread: {e}")
            time.sleep(error_backoff_s)


def run_web_dashboard(port: int = 8050, demo_mode: bool = True):
    """Build the data provider, orchestrator and dashboard, then serve it.

    Args:
        port: port passed to ``dashboard.run``.
        demo_mode: when True, a daemon thread feeds the dashboard a random
            trading decision every 5 seconds.

    Raises:
        Exception: any error during setup or while running the dashboard is
            logged with its traceback and re-raised.
    """
    try:
        from web.dashboard import TradingDashboard

        logger.info("Starting Web Dashboard...")

        data_provider = DataProvider(symbols=['ETH/USDT'], timeframes=['1h', '4h'])
        orchestrator = TradingOrchestrator(data_provider)
        dashboard = TradingDashboard(data_provider, orchestrator)

        # Forward orchestrator decisions to the dashboard.
        async def decision_callback(decision):
            dashboard.add_trading_decision(decision)

        orchestrator.add_decision_callback(decision_callback)

        if demo_mode:
            logger.info("Starting demo mode with simulated trading decisions...")
            Thread(
                target=_demo_decision_loop,
                args=(dashboard,),
                name="demo-decisions",
                daemon=True,  # do not keep the process alive after the dashboard exits
            ).start()

        # Real-time streaming is intentionally not started here. This function
        # is called from inside asyncio.run(main()), so a nested asyncio.run()
        # is not possible.
        logger.info("Real-time streaming would be started in production deployment")

        # NOTE(review): dashboard.run presumably blocks until the server stops.
        # While it does, the event loop driving main() cannot run coroutines
        # such as decision_callback. Confirm how the orchestrator invokes
        # async callbacks.
        dashboard.run(port=port, debug=False)

    except Exception as e:
        logger.exception(f"Error running web dashboard: {e}")
        raise
|
|
|
|
async def main():
    """Parse CLI arguments, configure logging and run the selected mode.

    Modes ``test``, ``orchestrator`` and ``web`` dispatch to their runner
    functions. ``train`` and ``trade`` are accepted but only log that they are
    not implemented yet.

    Returns:
        Process exit code: 0 on normal completion (including Ctrl-C), 1 on a
        fatal error.
    """
    parser = argparse.ArgumentParser(description='Clean Trading System')
    parser.add_argument('--mode', choices=['trade', 'train', 'web', 'test', 'orchestrator'],
                        default='test', help='Mode to run the system in')
    parser.add_argument('--symbol', type=str, help='Override default symbol')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Configuration file path')
    parser.add_argument('--port', type=int, default=8050,
                        help='Port for web dashboard')
    parser.add_argument('--demo', action='store_true',
                        help='Run web dashboard in demo mode with simulated data')

    args = parser.parse_args()

    setup_logging()

    try:
        logger.info("=" * 60)
        logger.info("CLEAN TRADING SYSTEM STARTING")
        logger.info("=" * 60)

        # These options are parsed but not yet wired into any mode. Warn
        # instead of ignoring them silently.
        if args.symbol is not None:
            logger.warning(f"--symbol {args.symbol!r} is not yet applied; modes use hard-coded ETH/USDT")
        if args.config != parser.get_default('config'):
            logger.warning(f"--config {args.config!r} is not yet applied; get_config() is called without a path")

        runners = {
            'test': run_data_test,
            'orchestrator': run_orchestrator_test,
            'web': lambda: run_web_dashboard(port=args.port, demo_mode=args.demo),
        }
        runner = runners.get(args.mode)
        if runner is None:
            logger.info(f"Mode '{args.mode}' not yet implemented in clean architecture")
        else:
            runner()

    except KeyboardInterrupt:
        logger.info("System shutdown requested by user")
    except Exception as e:
        logger.exception(f"Fatal error: {e}")
        return 1

    logger.info("Clean Trading System finished")
    return 0
|
|
|
|
if __name__ == "__main__":
    # main() is a coroutine: drive it to completion, then use its result as
    # the process exit status.
    exit_code = asyncio.run(main())
    sys.exit(exit_code)