rl cob agent
This commit is contained in:
15
.vscode/launch.json
vendored
15
.vscode/launch.json
vendored
@ -78,6 +78,21 @@
|
|||||||
"COB_ETH_BUCKET_SIZE": "1"
|
"COB_ETH_BUCKET_SIZE": "1"
|
||||||
},
|
},
|
||||||
"preLaunchTask": "Kill Stale Processes"
|
"preLaunchTask": "Kill Stale Processes"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "🔥 Real-time RL COB Trader (1B Parameters)",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "run_realtime_rl_cob_trader.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": false,
|
||||||
|
"env": {
|
||||||
|
"PYTHONUNBUFFERED": "1",
|
||||||
|
"CUDA_VISIBLE_DEVICES": "0",
|
||||||
|
"PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512",
|
||||||
|
"ENABLE_REALTIME_RL": "1"
|
||||||
|
},
|
||||||
|
"preLaunchTask": "Kill Stale Processes"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"compounds": [
|
"compounds": [
|
||||||
|
44
config.yaml
44
config.yaml
@ -196,6 +196,50 @@ memory:
|
|||||||
model_limit_gb: 4.0 # Per-model memory limit
|
model_limit_gb: 4.0 # Per-model memory limit
|
||||||
cleanup_interval: 1800 # Memory cleanup every 30 minutes
|
cleanup_interval: 1800 # Memory cleanup every 30 minutes
|
||||||
|
|
||||||
|
# Real-time RL COB Trader Configuration
|
||||||
|
realtime_rl:
|
||||||
|
# Model parameters for 1B parameter network
|
||||||
|
model:
|
||||||
|
input_size: 2000 # COB feature dimensions
|
||||||
|
hidden_size: 4096 # Massive hidden layer size
|
||||||
|
num_layers: 12 # Deep transformer layers
|
||||||
|
learning_rate: 0.00001 # Very low for stability
|
||||||
|
weight_decay: 0.000001 # L2 regularization
|
||||||
|
|
||||||
|
# Inference configuration
|
||||||
|
inference_interval_ms: 200 # Inference every 200ms
|
||||||
|
min_confidence_threshold: 0.7 # Minimum confidence for signal accumulation
|
||||||
|
required_confident_predictions: 3 # Need 3 confident predictions for trade
|
||||||
|
|
||||||
|
# Training configuration
|
||||||
|
training_interval_s: 1.0 # Train every second
|
||||||
|
batch_size: 32 # Training batch size
|
||||||
|
replay_buffer_size: 1000 # Store last 1000 predictions for training
|
||||||
|
|
||||||
|
# Signal accumulation
|
||||||
|
signal_buffer_size: 10 # Buffer size for signal accumulation
|
||||||
|
consensus_threshold: 3 # Need 3 signals in same direction
|
||||||
|
|
||||||
|
# Model checkpointing
|
||||||
|
model_checkpoint_dir: "models/realtime_rl_cob"
|
||||||
|
save_interval_s: 300 # Save models every 5 minutes
|
||||||
|
|
||||||
|
# COB integration
|
||||||
|
symbols: ["BTC/USDT", "ETH/USDT"] # Symbols to trade
|
||||||
|
cob_feature_normalization: "robust" # Feature normalization method
|
||||||
|
|
||||||
|
# Reward engineering for RL
|
||||||
|
reward_structure:
|
||||||
|
correct_direction_base: 1.0 # Base reward for correct prediction
|
||||||
|
confidence_scaling: true # Scale reward by confidence
|
||||||
|
magnitude_bonus: 0.5 # Bonus for predicting magnitude accurately
|
||||||
|
overconfidence_penalty: 1.5 # Penalty multiplier for wrong high-confidence predictions
|
||||||
|
trade_execution_multiplier: 10.0 # Higher weight for actual trade outcomes
|
||||||
|
|
||||||
|
# Performance monitoring
|
||||||
|
statistics_interval_s: 60 # Print stats every minute
|
||||||
|
detailed_logging: true # Enable detailed performance logging
|
||||||
|
|
||||||
# Web Dashboard
|
# Web Dashboard
|
||||||
web:
|
web:
|
||||||
host: "127.0.0.1"
|
host: "127.0.0.1"
|
||||||
|
1047
core/realtime_rl_cob_trader.py
Normal file
1047
core/realtime_rl_cob_trader.py
Normal file
File diff suppressed because it is too large
Load Diff
321
run_realtime_rl_cob_trader.py
Normal file
321
run_realtime_rl_cob_trader.py
Normal file
@ -0,0 +1,321 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Real-time RL COB Trader Launcher
|
||||||
|
|
||||||
|
Launch script for the real-time reinforcement learning trader that:
|
||||||
|
1. Uses COB data for training a 1B parameter model
|
||||||
|
2. Performs inference every 200ms
|
||||||
|
3. Accumulates confident signals for trade execution
|
||||||
|
4. Trains continuously in real-time based on outcomes
|
||||||
|
|
||||||
|
This script provides a complete trading system integration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
# Local imports
|
||||||
|
from core.realtime_rl_cob_trader import RealtimeRLCOBTrader
|
||||||
|
from core.trading_executor import TradingExecutor
|
||||||
|
from core.config import load_config
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.FileHandler('logs/realtime_rl_cob_trader.log'),
|
||||||
|
logging.StreamHandler(sys.stdout)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class RealtimeRLCOBTraderLauncher:
|
||||||
|
"""
|
||||||
|
Launcher for Real-time RL COB Trader system
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config_path: str = "config.yaml"):
|
||||||
|
"""Initialize launcher with configuration"""
|
||||||
|
self.config = load_config(config_path)
|
||||||
|
self.trader = None
|
||||||
|
self.trading_executor = None
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
# Setup signal handlers for graceful shutdown
|
||||||
|
signal.signal(signal.SIGINT, self._signal_handler)
|
||||||
|
signal.signal(signal.SIGTERM, self._signal_handler)
|
||||||
|
|
||||||
|
logger.info("RealtimeRLCOBTraderLauncher initialized")
|
||||||
|
|
||||||
|
def _signal_handler(self, signum, frame):
|
||||||
|
"""Handle shutdown signals"""
|
||||||
|
logger.info(f"Received signal {signum}, initiating graceful shutdown...")
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Start the real-time RL COB trading system"""
|
||||||
|
try:
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("REAL-TIME RL COB TRADER SYSTEM STARTING")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
# Initialize trading executor
|
||||||
|
await self._initialize_trading_executor()
|
||||||
|
|
||||||
|
# Initialize RL trader
|
||||||
|
await self._initialize_rl_trader()
|
||||||
|
|
||||||
|
# Start the trading system
|
||||||
|
await self._start_trading_system()
|
||||||
|
|
||||||
|
# Run main loop
|
||||||
|
await self._run_main_loop()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Critical error in trader launcher: {e}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await self.stop()
|
||||||
|
|
||||||
|
async def _initialize_trading_executor(self):
|
||||||
|
"""Initialize the trading executor"""
|
||||||
|
logger.info("Initializing Trading Executor...")
|
||||||
|
|
||||||
|
# Get trading configuration
|
||||||
|
trading_config = self.config.get('trading', {})
|
||||||
|
mexc_config = self.config.get('mexc', {})
|
||||||
|
|
||||||
|
# Determine if we should run in simulation mode
|
||||||
|
simulation_mode = mexc_config.get('simulation_mode', True)
|
||||||
|
|
||||||
|
if simulation_mode:
|
||||||
|
logger.info("Running in SIMULATION mode - no real trades will be executed")
|
||||||
|
else:
|
||||||
|
logger.warning("Running in LIVE TRADING mode - real money at risk!")
|
||||||
|
|
||||||
|
# Add safety confirmation for live trading
|
||||||
|
confirmation = input("Type 'CONFIRM_LIVE_TRADING' to proceed with live trading: ")
|
||||||
|
if confirmation != 'CONFIRM_LIVE_TRADING':
|
||||||
|
logger.info("Live trading not confirmed, switching to simulation mode")
|
||||||
|
simulation_mode = True
|
||||||
|
|
||||||
|
# Initialize trading executor
|
||||||
|
self.trading_executor = TradingExecutor(
|
||||||
|
simulation_mode=simulation_mode,
|
||||||
|
mexc_config=mexc_config
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Trading Executor initialized in {'SIMULATION' if simulation_mode else 'LIVE'} mode")
|
||||||
|
|
||||||
|
async def _initialize_rl_trader(self):
|
||||||
|
"""Initialize the RL trader"""
|
||||||
|
logger.info("Initializing Real-time RL COB Trader...")
|
||||||
|
|
||||||
|
# Get RL configuration
|
||||||
|
rl_config = self.config.get('realtime_rl', {})
|
||||||
|
|
||||||
|
# Trading symbols
|
||||||
|
symbols = rl_config.get('symbols', ['BTC/USDT', 'ETH/USDT'])
|
||||||
|
|
||||||
|
# RL parameters
|
||||||
|
inference_interval_ms = rl_config.get('inference_interval_ms', 200)
|
||||||
|
min_confidence_threshold = rl_config.get('min_confidence_threshold', 0.7)
|
||||||
|
required_confident_predictions = rl_config.get('required_confident_predictions', 3)
|
||||||
|
model_checkpoint_dir = rl_config.get('model_checkpoint_dir', 'models/realtime_rl_cob')
|
||||||
|
|
||||||
|
# Initialize RL trader
|
||||||
|
self.trader = RealtimeRLCOBTrader(
|
||||||
|
symbols=symbols,
|
||||||
|
trading_executor=self.trading_executor,
|
||||||
|
model_checkpoint_dir=model_checkpoint_dir,
|
||||||
|
inference_interval_ms=inference_interval_ms,
|
||||||
|
min_confidence_threshold=min_confidence_threshold,
|
||||||
|
required_confident_predictions=required_confident_predictions
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"RL Trader initialized for symbols: {symbols}")
|
||||||
|
logger.info(f"Inference interval: {inference_interval_ms}ms")
|
||||||
|
logger.info(f"Confidence threshold: {min_confidence_threshold}")
|
||||||
|
logger.info(f"Required predictions: {required_confident_predictions}")
|
||||||
|
|
||||||
|
async def _start_trading_system(self):
|
||||||
|
"""Start the complete trading system"""
|
||||||
|
logger.info("Starting Real-time RL COB Trading System...")
|
||||||
|
|
||||||
|
# Start RL trader (this will start COB integration internally)
|
||||||
|
await self.trader.start()
|
||||||
|
|
||||||
|
self.running = True
|
||||||
|
|
||||||
|
logger.info("✅ Real-time RL COB Trading System started successfully!")
|
||||||
|
logger.info("🔥 1B parameter model training and inference active")
|
||||||
|
logger.info("📊 COB data streaming and processing")
|
||||||
|
logger.info("🎯 Signal accumulation and trade execution ready")
|
||||||
|
logger.info("⚡ Real-time training on prediction outcomes")
|
||||||
|
|
||||||
|
async def _run_main_loop(self):
|
||||||
|
"""Main monitoring and statistics loop"""
|
||||||
|
logger.info("Starting main monitoring loop...")
|
||||||
|
|
||||||
|
last_stats_time = datetime.now()
|
||||||
|
stats_interval = 60 # Print stats every 60 seconds
|
||||||
|
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
# Sleep for a bit
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
|
# Print periodic statistics
|
||||||
|
current_time = datetime.now()
|
||||||
|
if (current_time - last_stats_time).total_seconds() >= stats_interval:
|
||||||
|
await self._print_performance_stats()
|
||||||
|
last_stats_time = current_time
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in main loop: {e}")
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
|
logger.info("Main monitoring loop stopped")
|
||||||
|
|
||||||
|
async def _print_performance_stats(self):
|
||||||
|
"""Print comprehensive performance statistics"""
|
||||||
|
try:
|
||||||
|
if not self.trader:
|
||||||
|
return
|
||||||
|
|
||||||
|
stats = self.trader.get_performance_stats()
|
||||||
|
|
||||||
|
logger.info("=" * 80)
|
||||||
|
logger.info("🔥 REAL-TIME RL COB TRADER PERFORMANCE STATISTICS")
|
||||||
|
logger.info("=" * 80)
|
||||||
|
|
||||||
|
# Model information
|
||||||
|
logger.info("📊 Model Information:")
|
||||||
|
for symbol, model_info in stats.get('model_info', {}).items():
|
||||||
|
total_params = model_info.get('total_parameters', 0)
|
||||||
|
logger.info(f" {symbol}: {total_params:,} parameters ({total_params/1e9:.2f}B)")
|
||||||
|
|
||||||
|
# Training statistics
|
||||||
|
logger.info("\n🧠 Training Statistics:")
|
||||||
|
for symbol, training_stats in stats.get('training_stats', {}).items():
|
||||||
|
total_preds = training_stats.get('total_predictions', 0)
|
||||||
|
successful_preds = training_stats.get('successful_predictions', 0)
|
||||||
|
success_rate = (successful_preds / max(1, total_preds)) * 100
|
||||||
|
avg_loss = training_stats.get('average_loss', 0.0)
|
||||||
|
training_steps = training_stats.get('total_training_steps', 0)
|
||||||
|
last_training = training_stats.get('last_training_time')
|
||||||
|
|
||||||
|
logger.info(f" {symbol}:")
|
||||||
|
logger.info(f" Predictions: {total_preds} (Success: {success_rate:.1f}%)")
|
||||||
|
logger.info(f" Training Steps: {training_steps}")
|
||||||
|
logger.info(f" Average Loss: {avg_loss:.6f}")
|
||||||
|
if last_training:
|
||||||
|
logger.info(f" Last Training: {last_training}")
|
||||||
|
|
||||||
|
# Inference statistics
|
||||||
|
logger.info("\n⚡ Inference Statistics:")
|
||||||
|
for symbol, inference_stats in stats.get('inference_stats', {}).items():
|
||||||
|
total_inferences = inference_stats.get('total_inferences', 0)
|
||||||
|
avg_time = inference_stats.get('average_inference_time_ms', 0.0)
|
||||||
|
last_inference = inference_stats.get('last_inference_time')
|
||||||
|
|
||||||
|
logger.info(f" {symbol}:")
|
||||||
|
logger.info(f" Total Inferences: {total_inferences}")
|
||||||
|
logger.info(f" Average Time: {avg_time:.1f}ms")
|
||||||
|
if last_inference:
|
||||||
|
logger.info(f" Last Inference: {last_inference}")
|
||||||
|
|
||||||
|
# Signal statistics
|
||||||
|
logger.info("\n🎯 Signal Accumulation:")
|
||||||
|
for symbol, signal_stats in stats.get('signal_stats', {}).items():
|
||||||
|
current_signals = signal_stats.get('current_signals', 0)
|
||||||
|
confidence_sum = signal_stats.get('confidence_sum', 0.0)
|
||||||
|
success_rate = signal_stats.get('success_rate', 0.0) * 100
|
||||||
|
|
||||||
|
logger.info(f" {symbol}:")
|
||||||
|
logger.info(f" Current Signals: {current_signals}")
|
||||||
|
logger.info(f" Confidence Sum: {confidence_sum:.2f}")
|
||||||
|
logger.info(f" Historical Success Rate: {success_rate:.1f}%")
|
||||||
|
|
||||||
|
# Trading executor statistics
|
||||||
|
if self.trading_executor:
|
||||||
|
positions = self.trading_executor.get_positions()
|
||||||
|
trade_history = self.trading_executor.get_trade_history()
|
||||||
|
|
||||||
|
logger.info("\n💰 Trading Statistics:")
|
||||||
|
logger.info(f" Active Positions: {len(positions)}")
|
||||||
|
logger.info(f" Total Trades: {len(trade_history)}")
|
||||||
|
|
||||||
|
if trade_history:
|
||||||
|
# Calculate P&L statistics
|
||||||
|
total_pnl = sum(trade.pnl for trade in trade_history)
|
||||||
|
profitable_trades = sum(1 for trade in trade_history if trade.pnl > 0)
|
||||||
|
win_rate = (profitable_trades / len(trade_history)) * 100
|
||||||
|
|
||||||
|
logger.info(f" Total P&L: ${total_pnl:.2f}")
|
||||||
|
logger.info(f" Win Rate: {win_rate:.1f}%")
|
||||||
|
|
||||||
|
# Show active positions
|
||||||
|
if positions:
|
||||||
|
logger.info("\n📍 Active Positions:")
|
||||||
|
for symbol, position in positions.items():
|
||||||
|
logger.info(f" {symbol}: {position.side} {position.quantity:.6f} @ ${position.entry_price:.2f}")
|
||||||
|
|
||||||
|
logger.info("=" * 80)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error printing performance stats: {e}")
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""Stop the trading system gracefully"""
|
||||||
|
if not self.running:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Stopping Real-time RL COB Trading System...")
|
||||||
|
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
# Stop RL trader
|
||||||
|
if self.trader:
|
||||||
|
await self.trader.stop()
|
||||||
|
logger.info("✅ RL Trader stopped")
|
||||||
|
|
||||||
|
# Print final statistics
|
||||||
|
if self.trader:
|
||||||
|
logger.info("\n📊 Final Performance Summary:")
|
||||||
|
await self._print_performance_stats()
|
||||||
|
|
||||||
|
logger.info("Real-time RL COB Trading System stopped successfully")
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Main entry point"""
|
||||||
|
try:
|
||||||
|
# Create logs directory if it doesn't exist
|
||||||
|
os.makedirs('logs', exist_ok=True)
|
||||||
|
|
||||||
|
# Initialize and start launcher
|
||||||
|
launcher = RealtimeRLCOBTraderLauncher()
|
||||||
|
await launcher.start()
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Received keyboard interrupt, shutting down...")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Critical error: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Set event loop policy for Windows compatibility
|
||||||
|
if hasattr(asyncio, 'WindowsProactorEventLoopPolicy'):
|
||||||
|
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
||||||
|
|
||||||
|
asyncio.run(main())
|
547
test_realtime_rl_cob_trader.py
Normal file
547
test_realtime_rl_cob_trader.py
Normal file
@ -0,0 +1,547 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test Script for Real-time RL COB Trader
|
||||||
|
|
||||||
|
This script tests the real-time reinforcement learning system to ensure:
|
||||||
|
1. Proper model initialization and parameter count (~1B parameters)
|
||||||
|
2. COB data integration and feature extraction
|
||||||
|
3. Real-time inference pipeline
|
||||||
|
4. Signal accumulation and consensus
|
||||||
|
5. Training loop functionality
|
||||||
|
6. Trade execution integration
|
||||||
|
|
||||||
|
Run this before deploying the live system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
# Local imports
|
||||||
|
from core.realtime_rl_cob_trader import RealtimeRLCOBTrader, MassiveRLNetwork, PredictionResult
|
||||||
|
from core.trading_executor import TradingExecutor
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class RealtimeRLTester:
|
||||||
|
"""
|
||||||
|
Comprehensive tester for Real-time RL COB Trader
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.test_results = {}
|
||||||
|
self.trader = None
|
||||||
|
self.trading_executor = None
|
||||||
|
|
||||||
|
async def run_all_tests(self):
|
||||||
|
"""Run all tests and generate report"""
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("REAL-TIME RL COB TRADER TESTING SUITE")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
tests = [
|
||||||
|
self.test_model_initialization,
|
||||||
|
self.test_model_parameter_count,
|
||||||
|
self.test_feature_extraction,
|
||||||
|
self.test_inference_performance,
|
||||||
|
self.test_signal_accumulation,
|
||||||
|
self.test_training_pipeline,
|
||||||
|
self.test_trading_integration,
|
||||||
|
self.test_performance_monitoring
|
||||||
|
]
|
||||||
|
|
||||||
|
for test in tests:
|
||||||
|
try:
|
||||||
|
await test()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Test {test.__name__} failed: {e}")
|
||||||
|
self.test_results[test.__name__] = {'status': 'FAILED', 'error': str(e)}
|
||||||
|
|
||||||
|
await self.generate_test_report()
|
||||||
|
|
||||||
|
async def test_model_initialization(self):
|
||||||
|
"""Test model initialization and architecture"""
|
||||||
|
logger.info("🧠 Testing Model Initialization...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Test model creation
|
||||||
|
model = MassiveRLNetwork(input_size=2000, hidden_size=4096, num_layers=12)
|
||||||
|
|
||||||
|
# Check if CUDA is available
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
# Test forward pass
|
||||||
|
batch_size = 4
|
||||||
|
test_input = torch.randn(batch_size, 2000).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(test_input)
|
||||||
|
|
||||||
|
# Verify outputs
|
||||||
|
assert 'price_logits' in outputs
|
||||||
|
assert 'value' in outputs
|
||||||
|
assert 'confidence' in outputs
|
||||||
|
assert 'features' in outputs
|
||||||
|
|
||||||
|
assert outputs['price_logits'].shape == (batch_size, 3) # DOWN, SIDEWAYS, UP
|
||||||
|
assert outputs['value'].shape == (batch_size, 1)
|
||||||
|
assert outputs['confidence'].shape == (batch_size, 1)
|
||||||
|
|
||||||
|
self.test_results['test_model_initialization'] = {
|
||||||
|
'status': 'PASSED',
|
||||||
|
'device': str(device),
|
||||||
|
'output_shapes': {k: list(v.shape) for k, v in outputs.items()}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info("✅ Model initialization test PASSED")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.test_results['test_model_initialization'] = {'status': 'FAILED', 'error': str(e)}
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def test_model_parameter_count(self):
|
||||||
|
"""Test that model has approximately 1B parameters"""
|
||||||
|
logger.info("🔢 Testing Model Parameter Count...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
model = MassiveRLNetwork(input_size=2000, hidden_size=4096, num_layers=12)
|
||||||
|
|
||||||
|
total_params = sum(p.numel() for p in model.parameters())
|
||||||
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
logger.info(f"Total parameters: {total_params:,}")
|
||||||
|
logger.info(f"Trainable parameters: {trainable_params:,}")
|
||||||
|
|
||||||
|
self.test_results['test_model_parameter_count'] = {
|
||||||
|
'status': 'PASSED',
|
||||||
|
'total_parameters': total_params,
|
||||||
|
'trainable_parameters': trainable_params,
|
||||||
|
'parameter_size_gb': (total_params * 4) / (1024**3), # 4 bytes per float32
|
||||||
|
'is_massive': total_params > 100_000_000 # At least 100M parameters
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"✅ Model has {total_params:,} parameters ({total_params/1e9:.2f}B)")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.test_results['test_model_parameter_count'] = {'status': 'FAILED', 'error': str(e)}
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def test_feature_extraction(self):
|
||||||
|
"""Test feature extraction from COB data"""
|
||||||
|
logger.info("🔍 Testing Feature Extraction...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize trader
|
||||||
|
self.trading_executor = TradingExecutor(simulation_mode=True)
|
||||||
|
self.trader = RealtimeRLCOBTrader(
|
||||||
|
symbols=['BTC/USDT'],
|
||||||
|
trading_executor=self.trading_executor,
|
||||||
|
inference_interval_ms=1000 # Slower for testing
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create mock COB data
|
||||||
|
mock_cob_data = {
|
||||||
|
'state': np.random.randn(1500), # Mock state features
|
||||||
|
'timestamp': datetime.now(),
|
||||||
|
'type': 'cob_state'
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test feature extraction
|
||||||
|
features = self.trader._extract_features('BTC/USDT', mock_cob_data)
|
||||||
|
|
||||||
|
assert features is not None
|
||||||
|
assert len(features) == 2000 # Target feature size
|
||||||
|
assert features.dtype == np.float32
|
||||||
|
assert not np.any(np.isnan(features))
|
||||||
|
assert not np.any(np.isinf(features))
|
||||||
|
|
||||||
|
# Test normalization
|
||||||
|
assert np.abs(np.mean(features)) < 1.0 # Roughly normalized
|
||||||
|
assert np.std(features) < 10.0 # Not too spread out
|
||||||
|
|
||||||
|
self.test_results['test_feature_extraction'] = {
|
||||||
|
'status': 'PASSED',
|
||||||
|
'feature_size': len(features),
|
||||||
|
'feature_range': [float(np.min(features)), float(np.max(features))],
|
||||||
|
'feature_stats': {
|
||||||
|
'mean': float(np.mean(features)),
|
||||||
|
'std': float(np.std(features)),
|
||||||
|
'median': float(np.median(features))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info("✅ Feature extraction test PASSED")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.test_results['test_feature_extraction'] = {'status': 'FAILED', 'error': str(e)}
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def test_inference_performance(self):
|
||||||
|
"""Test inference speed and quality"""
|
||||||
|
logger.info("⚡ Testing Inference Performance...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self.trader:
|
||||||
|
self.trading_executor = TradingExecutor(simulation_mode=True)
|
||||||
|
self.trader = RealtimeRLCOBTrader(
|
||||||
|
symbols=['BTC/USDT'],
|
||||||
|
trading_executor=self.trading_executor
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test multiple inferences
|
||||||
|
num_tests = 10
|
||||||
|
inference_times = []
|
||||||
|
|
||||||
|
for i in range(num_tests):
|
||||||
|
# Create test features
|
||||||
|
test_features = np.random.randn(2000).astype(np.float32)
|
||||||
|
test_features = self.trader._normalize_features(test_features)
|
||||||
|
|
||||||
|
# Time inference
|
||||||
|
start_time = time.time()
|
||||||
|
prediction = self.trader._predict('BTC/USDT', test_features)
|
||||||
|
inference_time = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
inference_times.append(inference_time)
|
||||||
|
|
||||||
|
# Verify prediction structure
|
||||||
|
assert 'direction' in prediction
|
||||||
|
assert 'confidence' in prediction
|
||||||
|
assert 'change' in prediction
|
||||||
|
assert 'value' in prediction
|
||||||
|
|
||||||
|
assert 0 <= prediction['direction'] <= 2
|
||||||
|
assert 0.0 <= prediction['confidence'] <= 1.0
|
||||||
|
assert isinstance(prediction['change'], float)
|
||||||
|
assert isinstance(prediction['value'], float)
|
||||||
|
|
||||||
|
avg_inference_time = np.mean(inference_times)
|
||||||
|
max_inference_time = np.max(inference_times)
|
||||||
|
|
||||||
|
# Check if inference is fast enough (target: <50ms per inference)
|
||||||
|
inference_target_ms = 50.0
|
||||||
|
|
||||||
|
self.test_results['test_inference_performance'] = {
|
||||||
|
'status': 'PASSED' if avg_inference_time < inference_target_ms else 'WARNING',
|
||||||
|
'average_inference_time_ms': float(avg_inference_time),
|
||||||
|
'max_inference_time_ms': float(max_inference_time),
|
||||||
|
'target_time_ms': inference_target_ms,
|
||||||
|
'meets_target': avg_inference_time < inference_target_ms,
|
||||||
|
'inferences_per_second': 1000.0 / avg_inference_time
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"✅ Average inference time: {avg_inference_time:.2f}ms")
|
||||||
|
logger.info(f"✅ Max inference time: {max_inference_time:.2f}ms")
|
||||||
|
logger.info(f"✅ Inferences per second: {1000.0/avg_inference_time:.1f}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.test_results['test_inference_performance'] = {'status': 'FAILED', 'error': str(e)}
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def test_signal_accumulation(self):
|
||||||
|
"""Test signal accumulation and consensus logic"""
|
||||||
|
logger.info("🎯 Testing Signal Accumulation...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self.trader:
|
||||||
|
self.trading_executor = TradingExecutor(simulation_mode=True)
|
||||||
|
self.trader = RealtimeRLCOBTrader(
|
||||||
|
symbols=['BTC/USDT'],
|
||||||
|
trading_executor=self.trading_executor,
|
||||||
|
required_confident_predictions=3
|
||||||
|
)
|
||||||
|
|
||||||
|
symbol = 'BTC/USDT'
|
||||||
|
accumulator = self.trader.signal_accumulators[symbol]
|
||||||
|
|
||||||
|
# Test adding signals
|
||||||
|
test_predictions = []
|
||||||
|
for i in range(5):
|
||||||
|
prediction = PredictionResult(
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
symbol=symbol,
|
||||||
|
predicted_direction=2, # UP
|
||||||
|
confidence=0.8,
|
||||||
|
predicted_change=0.001,
|
||||||
|
features=np.random.randn(2000).astype(np.float32)
|
||||||
|
)
|
||||||
|
test_predictions.append(prediction)
|
||||||
|
self.trader._add_signal(symbol, prediction)
|
||||||
|
|
||||||
|
# Check accumulator state
|
||||||
|
assert len(accumulator.signals) == 5
|
||||||
|
assert accumulator.confidence_sum == 5 * 0.8
|
||||||
|
assert accumulator.total_predictions == 5
|
||||||
|
|
||||||
|
# Test consensus logic (simulate processing)
|
||||||
|
recent_signals = list(accumulator.signals)[-3:]
|
||||||
|
directions = [signal.predicted_direction for signal in recent_signals]
|
||||||
|
|
||||||
|
# All should be direction 2 (UP)
|
||||||
|
direction_counts = {0: 0, 1: 0, 2: 0}
|
||||||
|
for direction in directions:
|
||||||
|
direction_counts[direction] += 1
|
||||||
|
|
||||||
|
dominant_direction = max(direction_counts, key=direction_counts.get)
|
||||||
|
consensus_count = direction_counts[dominant_direction]
|
||||||
|
|
||||||
|
assert dominant_direction == 2
|
||||||
|
assert consensus_count == 3
|
||||||
|
|
||||||
|
self.test_results['test_signal_accumulation'] = {
|
||||||
|
'status': 'PASSED',
|
||||||
|
'signals_added': len(accumulator.signals),
|
||||||
|
'confidence_sum': accumulator.confidence_sum,
|
||||||
|
'consensus_direction': dominant_direction,
|
||||||
|
'consensus_count': consensus_count
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info("✅ Signal accumulation test PASSED")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.test_results['test_signal_accumulation'] = {'status': 'FAILED', 'error': str(e)}
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def test_training_pipeline(self):
|
||||||
|
"""Test training pipeline functionality"""
|
||||||
|
logger.info("🧠 Testing Training Pipeline...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self.trader:
|
||||||
|
self.trading_executor = TradingExecutor(simulation_mode=True)
|
||||||
|
self.trader = RealtimeRLCOBTrader(
|
||||||
|
symbols=['BTC/USDT'],
|
||||||
|
trading_executor=self.trading_executor
|
||||||
|
)
|
||||||
|
|
||||||
|
symbol = 'BTC/USDT'
|
||||||
|
|
||||||
|
# Create mock training data
|
||||||
|
test_predictions = []
|
||||||
|
for i in range(10):
|
||||||
|
prediction = PredictionResult(
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
symbol=symbol,
|
||||||
|
predicted_direction=np.random.randint(0, 3),
|
||||||
|
confidence=np.random.uniform(0.5, 1.0),
|
||||||
|
predicted_change=np.random.uniform(-0.001, 0.001),
|
||||||
|
features=np.random.randn(2000).astype(np.float32),
|
||||||
|
actual_direction=np.random.randint(0, 3),
|
||||||
|
actual_change=np.random.uniform(-0.001, 0.001),
|
||||||
|
reward=np.random.uniform(-1.0, 1.0)
|
||||||
|
)
|
||||||
|
test_predictions.append(prediction)
|
||||||
|
|
||||||
|
# Test training batch
|
||||||
|
loss = await self.trader._train_batch(symbol, test_predictions)
|
||||||
|
|
||||||
|
assert isinstance(loss, float)
|
||||||
|
assert not np.isnan(loss)
|
||||||
|
assert not np.isinf(loss)
|
||||||
|
assert loss >= 0.0 # Loss should be non-negative
|
||||||
|
|
||||||
|
self.test_results['test_training_pipeline'] = {
|
||||||
|
'status': 'PASSED',
|
||||||
|
'training_loss': float(loss),
|
||||||
|
'batch_size': len(test_predictions),
|
||||||
|
'training_successful': True
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"✅ Training pipeline test PASSED (loss: {loss:.6f})")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.test_results['test_training_pipeline'] = {'status': 'FAILED', 'error': str(e)}
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def test_trading_integration(self):
|
||||||
|
"""Test integration with trading executor"""
|
||||||
|
logger.info("💰 Testing Trading Integration...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize with simulation mode
|
||||||
|
trading_executor = TradingExecutor(simulation_mode=True)
|
||||||
|
|
||||||
|
# Test signal execution
|
||||||
|
success = trading_executor.execute_signal(
|
||||||
|
symbol='BTC/USDT',
|
||||||
|
action='BUY',
|
||||||
|
confidence=0.8,
|
||||||
|
current_price=50000.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# In simulation mode, this should always succeed
|
||||||
|
assert success == True
|
||||||
|
|
||||||
|
# Check positions
|
||||||
|
positions = trading_executor.get_positions()
|
||||||
|
assert 'BTC/USDT' in positions
|
||||||
|
|
||||||
|
# Test sell signal
|
||||||
|
success = trading_executor.execute_signal(
|
||||||
|
symbol='BTC/USDT',
|
||||||
|
action='SELL',
|
||||||
|
confidence=0.8,
|
||||||
|
current_price=50100.0
|
||||||
|
)
|
||||||
|
|
||||||
|
assert success == True
|
||||||
|
|
||||||
|
# Check trade history
|
||||||
|
trade_history = trading_executor.get_trade_history()
|
||||||
|
assert len(trade_history) > 0
|
||||||
|
|
||||||
|
last_trade = trade_history[-1]
|
||||||
|
assert last_trade.symbol == 'BTC/USDT'
|
||||||
|
assert last_trade.pnl != 0 # Should have some P&L
|
||||||
|
|
||||||
|
self.test_results['test_trading_integration'] = {
|
||||||
|
'status': 'PASSED',
|
||||||
|
'simulation_mode': True,
|
||||||
|
'trades_executed': len(trade_history),
|
||||||
|
'last_trade_pnl': float(last_trade.pnl)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info("✅ Trading integration test PASSED")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.test_results['test_trading_integration'] = {'status': 'FAILED', 'error': str(e)}
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def test_performance_monitoring(self):
|
||||||
|
"""Test performance monitoring and statistics"""
|
||||||
|
logger.info("📊 Testing Performance Monitoring...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self.trader:
|
||||||
|
self.trading_executor = TradingExecutor(simulation_mode=True)
|
||||||
|
self.trader = RealtimeRLCOBTrader(
|
||||||
|
symbols=['BTC/USDT', 'ETH/USDT'],
|
||||||
|
trading_executor=self.trading_executor
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get performance stats
|
||||||
|
stats = self.trader.get_performance_stats()
|
||||||
|
|
||||||
|
# Verify structure
|
||||||
|
assert 'symbols' in stats
|
||||||
|
assert 'training_stats' in stats
|
||||||
|
assert 'inference_stats' in stats
|
||||||
|
assert 'signal_stats' in stats
|
||||||
|
assert 'model_info' in stats
|
||||||
|
|
||||||
|
# Check symbols
|
||||||
|
assert 'BTC/USDT' in stats['symbols']
|
||||||
|
assert 'ETH/USDT' in stats['symbols']
|
||||||
|
|
||||||
|
# Check model info
|
||||||
|
for symbol in stats['symbols']:
|
||||||
|
assert symbol in stats['model_info']
|
||||||
|
model_info = stats['model_info'][symbol]
|
||||||
|
assert 'total_parameters' in model_info
|
||||||
|
assert 'trainable_parameters' in model_info
|
||||||
|
assert model_info['total_parameters'] > 0
|
||||||
|
|
||||||
|
self.test_results['test_performance_monitoring'] = {
|
||||||
|
'status': 'PASSED',
|
||||||
|
'stats_structure_valid': True,
|
||||||
|
'symbols_tracked': len(stats['symbols']),
|
||||||
|
'model_info_available': len(stats['model_info'])
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info("✅ Performance monitoring test PASSED")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.test_results['test_performance_monitoring'] = {'status': 'FAILED', 'error': str(e)}
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def generate_test_report(self):
|
||||||
|
"""Generate comprehensive test report"""
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("REAL-TIME RL COB TRADER TEST REPORT")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
total_tests = len(self.test_results)
|
||||||
|
passed_tests = sum(1 for result in self.test_results.values() if result['status'] == 'PASSED')
|
||||||
|
failed_tests = sum(1 for result in self.test_results.values() if result['status'] == 'FAILED')
|
||||||
|
warning_tests = sum(1 for result in self.test_results.values() if result['status'] == 'WARNING')
|
||||||
|
|
||||||
|
logger.info(f"📊 Test Summary:")
|
||||||
|
logger.info(f" Total Tests: {total_tests}")
|
||||||
|
logger.info(f" ✅ Passed: {passed_tests}")
|
||||||
|
logger.info(f" ⚠️ Warnings: {warning_tests}")
|
||||||
|
logger.info(f" ❌ Failed: {failed_tests}")
|
||||||
|
|
||||||
|
success_rate = (passed_tests / total_tests) * 100 if total_tests > 0 else 0
|
||||||
|
logger.info(f" Success Rate: {success_rate:.1f}%")
|
||||||
|
|
||||||
|
logger.info("\n📋 Detailed Results:")
|
||||||
|
for test_name, result in self.test_results.items():
|
||||||
|
status_icon = "✅" if result['status'] == 'PASSED' else "⚠️" if result['status'] == 'WARNING' else "❌"
|
||||||
|
logger.info(f" {status_icon} {test_name}: {result['status']}")
|
||||||
|
|
||||||
|
if result['status'] == 'FAILED':
|
||||||
|
logger.error(f" Error: {result.get('error', 'Unknown error')}")
|
||||||
|
|
||||||
|
# System readiness assessment
|
||||||
|
logger.info("\n🎯 System Readiness Assessment:")
|
||||||
|
if failed_tests == 0:
|
||||||
|
if warning_tests == 0:
|
||||||
|
logger.info(" 🟢 SYSTEM READY FOR DEPLOYMENT")
|
||||||
|
logger.info(" All tests passed. The real-time RL COB trader is ready for live operation.")
|
||||||
|
else:
|
||||||
|
logger.info(" 🟡 SYSTEM READY WITH WARNINGS")
|
||||||
|
logger.info(" System is functional but some performance warnings exist.")
|
||||||
|
else:
|
||||||
|
logger.info(" 🔴 SYSTEM NOT READY")
|
||||||
|
logger.info(" Critical issues found. Fix errors before deployment.")
|
||||||
|
|
||||||
|
# Save detailed report
|
||||||
|
report_data = {
|
||||||
|
'timestamp': datetime.now().isoformat(),
|
||||||
|
'test_summary': {
|
||||||
|
'total_tests': total_tests,
|
||||||
|
'passed_tests': passed_tests,
|
||||||
|
'warning_tests': warning_tests,
|
||||||
|
'failed_tests': failed_tests,
|
||||||
|
'success_rate': success_rate
|
||||||
|
},
|
||||||
|
'test_results': self.test_results,
|
||||||
|
'system_readiness': 'READY' if failed_tests == 0 else 'NOT_READY'
|
||||||
|
}
|
||||||
|
|
||||||
|
report_file = f"test_reports/realtime_rl_test_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
||||||
|
|
||||||
|
import os
|
||||||
|
os.makedirs('test_reports', exist_ok=True)
|
||||||
|
|
||||||
|
with open(report_file, 'w') as f:
|
||||||
|
json.dump(report_data, f, indent=2, default=str)
|
||||||
|
|
||||||
|
logger.info(f"\n📄 Detailed report saved to: {report_file}")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Main test entry point"""
|
||||||
|
logger.info("Starting Real-time RL COB Trader Test Suite...")
|
||||||
|
|
||||||
|
tester = RealtimeRLTester()
|
||||||
|
await tester.run_all_tests()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Set event loop policy for Windows compatibility
|
||||||
|
if hasattr(asyncio, 'WindowsProactorEventLoopPolicy'):
|
||||||
|
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
||||||
|
|
||||||
|
asyncio.run(main())
|
Reference in New Issue
Block a user