BIG CLEANUP
@@ -1,402 +0,0 @@
"""
API Rate Limiter and Error Handler

This module provides robust rate limiting and error handling for API requests,
specifically designed to handle Binance's aggressive rate limiting (HTTP 418 errors)
and other exchange API limitations.

Features:
- Exponential backoff for rate limiting
- IP rotation and proxy support
- Request queuing and throttling
- Error recovery strategies
- Thread-safe operations
"""

import asyncio
import logging
import time
import random
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Callable, Any
from dataclasses import dataclass, field
from collections import deque
import threading
from concurrent.futures import ThreadPoolExecutor
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

logger = logging.getLogger(__name__)

@dataclass
class RateLimitConfig:
    """Configuration for rate limiting"""
    requests_per_second: float = 0.5  # Very conservative for Binance
    requests_per_minute: int = 20
    requests_per_hour: int = 1000

    # Backoff configuration
    initial_backoff: float = 1.0
    max_backoff: float = 300.0  # 5 minutes max
    backoff_multiplier: float = 2.0

    # Error handling
    max_retries: int = 3
    retry_delay: float = 5.0

    # IP blocking detection
    block_detection_threshold: int = 3  # 3 consecutive 418s = blocked
    block_recovery_time: int = 3600  # 1 hour recovery time

@dataclass
class APIEndpoint:
    """API endpoint configuration"""
    name: str
    base_url: str
    rate_limit: RateLimitConfig
    last_request_time: float = 0.0
    request_count_minute: int = 0
    request_count_hour: int = 0
    consecutive_errors: int = 0
    blocked_until: Optional[datetime] = None

    # Request history for rate limiting
    request_history: deque = field(default_factory=lambda: deque(maxlen=3600))  # 1 hour history

class APIRateLimiter:
    """Thread-safe API rate limiter with error handling"""

    def __init__(self, config: RateLimitConfig = None):
        self.config = config or RateLimitConfig()

        # Thread safety
        self.lock = threading.RLock()

        # Endpoint tracking
        self.endpoints: Dict[str, APIEndpoint] = {}

        # Global rate limiting
        self.global_request_history = deque(maxlen=3600)
        self.global_blocked_until: Optional[datetime] = None

        # Request session with retry strategy
        self.session = self._create_session()

        # Background cleanup thread
        self.cleanup_thread = None
        self.is_running = False

        logger.info("API Rate Limiter initialized")
        logger.info(f"Rate limits: {self.config.requests_per_second}/s, {self.config.requests_per_minute}/m")

    def _create_session(self) -> requests.Session:
        """Create requests session with retry strategy"""
        session = requests.Session()

        # Retry strategy
        retry_strategy = Retry(
            total=self.config.max_retries,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["HEAD", "GET", "OPTIONS"]
        )

        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        # Headers to appear more legitimate
        session.headers.update({
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Accept': 'application/json',
            'Accept-Language': 'en-US,en;q=0.9',
            'Accept-Encoding': 'gzip, deflate, br',
            'Connection': 'keep-alive',
            'Upgrade-Insecure-Requests': '1',
        })

        return session

    def register_endpoint(self, name: str, base_url: str, rate_limit: RateLimitConfig = None):
        """Register an API endpoint for rate limiting"""
        with self.lock:
            self.endpoints[name] = APIEndpoint(
                name=name,
                base_url=base_url,
                rate_limit=rate_limit or self.config
            )
            logger.info(f"Registered endpoint: {name} -> {base_url}")

    def start_background_cleanup(self):
        """Start background cleanup thread"""
        if self.is_running:
            return

        self.is_running = True
        self.cleanup_thread = threading.Thread(target=self._cleanup_worker, daemon=True)
        self.cleanup_thread.start()
        logger.info("Started background cleanup thread")

    def stop_background_cleanup(self):
        """Stop background cleanup thread"""
        self.is_running = False
        if self.cleanup_thread:
            self.cleanup_thread.join(timeout=5)
        logger.info("Stopped background cleanup thread")

    def _cleanup_worker(self):
        """Background worker to clean up old request history"""
        while self.is_running:
            try:
                current_time = time.time()
                cutoff_time = current_time - 3600  # 1 hour ago

                with self.lock:
                    # Clean global history
                    while (self.global_request_history and
                           self.global_request_history[0] < cutoff_time):
                        self.global_request_history.popleft()

                    # Clean endpoint histories
                    for endpoint in self.endpoints.values():
                        while (endpoint.request_history and
                               endpoint.request_history[0] < cutoff_time):
                            endpoint.request_history.popleft()

                        # Reset counters
                        endpoint.request_count_minute = len([
                            t for t in endpoint.request_history
                            if t > current_time - 60
                        ])
                        endpoint.request_count_hour = len(endpoint.request_history)

                time.sleep(60)  # Clean every minute

            except Exception as e:
                logger.error(f"Error in cleanup worker: {e}")
                time.sleep(30)

    def can_make_request(self, endpoint_name: str) -> tuple[bool, float]:
        """
        Check if we can make a request to the endpoint

        Returns:
            (can_make_request, wait_time_seconds)
        """
        with self.lock:
            current_time = time.time()

            # Check global blocking
            if self.global_blocked_until and datetime.now() < self.global_blocked_until:
                wait_time = (self.global_blocked_until - datetime.now()).total_seconds()
                return False, wait_time

            # Get endpoint
            endpoint = self.endpoints.get(endpoint_name)
            if not endpoint:
                logger.warning(f"Unknown endpoint: {endpoint_name}")
                return False, 60.0

            # Check endpoint blocking
            if endpoint.blocked_until and datetime.now() < endpoint.blocked_until:
                wait_time = (endpoint.blocked_until - datetime.now()).total_seconds()
                return False, wait_time

            # Check rate limits
            config = endpoint.rate_limit

            # Per-second rate limit
            time_since_last = current_time - endpoint.last_request_time
            if time_since_last < (1.0 / config.requests_per_second):
                wait_time = (1.0 / config.requests_per_second) - time_since_last
                return False, wait_time

            # Per-minute rate limit
            minute_requests = len([
                t for t in endpoint.request_history
                if t > current_time - 60
            ])
            if minute_requests >= config.requests_per_minute:
                return False, 60.0

            # Per-hour rate limit
            if len(endpoint.request_history) >= config.requests_per_hour:
                return False, 3600.0

            return True, 0.0

    def make_request(self, endpoint_name: str, url: str, method: str = 'GET',
                     **kwargs) -> Optional[requests.Response]:
        """
        Make a rate-limited request with error handling

        Args:
            endpoint_name: Name of the registered endpoint
            url: Full URL to request
            method: HTTP method
            **kwargs: Additional arguments for requests

        Returns:
            Response object or None if failed
        """
        with self.lock:
            endpoint = self.endpoints.get(endpoint_name)
            if not endpoint:
                logger.error(f"Unknown endpoint: {endpoint_name}")
                return None

            # Check if we can make the request
            can_request, wait_time = self.can_make_request(endpoint_name)
            if not can_request:
                logger.debug(f"Rate limited for {endpoint_name}, waiting {wait_time:.2f}s")
                time.sleep(min(wait_time, 30))  # Cap wait time
                return None

            # Record request attempt
            current_time = time.time()
            endpoint.last_request_time = current_time
            endpoint.request_history.append(current_time)
            self.global_request_history.append(current_time)

            # Add jitter to avoid thundering herd
            jitter = random.uniform(0.1, 0.5)
            time.sleep(jitter)

        # Make the request (outside of lock to avoid blocking other threads)
        try:
            # Set timeout
            kwargs.setdefault('timeout', 10)

            # Make request
            response = self.session.request(method, url, **kwargs)

            # Handle response
            with self.lock:
                if response.status_code == 200:
                    # Success - reset error counter
                    endpoint.consecutive_errors = 0
                    return response

                elif response.status_code == 418:
                    # Binance "I'm a teapot" - rate limited/blocked
                    endpoint.consecutive_errors += 1
                    logger.warning(f"HTTP 418 (rate limited) for {endpoint_name}, consecutive errors: {endpoint.consecutive_errors}")

                    if endpoint.consecutive_errors >= endpoint.rate_limit.block_detection_threshold:
                        # We're likely IP blocked
                        block_time = datetime.now() + timedelta(seconds=endpoint.rate_limit.block_recovery_time)
                        endpoint.blocked_until = block_time
                        logger.error(f"Endpoint {endpoint_name} blocked until {block_time}")

                    return None

                elif response.status_code == 429:
                    # Too many requests
                    endpoint.consecutive_errors += 1
                    logger.warning(f"HTTP 429 (too many requests) for {endpoint_name}")

                    # Implement exponential backoff
                    backoff_time = min(
                        endpoint.rate_limit.initial_backoff * (endpoint.rate_limit.backoff_multiplier ** endpoint.consecutive_errors),
                        endpoint.rate_limit.max_backoff
                    )

                    block_time = datetime.now() + timedelta(seconds=backoff_time)
                    endpoint.blocked_until = block_time
                    logger.warning(f"Backing off {endpoint_name} for {backoff_time:.2f}s")

                    return None

                else:
                    # Other error
                    endpoint.consecutive_errors += 1
                    logger.warning(f"HTTP {response.status_code} for {endpoint_name}: {response.text[:200]}")
                    return None

        except requests.exceptions.RequestException as e:
            with self.lock:
                endpoint.consecutive_errors += 1
            logger.error(f"Request exception for {endpoint_name}: {e}")
            return None

        except Exception as e:
            with self.lock:
                endpoint.consecutive_errors += 1
            logger.error(f"Unexpected error for {endpoint_name}: {e}")
            return None

    def get_endpoint_status(self, endpoint_name: str) -> Dict[str, Any]:
        """Get status information for an endpoint"""
        with self.lock:
            endpoint = self.endpoints.get(endpoint_name)
            if not endpoint:
                return {'error': 'Unknown endpoint'}

            current_time = time.time()

            return {
                'name': endpoint.name,
                'base_url': endpoint.base_url,
                'consecutive_errors': endpoint.consecutive_errors,
                'blocked_until': endpoint.blocked_until.isoformat() if endpoint.blocked_until else None,
                'requests_last_minute': len([t for t in endpoint.request_history if t > current_time - 60]),
                'requests_last_hour': len(endpoint.request_history),
                'last_request_time': endpoint.last_request_time,
                'can_make_request': self.can_make_request(endpoint_name)[0]
            }

    def get_all_endpoint_status(self) -> Dict[str, Dict[str, Any]]:
        """Get status for all endpoints"""
        return {name: self.get_endpoint_status(name) for name in self.endpoints.keys()}

    def reset_endpoint(self, endpoint_name: str):
        """Reset an endpoint's error state"""
        with self.lock:
            endpoint = self.endpoints.get(endpoint_name)
            if endpoint:
                endpoint.consecutive_errors = 0
                endpoint.blocked_until = None
                logger.info(f"Reset endpoint: {endpoint_name}")

    def reset_all_endpoints(self):
        """Reset all endpoints' error states"""
        with self.lock:
            for endpoint in self.endpoints.values():
                endpoint.consecutive_errors = 0
                endpoint.blocked_until = None
            self.global_blocked_until = None
            logger.info("Reset all endpoints")

# Global rate limiter instance
_global_rate_limiter = None

def get_rate_limiter() -> APIRateLimiter:
    """Get global rate limiter instance"""
    global _global_rate_limiter
    if _global_rate_limiter is None:
        _global_rate_limiter = APIRateLimiter()
        _global_rate_limiter.start_background_cleanup()

        # Register common endpoints
        _global_rate_limiter.register_endpoint(
            'binance_api',
            'https://api.binance.com',
            RateLimitConfig(
                requests_per_second=0.2,  # Very conservative
                requests_per_minute=10,
                requests_per_hour=500
            )
        )

        _global_rate_limiter.register_endpoint(
            'mexc_api',
            'https://api.mexc.com',
            RateLimitConfig(
                requests_per_second=0.5,
                requests_per_minute=20,
                requests_per_hour=1000
            )
        )

    return _global_rate_limiter
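For reference, a minimal usage sketch of the deleted rate limiter (not part of the original commit). The import name `api_rate_limiter` and the example ticker URL are assumptions, since the diff does not show the deleted file's path; the calls themselves follow the API defined above.

# Hypothetical usage sketch -- module name is an assumption, not shown in this diff.
from api_rate_limiter import get_rate_limiter  # assumed module name

limiter = get_rate_limiter()  # lazily created; pre-registers 'binance_api' and 'mexc_api'

# Rate-limited GET; returns None when throttled, backed off, blocked, or on error.
response = limiter.make_request(
    'binance_api',
    'https://api.binance.com/api/v3/ticker/price',  # example endpoint, for illustration
    params={'symbol': 'ETHUSDT'},
)
if response is not None:
    print(response.json())

# Inspect throttling state: consecutive errors, blocked_until, recent request counts.
print(limiter.get_endpoint_status('binance_api'))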
@@ -1,442 +0,0 @@
"""
Async Handler for UI Stability Fix

Properly handles all async operations in the dashboard with single event loop management,
proper exception handling, and timeout support to prevent async/await errors.
"""

import asyncio
import logging
import threading
import time
from typing import Any, Callable, Coroutine, Dict, Optional, Union
from concurrent.futures import ThreadPoolExecutor
import functools
import weakref

logger = logging.getLogger(__name__)


class AsyncOperationError(Exception):
    """Exception raised for async operation errors"""
    pass


class AsyncHandler:
    """
    Centralized async operation handler with single event loop management
    and proper exception handling for async operations.
    """

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None):
        """
        Initialize the async handler

        Args:
            loop: Optional event loop to use. If None, creates a new one.
        """
        self._loop = loop
        self._thread = None
        self._executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="AsyncHandler")
        self._running = False
        self._callbacks = weakref.WeakSet()
        self._timeout_default = 30.0  # Default timeout for operations

        # Start the event loop in a separate thread if not provided
        if self._loop is None:
            self._start_event_loop_thread()

        logger.info("AsyncHandler initialized with event loop management")

    def _start_event_loop_thread(self):
        """Start the event loop in a separate thread"""
        def run_event_loop():
            """Run the event loop in a separate thread"""
            try:
                self._loop = asyncio.new_event_loop()
                asyncio.set_event_loop(self._loop)
                self._running = True
                logger.debug("Event loop started in separate thread")
                self._loop.run_forever()
            except Exception as e:
                logger.error(f"Error in event loop thread: {e}")
            finally:
                self._running = False
                logger.debug("Event loop thread stopped")

        self._thread = threading.Thread(target=run_event_loop, daemon=True, name="AsyncHandler-EventLoop")
        self._thread.start()

        # Wait for the loop to be ready
        timeout = 5.0
        start_time = time.time()
        while not self._running and (time.time() - start_time) < timeout:
            time.sleep(0.1)

        if not self._running:
            raise AsyncOperationError("Failed to start event loop within timeout")

    def is_running(self) -> bool:
        """Check if the async handler is running"""
        return self._running and self._loop is not None and not self._loop.is_closed()

    def run_async_safely(self, coro: Coroutine, timeout: Optional[float] = None) -> Any:
        """
        Run an async coroutine safely with proper error handling and timeout

        Args:
            coro: The coroutine to run
            timeout: Timeout in seconds (uses default if None)

        Returns:
            The result of the coroutine

        Raises:
            AsyncOperationError: If the operation fails or times out
        """
        if not self.is_running():
            raise AsyncOperationError("AsyncHandler is not running")

        timeout = timeout or self._timeout_default

        try:
            # Schedule the coroutine on the event loop
            future = asyncio.run_coroutine_threadsafe(
                asyncio.wait_for(coro, timeout=timeout),
                self._loop
            )

            # Wait for the result with timeout
            result = future.result(timeout=timeout + 1.0)  # Add buffer to future timeout
            logger.debug("Async operation completed successfully")
            return result

        except asyncio.TimeoutError:
            logger.error(f"Async operation timed out after {timeout} seconds")
            raise AsyncOperationError(f"Operation timed out after {timeout} seconds")
        except Exception as e:
            logger.error(f"Async operation failed: {e}")
            raise AsyncOperationError(f"Async operation failed: {e}")

    def schedule_coroutine(self, coro: Coroutine, callback: Optional[Callable] = None) -> None:
        """
        Schedule a coroutine to run asynchronously without waiting for result

        Args:
            coro: The coroutine to schedule
            callback: Optional callback to call with the result
        """
        if not self.is_running():
            logger.warning("Cannot schedule coroutine: AsyncHandler is not running")
            return

        async def wrapped_coro():
            """Wrapper to handle exceptions and callbacks"""
            try:
                result = await coro
                if callback:
                    try:
                        callback(result)
                    except Exception as e:
                        logger.error(f"Error in coroutine callback: {e}")
                return result
            except Exception as e:
                logger.error(f"Error in scheduled coroutine: {e}")
                if callback:
                    try:
                        callback(None)  # Call callback with None on error
                    except Exception as cb_e:
                        logger.error(f"Error in error callback: {cb_e}")

        try:
            asyncio.run_coroutine_threadsafe(wrapped_coro(), self._loop)
            logger.debug("Coroutine scheduled successfully")
        except Exception as e:
            logger.error(f"Failed to schedule coroutine: {e}")

    def create_task_safely(self, coro: Coroutine, name: Optional[str] = None) -> Optional[asyncio.Task]:
        """
        Create an asyncio task safely with proper error handling

        Args:
            coro: The coroutine to create a task for
            name: Optional name for the task

        Returns:
            The created task or None if failed
        """
        if not self.is_running():
            logger.warning("Cannot create task: AsyncHandler is not running")
            return None

        async def create_task():
            """Create the task in the event loop"""
            try:
                task = asyncio.create_task(coro, name=name)
                logger.debug(f"Task created: {name or 'unnamed'}")
                return task
            except Exception as e:
                logger.error(f"Failed to create task {name}: {e}")
                return None

        try:
            future = asyncio.run_coroutine_threadsafe(create_task(), self._loop)
            return future.result(timeout=5.0)
        except Exception as e:
            logger.error(f"Failed to create task {name}: {e}")
            return None

    async def handle_orchestrator_connection(self, orchestrator) -> bool:
        """
        Handle orchestrator connection with proper async patterns

        Args:
            orchestrator: The orchestrator instance to connect to

        Returns:
            True if connection successful, False otherwise
        """
        try:
            logger.info("Connecting to orchestrator...")

            # Add decision callback if orchestrator supports it
            if hasattr(orchestrator, 'add_decision_callback'):
                await orchestrator.add_decision_callback(self._handle_trading_decision)
                logger.info("Decision callback added to orchestrator")

            # Start COB integration if available
            if hasattr(orchestrator, 'start_cob_integration'):
                await orchestrator.start_cob_integration()
                logger.info("COB integration started")

            # Start continuous trading if available
            if hasattr(orchestrator, 'start_continuous_trading'):
                await orchestrator.start_continuous_trading()
                logger.info("Continuous trading started")

            logger.info("Successfully connected to orchestrator")
            return True

        except Exception as e:
            logger.error(f"Failed to connect to orchestrator: {e}")
            return False

    async def handle_cob_integration(self, cob_integration) -> bool:
        """
        Handle COB integration startup with proper async patterns

        Args:
            cob_integration: The COB integration instance

        Returns:
            True if startup successful, False otherwise
        """
        try:
            logger.info("Starting COB integration...")

            if hasattr(cob_integration, 'start'):
                await cob_integration.start()
                logger.info("COB integration started successfully")
                return True
            else:
                logger.warning("COB integration does not have start method")
                return False

        except Exception as e:
            logger.error(f"Failed to start COB integration: {e}")
            return False

    async def _handle_trading_decision(self, decision: Dict[str, Any]) -> None:
        """
        Handle trading decision with proper async patterns

        Args:
            decision: The trading decision dictionary
        """
        try:
            logger.debug(f"Handling trading decision: {decision.get('action', 'UNKNOWN')}")

            # Process the decision (this would be customized based on needs)
            # For now, just log it
            symbol = decision.get('symbol', 'UNKNOWN')
            action = decision.get('action', 'HOLD')
            confidence = decision.get('confidence', 0.0)

            logger.info(f"Trading decision processed: {action} {symbol} (confidence: {confidence:.2f})")

        except Exception as e:
            logger.error(f"Error handling trading decision: {e}")

    def run_in_executor(self, func: Callable, *args, **kwargs) -> Any:
        """
        Run a blocking function in the thread pool executor

        Args:
            func: The function to run
            *args: Positional arguments for the function
            **kwargs: Keyword arguments for the function

        Returns:
            The result of the function
        """
        if not self.is_running():
            raise AsyncOperationError("AsyncHandler is not running")

        try:
            # Create a partial function with the arguments
            partial_func = functools.partial(func, *args, **kwargs)

            # Create a coroutine that runs the function in executor
            async def run_in_executor_coro():
                return await self._loop.run_in_executor(self._executor, partial_func)

            # Run the coroutine
            future = asyncio.run_coroutine_threadsafe(run_in_executor_coro(), self._loop)

            result = future.result(timeout=self._timeout_default)
            logger.debug("Executor function completed successfully")
            return result

        except Exception as e:
            logger.error(f"Error running function in executor: {e}")
            raise AsyncOperationError(f"Executor function failed: {e}")

    def add_periodic_task(self, coro_func: Callable[[], Coroutine], interval: float, name: Optional[str] = None) -> Optional[asyncio.Task]:
        """
        Add a periodic task that runs at specified intervals

        Args:
            coro_func: Function that returns a coroutine to run periodically
            interval: Interval in seconds between runs
            name: Optional name for the task

        Returns:
            The created task or None if failed
        """
        async def periodic_runner():
            """Run the coroutine periodically"""
            task_name = name or "periodic_task"
            logger.info(f"Starting periodic task: {task_name} (interval: {interval}s)")

            try:
                while True:
                    try:
                        coro = coro_func()
                        await coro
                        logger.debug(f"Periodic task {task_name} completed")
                    except Exception as e:
                        logger.error(f"Error in periodic task {task_name}: {e}")

                    await asyncio.sleep(interval)

            except asyncio.CancelledError:
                logger.info(f"Periodic task {task_name} cancelled")
                raise
            except Exception as e:
                logger.error(f"Fatal error in periodic task {task_name}: {e}")

        return self.create_task_safely(periodic_runner(), name=f"periodic_{name}")

    def stop(self) -> None:
        """Stop the async handler and clean up resources"""
        try:
            logger.info("Stopping AsyncHandler...")

            if self._loop and not self._loop.is_closed():
                # Cancel all tasks
                if self._loop.is_running():
                    asyncio.run_coroutine_threadsafe(self._cancel_all_tasks(), self._loop)

                # Stop the event loop
                self._loop.call_soon_threadsafe(self._loop.stop)

            # Shutdown executor
            if self._executor:
                self._executor.shutdown(wait=True)

            # Wait for thread to finish
            if self._thread and self._thread.is_alive():
                self._thread.join(timeout=5.0)

            self._running = False
            logger.info("AsyncHandler stopped successfully")

        except Exception as e:
            logger.error(f"Error stopping AsyncHandler: {e}")

    async def _cancel_all_tasks(self) -> None:
        """Cancel all running tasks"""
        try:
            tasks = [task for task in asyncio.all_tasks(self._loop) if not task.done()]
            if tasks:
                logger.info(f"Cancelling {len(tasks)} running tasks")
                for task in tasks:
                    task.cancel()

                # Wait for tasks to be cancelled
                await asyncio.gather(*tasks, return_exceptions=True)
                logger.debug("All tasks cancelled")
        except Exception as e:
            logger.error(f"Error cancelling tasks: {e}")

    def __enter__(self):
        """Context manager entry"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit"""
        self.stop()


class AsyncContextManager:
    """
    Context manager for async operations that ensures proper cleanup
    """

    def __init__(self, async_handler: AsyncHandler):
        self.async_handler = async_handler
        self.active_tasks = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Cancel any active tasks
        for task in self.active_tasks:
            if not task.done():
                task.cancel()

    def create_task(self, coro: Coroutine, name: Optional[str] = None) -> Optional[asyncio.Task]:
        """Create a task and track it for cleanup"""
        task = self.async_handler.create_task_safely(coro, name)
        if task:
            self.active_tasks.append(task)
        return task


def create_async_handler(loop: Optional[asyncio.AbstractEventLoop] = None) -> AsyncHandler:
    """
    Factory function to create an AsyncHandler instance

    Args:
        loop: Optional event loop to use

    Returns:
        AsyncHandler instance
    """
    return AsyncHandler(loop=loop)


def run_async_safely(coro: Coroutine, timeout: Optional[float] = None) -> Any:
    """
    Convenience function to run a coroutine safely with a temporary AsyncHandler

    Args:
        coro: The coroutine to run
        timeout: Timeout in seconds

    Returns:
        The result of the coroutine
    """
    with AsyncHandler() as handler:
        return handler.run_async_safely(coro, timeout=timeout)
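A minimal usage sketch of the deleted AsyncHandler (not part of the original commit). The import name `async_handler` and the `fetch_status` coroutine are illustrative assumptions; the handler calls themselves come from the class above.

# Hypothetical usage sketch -- module name is an assumption, not shown in this diff.
import asyncio
from async_handler import AsyncHandler  # assumed module name

async def fetch_status() -> str:
    await asyncio.sleep(0.1)  # stand-in for real async dashboard work
    return "ok"

with AsyncHandler() as handler:  # event loop spins up in a background thread
    # Blocking call from sync code; raises AsyncOperationError on timeout/failure.
    print(handler.run_async_safely(fetch_status(), timeout=5.0))

    # Fire-and-forget scheduling with an optional result callback.
    handler.schedule_coroutine(fetch_status(), callback=print)

    # Repeats every 10 s until the handler is stopped (here: when the with-block exits).
    handler.add_periodic_task(fetch_status, interval=10.0, name="status_poll")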
@@ -1,952 +0,0 @@
|
||||
"""
|
||||
Bookmap Order Book Data Provider
|
||||
|
||||
This module integrates with Bookmap to gather:
|
||||
- Current Order Book (COB) data
|
||||
- Session Volume Profile (SVP) data
|
||||
- Order book sweeps and momentum trades detection
|
||||
- Real-time order size heatmap matrix (last 10 minutes)
|
||||
- Level 2 market depth analysis
|
||||
|
||||
The data is processed and fed to CNN and DQN networks for enhanced trading decisions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import websockets
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any, Callable
|
||||
from collections import deque, defaultdict
|
||||
from dataclasses import dataclass
|
||||
from threading import Thread, Lock
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class OrderBookLevel:
|
||||
"""Represents a single order book level"""
|
||||
price: float
|
||||
size: float
|
||||
orders: int
|
||||
side: str # 'bid' or 'ask'
|
||||
timestamp: datetime
|
||||
|
||||
@dataclass
|
||||
class OrderBookSnapshot:
|
||||
"""Complete order book snapshot"""
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
bids: List[OrderBookLevel]
|
||||
asks: List[OrderBookLevel]
|
||||
spread: float
|
||||
mid_price: float
|
||||
|
||||
@dataclass
|
||||
class VolumeProfileLevel:
|
||||
"""Volume profile level data"""
|
||||
price: float
|
||||
volume: float
|
||||
buy_volume: float
|
||||
sell_volume: float
|
||||
trades_count: int
|
||||
vwap: float
|
||||
|
||||
@dataclass
|
||||
class OrderFlowSignal:
|
||||
"""Order flow signal detection"""
|
||||
timestamp: datetime
|
||||
signal_type: str # 'sweep', 'absorption', 'iceberg', 'momentum'
|
||||
price: float
|
||||
volume: float
|
||||
confidence: float
|
||||
description: str
|
||||
|
||||
class BookmapDataProvider:
|
||||
"""
|
||||
Real-time order book data provider using Bookmap-style analysis
|
||||
|
||||
Features:
|
||||
- Level 2 order book monitoring
|
||||
- Order flow detection (sweeps, absorptions)
|
||||
- Volume profile analysis
|
||||
- Order size heatmap generation
|
||||
- Market microstructure analysis
|
||||
"""
|
||||
|
||||
def __init__(self, symbols: List[str] = None, depth_levels: int = 20):
|
||||
"""
|
||||
Initialize Bookmap data provider
|
||||
|
||||
Args:
|
||||
symbols: List of symbols to monitor
|
||||
depth_levels: Number of order book levels to track
|
||||
"""
|
||||
self.symbols = symbols or ['ETHUSDT', 'BTCUSDT']
|
||||
self.depth_levels = depth_levels
|
||||
self.is_streaming = False
|
||||
|
||||
# Order book data storage
|
||||
self.order_books: Dict[str, OrderBookSnapshot] = {}
|
||||
self.order_book_history: Dict[str, deque] = {}
|
||||
self.volume_profiles: Dict[str, List[VolumeProfileLevel]] = {}
|
||||
|
||||
# Heatmap data (10-minute rolling window)
|
||||
self.heatmap_window = timedelta(minutes=10)
|
||||
self.order_heatmaps: Dict[str, deque] = {}
|
||||
self.price_levels: Dict[str, List[float]] = {}
|
||||
|
||||
# Order flow detection
|
||||
self.flow_signals: Dict[str, deque] = {}
|
||||
self.sweep_threshold = 0.8 # Minimum confidence for sweep detection
|
||||
self.absorption_threshold = 0.7 # Minimum confidence for absorption
|
||||
|
||||
# Market microstructure metrics
|
||||
self.bid_ask_spreads: Dict[str, deque] = {}
|
||||
self.order_book_imbalances: Dict[str, deque] = {}
|
||||
self.liquidity_metrics: Dict[str, Dict] = {}
|
||||
|
||||
# WebSocket connections
|
||||
self.websocket_tasks: Dict[str, asyncio.Task] = {}
|
||||
self.data_lock = Lock()
|
||||
|
||||
# Callbacks for CNN/DQN integration
|
||||
self.cnn_callbacks: List[Callable] = []
|
||||
self.dqn_callbacks: List[Callable] = []
|
||||
|
||||
# Performance tracking
|
||||
self.update_counts = defaultdict(int)
|
||||
self.last_update_times = {}
|
||||
|
||||
# Initialize data structures
|
||||
for symbol in self.symbols:
|
||||
self.order_book_history[symbol] = deque(maxlen=1000)
|
||||
self.order_heatmaps[symbol] = deque(maxlen=600) # 10 min at 1s intervals
|
||||
self.flow_signals[symbol] = deque(maxlen=500)
|
||||
self.bid_ask_spreads[symbol] = deque(maxlen=1000)
|
||||
self.order_book_imbalances[symbol] = deque(maxlen=1000)
|
||||
self.liquidity_metrics[symbol] = {
|
||||
'total_bid_size': 0.0,
|
||||
'total_ask_size': 0.0,
|
||||
'weighted_mid': 0.0,
|
||||
'liquidity_ratio': 1.0
|
||||
}
|
||||
|
||||
logger.info(f"BookmapDataProvider initialized for {len(self.symbols)} symbols")
|
||||
logger.info(f"Tracking {depth_levels} order book levels per side")
|
||||
|
||||
def add_cnn_callback(self, callback: Callable[[str, Dict], None]):
|
||||
"""Add callback for CNN model updates"""
|
||||
self.cnn_callbacks.append(callback)
|
||||
logger.info(f"Added CNN callback: {len(self.cnn_callbacks)} total")
|
||||
|
||||
def add_dqn_callback(self, callback: Callable[[str, Dict], None]):
|
||||
"""Add callback for DQN model updates"""
|
||||
self.dqn_callbacks.append(callback)
|
||||
logger.info(f"Added DQN callback: {len(self.dqn_callbacks)} total")
|
||||
|
||||
async def start_streaming(self):
|
||||
"""Start real-time order book streaming"""
|
||||
if self.is_streaming:
|
||||
logger.warning("Bookmap streaming already active")
|
||||
return
|
||||
|
||||
self.is_streaming = True
|
||||
logger.info("Starting Bookmap order book streaming")
|
||||
|
||||
# Start order book streams for each symbol
|
||||
for symbol in self.symbols:
|
||||
# Order book depth stream
|
||||
depth_task = asyncio.create_task(self._stream_order_book_depth(symbol))
|
||||
self.websocket_tasks[f"{symbol}_depth"] = depth_task
|
||||
|
||||
# Trade stream for order flow analysis
|
||||
trade_task = asyncio.create_task(self._stream_trades(symbol))
|
||||
self.websocket_tasks[f"{symbol}_trades"] = trade_task
|
||||
|
||||
# Start analysis threads
|
||||
analysis_task = asyncio.create_task(self._continuous_analysis())
|
||||
self.websocket_tasks["analysis"] = analysis_task
|
||||
|
||||
logger.info(f"Started streaming for {len(self.symbols)} symbols")
|
||||
|
||||
async def stop_streaming(self):
|
||||
"""Stop order book streaming"""
|
||||
if not self.is_streaming:
|
||||
return
|
||||
|
||||
logger.info("Stopping Bookmap streaming")
|
||||
self.is_streaming = False
|
||||
|
||||
# Cancel all tasks
|
||||
for name, task in self.websocket_tasks.items():
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self.websocket_tasks.clear()
|
||||
logger.info("Bookmap streaming stopped")
|
||||
|
||||
async def _stream_order_book_depth(self, symbol: str):
|
||||
"""Stream order book depth data"""
|
||||
binance_symbol = symbol.lower()
|
||||
url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@depth20@100ms"
|
||||
|
||||
while self.is_streaming:
|
||||
try:
|
||||
async with websockets.connect(url) as websocket:
|
||||
logger.info(f"Order book depth WebSocket connected for {symbol}")
|
||||
|
||||
async for message in websocket:
|
||||
if not self.is_streaming:
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(message)
|
||||
await self._process_depth_update(symbol, data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing depth for {symbol}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Depth WebSocket error for {symbol}: {e}")
|
||||
if self.is_streaming:
|
||||
await asyncio.sleep(2)
|
||||
|
||||
async def _stream_trades(self, symbol: str):
|
||||
"""Stream trade data for order flow analysis"""
|
||||
binance_symbol = symbol.lower()
|
||||
url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@trade"
|
||||
|
||||
while self.is_streaming:
|
||||
try:
|
||||
async with websockets.connect(url) as websocket:
|
||||
logger.info(f"Trade WebSocket connected for {symbol}")
|
||||
|
||||
async for message in websocket:
|
||||
if not self.is_streaming:
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(message)
|
||||
await self._process_trade_update(symbol, data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing trade for {symbol}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Trade WebSocket error for {symbol}: {e}")
|
||||
if self.is_streaming:
|
||||
await asyncio.sleep(2)
|
||||
|
||||
async def _process_depth_update(self, symbol: str, data: Dict):
|
||||
"""Process order book depth update"""
|
||||
try:
|
||||
timestamp = datetime.now()
|
||||
|
||||
# Parse bids and asks
|
||||
bids = []
|
||||
asks = []
|
||||
|
||||
for bid_data in data.get('bids', []):
|
||||
price = float(bid_data[0])
|
||||
size = float(bid_data[1])
|
||||
bids.append(OrderBookLevel(
|
||||
price=price,
|
||||
size=size,
|
||||
orders=1, # Binance doesn't provide order count
|
||||
side='bid',
|
||||
timestamp=timestamp
|
||||
))
|
||||
|
||||
for ask_data in data.get('asks', []):
|
||||
price = float(ask_data[0])
|
||||
size = float(ask_data[1])
|
||||
asks.append(OrderBookLevel(
|
||||
price=price,
|
||||
size=size,
|
||||
orders=1,
|
||||
side='ask',
|
||||
timestamp=timestamp
|
||||
))
|
||||
|
||||
# Sort order book levels
|
||||
bids.sort(key=lambda x: x.price, reverse=True)
|
||||
asks.sort(key=lambda x: x.price)
|
||||
|
||||
# Calculate spread and mid price
|
||||
if bids and asks:
|
||||
best_bid = bids[0].price
|
||||
best_ask = asks[0].price
|
||||
spread = best_ask - best_bid
|
||||
mid_price = (best_bid + best_ask) / 2
|
||||
else:
|
||||
spread = 0.0
|
||||
mid_price = 0.0
|
||||
|
||||
# Create order book snapshot
|
||||
snapshot = OrderBookSnapshot(
|
||||
symbol=symbol,
|
||||
timestamp=timestamp,
|
||||
bids=bids,
|
||||
asks=asks,
|
||||
spread=spread,
|
||||
mid_price=mid_price
|
||||
)
|
||||
|
||||
with self.data_lock:
|
||||
self.order_books[symbol] = snapshot
|
||||
self.order_book_history[symbol].append(snapshot)
|
||||
|
||||
# Update liquidity metrics
|
||||
self._update_liquidity_metrics(symbol, snapshot)
|
||||
|
||||
# Update order book imbalance
|
||||
self._calculate_order_book_imbalance(symbol, snapshot)
|
||||
|
||||
# Update heatmap data
|
||||
self._update_order_heatmap(symbol, snapshot)
|
||||
|
||||
# Update counters
|
||||
self.update_counts[f"{symbol}_depth"] += 1
|
||||
self.last_update_times[f"{symbol}_depth"] = timestamp
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing depth update for {symbol}: {e}")
|
||||
|
||||
async def _process_trade_update(self, symbol: str, data: Dict):
|
||||
"""Process trade data for order flow analysis"""
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(int(data['T']) / 1000)
|
||||
price = float(data['p'])
|
||||
quantity = float(data['q'])
|
||||
is_buyer_maker = data['m']
|
||||
|
||||
# Analyze for order flow signals
|
||||
await self._analyze_order_flow(symbol, timestamp, price, quantity, is_buyer_maker)
|
||||
|
||||
# Update volume profile
|
||||
self._update_volume_profile(symbol, price, quantity, is_buyer_maker)
|
||||
|
||||
self.update_counts[f"{symbol}_trades"] += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing trade for {symbol}: {e}")
|
||||
|
||||
def _update_liquidity_metrics(self, symbol: str, snapshot: OrderBookSnapshot):
|
||||
"""Update liquidity metrics from order book snapshot"""
|
||||
try:
|
||||
total_bid_size = sum(level.size for level in snapshot.bids)
|
||||
total_ask_size = sum(level.size for level in snapshot.asks)
|
||||
|
||||
# Calculate weighted mid price
|
||||
if snapshot.bids and snapshot.asks:
|
||||
bid_weight = total_bid_size / (total_bid_size + total_ask_size)
|
||||
ask_weight = total_ask_size / (total_bid_size + total_ask_size)
|
||||
weighted_mid = (snapshot.bids[0].price * ask_weight +
|
||||
snapshot.asks[0].price * bid_weight)
|
||||
else:
|
||||
weighted_mid = snapshot.mid_price
|
||||
|
||||
# Liquidity ratio (bid/ask balance)
|
||||
if total_ask_size > 0:
|
||||
liquidity_ratio = total_bid_size / total_ask_size
|
||||
else:
|
||||
liquidity_ratio = 1.0
|
||||
|
||||
self.liquidity_metrics[symbol] = {
|
||||
'total_bid_size': total_bid_size,
|
||||
'total_ask_size': total_ask_size,
|
||||
'weighted_mid': weighted_mid,
|
||||
'liquidity_ratio': liquidity_ratio,
|
||||
'spread_bps': (snapshot.spread / snapshot.mid_price) * 10000 if snapshot.mid_price > 0 else 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating liquidity metrics for {symbol}: {e}")
|
||||
|
||||
def _calculate_order_book_imbalance(self, symbol: str, snapshot: OrderBookSnapshot):
|
||||
"""Calculate order book imbalance ratio"""
|
||||
try:
|
||||
if not snapshot.bids or not snapshot.asks:
|
||||
return
|
||||
|
||||
# Calculate imbalance for top N levels
|
||||
n_levels = min(5, len(snapshot.bids), len(snapshot.asks))
|
||||
|
||||
total_bid_size = sum(snapshot.bids[i].size for i in range(n_levels))
|
||||
total_ask_size = sum(snapshot.asks[i].size for i in range(n_levels))
|
||||
|
||||
if total_bid_size + total_ask_size > 0:
|
||||
imbalance = (total_bid_size - total_ask_size) / (total_bid_size + total_ask_size)
|
||||
else:
|
||||
imbalance = 0.0
|
||||
|
||||
self.order_book_imbalances[symbol].append({
|
||||
'timestamp': snapshot.timestamp,
|
||||
'imbalance': imbalance,
|
||||
'bid_size': total_bid_size,
|
||||
'ask_size': total_ask_size
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating imbalance for {symbol}: {e}")
|
||||
|
||||
def _update_order_heatmap(self, symbol: str, snapshot: OrderBookSnapshot):
|
||||
"""Update order size heatmap matrix"""
|
||||
try:
|
||||
# Create heatmap entry
|
||||
heatmap_entry = {
|
||||
'timestamp': snapshot.timestamp,
|
||||
'mid_price': snapshot.mid_price,
|
||||
'levels': {}
|
||||
}
|
||||
|
||||
# Add bid levels
|
||||
for level in snapshot.bids:
|
||||
price_offset = level.price - snapshot.mid_price
|
||||
heatmap_entry['levels'][price_offset] = {
|
||||
'side': 'bid',
|
||||
'size': level.size,
|
||||
'price': level.price
|
||||
}
|
||||
|
||||
# Add ask levels
|
||||
for level in snapshot.asks:
|
||||
price_offset = level.price - snapshot.mid_price
|
||||
heatmap_entry['levels'][price_offset] = {
|
||||
'side': 'ask',
|
||||
'size': level.size,
|
||||
'price': level.price
|
||||
}
|
||||
|
||||
self.order_heatmaps[symbol].append(heatmap_entry)
|
||||
|
||||
# Clean old entries (keep 10 minutes)
|
||||
cutoff_time = snapshot.timestamp - self.heatmap_window
|
||||
while (self.order_heatmaps[symbol] and
|
||||
self.order_heatmaps[symbol][0]['timestamp'] < cutoff_time):
|
||||
self.order_heatmaps[symbol].popleft()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating heatmap for {symbol}: {e}")
|
||||
|
||||
def _update_volume_profile(self, symbol: str, price: float, quantity: float, is_buyer_maker: bool):
|
||||
"""Update volume profile with new trade"""
|
||||
try:
|
||||
# Initialize if not exists
|
||||
if symbol not in self.volume_profiles:
|
||||
self.volume_profiles[symbol] = []
|
||||
|
||||
# Find or create price level
|
||||
price_level = None
|
||||
for level in self.volume_profiles[symbol]:
|
||||
if abs(level.price - price) < 0.01: # Price tolerance
|
||||
price_level = level
|
||||
break
|
||||
|
||||
if not price_level:
|
||||
price_level = VolumeProfileLevel(
|
||||
price=price,
|
||||
volume=0.0,
|
||||
buy_volume=0.0,
|
||||
sell_volume=0.0,
|
||||
trades_count=0,
|
||||
vwap=price
|
||||
)
|
||||
self.volume_profiles[symbol].append(price_level)
|
||||
|
||||
# Update volume profile
|
||||
volume = price * quantity
|
||||
old_total = price_level.volume
|
||||
|
||||
price_level.volume += volume
|
||||
price_level.trades_count += 1
|
||||
|
||||
if is_buyer_maker:
|
||||
price_level.sell_volume += volume
|
||||
else:
|
||||
price_level.buy_volume += volume
|
||||
|
||||
# Update VWAP
|
||||
if price_level.volume > 0:
|
||||
price_level.vwap = ((price_level.vwap * old_total) + (price * volume)) / price_level.volume
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating volume profile for {symbol}: {e}")
|
||||
|
||||
async def _analyze_order_flow(self, symbol: str, timestamp: datetime, price: float,
|
||||
quantity: float, is_buyer_maker: bool):
|
||||
"""Analyze order flow for sweep and absorption patterns"""
|
||||
try:
|
||||
# Get recent order book data
|
||||
if symbol not in self.order_book_history or not self.order_book_history[symbol]:
|
||||
return
|
||||
|
||||
recent_snapshots = list(self.order_book_history[symbol])[-10:] # Last 10 snapshots
|
||||
|
||||
# Check for order book sweeps
|
||||
sweep_signal = self._detect_order_sweep(symbol, recent_snapshots, price, quantity, is_buyer_maker)
|
||||
if sweep_signal:
|
||||
self.flow_signals[symbol].append(sweep_signal)
|
||||
await self._notify_flow_signal(symbol, sweep_signal)
|
||||
|
||||
# Check for absorption patterns
|
||||
absorption_signal = self._detect_absorption(symbol, recent_snapshots, price, quantity)
|
||||
if absorption_signal:
|
||||
self.flow_signals[symbol].append(absorption_signal)
|
||||
await self._notify_flow_signal(symbol, absorption_signal)
|
||||
|
||||
# Check for momentum trades
|
||||
momentum_signal = self._detect_momentum_trade(symbol, price, quantity, is_buyer_maker)
|
||||
if momentum_signal:
|
||||
self.flow_signals[symbol].append(momentum_signal)
|
||||
await self._notify_flow_signal(symbol, momentum_signal)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing order flow for {symbol}: {e}")
|
||||
|
||||
def _detect_order_sweep(self, symbol: str, snapshots: List[OrderBookSnapshot],
|
||||
price: float, quantity: float, is_buyer_maker: bool) -> Optional[OrderFlowSignal]:
|
||||
"""Detect order book sweep patterns"""
|
||||
try:
|
||||
if len(snapshots) < 2:
|
||||
return None
|
||||
|
||||
before_snapshot = snapshots[-2]
|
||||
after_snapshot = snapshots[-1]
|
||||
|
||||
# Check if multiple levels were consumed
|
||||
if is_buyer_maker: # Sell order, check ask side
|
||||
levels_consumed = 0
|
||||
total_consumed_size = 0
|
||||
|
||||
for level in before_snapshot.asks[:5]: # Check top 5 levels
|
||||
if level.price <= price:
|
||||
levels_consumed += 1
|
||||
total_consumed_size += level.size
|
||||
|
||||
if levels_consumed >= 2 and total_consumed_size > quantity * 1.5:
|
||||
confidence = min(0.9, levels_consumed / 5.0 + 0.3)
|
||||
|
||||
return OrderFlowSignal(
|
||||
timestamp=datetime.now(),
|
||||
signal_type='sweep',
|
||||
price=price,
|
||||
volume=quantity * price,
|
||||
confidence=confidence,
|
||||
description=f"Sell sweep: {levels_consumed} levels, {total_consumed_size:.2f} size"
|
||||
)
|
||||
else: # Buy order, check bid side
|
||||
levels_consumed = 0
|
||||
total_consumed_size = 0
|
||||
|
||||
for level in before_snapshot.bids[:5]:
|
||||
if level.price >= price:
|
||||
levels_consumed += 1
|
||||
total_consumed_size += level.size
|
||||
|
||||
if levels_consumed >= 2 and total_consumed_size > quantity * 1.5:
|
||||
confidence = min(0.9, levels_consumed / 5.0 + 0.3)
|
||||
|
||||
return OrderFlowSignal(
|
||||
timestamp=datetime.now(),
|
||||
signal_type='sweep',
|
||||
price=price,
|
||||
volume=quantity * price,
|
||||
confidence=confidence,
|
||||
description=f"Buy sweep: {levels_consumed} levels, {total_consumed_size:.2f} size"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error detecting sweep for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _detect_absorption(self, symbol: str, snapshots: List[OrderBookSnapshot],
|
||||
price: float, quantity: float) -> Optional[OrderFlowSignal]:
|
||||
"""Detect absorption patterns where large orders are absorbed without price movement"""
|
||||
try:
|
||||
if len(snapshots) < 3:
|
||||
return None
|
||||
|
||||
# Check if large order was absorbed with minimal price impact
|
||||
volume_threshold = 10000 # $10K minimum for absorption
|
||||
price_impact_threshold = 0.001 # 0.1% max price impact
|
||||
|
||||
trade_value = price * quantity
|
||||
if trade_value < volume_threshold:
|
||||
return None
|
||||
|
||||
# Calculate price impact
|
||||
price_before = snapshots[-3].mid_price
|
||||
price_after = snapshots[-1].mid_price
|
||||
price_impact = abs(price_after - price_before) / price_before
|
||||
|
||||
if price_impact < price_impact_threshold:
|
||||
confidence = min(0.8, (trade_value / 50000) * 0.5 + 0.3) # Scale with size
|
||||
|
||||
return OrderFlowSignal(
|
||||
timestamp=datetime.now(),
|
||||
signal_type='absorption',
|
||||
price=price,
|
||||
volume=trade_value,
|
||||
confidence=confidence,
|
||||
description=f"Absorption: ${trade_value:.0f} with {price_impact*100:.3f}% impact"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error detecting absorption for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _detect_momentum_trade(self, symbol: str, price: float, quantity: float,
|
||||
is_buyer_maker: bool) -> Optional[OrderFlowSignal]:
|
||||
"""Detect momentum trades based on size and direction"""
|
||||
try:
|
||||
trade_value = price * quantity
|
||||
momentum_threshold = 25000 # $25K minimum for momentum classification
|
||||
|
||||
if trade_value < momentum_threshold:
|
||||
return None
|
||||
|
||||
# Calculate confidence based on trade size
|
||||
confidence = min(0.9, trade_value / 100000 * 0.6 + 0.3)
|
||||
|
||||
direction = "sell" if is_buyer_maker else "buy"
|
||||
|
||||
return OrderFlowSignal(
|
||||
timestamp=datetime.now(),
|
||||
signal_type='momentum',
|
||||
price=price,
|
||||
volume=trade_value,
|
||||
confidence=confidence,
|
||||
description=f"Large {direction}: ${trade_value:.0f}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error detecting momentum for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
async def _notify_flow_signal(self, symbol: str, signal: OrderFlowSignal):
|
||||
"""Notify CNN and DQN models of order flow signals"""
|
||||
try:
|
||||
signal_data = {
|
||||
'signal_type': signal.signal_type,
|
||||
'price': signal.price,
|
||||
'volume': signal.volume,
|
||||
'confidence': signal.confidence,
|
||||
'timestamp': signal.timestamp,
|
||||
'description': signal.description
|
||||
}
|
||||
|
||||
# Notify CNN callbacks
|
||||
for callback in self.cnn_callbacks:
|
||||
try:
|
||||
callback(symbol, signal_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in CNN callback: {e}")
|
||||
|
||||
# Notify DQN callbacks
|
||||
for callback in self.dqn_callbacks:
|
||||
try:
|
||||
callback(symbol, signal_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in DQN callback: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error notifying flow signal: {e}")
|
||||
|
||||
async def _continuous_analysis(self):
|
||||
"""Continuous analysis of market microstructure"""
|
||||
while self.is_streaming:
|
||||
try:
|
||||
await asyncio.sleep(1) # Analyze every second
|
||||
|
||||
for symbol in self.symbols:
|
||||
# Generate CNN features
|
||||
cnn_features = self.get_cnn_features(symbol)
|
||||
if cnn_features is not None:
|
||||
for callback in self.cnn_callbacks:
|
||||
try:
|
||||
callback(symbol, {'features': cnn_features, 'type': 'orderbook'})
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in CNN feature callback: {e}")
|
||||
|
||||
# Generate DQN state features
|
||||
dqn_features = self.get_dqn_state_features(symbol)
|
||||
if dqn_features is not None:
|
||||
for callback in self.dqn_callbacks:
|
||||
try:
|
||||
callback(symbol, {'state': dqn_features, 'type': 'orderbook'})
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in DQN state callback: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in continuous analysis: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
def get_cnn_features(self, symbol: str) -> Optional[np.ndarray]:
|
||||
"""Generate CNN input features from order book data"""
|
||||
try:
|
||||
if symbol not in self.order_books:
|
||||
return None
|
||||
|
||||
snapshot = self.order_books[symbol]
|
||||
features = []
|
||||
|
||||
# Order book features (40 features: 20 levels x 2 sides)
|
||||
for i in range(min(20, len(snapshot.bids))):
|
||||
bid = snapshot.bids[i]
|
||||
features.append(bid.size)
|
||||
features.append(bid.price - snapshot.mid_price) # Price offset
|
||||
|
||||
# Pad if not enough bid levels
|
||||
while len(features) < 40:
|
||||
features.extend([0.0, 0.0])
|
||||
|
||||
for i in range(min(20, len(snapshot.asks))):
|
||||
ask = snapshot.asks[i]
|
||||
features.append(ask.size)
|
||||
features.append(ask.price - snapshot.mid_price) # Price offset
|
||||
|
||||
# Pad if not enough ask levels
|
||||
while len(features) < 80:
|
||||
features.extend([0.0, 0.0])
|
||||
|
||||
# Liquidity metrics (10 features)
|
||||
metrics = self.liquidity_metrics.get(symbol, {})
|
||||
features.extend([
|
||||
metrics.get('total_bid_size', 0.0),
|
||||
metrics.get('total_ask_size', 0.0),
|
||||
metrics.get('liquidity_ratio', 1.0),
|
||||
metrics.get('spread_bps', 0.0),
|
||||
snapshot.spread,
|
||||
metrics.get('weighted_mid', snapshot.mid_price) - snapshot.mid_price,
|
||||
len(snapshot.bids),
|
||||
len(snapshot.asks),
|
||||
snapshot.mid_price,
|
||||
time.time() % 86400 # Time of day
|
||||
])
|
||||
|
||||
# Order book imbalance features (5 features)
|
||||
if self.order_book_imbalances[symbol]:
|
||||
latest_imbalance = self.order_book_imbalances[symbol][-1]
|
||||
features.extend([
|
||||
latest_imbalance['imbalance'],
|
||||
latest_imbalance['bid_size'],
|
||||
latest_imbalance['ask_size'],
|
||||
latest_imbalance['bid_size'] + latest_imbalance['ask_size'],
|
||||
abs(latest_imbalance['imbalance'])
|
||||
])
|
||||
else:
|
||||
features.extend([0.0, 0.0, 0.0, 0.0, 0.0])
|
||||
|
||||
# Flow signal features (5 features)
|
||||
recent_signals = [s for s in self.flow_signals[symbol]
|
||||
if (datetime.now() - s.timestamp).seconds < 60]
|
||||
|
||||
sweep_count = sum(1 for s in recent_signals if s.signal_type == 'sweep')
|
||||
absorption_count = sum(1 for s in recent_signals if s.signal_type == 'absorption')
|
||||
momentum_count = sum(1 for s in recent_signals if s.signal_type == 'momentum')
|
||||
|
||||
max_confidence = max([s.confidence for s in recent_signals], default=0.0)
|
||||
total_flow_volume = sum(s.volume for s in recent_signals)
|
||||
|
||||
features.extend([
|
||||
sweep_count,
|
||||
absorption_count,
|
||||
momentum_count,
|
||||
max_confidence,
|
||||
total_flow_volume
|
||||
])
|
||||
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating CNN features for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
    def get_dqn_state_features(self, symbol: str) -> Optional[np.ndarray]:
        """Generate DQN state features from order book data"""
        try:
            if symbol not in self.order_books:
                return None

            snapshot = self.order_books[symbol]
            state_features = []

            # Normalized order book state (20 features)
            total_bid_size = sum(level.size for level in snapshot.bids[:10])
            total_ask_size = sum(level.size for level in snapshot.asks[:10])
            total_size = total_bid_size + total_ask_size

            if total_size > 0:
                for i in range(min(10, len(snapshot.bids))):
                    state_features.append(snapshot.bids[i].size / total_size)

                # Pad bids
                while len(state_features) < 10:
                    state_features.append(0.0)

                for i in range(min(10, len(snapshot.asks))):
                    state_features.append(snapshot.asks[i].size / total_size)

                # Pad asks
                while len(state_features) < 20:
                    state_features.append(0.0)
            else:
                state_features.extend([0.0] * 20)

            # Market state indicators (10 features)
            metrics = self.liquidity_metrics.get(symbol, {})

            # Normalize spread as percentage
            spread_pct = (snapshot.spread / snapshot.mid_price) if snapshot.mid_price > 0 else 0

            # Liquidity imbalance
            liquidity_ratio = metrics.get('liquidity_ratio', 1.0)
            liquidity_imbalance = (liquidity_ratio - 1) / (liquidity_ratio + 1)

            # Recent flow signals strength
            recent_signals = [s for s in self.flow_signals[symbol]
                              if (datetime.now() - s.timestamp).seconds < 30]
            flow_strength = sum(s.confidence for s in recent_signals) / max(len(recent_signals), 1)

            # Price volatility (from recent snapshots)
            if len(self.order_book_history[symbol]) >= 10:
                recent_prices = [s.mid_price for s in list(self.order_book_history[symbol])[-10:]]
                price_volatility = np.std(recent_prices) / np.mean(recent_prices) if recent_prices else 0
            else:
                price_volatility = 0

            # Flow signal counts over the last 30 seconds. The original code guarded these
            # with "'sweep_count' in locals()", but those names are never bound in this
            # method, so the three features were silently zero; compute them here instead.
            sweep_count = sum(1 for s in recent_signals if s.signal_type == 'sweep')
            absorption_count = sum(1 for s in recent_signals if s.signal_type == 'absorption')
            momentum_count = sum(1 for s in recent_signals if s.signal_type == 'momentum')

            state_features.extend([
                spread_pct * 10000,  # Spread in basis points
                liquidity_imbalance,
                flow_strength,
                price_volatility * 100,  # Volatility as percentage
                min(len(snapshot.bids), 20) / 20,  # Book depth ratio
                min(len(snapshot.asks), 20) / 20,
                sweep_count / 10,  # Normalized flow signal counts
                absorption_count / 5,
                momentum_count / 5,
                (datetime.now().hour * 60 + datetime.now().minute) / 1440  # Time of day normalized
            ])

            return np.array(state_features, dtype=np.float32)

        except Exception as e:
            logger.error(f"Error generating DQN features for {symbol}: {e}")
            return None

    def get_order_heatmap_matrix(self, symbol: str, levels: int = 40) -> Optional[np.ndarray]:
        """Generate order size heatmap matrix for dashboard visualization"""
        try:
            if symbol not in self.order_heatmaps or not self.order_heatmaps[symbol]:
                return None

            # Create price levels around current mid price
            current_snapshot = self.order_books.get(symbol)
            if not current_snapshot:
                return None

            mid_price = current_snapshot.mid_price
            price_step = mid_price * 0.0001  # 1 basis point steps

            # Create matrix: time x price levels
            time_window = min(600, len(self.order_heatmaps[symbol]))  # 10 minutes max
            heatmap_matrix = np.zeros((time_window, levels))

            # Fill matrix with order sizes
            for t, entry in enumerate(list(self.order_heatmaps[symbol])[-time_window:]):
                for price_offset, level_data in entry['levels'].items():
                    # Convert price offset to matrix index
                    level_idx = int((price_offset + (levels / 2) * price_step) / price_step)

                    if 0 <= level_idx < levels:
                        size_weight = 1.0 if level_data['side'] == 'bid' else -1.0
                        heatmap_matrix[t, level_idx] = level_data['size'] * size_weight

            return heatmap_matrix

        except Exception as e:
            logger.error(f"Error generating heatmap matrix for {symbol}: {e}")
            return None

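    # --- Illustrative sketch (not part of the original module) ---
    # The bucket math above maps a signed price offset into one of `levels` columns
    # centred on the mid price, in 1-basis-point steps. Pulling it into a helper
    # (the name is an assumption) makes the rounding behaviour easy to unit-test:
    # an offset of 0 lands in column levels // 2, +1 bp one column above, -1 bp below.
    @staticmethod
    def price_offset_to_level_index(price_offset: float, mid_price: float, levels: int = 40) -> int:
        """Map a price offset (in quote units) to a heatmap column index (sketch)."""
        price_step = mid_price * 0.0001  # 1 basis point, same step as get_order_heatmap_matrix
        return int((price_offset + (levels / 2) * price_step) / price_step)
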
    def get_volume_profile_data(self, symbol: str) -> Optional[List[Dict]]:
        """Get session volume profile data"""
        try:
            if symbol not in self.volume_profiles:
                return None

            profile_data = []
            for level in sorted(self.volume_profiles[symbol], key=lambda x: x.price):
                profile_data.append({
                    'price': level.price,
                    'volume': level.volume,
                    'buy_volume': level.buy_volume,
                    'sell_volume': level.sell_volume,
                    'trades_count': level.trades_count,
                    'vwap': level.vwap,
                    'net_volume': level.buy_volume - level.sell_volume
                })

            return profile_data

        except Exception as e:
            logger.error(f"Error getting volume profile for {symbol}: {e}")
            return None

    def get_current_order_book(self, symbol: str) -> Optional[Dict]:
        """Get current order book snapshot"""
        try:
            if symbol not in self.order_books:
                return None

            snapshot = self.order_books[symbol]

            return {
                'timestamp': snapshot.timestamp.isoformat(),
                'symbol': symbol,
                'mid_price': snapshot.mid_price,
                'spread': snapshot.spread,
                'bids': [{'price': l.price, 'size': l.size} for l in snapshot.bids[:20]],
                'asks': [{'price': l.price, 'size': l.size} for l in snapshot.asks[:20]],
                'liquidity_metrics': self.liquidity_metrics.get(symbol, {}),
                'recent_signals': [
                    {
                        'type': s.signal_type,
                        'price': s.price,
                        'volume': s.volume,
                        'confidence': s.confidence,
                        'timestamp': s.timestamp.isoformat()
                    }
                    for s in list(self.flow_signals[symbol])[-5:]  # Last 5 signals
                ]
            }

        except Exception as e:
            logger.error(f"Error getting order book for {symbol}: {e}")
            return None

    def get_statistics(self) -> Dict[str, Any]:
        """Get provider statistics"""
        return {
            'symbols': self.symbols,
            'is_streaming': self.is_streaming,
            'update_counts': dict(self.update_counts),
            'last_update_times': {k: v.isoformat() if isinstance(v, datetime) else v
                                  for k, v in self.last_update_times.items()},
            'order_books_active': len(self.order_books),
            'flow_signals_total': sum(len(signals) for signals in self.flow_signals.values()),
            'cnn_callbacks': len(self.cnn_callbacks),
            'dqn_callbacks': len(self.dqn_callbacks),
            'websocket_tasks': len(self.websocket_tasks)
        }
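
# Illustrative usage sketch (not in the original file): how a dashboard or model loop
# might poll the accessors above. The `provider` argument is assumed to be an already
# started instance of the order book provider class; the symbol is an example value.
def _demo_poll_provider(provider, symbol: str = 'ETHUSDT') -> None:
    """Poll the feature/state accessors once and log a one-line summary (sketch)."""
    features = provider.get_cnn_features(symbol)       # 100-element vector or None
    state = provider.get_dqn_state_features(symbol)    # 30-element vector or None
    book = provider.get_current_order_book(symbol)     # dict for dashboard rendering
    stats = provider.get_statistics()
    logger.info("CNN features: %s, DQN state: %s, mid price: %s, updates: %s",
                None if features is None else features.shape,
                None if state is None else state.shape,
                None if book is None else book.get('mid_price'),
                stats.get('update_counts'))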
File diff suppressed because it is too large
@@ -1,785 +0,0 @@
"""
|
||||
CNN Training Pipeline with Comprehensive Data Storage and Replay
|
||||
|
||||
This module implements a robust CNN training pipeline that:
|
||||
1. Integrates with the comprehensive training data collection system
|
||||
2. Stores all backpropagation data for gradient replay
|
||||
3. Enables retraining on most profitable setups
|
||||
4. Maintains training episode profitability tracking
|
||||
5. Supports both real-time and batch training modes
|
||||
|
||||
Key Features:
|
||||
- Integration with TrainingDataCollector for data validation
|
||||
- Gradient and loss storage for each training step
|
||||
- Profitable episode prioritization and replay
|
||||
- Comprehensive training metrics and validation
|
||||
- Real-time pivot point prediction with outcome tracking
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Any, Callable
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
import pickle
|
||||
from collections import deque, defaultdict
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from .training_data_collector import (
|
||||
TrainingDataCollector,
|
||||
TrainingEpisode,
|
||||
ModelInputPackage,
|
||||
get_training_data_collector
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
class CNNTrainingStep:
    """Single CNN training step with complete backpropagation data"""
    step_id: str
    timestamp: datetime
    episode_id: str

    # Input data
    input_features: torch.Tensor
    target_labels: torch.Tensor

    # Forward pass results
    model_outputs: Dict[str, torch.Tensor]
    predictions: Dict[str, Any]
    confidence_scores: torch.Tensor

    # Loss components
    total_loss: float
    pivot_prediction_loss: float
    confidence_loss: float
    regularization_loss: float

    # Backpropagation data
    gradients: Dict[str, torch.Tensor]  # Gradients for each parameter
    gradient_norms: Dict[str, float]  # Gradient norms for monitoring

    # Model state
    model_state_dict: Optional[Dict[str, torch.Tensor]] = None
    optimizer_state: Optional[Dict[str, Any]] = None

    # Training metadata
    learning_rate: float = 0.001
    batch_size: int = 32
    epoch: int = 0

    # Profitability tracking
    actual_profitability: Optional[float] = None
    prediction_accuracy: Optional[float] = None
    training_value: float = 0.0  # Value of this training step for replay

@dataclass
class CNNTrainingSession:
    """Complete CNN training session with multiple steps"""
    session_id: str
    start_timestamp: datetime
    end_timestamp: Optional[datetime] = None

    # Session configuration
    training_mode: str = 'real_time'  # 'real_time', 'batch', 'replay'
    symbol: str = ''

    # Training steps
    training_steps: List[CNNTrainingStep] = field(default_factory=list)

    # Session metrics
    total_steps: int = 0
    average_loss: float = 0.0
    best_loss: float = float('inf')
    convergence_achieved: bool = False

    # Profitability metrics
    profitable_predictions: int = 0
    total_predictions: int = 0
    profitability_rate: float = 0.0

    # Session value for replay prioritization
    session_value: float = 0.0

class CNNPivotPredictor(nn.Module):
    """CNN model for pivot point prediction with comprehensive output"""

    def __init__(self,
                 input_channels: int = 10,  # Multiple timeframes
                 sequence_length: int = 300,  # 300 bars
                 hidden_dim: int = 256,
                 num_pivot_classes: int = 3,  # high, low, none
                 dropout_rate: float = 0.2):

        super(CNNPivotPredictor, self).__init__()

        self.input_channels = input_channels
        self.sequence_length = sequence_length
        self.hidden_dim = hidden_dim

        # Convolutional layers for pattern extraction
        self.conv_layers = nn.Sequential(
            # First conv block
            nn.Conv1d(input_channels, 64, kernel_size=7, padding=3),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            # Second conv block
            nn.Conv1d(64, 128, kernel_size=5, padding=2),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            # Third conv block
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        )

        # LSTM for temporal dependencies
        self.lstm = nn.LSTM(
            input_size=256,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=dropout_rate,
            bidirectional=True
        )

        # Attention mechanism
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim * 2,  # Bidirectional LSTM
            num_heads=8,
            dropout=dropout_rate,
            batch_first=True
        )

        # Output heads
        self.pivot_classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, num_pivot_classes)
        )

        self.pivot_price_regressor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, 1)
        )

        self.confidence_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights with proper scaling"""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv1d):
            torch.nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """
        Forward pass through CNN pivot predictor

        Args:
            x: Input tensor [batch_size, input_channels, sequence_length]

        Returns:
            Dict containing predictions and hidden states
        """
        batch_size = x.size(0)

        # Convolutional feature extraction
        conv_features = self.conv_layers(x)  # [batch, 256, sequence_length]

        # Prepare for LSTM (transpose to [batch, sequence, features])
        lstm_input = conv_features.transpose(1, 2)  # [batch, sequence_length, 256]

        # LSTM processing
        lstm_output, (hidden, cell) = self.lstm(lstm_input)  # [batch, sequence_length, hidden_dim*2]

        # Attention mechanism
        attended_output, attention_weights = self.attention(
            lstm_output, lstm_output, lstm_output
        )

        # Use the last timestep for predictions
        final_features = attended_output[:, -1, :]  # [batch, hidden_dim*2]

        # Generate predictions
        pivot_logits = self.pivot_classifier(final_features)
        pivot_price = self.pivot_price_regressor(final_features)
        confidence = self.confidence_head(final_features)

        return {
            'pivot_logits': pivot_logits,
            'pivot_price': pivot_price,
            'confidence': confidence,
            'hidden_states': final_features,
            'attention_weights': attention_weights,
            'conv_features': conv_features,
            'lstm_output': lstm_output
        }

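# Illustrative sketch (not part of the original file): a shape smoke test for the
# model above. The expected shapes follow from the constructor defaults
# (10 input channels, 300 bars, hidden_dim=256, bidirectional LSTM).
def _smoke_test_cnn_pivot_predictor() -> None:
    """Run a random batch through CNNPivotPredictor and check output shapes (sketch)."""
    model = CNNPivotPredictor()
    model.eval()
    x = torch.randn(4, 10, 300)  # [batch, input_channels, sequence_length]
    with torch.no_grad():
        out = model(x)
    assert out['pivot_logits'].shape == (4, 3)     # high / low / none
    assert out['pivot_price'].shape == (4, 1)
    assert out['confidence'].shape == (4, 1)
    assert out['hidden_states'].shape == (4, 512)  # 2 * hidden_dim (bidirectional)
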
class CNNTrainingDataset(Dataset):
    """Dataset for CNN training with training episodes"""

    def __init__(self, training_episodes: List[TrainingEpisode]):
        self.episodes = training_episodes
        self.valid_episodes = self._validate_episodes()

    def _validate_episodes(self) -> List[TrainingEpisode]:
        """Validate and filter episodes for training"""
        valid = []
        for episode in self.episodes:
            try:
                # Check if episode has required data
                if (episode.input_package.cnn_features is not None and
                    episode.actual_outcome.outcome_validated):
                    valid.append(episode)
            except Exception as e:
                logger.warning(f"Invalid episode {episode.episode_id}: {e}")

        logger.info(f"Validated {len(valid)}/{len(self.episodes)} episodes for training")
        return valid

    def __len__(self):
        return len(self.valid_episodes)

    def __getitem__(self, idx):
        episode = self.valid_episodes[idx]

        # Extract features
        features = torch.from_numpy(episode.input_package.cnn_features).float()

        # Create labels from actual outcomes
        pivot_class = self._determine_pivot_class(episode.actual_outcome)
        pivot_price = episode.actual_outcome.optimal_exit_price
        confidence_target = episode.actual_outcome.profitability_score

        return {
            'features': features,
            'pivot_class': torch.tensor(pivot_class, dtype=torch.long),
            'pivot_price': torch.tensor(pivot_price, dtype=torch.float),
            'confidence_target': torch.tensor(confidence_target, dtype=torch.float),
            'episode_id': episode.episode_id,
            'profitability': episode.actual_outcome.profitability_score
        }

    def _determine_pivot_class(self, outcome) -> int:
        """Determine pivot class from outcome"""
        if outcome.price_change_15m > 0.5:  # Significant upward movement
            return 0  # High pivot
        elif outcome.price_change_15m < -0.5:  # Significant downward movement
            return 1  # Low pivot
        else:
            return 2  # No significant pivot

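# Illustrative sketch (not part of the original file): the dataset above plugs
# straight into a DataLoader, mirroring how _train_on_episodes consumes it further
# below; the helper name and defaults are assumptions for illustration.
def _make_training_loader(episodes: List[TrainingEpisode], batch_size: int = 32) -> DataLoader:
    """Wrap validated training episodes in a shuffled DataLoader (sketch)."""
    dataset = CNNTrainingDataset(episodes)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
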
class CNNTrainer:
    """CNN trainer with comprehensive data storage and replay capabilities"""

    def __init__(self,
                 model: CNNPivotPredictor,
                 device: str = 'cuda',
                 learning_rate: float = 0.001,
                 storage_dir: str = "cnn_training_storage"):

        self.model = model.to(device)
        self.device = device
        self.learning_rate = learning_rate

        # Storage
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=1e-5
        )

        # Learning rate scheduler
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', patience=10, factor=0.5
        )

        # Training data collector
        self.data_collector = get_training_data_collector()

        # Training sessions storage
        self.training_sessions: List[CNNTrainingSession] = []
        self.current_session: Optional[CNNTrainingSession] = None

        # Training statistics
        self.training_stats = {
            'total_sessions': 0,
            'total_steps': 0,
            'best_validation_loss': float('inf'),
            'profitable_predictions': 0,
            'total_predictions': 0,
            'replay_sessions': 0
        }

        # Background training
        self.is_training = False
        self.training_thread = None

        logger.info("CNN Trainer initialized")
        logger.info(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        logger.info(f"Storage directory: {self.storage_dir}")

def start_real_time_training(self, symbol: str):
|
||||
"""Start real-time training for a symbol"""
|
||||
if self.is_training:
|
||||
logger.warning("CNN training already running")
|
||||
return
|
||||
|
||||
self.is_training = True
|
||||
self.training_thread = threading.Thread(
|
||||
target=self._real_time_training_worker,
|
||||
args=(symbol,),
|
||||
daemon=True
|
||||
)
|
||||
self.training_thread.start()
|
||||
|
||||
logger.info(f"Started real-time CNN training for {symbol}")
|
||||
|
||||
def stop_training(self):
|
||||
"""Stop training"""
|
||||
self.is_training = False
|
||||
if self.training_thread:
|
||||
self.training_thread.join(timeout=10)
|
||||
|
||||
if self.current_session:
|
||||
self._finalize_training_session()
|
||||
|
||||
logger.info("CNN training stopped")
|
||||
|
||||
def _real_time_training_worker(self, symbol: str):
|
||||
"""Real-time training worker"""
|
||||
logger.info(f"Real-time CNN training worker started for {symbol}")
|
||||
|
||||
while self.is_training:
|
||||
try:
|
||||
# Get high-priority episodes for training
|
||||
episodes = self.data_collector.get_high_priority_episodes(
|
||||
symbol=symbol,
|
||||
limit=100,
|
||||
min_priority=0.3
|
||||
)
|
||||
|
||||
if len(episodes) >= 32: # Minimum batch size
|
||||
self._train_on_episodes(episodes, training_mode='real_time')
|
||||
|
||||
# Wait before next training cycle
|
||||
threading.Event().wait(300) # Train every 5 minutes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in real-time training worker: {e}")
|
||||
threading.Event().wait(60) # Wait before retrying
|
||||
|
||||
logger.info(f"Real-time CNN training worker stopped for {symbol}")
|
||||
|
||||
def train_on_profitable_episodes(self,
|
||||
symbol: str,
|
||||
min_profitability: float = 0.7,
|
||||
max_episodes: int = 500) -> Dict[str, Any]:
|
||||
"""Train specifically on most profitable episodes"""
|
||||
try:
|
||||
# Get all episodes for symbol
|
||||
all_episodes = self.data_collector.training_episodes.get(symbol, [])
|
||||
|
||||
# Filter for profitable episodes
|
||||
profitable_episodes = [
|
||||
ep for ep in all_episodes
|
||||
if (ep.actual_outcome.is_profitable and
|
||||
ep.actual_outcome.profitability_score >= min_profitability)
|
||||
]
|
||||
|
||||
# Sort by profitability and limit
|
||||
profitable_episodes.sort(
|
||||
key=lambda x: x.actual_outcome.profitability_score,
|
||||
reverse=True
|
||||
)
|
||||
profitable_episodes = profitable_episodes[:max_episodes]
|
||||
|
||||
if len(profitable_episodes) < 10:
|
||||
logger.warning(f"Insufficient profitable episodes for {symbol}: {len(profitable_episodes)}")
|
||||
return {'status': 'insufficient_data', 'episodes_found': len(profitable_episodes)}
|
||||
|
||||
# Train on profitable episodes
|
||||
results = self._train_on_episodes(
|
||||
profitable_episodes,
|
||||
training_mode='profitable_replay'
|
||||
)
|
||||
|
||||
logger.info(f"Trained on {len(profitable_episodes)} profitable episodes for {symbol}")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training on profitable episodes: {e}")
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
|
||||
def _train_on_episodes(self,
|
||||
episodes: List[TrainingEpisode],
|
||||
training_mode: str = 'batch') -> Dict[str, Any]:
|
||||
"""Train on a batch of episodes with comprehensive data storage"""
|
||||
try:
|
||||
# Start new training session
|
||||
session = CNNTrainingSession(
|
||||
session_id=f"{training_mode}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
||||
start_timestamp=datetime.now(),
|
||||
training_mode=training_mode,
|
||||
symbol=episodes[0].input_package.symbol if episodes else 'unknown'
|
||||
)
|
||||
self.current_session = session
|
||||
|
||||
# Create dataset and dataloader
|
||||
dataset = CNNTrainingDataset(episodes)
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
num_workers=2
|
||||
)
|
||||
|
||||
# Training loop
|
||||
self.model.train()
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
for batch_idx, batch in enumerate(dataloader):
|
||||
# Move to device
|
||||
features = batch['features'].to(self.device)
|
||||
pivot_class = batch['pivot_class'].to(self.device)
|
||||
pivot_price = batch['pivot_price'].to(self.device)
|
||||
confidence_target = batch['confidence_target'].to(self.device)
|
||||
|
||||
# Forward pass
|
||||
self.optimizer.zero_grad()
|
||||
outputs = self.model(features)
|
||||
|
||||
# Calculate losses
|
||||
classification_loss = F.cross_entropy(outputs['pivot_logits'], pivot_class)
|
||||
regression_loss = F.mse_loss(outputs['pivot_price'].squeeze(), pivot_price)
|
||||
confidence_loss = F.binary_cross_entropy(
|
||||
outputs['confidence'].squeeze(),
|
||||
confidence_target
|
||||
)
|
||||
|
||||
# Combined loss
|
||||
total_batch_loss = classification_loss + 0.5 * regression_loss + 0.3 * confidence_loss
|
||||
|
||||
# Backward pass
|
||||
total_batch_loss.backward()
|
||||
|
||||
# Gradient clipping
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
||||
|
||||
# Store gradients before optimizer step
|
||||
gradients = {}
|
||||
gradient_norms = {}
|
||||
for name, param in self.model.named_parameters():
|
||||
if param.grad is not None:
|
||||
gradients[name] = param.grad.clone().detach()
|
||||
gradient_norms[name] = param.grad.norm().item()
|
||||
|
||||
# Optimizer step
|
||||
self.optimizer.step()
|
||||
|
||||
# Create training step record
|
||||
step = CNNTrainingStep(
|
||||
step_id=f"{session.session_id}_step_{batch_idx}",
|
||||
timestamp=datetime.now(),
|
||||
episode_id=f"batch_{batch_idx}",
|
||||
input_features=features.detach().cpu(),
|
||||
target_labels=pivot_class.detach().cpu(),
|
||||
model_outputs={k: v.detach().cpu() for k, v in outputs.items()},
|
||||
predictions=self._extract_predictions(outputs),
|
||||
confidence_scores=outputs['confidence'].detach().cpu(),
|
||||
total_loss=total_batch_loss.item(),
|
||||
pivot_prediction_loss=classification_loss.item(),
|
||||
confidence_loss=confidence_loss.item(),
|
||||
regularization_loss=0.0,
|
||||
gradients=gradients,
|
||||
gradient_norms=gradient_norms,
|
||||
learning_rate=self.optimizer.param_groups[0]['lr'],
|
||||
batch_size=features.size(0)
|
||||
)
|
||||
|
||||
# Calculate training value for this step
|
||||
step.training_value = self._calculate_step_training_value(step, batch)
|
||||
|
||||
# Add to session
|
||||
session.training_steps.append(step)
|
||||
|
||||
total_loss += total_batch_loss.item()
|
||||
num_batches += 1
|
||||
|
||||
# Log progress
|
||||
if batch_idx % 10 == 0:
|
||||
logger.debug(f"Batch {batch_idx}: Loss = {total_batch_loss.item():.4f}")
|
||||
|
||||
# Finalize session
|
||||
session.end_timestamp = datetime.now()
|
||||
session.total_steps = num_batches
|
||||
session.average_loss = total_loss / num_batches if num_batches > 0 else 0.0
|
||||
session.best_loss = min(step.total_loss for step in session.training_steps)
|
||||
|
||||
# Calculate session value
|
||||
session.session_value = self._calculate_session_value(session)
|
||||
|
||||
# Update scheduler
|
||||
self.scheduler.step(session.average_loss)
|
||||
|
||||
# Save session
|
||||
self._save_training_session(session)
|
||||
|
||||
# Update statistics
|
||||
self.training_stats['total_sessions'] += 1
|
||||
self.training_stats['total_steps'] += session.total_steps
|
||||
if training_mode == 'profitable_replay':
|
||||
self.training_stats['replay_sessions'] += 1
|
||||
|
||||
logger.info(f"Training session completed: {session.session_id}")
|
||||
logger.info(f"Average loss: {session.average_loss:.4f}")
|
||||
logger.info(f"Session value: {session.session_value:.3f}")
|
||||
|
||||
return {
|
||||
'status': 'success',
|
||||
'session_id': session.session_id,
|
||||
'average_loss': session.average_loss,
|
||||
'total_steps': session.total_steps,
|
||||
'session_value': session.session_value
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in training session: {e}")
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
finally:
|
||||
self.current_session = None
|
||||
|
||||
def _extract_predictions(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
|
||||
"""Extract human-readable predictions from model outputs"""
|
||||
try:
|
||||
pivot_probs = F.softmax(outputs['pivot_logits'], dim=1)
|
||||
predicted_class = torch.argmax(pivot_probs, dim=1)
|
||||
|
||||
return {
|
||||
'pivot_class': predicted_class.cpu().numpy().tolist(),
|
||||
'pivot_probabilities': pivot_probs.cpu().numpy().tolist(),
|
||||
'pivot_price': outputs['pivot_price'].cpu().numpy().tolist(),
|
||||
'confidence': outputs['confidence'].cpu().numpy().tolist()
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Error extracting predictions: {e}")
|
||||
return {}
|
||||
|
||||
def _calculate_step_training_value(self,
|
||||
step: CNNTrainingStep,
|
||||
batch: Dict[str, Any]) -> float:
|
||||
"""Calculate the training value of a step for replay prioritization"""
|
||||
try:
|
||||
value = 0.0
|
||||
|
||||
# Base value from loss (lower loss = higher value)
|
||||
if step.total_loss > 0:
|
||||
value += 1.0 / (1.0 + step.total_loss)
|
||||
|
||||
# Bonus for high profitability episodes in batch
|
||||
avg_profitability = torch.mean(batch['profitability']).item()
|
||||
value += avg_profitability * 0.3
|
||||
|
||||
# Bonus for gradient magnitude (indicates learning)
|
||||
avg_grad_norm = np.mean(list(step.gradient_norms.values()))
|
||||
value += min(avg_grad_norm / 10.0, 0.2) # Cap at 0.2
|
||||
|
||||
return min(value, 1.0)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating step training value: {e}")
|
||||
return 0.0
|
||||
|
||||
def _calculate_session_value(self, session: CNNTrainingSession) -> float:
|
||||
"""Calculate overall session value for replay prioritization"""
|
||||
try:
|
||||
if not session.training_steps:
|
||||
return 0.0
|
||||
|
||||
# Average step values
|
||||
avg_step_value = np.mean([step.training_value for step in session.training_steps])
|
||||
|
||||
# Bonus for convergence
|
||||
convergence_bonus = 0.0
|
||||
if len(session.training_steps) > 10:
|
||||
early_loss = np.mean([s.total_loss for s in session.training_steps[:5]])
|
||||
late_loss = np.mean([s.total_loss for s in session.training_steps[-5:]])
|
||||
if early_loss > late_loss:
|
||||
convergence_bonus = min((early_loss - late_loss) / early_loss, 0.3)
|
||||
|
||||
# Bonus for profitable replay sessions
|
||||
mode_bonus = 0.2 if session.training_mode == 'profitable_replay' else 0.0
|
||||
|
||||
return min(avg_step_value + convergence_bonus + mode_bonus, 1.0)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating session value: {e}")
|
||||
return 0.0
|
||||
|
||||
def _save_training_session(self, session: CNNTrainingSession):
|
||||
"""Save training session to disk"""
|
||||
try:
|
||||
session_dir = self.storage_dir / session.symbol / 'sessions'
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save full session data
|
||||
session_file = session_dir / f"{session.session_id}.pkl"
|
||||
with open(session_file, 'wb') as f:
|
||||
pickle.dump(session, f)
|
||||
|
||||
# Save session metadata
|
||||
metadata = {
|
||||
'session_id': session.session_id,
|
||||
'start_timestamp': session.start_timestamp.isoformat(),
|
||||
'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None,
|
||||
'training_mode': session.training_mode,
|
||||
'symbol': session.symbol,
|
||||
'total_steps': session.total_steps,
|
||||
'average_loss': session.average_loss,
|
||||
'best_loss': session.best_loss,
|
||||
'session_value': session.session_value
|
||||
}
|
||||
|
||||
metadata_file = session_dir / f"{session.session_id}_metadata.json"
|
||||
with open(metadata_file, 'w') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
logger.debug(f"Saved training session: {session.session_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving training session: {e}")
|
||||
|
||||
def _finalize_training_session(self):
|
||||
"""Finalize current training session"""
|
||||
if self.current_session:
|
||||
self.current_session.end_timestamp = datetime.now()
|
||||
self._save_training_session(self.current_session)
|
||||
self.training_sessions.append(self.current_session)
|
||||
self.current_session = None
|
||||
|
||||
def get_training_statistics(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive training statistics"""
|
||||
stats = self.training_stats.copy()
|
||||
|
||||
# Add recent session information
|
||||
if self.training_sessions:
|
||||
recent_sessions = sorted(
|
||||
self.training_sessions,
|
||||
key=lambda x: x.start_timestamp,
|
||||
reverse=True
|
||||
)[:10]
|
||||
|
||||
stats['recent_sessions'] = [
|
||||
{
|
||||
'session_id': s.session_id,
|
||||
'timestamp': s.start_timestamp.isoformat(),
|
||||
'mode': s.training_mode,
|
||||
'average_loss': s.average_loss,
|
||||
'session_value': s.session_value
|
||||
}
|
||||
for s in recent_sessions
|
||||
]
|
||||
|
||||
# Calculate profitability rate
|
||||
if stats['total_predictions'] > 0:
|
||||
stats['profitability_rate'] = stats['profitable_predictions'] / stats['total_predictions']
|
||||
else:
|
||||
stats['profitability_rate'] = 0.0
|
||||
|
||||
return stats
|
||||
|
||||
def replay_high_value_sessions(self,
|
||||
symbol: str,
|
||||
min_session_value: float = 0.7,
|
||||
max_sessions: int = 10) -> Dict[str, Any]:
|
||||
"""Replay high-value training sessions"""
|
||||
try:
|
||||
# Find high-value sessions
|
||||
high_value_sessions = [
|
||||
s for s in self.training_sessions
|
||||
if (s.symbol == symbol and
|
||||
s.session_value >= min_session_value)
|
||||
]
|
||||
|
||||
# Sort by value and limit
|
||||
high_value_sessions.sort(key=lambda x: x.session_value, reverse=True)
|
||||
high_value_sessions = high_value_sessions[:max_sessions]
|
||||
|
||||
if not high_value_sessions:
|
||||
return {'status': 'no_high_value_sessions', 'sessions_found': 0}
|
||||
|
||||
# Replay sessions
|
||||
total_replayed = 0
|
||||
for session in high_value_sessions:
|
||||
# Extract episodes from session steps
|
||||
episode_ids = list(set(step.episode_id for step in session.training_steps))
|
||||
|
||||
# Get corresponding episodes
|
||||
episodes = []
|
||||
for episode_id in episode_ids:
|
||||
# Find episode in data collector
|
||||
for ep in self.data_collector.training_episodes.get(symbol, []):
|
||||
if ep.episode_id == episode_id:
|
||||
episodes.append(ep)
|
||||
break
|
||||
|
||||
if episodes:
|
||||
self._train_on_episodes(episodes, training_mode='high_value_replay')
|
||||
total_replayed += 1
|
||||
|
||||
logger.info(f"Replayed {total_replayed} high-value sessions for {symbol}")
|
||||
return {
|
||||
'status': 'success',
|
||||
'sessions_replayed': total_replayed,
|
||||
'sessions_found': len(high_value_sessions)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error replaying high-value sessions: {e}")
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
|
||||
# Global instance
cnn_trainer = None

def get_cnn_trainer(model: CNNPivotPredictor = None) -> CNNTrainer:
    """Get global CNN trainer instance"""
    global cnn_trainer
    if cnn_trainer is None:
        if model is None:
            model = CNNPivotPredictor()
        cnn_trainer = CNNTrainer(model)
    return cnn_trainer
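
# Illustrative usage sketch (not in the original file): wiring the global trainer
# into an application. The symbol and thresholds are example values; all calls below
# are methods defined earlier in this module.
def _example_training_run(symbol: str = 'ETHUSDT') -> None:
    """Start real-time training, replay profitable setups, then shut down (sketch)."""
    trainer = get_cnn_trainer()                       # lazily builds a CNNPivotPredictor
    trainer.start_real_time_training(symbol)          # background thread, trains every 5 minutes
    results = trainer.train_on_profitable_episodes(   # replay only the best setups
        symbol, min_profitability=0.7, max_episodes=500)
    logger.info("Replay training result: %s", results)
    logger.info("Training stats: %s", trainer.get_training_statistics())
    trainer.stop_training()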
@@ -1,864 +0,0 @@
# """
# Enhanced CNN Adapter for Standardized Input Format

# This module provides an adapter for the EnhancedCNN model to work with the standardized
# BaseDataInput format, enabling seamless integration with the multi-modal trading system.
# """

# import torch
# import numpy as np
# import logging
# import os
# import random
# from datetime import datetime, timedelta
# from typing import Dict, List, Optional, Tuple, Any, Union
# from threading import Lock

# from .data_models import BaseDataInput, ModelOutput, create_model_output
# from NN.models.enhanced_cnn import EnhancedCNN
# from utils.inference_logger import log_model_inference

# logger = logging.getLogger(__name__)

# class EnhancedCNNAdapter:
|
||||
# """
|
||||
# Adapter for EnhancedCNN model to work with standardized BaseDataInput format
|
||||
|
||||
# This adapter:
|
||||
# 1. Converts BaseDataInput to the format expected by EnhancedCNN
|
||||
# 2. Processes model outputs to create standardized ModelOutput
|
||||
# 3. Manages model training with collected data
|
||||
# 4. Handles checkpoint management
|
||||
# """
|
||||
|
||||
# def __init__(self, model_path: str = None, checkpoint_dir: str = "models/enhanced_cnn"):
|
||||
# """
|
||||
# Initialize the EnhancedCNN adapter
|
||||
|
||||
# Args:
|
||||
# model_path: Path to load model from, if None a new model is created
|
||||
# checkpoint_dir: Directory to save checkpoints to
|
||||
# """
|
||||
# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
# self.model = None
|
||||
# self.model_path = model_path
|
||||
# self.checkpoint_dir = checkpoint_dir
|
||||
# self.training_lock = Lock()
|
||||
# self.training_data = []
|
||||
# self.max_training_samples = 10000
|
||||
# self.batch_size = 32
|
||||
# self.learning_rate = 0.0001
|
||||
# self.model_name = "enhanced_cnn"
|
||||
|
||||
# # Enhanced metrics tracking
|
||||
# self.last_inference_time = None
|
||||
# self.last_inference_duration = 0.0
|
||||
# self.last_prediction_output = None
|
||||
# self.last_training_time = None
|
||||
# self.last_training_duration = 0.0
|
||||
# self.last_training_loss = 0.0
|
||||
# self.inference_count = 0
|
||||
# self.training_count = 0
|
||||
|
||||
# # Create checkpoint directory if it doesn't exist
|
||||
# os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# # Initialize the model
|
||||
# self._initialize_model()
|
||||
|
||||
# # Load checkpoint if available
|
||||
# if model_path and os.path.exists(model_path):
|
||||
# self._load_checkpoint(model_path)
|
||||
# else:
|
||||
# self._load_best_checkpoint()
|
||||
|
||||
# # Final device check and move
|
||||
# self._ensure_model_on_device()
|
||||
|
||||
# logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
|
||||
|
||||
# def _create_realistic_synthetic_features(self, symbol: str) -> torch.Tensor:
|
||||
# """Create realistic synthetic features instead of random data"""
|
||||
# try:
|
||||
# # Create realistic market-like features
|
||||
# features = torch.zeros(7850, dtype=torch.float32, device=self.device)
|
||||
|
||||
# # OHLCV features (6000 features: 300 frames x 4 timeframes x 5 features)
|
||||
# ohlcv_start = 0
|
||||
# for timeframe_idx in range(4): # 1s, 1m, 1h, 1d
|
||||
# base_price = 3500.0 + timeframe_idx * 10 # Slight variation per timeframe
|
||||
# for frame_idx in range(300):
|
||||
# # Create realistic price movement
|
||||
# price_change = torch.sin(torch.tensor(frame_idx * 0.1)) * 0.01 # Cyclical movement
|
||||
# current_price = base_price * (1 + price_change)
|
||||
|
||||
# # Realistic OHLCV values
|
||||
# open_price = current_price
|
||||
# high_price = current_price * torch.uniform(1.0, 1.005)
|
||||
# low_price = current_price * torch.uniform(0.995, 1.0)
|
||||
# close_price = current_price * torch.uniform(0.998, 1.002)
|
||||
# volume = torch.uniform(500.0, 2000.0)
|
||||
|
||||
# # Set features
|
||||
# feature_idx = ohlcv_start + frame_idx * 5 + timeframe_idx * 1500
|
||||
# features[feature_idx:feature_idx+5] = torch.tensor([open_price, high_price, low_price, close_price, volume])
|
||||
|
||||
# # BTC OHLCV features (1500 features: 300 frames x 5 features)
|
||||
# btc_start = 6000
|
||||
# btc_base_price = 50000.0
|
||||
# for frame_idx in range(300):
|
||||
# price_change = torch.sin(torch.tensor(frame_idx * 0.05)) * 0.02
|
||||
# current_price = btc_base_price * (1 + price_change)
|
||||
|
||||
# open_price = current_price
|
||||
# high_price = current_price * torch.uniform(1.0, 1.01)
|
||||
# low_price = current_price * torch.uniform(0.99, 1.0)
|
||||
# close_price = current_price * torch.uniform(0.995, 1.005)
|
||||
# volume = torch.uniform(100.0, 500.0)
|
||||
|
||||
# feature_idx = btc_start + frame_idx * 5
|
||||
# features[feature_idx:feature_idx+5] = torch.tensor([open_price, high_price, low_price, close_price, volume])
|
||||
|
||||
# # COB features (200 features) - realistic order book data
|
||||
# cob_start = 7500
|
||||
# for i in range(200):
|
||||
# features[cob_start + i] = torch.uniform(0.0, 1000.0) # Realistic COB values
|
||||
|
||||
# # Technical indicators (100 features)
|
||||
# indicator_start = 7700
|
||||
# for i in range(100):
|
||||
# features[indicator_start + i] = torch.uniform(-1.0, 1.0) # Normalized indicators
|
||||
|
||||
# # Last predictions (50 features)
|
||||
# prediction_start = 7800
|
||||
# for i in range(50):
|
||||
# features[prediction_start + i] = torch.uniform(0.0, 1.0) # Probability values
|
||||
|
||||
# return features
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error creating realistic synthetic features: {e}")
|
||||
# # Fallback to small random variation
|
||||
# base_features = torch.ones(7850, dtype=torch.float32, device=self.device) * 0.5
|
||||
# noise = torch.randn(7850, dtype=torch.float32, device=self.device) * 0.1
|
||||
# return base_features + noise
|
||||
|
||||
# def _create_realistic_features(self, symbol: str) -> torch.Tensor:
|
||||
# """Create features from real market data if available"""
|
||||
# try:
|
||||
# # This would need to be implemented to use actual market data
|
||||
# # For now, fall back to synthetic features
|
||||
# return self._create_realistic_synthetic_features(symbol)
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error creating realistic features: {e}")
|
||||
# return self._create_realistic_synthetic_features(symbol)
|
||||
|
||||
# def _initialize_model(self):
|
||||
# """Initialize the EnhancedCNN model"""
|
||||
# try:
|
||||
# # Calculate input shape based on BaseDataInput structure
|
||||
# # OHLCV: 300 frames x 4 timeframes x 5 features = 6000 features
|
||||
# # BTC OHLCV: 300 frames x 5 features = 1500 features
|
||||
# # COB: ±20 buckets x 4 metrics = 160 features
|
||||
# # MA: 4 timeframes x 10 buckets = 40 features
|
||||
# # Technical indicators: 100 features
|
||||
# # Last predictions: 50 features
|
||||
# # Total: 7850 features
|
||||
# input_shape = 7850
|
||||
# n_actions = 3 # BUY, SELL, HOLD
|
||||
|
||||
# # Create model
|
||||
# self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
|
||||
# # Ensure model is moved to the correct device
|
||||
# self.model.to(self.device)
|
||||
|
||||
# logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions} on device {self.device}")
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error initializing EnhancedCNN model: {e}")
|
||||
# raise
|
||||
|
||||
# def _load_checkpoint(self, checkpoint_path: str) -> bool:
|
||||
# """Load model from checkpoint path"""
|
||||
# try:
|
||||
# if self.model and os.path.exists(checkpoint_path):
|
||||
# success = self.model.load(checkpoint_path)
|
||||
# if success:
|
||||
# # Ensure model is moved to the correct device after loading
|
||||
# self.model.to(self.device)
|
||||
# logger.info(f"Loaded model from {checkpoint_path} and moved to {self.device}")
|
||||
# return True
|
||||
# else:
|
||||
# logger.warning(f"Failed to load model from {checkpoint_path}")
|
||||
# return False
|
||||
# else:
|
||||
# logger.warning(f"Checkpoint path does not exist: {checkpoint_path}")
|
||||
# return False
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error loading checkpoint: {e}")
|
||||
# return False
|
||||
|
||||
# def _load_best_checkpoint(self) -> bool:
|
||||
# """Load the best available checkpoint"""
|
||||
# try:
|
||||
# return self.load_best_checkpoint()
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error loading best checkpoint: {e}")
|
||||
# return False
|
||||
|
||||
# def load_best_checkpoint(self) -> bool:
|
||||
# """Load the best checkpoint based on accuracy"""
|
||||
# try:
|
||||
# # Import checkpoint manager
|
||||
# from utils.checkpoint_manager import CheckpointManager
|
||||
|
||||
# # Create checkpoint manager
|
||||
# checkpoint_manager = CheckpointManager(
|
||||
# checkpoint_dir=self.checkpoint_dir,
|
||||
# max_checkpoints=10,
|
||||
# metric_name="accuracy"
|
||||
# )
|
||||
|
||||
# # Load best checkpoint
|
||||
# best_checkpoint_path, best_checkpoint_metadata = checkpoint_manager.load_best_checkpoint(self.model_name)
|
||||
|
||||
# if not best_checkpoint_path:
|
||||
# logger.info(f"No checkpoints found for {self.model_name} - starting in COLD START mode")
|
||||
# return False
|
||||
|
||||
# # Load model
|
||||
# success = self.model.load(best_checkpoint_path)
|
||||
|
||||
# if success:
|
||||
# # Ensure model is moved to the correct device after loading
|
||||
# self.model.to(self.device)
|
||||
# logger.info(f"Loaded best checkpoint from {best_checkpoint_path} and moved to {self.device}")
|
||||
|
||||
# # Log metrics
|
||||
# metrics = best_checkpoint_metadata.get('metrics', {})
|
||||
# logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}")
|
||||
|
||||
# return True
|
||||
# else:
|
||||
# logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}")
|
||||
# return False
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error loading best checkpoint: {e}")
|
||||
# return False
|
||||
|
||||
# def _ensure_model_on_device(self):
|
||||
# """Ensure model and all its components are on the correct device"""
|
||||
# try:
|
||||
# if self.model:
|
||||
# self.model.to(self.device)
|
||||
# # Also ensure the model's internal device is set correctly
|
||||
# if hasattr(self.model, 'device'):
|
||||
# self.model.device = self.device
|
||||
# logger.debug(f"Model ensured on device {self.device}")
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error ensuring model on device: {e}")
|
||||
|
||||
# def _create_default_output(self, symbol: str) -> ModelOutput:
|
||||
# """Create default output when prediction fails"""
|
||||
# return create_model_output(
|
||||
# model_type='cnn',
|
||||
# model_name=self.model_name,
|
||||
# symbol=symbol,
|
||||
# action='HOLD',
|
||||
# confidence=0.0,
|
||||
# metadata={'error': 'Prediction failed, using default output'}
|
||||
# )
|
||||
|
||||
# def _process_hidden_states(self, hidden_states: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# """Process hidden states for cross-model feeding"""
|
||||
# processed_states = {}
|
||||
|
||||
# for key, value in hidden_states.items():
|
||||
# if isinstance(value, torch.Tensor):
|
||||
# # Convert tensor to numpy array
|
||||
# processed_states[key] = value.cpu().numpy().tolist()
|
||||
# else:
|
||||
# processed_states[key] = value
|
||||
|
||||
# return processed_states
|
||||
|
||||
|
||||
|
||||
# def _convert_base_data_to_features(self, base_data: BaseDataInput) -> torch.Tensor:
|
||||
# """
|
||||
# Convert BaseDataInput to feature vector for EnhancedCNN
|
||||
|
||||
# Args:
|
||||
# base_data: Standardized input data
|
||||
|
||||
# Returns:
|
||||
# torch.Tensor: Feature vector for EnhancedCNN
|
||||
# """
|
||||
# try:
|
||||
# # Use the get_feature_vector method from BaseDataInput
|
||||
# features = base_data.get_feature_vector()
|
||||
|
||||
# # Validate feature quality before using
|
||||
# self._validate_feature_quality(features)
|
||||
|
||||
# # Convert to torch tensor
|
||||
# features_tensor = torch.tensor(features, dtype=torch.float32, device=self.device)
|
||||
|
||||
# return features_tensor
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error converting BaseDataInput to features: {e}")
|
||||
# # Return empty tensor with correct shape
|
||||
# return torch.zeros(7850, dtype=torch.float32, device=self.device)
|
||||
|
||||
# def _validate_feature_quality(self, features: np.ndarray):
|
||||
# """Validate that features are realistic and not synthetic/placeholder data"""
|
||||
# try:
|
||||
# if len(features) != 7850:
|
||||
# logger.warning(f"Feature vector has wrong size: {len(features)} != 7850")
|
||||
# return
|
||||
|
||||
# # Check for all-zero or all-identical features (indicates placeholder data)
|
||||
# if np.all(features == 0):
|
||||
# logger.warning("Feature vector contains all zeros - likely placeholder data")
|
||||
# return
|
||||
|
||||
# # Check for repetitive patterns in OHLCV data (first 6000 features)
|
||||
# ohlcv_features = features[:6000]
|
||||
# if len(ohlcv_features) >= 20:
|
||||
# # Check if first 20 values are identical (indicates padding with same bar)
|
||||
# if np.allclose(ohlcv_features[:20], ohlcv_features[0], atol=1e-6):
|
||||
# logger.warning("OHLCV features show repetitive pattern - possible synthetic data")
|
||||
|
||||
# # Check for unrealistic values
|
||||
# if np.any(features > 1e6) or np.any(features < -1e6):
|
||||
# logger.warning("Feature vector contains unrealistic values")
|
||||
|
||||
# # Check for NaN or infinite values
|
||||
# if np.any(np.isnan(features)) or np.any(np.isinf(features)):
|
||||
# logger.warning("Feature vector contains NaN or infinite values")
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error validating feature quality: {e}")
|
||||
|
||||
# def predict(self, base_data: BaseDataInput) -> ModelOutput:
|
||||
# """
|
||||
# Make a prediction using the EnhancedCNN model
|
||||
|
||||
# Args:
|
||||
# base_data: Standardized input data
|
||||
|
||||
# Returns:
|
||||
# ModelOutput: Standardized model output
|
||||
# """
|
||||
# try:
|
||||
# # Track inference timing
|
||||
# start_time = datetime.now()
|
||||
# inference_start = start_time.timestamp()
|
||||
|
||||
# # Convert BaseDataInput to features
|
||||
# features = self._convert_base_data_to_features(base_data)
|
||||
|
||||
# # Ensure features has batch dimension
|
||||
# if features.dim() == 1:
|
||||
# features = features.unsqueeze(0)
|
||||
|
||||
# # Ensure model is on correct device before prediction
|
||||
# self._ensure_model_on_device()
|
||||
|
||||
# # Set model to evaluation mode
|
||||
# self.model.eval()
|
||||
|
||||
# # Make prediction
|
||||
# with torch.no_grad():
|
||||
# q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.model(features)
|
||||
|
||||
# # Get action and confidence
|
||||
# action_probs = torch.softmax(q_values, dim=1)
|
||||
# action_idx = torch.argmax(action_probs, dim=1).item()
|
||||
# raw_confidence = float(action_probs[0, action_idx].item())
|
||||
|
||||
# # Validate confidence - prevent 100% confidence which indicates overfitting
|
||||
# if raw_confidence >= 0.99:
|
||||
# logger.warning(f"CNN produced suspiciously high confidence: {raw_confidence:.4f} - possible overfitting")
|
||||
# # Cap confidence at 0.95 to prevent unrealistic predictions
|
||||
# confidence = min(raw_confidence, 0.95)
|
||||
# logger.info(f"Capped confidence from {raw_confidence:.4f} to {confidence:.4f}")
|
||||
# else:
|
||||
# confidence = raw_confidence
|
||||
|
||||
# # Map action index to action string
|
||||
# actions = ['BUY', 'SELL', 'HOLD']
|
||||
# action = actions[action_idx]
|
||||
|
||||
# # Extract pivot price prediction (simplified - take first value from price_pred)
|
||||
# pivot_price = None
|
||||
# if price_pred is not None and len(price_pred.squeeze()) > 0:
|
||||
# # Get current price from base_data for context
|
||||
# current_price = 0.0
|
||||
# if base_data.ohlcv_1s and len(base_data.ohlcv_1s) > 0:
|
||||
# current_price = base_data.ohlcv_1s[-1].close
|
||||
|
||||
# # Calculate pivot price as current price + predicted change
|
||||
# price_change_pct = float(price_pred.squeeze()[0].item()) # First prediction value
|
||||
# pivot_price = current_price * (1 + price_change_pct * 0.01) # Convert percentage to price
|
||||
|
||||
# # Create predictions dictionary
|
||||
# predictions = {
|
||||
# 'action': action,
|
||||
# 'buy_probability': float(action_probs[0, 0].item()),
|
||||
# 'sell_probability': float(action_probs[0, 1].item()),
|
||||
# 'hold_probability': float(action_probs[0, 2].item()),
|
||||
# 'extrema': extrema_pred.squeeze(0).cpu().numpy().tolist(),
|
||||
# 'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist(),
|
||||
# 'pivot_price': pivot_price
|
||||
# }
|
||||
|
||||
# # Create hidden states dictionary
|
||||
# hidden_states = {
|
||||
# 'features': features_refined.squeeze(0).cpu().numpy().tolist()
|
||||
# }
|
||||
|
||||
# # Calculate inference duration
|
||||
# end_time = datetime.now()
|
||||
# inference_duration = (end_time.timestamp() - inference_start) * 1000 # Convert to milliseconds
|
||||
|
||||
# # Update metrics
|
||||
# self.last_inference_time = start_time
|
||||
# self.last_inference_duration = inference_duration
|
||||
# self.inference_count += 1
|
||||
|
||||
# # Store last prediction output for dashboard
|
||||
# self.last_prediction_output = {
|
||||
# 'action': action,
|
||||
# 'confidence': confidence,
|
||||
# 'pivot_price': pivot_price,
|
||||
# 'timestamp': start_time,
|
||||
# 'symbol': base_data.symbol
|
||||
# }
|
||||
|
||||
# # Create metadata dictionary
|
||||
# metadata = {
|
||||
# 'model_version': '1.0',
|
||||
# 'timestamp': start_time.isoformat(),
|
||||
# 'input_shape': features.shape,
|
||||
# 'inference_duration_ms': inference_duration,
|
||||
# 'inference_count': self.inference_count
|
||||
# }
|
||||
|
||||
# # Create ModelOutput
|
||||
# model_output = ModelOutput(
|
||||
# model_type='cnn',
|
||||
# model_name=self.model_name,
|
||||
# symbol=base_data.symbol,
|
||||
# timestamp=start_time,
|
||||
# confidence=confidence,
|
||||
# predictions=predictions,
|
||||
# hidden_states=hidden_states,
|
||||
# metadata=metadata
|
||||
# )
|
||||
|
||||
# # Log inference with full input data for training feedback
|
||||
# log_model_inference(
|
||||
# model_name=self.model_name,
|
||||
# symbol=base_data.symbol,
|
||||
# action=action,
|
||||
# confidence=confidence,
|
||||
# probabilities={
|
||||
# 'BUY': predictions['buy_probability'],
|
||||
# 'SELL': predictions['sell_probability'],
|
||||
# 'HOLD': predictions['hold_probability']
|
||||
# },
|
||||
# input_features=features.cpu().numpy(), # Store full feature vector
|
||||
# processing_time_ms=inference_duration,
|
||||
# checkpoint_id=None, # Could be enhanced to track checkpoint
|
||||
# metadata={
|
||||
# 'base_data_input': {
|
||||
# 'symbol': base_data.symbol,
|
||||
# 'timestamp': base_data.timestamp.isoformat(),
|
||||
# 'ohlcv_1s_count': len(base_data.ohlcv_1s),
|
||||
# 'ohlcv_1m_count': len(base_data.ohlcv_1m),
|
||||
# 'ohlcv_1h_count': len(base_data.ohlcv_1h),
|
||||
# 'ohlcv_1d_count': len(base_data.ohlcv_1d),
|
||||
# 'btc_ohlcv_1s_count': len(base_data.btc_ohlcv_1s),
|
||||
# 'has_cob_data': base_data.cob_data is not None,
|
||||
# 'technical_indicators_count': len(base_data.technical_indicators),
|
||||
# 'pivot_points_count': len(base_data.pivot_points),
|
||||
# 'last_predictions_count': len(base_data.last_predictions)
|
||||
# },
|
||||
# 'model_predictions': {
|
||||
# 'pivot_price': pivot_price,
|
||||
# 'extrema_prediction': predictions['extrema'],
|
||||
# 'price_prediction': predictions['price_prediction']
|
||||
# }
|
||||
# }
|
||||
# )
|
||||
|
||||
# return model_output
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error making prediction with EnhancedCNN: {e}")
|
||||
# # Return default ModelOutput
|
||||
# return create_model_output(
|
||||
# model_type='cnn',
|
||||
# model_name=self.model_name,
|
||||
# symbol=base_data.symbol,
|
||||
# action='HOLD',
|
||||
# confidence=0.0
|
||||
# )
|
||||
|
||||
# def add_training_sample(self, symbol_or_base_data, actual_action: str, reward: float):
|
||||
# """
|
||||
# Add a training sample to the training data
|
||||
|
||||
# Args:
|
||||
# symbol_or_base_data: Either a symbol string or BaseDataInput object
|
||||
# actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
|
||||
# reward: Reward received for the action
|
||||
# """
|
||||
# try:
|
||||
# # Handle both symbol string and BaseDataInput object
|
||||
# if isinstance(symbol_or_base_data, str):
|
||||
# # For cold start mode - create a simple training sample with current features
|
||||
# # This is a simplified approach for rapid training
|
||||
# symbol = symbol_or_base_data
|
||||
|
||||
# # Create a realistic feature vector instead of random data
|
||||
# # Use actual market data if available, otherwise create realistic synthetic data
|
||||
# try:
|
||||
# # Try to get real market data first
|
||||
# if hasattr(self, 'data_provider') and self.data_provider:
|
||||
# # This would need to be implemented in the adapter
|
||||
# features = self._create_realistic_features(symbol)
|
||||
# else:
|
||||
# # Create realistic synthetic features (not random)
|
||||
# features = self._create_realistic_synthetic_features(symbol)
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Could not create realistic features for {symbol}: {e}")
|
||||
# # Fallback to small random variation instead of pure random
|
||||
# base_features = torch.ones(7850, dtype=torch.float32, device=self.device) * 0.5
|
||||
# noise = torch.randn(7850, dtype=torch.float32, device=self.device) * 0.1
|
||||
# features = base_features + noise
|
||||
|
||||
# logger.debug(f"Added realistic training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")
|
||||
|
||||
# else:
|
||||
# # Full BaseDataInput object
|
||||
# base_data = symbol_or_base_data
|
||||
# features = self._convert_base_data_to_features(base_data)
|
||||
# symbol = base_data.symbol
|
||||
|
||||
# logger.debug(f"Added full training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")
|
||||
|
||||
# # Convert action to index
|
||||
# actions = ['BUY', 'SELL', 'HOLD']
|
||||
# action_idx = actions.index(actual_action)
|
||||
|
||||
# # Add to training data
|
||||
# with self.training_lock:
|
||||
# self.training_data.append((features, action_idx, reward))
|
||||
|
||||
# # Limit training data size
|
||||
# if len(self.training_data) > self.max_training_samples:
|
||||
# # Sort by reward (highest first) and keep top samples
|
||||
# self.training_data.sort(key=lambda x: x[2], reverse=True)
|
||||
# self.training_data = self.training_data[:self.max_training_samples]
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error adding training sample: {e}")
|
||||
|
||||
# def train(self, epochs: int = 1) -> Dict[str, float]:
|
||||
# """
|
||||
# Train the model with collected data and inference history
|
||||
|
||||
# Args:
|
||||
# epochs: Number of epochs to train for
|
||||
|
||||
# Returns:
|
||||
# Dict[str, float]: Training metrics
|
||||
# """
|
||||
# try:
|
||||
# # Track training timing
|
||||
# training_start_time = datetime.now()
|
||||
# training_start = training_start_time.timestamp()
|
||||
|
||||
# with self.training_lock:
|
||||
# # Get additional training data from inference history
|
||||
# self._load_training_data_from_inference_history()
|
||||
|
||||
# # Check if we have enough data
|
||||
# if len(self.training_data) < self.batch_size:
|
||||
# logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
|
||||
# return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}
|
||||
|
||||
# # Ensure model is on correct device before training
|
||||
# self._ensure_model_on_device()
|
||||
|
||||
# # Set model to training mode
|
||||
# self.model.train()
|
||||
|
||||
# # Create optimizer
|
||||
# optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
|
||||
|
||||
# # Training metrics
|
||||
# total_loss = 0.0
|
||||
# correct_predictions = 0
|
||||
# total_predictions = 0
|
||||
|
||||
# # Train for specified number of epochs
|
||||
# for epoch in range(epochs):
|
||||
# # Shuffle training data
|
||||
# np.random.shuffle(self.training_data)
|
||||
|
||||
# # Process in batches
|
||||
# for i in range(0, len(self.training_data), self.batch_size):
|
||||
# batch = self.training_data[i:i+self.batch_size]
|
||||
|
||||
# # Skip if batch is too small
|
||||
# if len(batch) < 2:
|
||||
# continue
|
||||
|
||||
# # Prepare batch - ensure all tensors are on the correct device
|
||||
# features = torch.stack([sample[0].to(self.device) for sample in batch])
|
||||
# actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
|
||||
# rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)
|
||||
|
||||
# # Zero gradients
|
||||
# optimizer.zero_grad()
|
||||
|
||||
# # Forward pass
|
||||
# q_values, _, _, _, _ = self.model(features)
|
||||
|
||||
# # Calculate loss (CrossEntropyLoss with reward weighting)
|
||||
# # First, apply softmax to get probabilities
|
||||
# probs = torch.softmax(q_values, dim=1)
|
||||
|
||||
# # Get probability of chosen action
|
||||
# chosen_probs = probs[torch.arange(len(actions)), actions]
|
||||
|
||||
# # Calculate negative log likelihood loss
|
||||
# nll_loss = -torch.log(chosen_probs + 1e-10)
|
||||
|
||||
# # Weight by reward (higher reward = higher weight)
|
||||
# # Normalize rewards to [0, 1] range
|
||||
# min_reward = rewards.min()
|
||||
# max_reward = rewards.max()
|
||||
# if max_reward > min_reward:
|
||||
# normalized_rewards = (rewards - min_reward) / (max_reward - min_reward)
|
||||
# else:
|
||||
# normalized_rewards = torch.ones_like(rewards)
|
||||
|
||||
# # Apply reward weighting (higher reward = higher weight)
|
||||
# weighted_loss = nll_loss * (normalized_rewards + 0.1) # Add small constant to avoid zero weights
|
||||
|
||||
# # Mean loss
|
||||
# loss = weighted_loss.mean()
|
||||
|
||||
# # Backward pass
|
||||
# loss.backward()
|
||||
|
||||
# # Update weights
|
||||
# optimizer.step()
|
||||
|
||||
# # Update metrics
|
||||
# total_loss += loss.item()
|
||||
|
||||
# # Calculate accuracy
|
||||
# predicted_actions = torch.argmax(q_values, dim=1)
|
||||
# correct_predictions += (predicted_actions == actions).sum().item()
|
||||
# total_predictions += len(actions)
|
||||
|
||||
# # Validate training - detect overfitting
|
||||
# if total_predictions > 0:
|
||||
# current_accuracy = correct_predictions / total_predictions
|
||||
# if current_accuracy >= 0.99:
|
||||
# logger.warning(f"CNN training shows suspiciously high accuracy: {current_accuracy:.4f} - possible overfitting")
|
||||
# # Add regularization to prevent overfitting
|
||||
# l2_reg = 0.01 * sum(p.pow(2.0).sum() for p in self.model.parameters())
|
||||
# loss = loss + l2_reg
|
||||
# logger.info("Added L2 regularization to prevent overfitting")
|
||||
|
||||
# # Calculate final metrics
|
||||
# avg_loss = total_loss / (len(self.training_data) / self.batch_size)
|
||||
# accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
|
||||
|
||||
# # Calculate training duration
|
||||
# training_end_time = datetime.now()
|
||||
# training_duration = (training_end_time.timestamp() - training_start) * 1000 # Convert to milliseconds
|
||||
|
||||
# # Update training metrics
|
||||
# self.last_training_time = training_start_time
|
||||
# self.last_training_duration = training_duration
|
||||
# self.last_training_loss = avg_loss
|
||||
# self.training_count += 1
|
||||
|
||||
# # Save checkpoint
|
||||
# self._save_checkpoint(avg_loss, accuracy)
|
||||
|
||||
# logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={len(self.training_data)}, duration={training_duration:.1f}ms")
|
||||
|
||||
# return {
|
||||
# 'loss': avg_loss,
|
||||
# 'accuracy': accuracy,
|
||||
# 'samples': len(self.training_data),
|
||||
# 'duration_ms': training_duration,
|
||||
# 'training_count': self.training_count
|
||||
# }
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error training model: {e}")
|
||||
# return {'loss': 0.0, 'accuracy': 0.0, 'samples': 0, 'error': str(e)}
|
||||
|
||||
# def _save_checkpoint(self, loss: float, accuracy: float):
|
||||
# """
|
||||
# Save model checkpoint
|
||||
|
||||
# Args:
|
||||
# loss: Training loss
|
||||
# accuracy: Training accuracy
|
||||
# """
|
||||
# try:
|
||||
# # Import checkpoint manager
|
||||
# from utils.checkpoint_manager import CheckpointManager
|
||||
|
||||
# # Create checkpoint manager
|
||||
# checkpoint_manager = CheckpointManager(
|
||||
# checkpoint_dir=self.checkpoint_dir,
|
||||
# max_checkpoints=10,
|
||||
# metric_name="accuracy"
|
||||
# )
|
||||
|
||||
# # Create temporary model file
|
||||
# temp_path = os.path.join(self.checkpoint_dir, f"{self.model_name}_temp")
|
||||
# self.model.save(temp_path)
|
||||
|
||||
# # Create metrics
|
||||
# metrics = {
|
||||
# 'loss': loss,
|
||||
# 'accuracy': accuracy,
|
||||
# 'samples': len(self.training_data)
|
||||
# }
|
||||
|
||||
# # Create metadata
|
||||
# metadata = {
|
||||
# 'timestamp': datetime.now().isoformat(),
|
||||
# 'model_name': self.model_name,
|
||||
# 'input_shape': self.model.input_shape,
|
||||
# 'n_actions': self.model.n_actions
|
||||
# }
|
||||
|
||||
# # Save checkpoint
|
||||
# checkpoint_path = checkpoint_manager.save_checkpoint(
|
||||
# model_name=self.model_name,
|
||||
# model_path=f"{temp_path}.pt",
|
||||
# metrics=metrics,
|
||||
# metadata=metadata
|
||||
# )
|
||||
|
||||
# # Delete temporary model file
|
||||
# if os.path.exists(f"{temp_path}.pt"):
|
||||
# os.remove(f"{temp_path}.pt")
|
||||
|
||||
# logger.info(f"Model checkpoint saved to {checkpoint_path}")
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error saving checkpoint: {e}")
|
||||
|
||||
# def _load_training_data_from_inference_history(self):
|
||||
# """Load training data from inference history for continuous learning"""
|
||||
# try:
|
||||
# from utils.database_manager import get_database_manager
|
||||
|
||||
# db_manager = get_database_manager()
|
||||
|
||||
# # Get recent inference records with input features
|
||||
# inference_records = db_manager.get_inference_records_for_training(
|
||||
# model_name=self.model_name,
|
||||
# hours_back=24, # Last 24 hours
|
||||
# limit=1000
|
||||
# )
|
||||
|
||||
# if not inference_records:
|
||||
# logger.debug("No inference records found for training")
|
||||
# return
|
||||
|
||||
# # Convert inference records to training samples
|
||||
# # For now, use a simple approach: treat high-confidence predictions as ground truth
|
||||
# for record in inference_records:
|
||||
# if record.input_features is not None and record.confidence > 0.7:
|
||||
# # Convert action to index
|
||||
# actions = ['BUY', 'SELL', 'HOLD']
|
||||
# if record.action in actions:
|
||||
# action_idx = actions.index(record.action)
|
||||
|
||||
# # Use confidence as a proxy for reward (high confidence = good prediction)
|
||||
# reward = record.confidence * 2 - 1 # Scale to [-1, 1]
|
||||
|
||||
# # Convert features to tensor
|
||||
# features_tensor = torch.tensor(record.input_features, dtype=torch.float32, device=self.device)
|
||||
|
||||
# # Add to training data if not already present (avoid duplicates)
|
||||
# sample_exists = any(
|
||||
# torch.equal(features_tensor, existing[0])
|
||||
# for existing in self.training_data
|
||||
# )
|
||||
|
||||
# if not sample_exists:
|
||||
# self.training_data.append((features_tensor, action_idx, reward))
|
||||
|
||||
# logger.info(f"Loaded {len(inference_records)} inference records for training, total training samples: {len(self.training_data)}")
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error loading training data from inference history: {e}")
|
||||
|
||||
# def evaluate_predictions_against_outcomes(self, hours_back: int = 1) -> Dict[str, float]:
|
||||
# """
|
||||
# Evaluate past predictions against actual market outcomes
|
||||
|
||||
# Args:
|
||||
# hours_back: How many hours back to evaluate
|
||||
|
||||
# Returns:
|
||||
# Dict with evaluation metrics
|
||||
# """
|
||||
# try:
|
||||
# from utils.database_manager import get_database_manager
|
||||
|
||||
# db_manager = get_database_manager()
|
||||
|
||||
# # Get inference records from the specified time period
|
||||
# inference_records = db_manager.get_inference_records_for_training(
|
||||
# model_name=self.model_name,
|
||||
# hours_back=hours_back,
|
||||
# limit=100
|
||||
# )
|
||||
|
||||
# if not inference_records:
|
||||
# return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}
|
||||
|
||||
# # For now, use a simple evaluation based on confidence
|
||||
# # In a real implementation, this would compare against actual price movements
|
||||
# correct_predictions = 0
|
||||
# total_predictions = len(inference_records)
|
||||
|
||||
# # Simple heuristic: high confidence predictions are more likely to be correct
|
||||
# for record in inference_records:
|
||||
# if record.confidence > 0.8: # High confidence threshold
|
||||
# correct_predictions += 1
|
||||
# elif record.confidence > 0.6: # Medium confidence
|
||||
# correct_predictions += 0.5
|
||||
|
||||
# accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
|
||||
|
||||
# logger.info(f"Prediction evaluation: {correct_predictions:.1f}/{total_predictions} = {accuracy:.3f} accuracy")
|
||||
|
||||
# return {
|
||||
# 'accuracy': accuracy,
|
||||
# 'total_predictions': total_predictions,
|
||||
# 'correct_predictions': correct_predictions
|
||||
# }
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error evaluating predictions: {e}")
|
||||
# return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}
|
||||
@@ -1,464 +0,0 @@
|
||||
"""
|
||||
Enhanced Trading Orchestrator
|
||||
|
||||
Central coordination hub for the multi-modal trading system that manages:
|
||||
- Data subscription and management
|
||||
- Model inference coordination
|
||||
- Cross-model data feeding
|
||||
- Training pipeline orchestration
|
||||
- Decision making using Mixture of Experts
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from core.data_provider import DataProvider
|
||||
from core.trading_action import TradingAction
|
||||
from utils.tensorboard_logger import TensorBoardLogger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class ModelOutput:
|
||||
"""Extensible model output format supporting all model types"""
|
||||
model_type: str # 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator'
|
||||
model_name: str # Specific model identifier
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
confidence: float
|
||||
predictions: Dict[str, Any] # Model-specific predictions
|
||||
hidden_states: Optional[Dict[str, Any]] = None # For cross-model feeding
|
||||
metadata: Dict[str, Any] = field(default_factory=dict) # Additional info
|
||||
|
||||
@dataclass
|
||||
class BaseDataInput:
|
||||
"""Unified base data input for all models"""
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
ohlcv_data: Dict[str, Any] = field(default_factory=dict) # Multi-timeframe OHLCV
|
||||
cob_data: Optional[Dict[str, Any]] = None # COB buckets for 1s timeframe
|
||||
technical_indicators: Dict[str, float] = field(default_factory=dict)
|
||||
pivot_points: List[Any] = field(default_factory=list)
|
||||
last_predictions: Dict[str, ModelOutput] = field(default_factory=dict) # From all models
|
||||
market_microstructure: Dict[str, Any] = field(default_factory=dict) # Order flow, etc.
|
||||
|
||||
@dataclass
|
||||
class COBData:
|
||||
"""Cumulative Order Book data for price buckets"""
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
current_price: float
|
||||
bucket_size: float # $1 for ETH, $10 for BTC
|
||||
price_buckets: Dict[float, Dict[str, float]] = field(default_factory=dict) # price -> {bid_volume, ask_volume, etc.}
|
||||
bid_ask_imbalance: Dict[float, float] = field(default_factory=dict) # price -> imbalance ratio
|
||||
volume_weighted_prices: Dict[float, float] = field(default_factory=dict) # price -> VWAP within bucket
|
||||
order_flow_metrics: Dict[str, float] = field(default_factory=dict) # Various order flow indicators
|
||||
|
||||
class EnhancedTradingOrchestrator:
|
||||
"""
|
||||
Enhanced Trading Orchestrator implementing the design specification
|
||||
|
||||
Coordinates data flow, model inference, and decision making for the multi-modal trading system.
|
||||
"""
|
||||
|
||||
def __init__(self, data_provider: DataProvider, symbols: List[str], enhanced_rl_training: bool = False, model_registry: Dict = None):
|
||||
"""Initialize the enhanced orchestrator"""
|
||||
self.data_provider = data_provider
|
||||
self.symbols = symbols
|
||||
self.enhanced_rl_training = enhanced_rl_training
|
||||
self.model_registry = model_registry or {}
|
||||
|
||||
# Data management
|
||||
self.data_buffers = {symbol: {} for symbol in symbols}
|
||||
self.last_update_times = {symbol: {} for symbol in symbols}
|
||||
|
||||
# Model output storage
|
||||
self.model_outputs = {symbol: {} for symbol in symbols}
|
||||
self.model_output_history = {symbol: {} for symbol in symbols}
|
||||
|
||||
# Training pipeline
|
||||
self.training_data = {symbol: [] for symbol in symbols}
|
||||
self.tensorboard_logger = TensorBoardLogger("runs", f"orchestrator_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
|
||||
|
||||
# COB integration
|
||||
self.cob_data = {symbol: None for symbol in symbols}
|
||||
|
||||
# Performance tracking
|
||||
self.performance_metrics = {
|
||||
'inference_count': 0,
|
||||
'successful_states': 0,
|
||||
'total_episodes': 0
|
||||
}
|
||||
|
||||
logger.info("Enhanced Trading Orchestrator initialized")
|
||||
|
||||
async def start_cob_integration(self):
|
||||
"""Start COB data integration for real-time market microstructure"""
|
||||
try:
|
||||
# Subscribe to COB data updates
|
||||
self.data_provider.subscribe_to_cob_data(self._on_cob_data_update)
|
||||
logger.info("COB integration started")
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting COB integration: {e}")
|
||||
|
||||
async def start_realtime_processing(self):
|
||||
"""Start real-time data processing"""
|
||||
try:
|
||||
# Subscribe to tick data for real-time processing
|
||||
for symbol in self.symbols:
|
||||
self.data_provider.subscribe_to_ticks(
|
||||
callback=self._on_tick_data,
|
||||
symbols=[symbol],
|
||||
subscriber_name=f"orchestrator_{symbol}"
|
||||
)
|
||||
|
||||
logger.info("Real-time processing started")
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting real-time processing: {e}")
|
||||
|
||||
def _on_cob_data_update(self, symbol: str, cob_data: dict):
|
||||
"""Handle COB data updates"""
|
||||
try:
|
||||
# Process and store COB data
|
||||
self.cob_data[symbol] = self._process_cob_data(symbol, cob_data)
|
||||
logger.debug(f"COB data updated for {symbol}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing COB data for {symbol}: {e}")
|
||||
|
||||
def _process_cob_data(self, symbol: str, cob_data: dict) -> COBData:
|
||||
"""Process raw COB data into structured format"""
|
||||
try:
|
||||
# Determine bucket size based on symbol
|
||||
bucket_size = 1.0 if 'ETH' in symbol else 10.0
|
||||
|
||||
# Extract current price
|
||||
stats = cob_data.get('stats', {})
|
||||
current_price = stats.get('mid_price', 0)
|
||||
|
||||
# Create COB data structure
|
||||
cob = COBData(
|
||||
symbol=symbol,
|
||||
timestamp=datetime.now(),
|
||||
current_price=current_price,
|
||||
bucket_size=bucket_size
|
||||
)
|
||||
|
||||
# Process order book data into price buckets
|
||||
bids = cob_data.get('bids', [])
|
||||
asks = cob_data.get('asks', [])
|
||||
|
||||
# Create price buckets around current price
|
||||
bucket_count = 20 # ±20 buckets
|
||||
for i in range(-bucket_count, bucket_count + 1):
|
||||
bucket_price = current_price + (i * bucket_size)
|
||||
cob.price_buckets[bucket_price] = {
|
||||
'bid_volume': 0.0,
|
||||
'ask_volume': 0.0
|
||||
}
|
||||
|
||||
# Aggregate bid volumes into buckets
|
||||
for price, volume in bids:
|
||||
bucket_price = round(price / bucket_size) * bucket_size
|
||||
if bucket_price in cob.price_buckets:
|
||||
cob.price_buckets[bucket_price]['bid_volume'] += volume
|
||||
|
||||
# Aggregate ask volumes into buckets
|
||||
for price, volume in asks:
|
||||
bucket_price = round(price / bucket_size) * bucket_size
|
||||
if bucket_price in cob.price_buckets:
|
||||
cob.price_buckets[bucket_price]['ask_volume'] += volume
|
||||
|
||||
# Calculate bid/ask imbalances
|
||||
for price, volumes in cob.price_buckets.items():
|
||||
bid_vol = volumes['bid_volume']
|
||||
ask_vol = volumes['ask_volume']
|
||||
total_vol = bid_vol + ask_vol
|
||||
if total_vol > 0:
|
||||
cob.bid_ask_imbalance[price] = (bid_vol - ask_vol) / total_vol
|
||||
else:
|
||||
cob.bid_ask_imbalance[price] = 0.0
|
||||
|
||||
# Calculate volume-weighted prices
|
||||
for price, volumes in cob.price_buckets.items():
|
||||
bid_vol = volumes['bid_volume']
|
||||
ask_vol = volumes['ask_volume']
|
||||
total_vol = bid_vol + ask_vol
|
||||
if total_vol > 0:
|
||||
cob.volume_weighted_prices[price] = (
|
||||
(price * bid_vol) + (price * ask_vol)
|
||||
) / total_vol
|
||||
else:
|
||||
cob.volume_weighted_prices[price] = price
|
||||
|
||||
# Calculate order flow metrics
|
||||
cob.order_flow_metrics = {
|
||||
'total_bid_volume': sum(v['bid_volume'] for v in cob.price_buckets.values()),
|
||||
'total_ask_volume': sum(v['ask_volume'] for v in cob.price_buckets.values()),
|
||||
'bid_ask_ratio': 0.0 if cob.order_flow_metrics['total_ask_volume'] == 0 else
|
||||
cob.order_flow_metrics['total_bid_volume'] / cob.order_flow_metrics['total_ask_volume']
|
||||
}
|
||||
|
||||
return cob
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing COB data for {symbol}: {e}")
|
||||
return COBData(symbol=symbol, timestamp=datetime.now(), current_price=0, bucket_size=bucket_size)
|
||||
|
||||
def _on_tick_data(self, tick):
|
||||
"""Handle incoming tick data"""
|
||||
try:
|
||||
# Update data buffers
|
||||
symbol = tick.symbol
|
||||
if symbol not in self.data_buffers:
|
||||
self.data_buffers[symbol] = {}
|
||||
|
||||
# Store tick data
|
||||
if 'ticks' not in self.data_buffers[symbol]:
|
||||
self.data_buffers[symbol]['ticks'] = []
|
||||
self.data_buffers[symbol]['ticks'].append(tick)
|
||||
|
||||
# Keep only last 1000 ticks
|
||||
if len(self.data_buffers[symbol]['ticks']) > 1000:
|
||||
self.data_buffers[symbol]['ticks'] = self.data_buffers[symbol]['ticks'][-1000:]
|
||||
|
||||
# Update last update time
|
||||
self.last_update_times[symbol]['tick'] = datetime.now()
|
||||
|
||||
logger.debug(f"Tick data updated for {symbol}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing tick data: {e}")
|
||||
|
||||
def build_comprehensive_rl_state(self, symbol: str) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Build comprehensive RL state with 13,400 features as specified
|
||||
|
||||
Returns:
|
||||
np.ndarray: State vector with 13,400 features
|
||||
"""
|
||||
try:
|
||||
# Initialize state vector
|
||||
state_size = 13400
|
||||
state = np.zeros(state_size, dtype=np.float32)
|
||||
|
||||
# Get latest data
|
||||
ohlcv_data = self.data_provider.get_latest_candles(symbol, '1s', limit=100)
|
||||
cob_data = self.cob_data.get(symbol)
|
||||
|
||||
# Feature index tracking
|
||||
idx = 0
|
||||
|
||||
# 1. OHLCV features (4000 features)
|
||||
if ohlcv_data is not None and not ohlcv_data.empty:
|
||||
# Use last 100 1s candles (40 features each: O,H,L,C,V + 36 indicators)
|
||||
for i in range(min(100, len(ohlcv_data))):
|
||||
if idx + 40 <= state_size:
|
||||
row = ohlcv_data.iloc[-(i+1)]
|
||||
state[idx] = row.get('open', 0) / 100000 # Normalized
|
||||
state[idx+1] = row.get('high', 0) / 100000
|
||||
state[idx+2] = row.get('low', 0) / 100000
|
||||
state[idx+3] = row.get('close', 0) / 100000
|
||||
state[idx+4] = row.get('volume', 0) / 1000000
|
||||
|
||||
# Add technical indicators if available
|
||||
indicator_idx = 5
|
||||
for col in ['sma_10', 'sma_20', 'ema_12', 'ema_26', 'rsi_14',
|
||||
'macd', 'bb_upper', 'bb_lower', 'atr', 'adx']:
|
||||
if col in row and idx + indicator_idx < state_size:
|
||||
state[idx + indicator_idx] = row[col] / 100000
|
||||
indicator_idx += 1
|
||||
|
||||
idx += 40
|
||||
|
||||
# 2. COB features (8000 features)
|
||||
if cob_data and idx + 8000 <= state_size:
|
||||
# Use 200 price buckets (40 features each)
|
||||
bucket_prices = sorted(cob_data.price_buckets.keys())
|
||||
for i, price in enumerate(bucket_prices[:200]):
|
||||
if idx + 40 <= state_size:
|
||||
bucket = cob_data.price_buckets[price]
|
||||
state[idx] = bucket.get('bid_volume', 0) / 1000000 # Normalized
|
||||
state[idx+1] = bucket.get('ask_volume', 0) / 1000000
|
||||
state[idx+2] = cob_data.bid_ask_imbalance.get(price, 0)
|
||||
state[idx+3] = cob_data.volume_weighted_prices.get(price, price) / 100000
|
||||
|
||||
# Additional COB metrics
|
||||
state[idx+4] = cob_data.order_flow_metrics.get('total_bid_volume', 0) / 10000000
|
||||
state[idx+5] = cob_data.order_flow_metrics.get('total_ask_volume', 0) / 10000000
|
||||
state[idx+6] = cob_data.order_flow_metrics.get('bid_ask_ratio', 0)
|
||||
|
||||
idx += 40
|
||||
|
||||
# 3. Technical indicator features (1000 features)
|
||||
# Already included in OHLCV section above
|
||||
|
||||
# 4. Market microstructure features (400 features)
|
||||
if cob_data and idx + 400 <= state_size:
|
||||
# Add order flow metrics
|
||||
metrics = list(cob_data.order_flow_metrics.values())
|
||||
for i, metric in enumerate(metrics[:400]):
|
||||
if idx + i < state_size:
|
||||
state[idx + i] = metric
|
||||
|
||||
# Log state building success
|
||||
self.performance_metrics['successful_states'] += 1
|
||||
logger.debug(f"Comprehensive RL state built for {symbol}: {len(state)} features")
|
||||
|
||||
# Log to TensorBoard
|
||||
self.tensorboard_logger.log_state_metrics(
|
||||
symbol=symbol,
|
||||
state_info={
|
||||
'size': len(state),
|
||||
'quality': 1.0,
|
||||
'feature_counts': {
|
||||
'total': len(state),
|
||||
'non_zero': np.count_nonzero(state)
|
||||
}
|
||||
},
|
||||
step=self.performance_metrics['successful_states']
|
||||
)
|
||||
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error building comprehensive RL state for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def calculate_enhanced_pivot_reward(self, trade_decision: Dict, market_data: Dict, trade_outcome: Dict) -> float:
|
||||
"""
|
||||
Calculate enhanced pivot-based reward
|
||||
|
||||
Args:
|
||||
trade_decision: Trading decision with action and confidence
|
||||
market_data: Market context data
|
||||
trade_outcome: Actual trade results
|
||||
|
||||
Returns:
|
||||
float: Enhanced reward value
|
||||
"""
|
||||
try:
|
||||
# Base reward from PnL
|
||||
pnl_reward = trade_outcome.get('net_pnl', 0) / 100 # Normalize
|
||||
|
||||
# Confidence weighting
|
||||
confidence = trade_decision.get('confidence', 0.5)
|
||||
confidence_reward = confidence * 0.2
|
||||
|
||||
# Volatility adjustment
|
||||
volatility = market_data.get('volatility', 0.01)
|
||||
volatility_reward = (1.0 - volatility * 10) * 0.1 # Prefer low volatility
|
||||
|
||||
# Order flow alignment
|
||||
order_flow = market_data.get('order_flow_strength', 0)
|
||||
order_flow_reward = order_flow * 0.2
|
||||
|
||||
# Pivot alignment bonus (if near pivot in favorable direction)
|
||||
pivot_bonus = 0.0
|
||||
if market_data.get('near_pivot', False):
|
||||
action = trade_decision.get('action', '').upper()
|
||||
pivot_type = market_data.get('pivot_type', '').upper()
|
||||
|
||||
# Bonus for buying near support or selling near resistance
|
||||
if (action == 'BUY' and pivot_type == 'LOW') or \
|
||||
(action == 'SELL' and pivot_type == 'HIGH'):
|
||||
pivot_bonus = 0.5
|
||||
|
||||
# Calculate final reward
|
||||
enhanced_reward = pnl_reward + confidence_reward + volatility_reward + order_flow_reward + pivot_bonus
|
||||
|
||||
# Log to TensorBoard
|
||||
self.tensorboard_logger.log_scalars('Rewards/Components', {
|
||||
'pnl_component': pnl_reward,
|
||||
'confidence': confidence_reward,
|
||||
'volatility': volatility_reward,
|
||||
'order_flow': order_flow_reward,
|
||||
'pivot_bonus': pivot_bonus
|
||||
}, self.performance_metrics['total_episodes'])
|
||||
|
||||
self.tensorboard_logger.log_scalar('Rewards/Enhanced', enhanced_reward, self.performance_metrics['total_episodes'])
|
||||
|
||||
logger.debug(f"Enhanced reward calculated: {enhanced_reward}")
|
||||
return enhanced_reward
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating enhanced pivot reward: {e}")
|
||||
return 0.0
|
||||
|
||||
async def make_coordinated_decisions(self) -> Dict[str, TradingAction]:
|
||||
"""
|
||||
Make coordinated trading decisions using all available models
|
||||
|
||||
Returns:
|
||||
Dict[str, TradingAction]: Trading actions for each symbol
|
||||
"""
|
||||
try:
|
||||
decisions = {}
|
||||
|
||||
# For each symbol, coordinate model inference
|
||||
for symbol in self.symbols:
|
||||
# Build comprehensive state for RL model
|
||||
rl_state = self.build_comprehensive_rl_state(symbol)
|
||||
|
||||
if rl_state is not None:
|
||||
# Store state for training
|
||||
self.performance_metrics['total_episodes'] += 1
|
||||
|
||||
# Create mock RL decision (in a real implementation, this would call the RL model)
|
||||
action = 'BUY' if np.mean(rl_state[:100]) > 0.5 else 'SELL'
|
||||
confidence = min(1.0, max(0.0, np.std(rl_state) * 10))
|
||||
|
||||
# Create trading action
|
||||
decisions[symbol] = TradingAction(
|
||||
symbol=symbol,
|
||||
timestamp=datetime.now(),
|
||||
action=action,
|
||||
confidence=confidence,
|
||||
source='rl_orchestrator'
|
||||
)
|
||||
|
||||
logger.info(f"Coordinated decision for {symbol}: {action} (confidence: {confidence:.3f})")
|
||||
else:
|
||||
logger.warning(f"Failed to build state for {symbol}, skipping decision")
|
||||
|
||||
self.performance_metrics['inference_count'] += 1
|
||||
return decisions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error making coordinated decisions: {e}")
|
||||
return {}
|
||||
|
||||
def _get_symbol_correlation(self, symbol1: str, symbol2: str) -> float:
|
||||
"""
|
||||
Calculate correlation between two symbols
|
||||
|
||||
Args:
|
||||
symbol1: First symbol
|
||||
symbol2: Second symbol
|
||||
|
||||
Returns:
|
||||
float: Correlation coefficient (-1 to 1)
|
||||
"""
|
||||
try:
|
||||
# Get recent price data for both symbols
|
||||
data1 = self.data_provider.get_latest_candles(symbol1, '1m', limit=50)
|
||||
data2 = self.data_provider.get_latest_candles(symbol2, '1m', limit=50)
|
||||
|
||||
if data1 is None or data2 is None or data1.empty or data2.empty:
|
||||
return 0.0
|
||||
|
||||
# Align data by timestamp
|
||||
merged = data1[['close']].join(data2[['close']], lsuffix='_1', rsuffix='_2', how='inner')
|
||||
|
||||
if len(merged) < 10:
|
||||
return 0.0
|
||||
|
||||
# Calculate correlation
|
||||
correlation = merged['close_1'].corr(merged['close_2'])
|
||||
return correlation if not np.isnan(correlation) else 0.0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating symbol correlation: {e}")
|
||||
return 0.0
|
||||
```
|
||||
@@ -1,775 +0,0 @@
|
||||
"""
|
||||
Enhanced Training Integration Module
|
||||
|
||||
This module provides comprehensive integration between the training data collection system,
|
||||
CNN training pipeline, RL training pipeline, and your existing infrastructure.
|
||||
|
||||
Key Features:
|
||||
- Real-time integration with existing DataProvider
|
||||
- Coordinated training across CNN and RL models
|
||||
- Automatic outcome validation and profitability tracking
|
||||
- Integration with existing COB RL model
|
||||
- Performance monitoring and optimization
|
||||
- Seamless connection to existing orchestrator and trading executor
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any, Callable
|
||||
from dataclasses import dataclass
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Import existing components
|
||||
from .data_provider import DataProvider
|
||||
from .orchestrator import Orchestrator
|
||||
from .trading_executor import TradingExecutor
|
||||
|
||||
# Import our training system components
|
||||
from .training_data_collector import (
|
||||
TrainingDataCollector,
|
||||
get_training_data_collector
|
||||
)
|
||||
from .cnn_training_pipeline import (
|
||||
CNNPivotPredictor,
|
||||
CNNTrainer,
|
||||
get_cnn_trainer
|
||||
)
|
||||
from .rl_training_pipeline import (
|
||||
RLTradingAgent,
|
||||
RLTrainer,
|
||||
get_rl_trainer
|
||||
)
|
||||
from .training_integration import TrainingIntegration
|
||||
|
||||
# Import existing RL model
|
||||
try:
|
||||
from NN.models.cob_rl_model import COBRLModelInterface
|
||||
except ImportError:
|
||||
logger.warning("Could not import COBRLModelInterface - using fallback")
|
||||
COBRLModelInterface = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class EnhancedTrainingConfig:
|
||||
"""Enhanced configuration for comprehensive training integration"""
|
||||
# Data collection
|
||||
collection_interval: float = 1.0
|
||||
min_data_completeness: float = 0.8
|
||||
|
||||
# Training triggers
|
||||
min_episodes_for_cnn_training: int = 100
|
||||
min_experiences_for_rl_training: int = 200
|
||||
training_frequency_minutes: int = 30
|
||||
|
||||
# Profitability thresholds
|
||||
min_profitability_for_replay: float = 0.1
|
||||
high_profitability_threshold: float = 0.5
|
||||
|
||||
# Model integration
|
||||
use_existing_cob_rl_model: bool = True
|
||||
enable_cross_model_learning: bool = True
|
||||
|
||||
# Performance optimization
|
||||
max_concurrent_training_sessions: int = 2
|
||||
enable_background_validation: bool = True
|
||||
|
||||
class EnhancedTrainingIntegration:
|
||||
"""Enhanced training integration with existing infrastructure"""
|
||||
|
||||
def __init__(self,
|
||||
data_provider: DataProvider,
|
||||
orchestrator: Orchestrator = None,
|
||||
trading_executor: TradingExecutor = None,
|
||||
config: EnhancedTrainingConfig = None):
|
||||
|
||||
self.data_provider = data_provider
|
||||
self.orchestrator = orchestrator
|
||||
self.trading_executor = trading_executor
|
||||
self.config = config or EnhancedTrainingConfig()
|
||||
|
||||
# Initialize training components
|
||||
self.data_collector = get_training_data_collector()
|
||||
|
||||
# Initialize CNN components
|
||||
self.cnn_model = CNNPivotPredictor()
|
||||
self.cnn_trainer = get_cnn_trainer(self.cnn_model)
|
||||
|
||||
# Initialize RL components
|
||||
if self.config.use_existing_cob_rl_model and COBRLModelInterface:
|
||||
self.existing_rl_model = COBRLModelInterface()
|
||||
logger.info("Using existing COB RL model")
|
||||
else:
|
||||
self.existing_rl_model = None
|
||||
|
||||
self.rl_agent = RLTradingAgent()
|
||||
self.rl_trainer = get_rl_trainer(self.rl_agent)
|
||||
|
||||
# Integration state
|
||||
self.is_running = False
|
||||
self.training_threads = {}
|
||||
self.validation_thread = None
|
||||
|
||||
# Performance tracking
|
||||
self.integration_stats = {
|
||||
'total_data_packages': 0,
|
||||
'cnn_training_sessions': 0,
|
||||
'rl_training_sessions': 0,
|
||||
'profitable_predictions': 0,
|
||||
'total_predictions': 0,
|
||||
'cross_model_improvements': 0,
|
||||
'last_update': datetime.now()
|
||||
}
|
||||
|
||||
# Model prediction tracking
|
||||
self.recent_predictions = {}
|
||||
self.prediction_outcomes = {}
|
||||
|
||||
# Cross-model learning
|
||||
self.model_performance_history = {
|
||||
'cnn': [],
|
||||
'rl': [],
|
||||
'orchestrator': []
|
||||
}
|
||||
|
||||
logger.info("Enhanced Training Integration initialized")
|
||||
logger.info(f"CNN model parameters: {sum(p.numel() for p in self.cnn_model.parameters()):,}")
|
||||
logger.info(f"RL agent parameters: {sum(p.numel() for p in self.rl_agent.parameters()):,}")
|
||||
logger.info(f"Using existing COB RL model: {self.existing_rl_model is not None}")
|
||||
|
||||
def start_enhanced_integration(self):
|
||||
"""Start the enhanced training integration system"""
|
||||
if self.is_running:
|
||||
logger.warning("Enhanced training integration already running")
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
|
||||
# Start data collection
|
||||
self.data_collector.start_collection()
|
||||
|
||||
# Start CNN training
|
||||
if self.config.min_episodes_for_cnn_training > 0:
|
||||
for symbol in self.data_provider.symbols:
|
||||
self.cnn_trainer.start_real_time_training(symbol)
|
||||
|
||||
# Start coordinated training thread
|
||||
self.training_threads['coordinator'] = threading.Thread(
|
||||
target=self._training_coordinator_worker,
|
||||
daemon=True
|
||||
)
|
||||
self.training_threads['coordinator'].start()
|
||||
|
||||
# Start data collection and validation
|
||||
self.training_threads['data_collector'] = threading.Thread(
|
||||
target=self._enhanced_data_collection_worker,
|
||||
daemon=True
|
||||
)
|
||||
self.training_threads['data_collector'].start()
|
||||
|
||||
# Start outcome validation if enabled
|
||||
if self.config.enable_background_validation:
|
||||
self.validation_thread = threading.Thread(
|
||||
target=self._outcome_validation_worker,
|
||||
daemon=True
|
||||
)
|
||||
self.validation_thread.start()
|
||||
|
||||
logger.info("Enhanced training integration started")
|
||||
|
||||
def stop_enhanced_integration(self):
|
||||
"""Stop the enhanced training integration system"""
|
||||
self.is_running = False
|
||||
|
||||
# Stop data collection
|
||||
self.data_collector.stop_collection()
|
||||
|
||||
# Stop CNN training
|
||||
self.cnn_trainer.stop_training()
|
||||
|
||||
# Wait for threads to finish
|
||||
for thread_name, thread in self.training_threads.items():
|
||||
thread.join(timeout=10)
|
||||
logger.info(f"Stopped {thread_name} thread")
|
||||
|
||||
if self.validation_thread:
|
||||
self.validation_thread.join(timeout=5)
|
||||
|
||||
logger.info("Enhanced training integration stopped")
|
||||
|
||||
def _enhanced_data_collection_worker(self):
|
||||
"""Enhanced data collection with real-time model integration"""
|
||||
logger.info("Enhanced data collection worker started")
|
||||
|
||||
while self.is_running:
|
||||
try:
|
||||
for symbol in self.data_provider.symbols:
|
||||
self._collect_enhanced_training_data(symbol)
|
||||
|
||||
time.sleep(self.config.collection_interval)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in enhanced data collection: {e}")
|
||||
time.sleep(5)
|
||||
|
||||
logger.info("Enhanced data collection worker stopped")
|
||||
|
||||
def _collect_enhanced_training_data(self, symbol: str):
|
||||
"""Collect enhanced training data with model predictions"""
|
||||
try:
|
||||
# Get comprehensive market data
|
||||
market_data = self._get_comprehensive_market_data(symbol)
|
||||
|
||||
if not market_data or not self._validate_market_data(market_data):
|
||||
return
|
||||
|
||||
# Get current model predictions
|
||||
model_predictions = self._get_all_model_predictions(symbol, market_data)
|
||||
|
||||
# Create enhanced features
|
||||
cnn_features = self._create_enhanced_cnn_features(symbol, market_data)
|
||||
rl_state = self._create_enhanced_rl_state(symbol, market_data, model_predictions)
|
||||
|
||||
# Collect training data with predictions
|
||||
episode_id = self.data_collector.collect_training_data(
|
||||
symbol=symbol,
|
||||
ohlcv_data=market_data['ohlcv'],
|
||||
tick_data=market_data['ticks'],
|
||||
cob_data=market_data['cob'],
|
||||
technical_indicators=market_data['indicators'],
|
||||
pivot_points=market_data['pivots'],
|
||||
cnn_features=cnn_features,
|
||||
rl_state=rl_state,
|
||||
orchestrator_context=market_data['context'],
|
||||
model_predictions=model_predictions
|
||||
)
|
||||
|
||||
if episode_id:
|
||||
# Store predictions for outcome validation
|
||||
self.recent_predictions[episode_id] = {
|
||||
'timestamp': datetime.now(),
|
||||
'symbol': symbol,
|
||||
'predictions': model_predictions,
|
||||
'market_data': market_data
|
||||
}
|
||||
|
||||
# Add RL experience if we have action
|
||||
if 'rl_action' in model_predictions:
|
||||
self._add_rl_experience(symbol, market_data, model_predictions, episode_id)
|
||||
|
||||
self.integration_stats['total_data_packages'] += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error collecting enhanced training data for {symbol}: {e}")
|
||||
|
||||
def _get_comprehensive_market_data(self, symbol: str) -> Dict[str, Any]:
|
||||
"""Get comprehensive market data from all sources"""
|
||||
try:
|
||||
market_data = {}
|
||||
|
||||
# OHLCV data
|
||||
ohlcv_data = {}
|
||||
for timeframe in ['1s', '1m', '5m', '15m', '1h', '1d']:
|
||||
df = self.data_provider.get_historical_data(symbol, timeframe, limit=300, refresh=True)
|
||||
if df is not None and not df.empty:
|
||||
ohlcv_data[timeframe] = df
|
||||
market_data['ohlcv'] = ohlcv_data
|
||||
|
||||
# Tick data
|
||||
market_data['ticks'] = self._get_recent_tick_data(symbol)
|
||||
|
||||
# COB data
|
||||
market_data['cob'] = self._get_cob_data(symbol)
|
||||
|
||||
# Technical indicators
|
||||
market_data['indicators'] = self._get_technical_indicators(symbol)
|
||||
|
||||
# Pivot points
|
||||
market_data['pivots'] = self._get_pivot_points(symbol)
|
||||
|
||||
# Market context
|
||||
market_data['context'] = self._get_market_context(symbol)
|
||||
|
||||
return market_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting comprehensive market data: {e}")
|
||||
return {}
|
||||
|
||||
def _get_all_model_predictions(self, symbol: str, market_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Get predictions from all available models"""
|
||||
predictions = {}
|
||||
|
||||
try:
|
||||
# CNN predictions
|
||||
if self.cnn_model and market_data.get('ohlcv'):
|
||||
cnn_features = self._create_enhanced_cnn_features(symbol, market_data)
|
||||
if cnn_features is not None:
|
||||
cnn_input = torch.from_numpy(cnn_features).float().unsqueeze(0)
|
||||
|
||||
# Reshape for CNN (add channel dimension)
|
||||
cnn_input = cnn_input.view(1, 10, -1) # Assuming 10 channels
|
||||
|
||||
with torch.no_grad():
|
||||
cnn_outputs = self.cnn_model(cnn_input)
|
||||
predictions['cnn'] = {
|
||||
'pivot_logits': cnn_outputs['pivot_logits'].cpu().numpy(),
|
||||
'pivot_price': cnn_outputs['pivot_price'].cpu().numpy(),
|
||||
'confidence': cnn_outputs['confidence'].cpu().numpy(),
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
|
||||
# RL predictions
|
||||
if self.rl_agent and market_data.get('cob'):
|
||||
rl_state = self._create_enhanced_rl_state(symbol, market_data, predictions)
|
||||
if rl_state is not None:
|
||||
action, confidence = self.rl_agent.select_action(rl_state, epsilon=0.1)
|
||||
predictions['rl'] = {
|
||||
'action': action,
|
||||
'confidence': confidence,
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
predictions['rl_action'] = action
|
||||
|
||||
# Existing COB RL model predictions
|
||||
if self.existing_rl_model and market_data.get('cob'):
|
||||
cob_features = market_data['cob'].get('cob_features', [])
|
||||
if cob_features and len(cob_features) >= 2000:
|
||||
cob_array = np.array(cob_features[:2000], dtype=np.float32)
|
||||
cob_prediction = self.existing_rl_model.predict(cob_array)
|
||||
predictions['cob_rl'] = {
|
||||
'predicted_direction': cob_prediction.get('predicted_direction', 1),
|
||||
'confidence': cob_prediction.get('confidence', 0.5),
|
||||
'value': cob_prediction.get('value', 0.0),
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
|
||||
# Orchestrator predictions (if available)
|
||||
if self.orchestrator:
|
||||
try:
|
||||
# This would integrate with your orchestrator's prediction method
|
||||
orchestrator_prediction = self._get_orchestrator_prediction(symbol, market_data, predictions)
|
||||
if orchestrator_prediction:
|
||||
predictions['orchestrator'] = orchestrator_prediction
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get orchestrator prediction: {e}")
|
||||
|
||||
return predictions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting model predictions: {e}")
|
||||
return {}
|
||||
|
||||
def _add_rl_experience(self, symbol: str, market_data: Dict[str, Any],
|
||||
predictions: Dict[str, Any], episode_id: str):
|
||||
"""Add RL experience to the training buffer"""
|
||||
try:
|
||||
# Create RL state
|
||||
state = self._create_enhanced_rl_state(symbol, market_data, predictions)
|
||||
if state is None:
|
||||
return
|
||||
|
||||
# Get action from predictions
|
||||
action = predictions.get('rl_action', 1) # Default to HOLD
|
||||
|
||||
# Calculate immediate reward (placeholder - would be updated with actual outcome)
|
||||
reward = 0.0
|
||||
|
||||
# Create next state (same as current for now - would be updated)
|
||||
next_state = state.copy()
|
||||
|
||||
# Market context
|
||||
market_context = {
|
||||
'symbol': symbol,
|
||||
'episode_id': episode_id,
|
||||
'timestamp': datetime.now(),
|
||||
'market_session': market_data['context'].get('market_session', 'unknown'),
|
||||
'volatility_regime': market_data['context'].get('volatility_regime', 'unknown')
|
||||
}
|
||||
|
||||
# Add experience
|
||||
experience_id = self.rl_trainer.add_experience(
|
||||
state=state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=False,
|
||||
market_context=market_context,
|
||||
cnn_predictions=predictions.get('cnn'),
|
||||
confidence_score=predictions.get('rl', {}).get('confidence', 0.0)
|
||||
)
|
||||
|
||||
if experience_id:
|
||||
logger.debug(f"Added RL experience: {experience_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding RL experience: {e}")
|
||||
|
||||
def _training_coordinator_worker(self):
|
||||
"""Coordinate training across all models"""
|
||||
logger.info("Training coordinator worker started")
|
||||
|
||||
while self.is_running:
|
||||
try:
|
||||
# Check if we should trigger training
|
||||
for symbol in self.data_provider.symbols:
|
||||
self._check_and_trigger_training(symbol)
|
||||
|
||||
# Wait before next check
|
||||
time.sleep(self.config.training_frequency_minutes * 60)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in training coordinator: {e}")
|
||||
time.sleep(60)
|
||||
|
||||
logger.info("Training coordinator worker stopped")
|
||||
|
||||
def _check_and_trigger_training(self, symbol: str):
|
||||
"""Check conditions and trigger training if needed"""
|
||||
try:
|
||||
# Get training episodes and experiences
|
||||
episodes = self.data_collector.get_high_priority_episodes(symbol, limit=1000)
|
||||
|
||||
# Check CNN training conditions
|
||||
if len(episodes) >= self.config.min_episodes_for_cnn_training:
|
||||
profitable_episodes = [ep for ep in episodes if ep.actual_outcome.is_profitable]
|
||||
|
||||
if len(profitable_episodes) >= 20: # Minimum profitable episodes
|
||||
logger.info(f"Triggering CNN training for {symbol} with {len(profitable_episodes)} profitable episodes")
|
||||
|
||||
results = self.cnn_trainer.train_on_profitable_episodes(
|
||||
symbol=symbol,
|
||||
min_profitability=self.config.min_profitability_for_replay,
|
||||
max_episodes=len(profitable_episodes)
|
||||
)
|
||||
|
||||
if results.get('status') == 'success':
|
||||
self.integration_stats['cnn_training_sessions'] += 1
|
||||
logger.info(f"CNN training completed for {symbol}")
|
||||
|
||||
# Check RL training conditions
|
||||
buffer_stats = self.rl_trainer.experience_buffer.get_buffer_statistics()
|
||||
total_experiences = buffer_stats.get('total_experiences', 0)
|
||||
|
||||
if total_experiences >= self.config.min_experiences_for_rl_training:
|
||||
profitable_experiences = buffer_stats.get('profitable_experiences', 0)
|
||||
|
||||
if profitable_experiences >= 50: # Minimum profitable experiences
|
||||
logger.info(f"Triggering RL training with {profitable_experiences} profitable experiences")
|
||||
|
||||
results = self.rl_trainer.train_on_profitable_experiences(
|
||||
min_profitability=self.config.min_profitability_for_replay,
|
||||
max_experiences=min(profitable_experiences, 500),
|
||||
batch_size=32
|
||||
)
|
||||
|
||||
if results.get('status') == 'success':
|
||||
self.integration_stats['rl_training_sessions'] += 1
|
||||
logger.info("RL training completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking training conditions for {symbol}: {e}")
|
||||
|
||||
def _outcome_validation_worker(self):
|
||||
"""Background worker for validating prediction outcomes"""
|
||||
logger.info("Outcome validation worker started")
|
||||
|
||||
while self.is_running:
|
||||
try:
|
||||
self._validate_recent_predictions()
|
||||
time.sleep(300) # Check every 5 minutes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in outcome validation: {e}")
|
||||
time.sleep(60)
|
||||
|
||||
logger.info("Outcome validation worker stopped")
|
||||
|
||||
def _validate_recent_predictions(self):
|
||||
"""Validate recent predictions against actual outcomes"""
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
validation_delay = timedelta(hours=1) # Wait 1 hour to validate
|
||||
|
||||
validated_predictions = []
|
||||
|
||||
for episode_id, prediction_data in self.recent_predictions.items():
|
||||
prediction_time = prediction_data['timestamp']
|
||||
|
||||
if current_time - prediction_time >= validation_delay:
|
||||
# Validate this prediction
|
||||
outcome = self._calculate_prediction_outcome(prediction_data)
|
||||
|
||||
if outcome:
|
||||
self.prediction_outcomes[episode_id] = outcome
|
||||
|
||||
# Update RL experience if exists
|
||||
if 'rl_action' in prediction_data['predictions']:
|
||||
self._update_rl_experience_outcome(episode_id, outcome)
|
||||
|
||||
# Update statistics
|
||||
if outcome['is_profitable']:
|
||||
self.integration_stats['profitable_predictions'] += 1
|
||||
self.integration_stats['total_predictions'] += 1
|
||||
|
||||
validated_predictions.append(episode_id)
|
||||
|
||||
# Remove validated predictions
|
||||
for episode_id in validated_predictions:
|
||||
del self.recent_predictions[episode_id]
|
||||
|
||||
if validated_predictions:
|
||||
logger.info(f"Validated {len(validated_predictions)} predictions")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating predictions: {e}")
|
||||
|
||||
def _calculate_prediction_outcome(self, prediction_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""Calculate actual outcome for a prediction"""
|
||||
try:
|
||||
symbol = prediction_data['symbol']
|
||||
prediction_time = prediction_data['timestamp']
|
||||
|
||||
# Get price data after prediction
|
||||
current_df = self.data_provider.get_historical_data(symbol, '1m', limit=100, refresh=True)
|
||||
|
||||
if current_df is None or current_df.empty:
|
||||
return None
|
||||
|
||||
# Find price at prediction time and current price
|
||||
prediction_price = prediction_data['market_data']['ohlcv'].get('1m', pd.DataFrame())
|
||||
if prediction_price.empty:
|
||||
return None
|
||||
|
||||
base_price = float(prediction_price['close'].iloc[-1])
|
||||
current_price = float(current_df['close'].iloc[-1])
|
||||
|
||||
# Calculate outcome
|
||||
price_change = (current_price - base_price) / base_price
|
||||
is_profitable = abs(price_change) > 0.005 # 0.5% threshold
|
||||
|
||||
return {
|
||||
'episode_id': prediction_data.get('episode_id'),
|
||||
'base_price': base_price,
|
||||
'current_price': current_price,
|
||||
'price_change': price_change,
|
||||
'is_profitable': is_profitable,
|
||||
'profitability_score': abs(price_change) * 10, # Scale to 0-1 range
|
||||
'validation_time': datetime.now()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating prediction outcome: {e}")
|
||||
return None
|
||||
|
||||
def _update_rl_experience_outcome(self, episode_id: str, outcome: Dict[str, Any]):
|
||||
"""Update RL experience with actual outcome"""
|
||||
try:
|
||||
# Find the experience ID associated with this episode
|
||||
# This is a simplified approach - in practice you'd maintain better mapping
|
||||
actual_profit = outcome['price_change']
|
||||
|
||||
# Determine optimal action based on outcome
|
||||
if outcome['price_change'] > 0.01:
|
||||
optimal_action = 2 # BUY
|
||||
elif outcome['price_change'] < -0.01:
|
||||
optimal_action = 0 # SELL
|
||||
else:
|
||||
optimal_action = 1 # HOLD
|
||||
|
||||
# Update experience (this would need proper experience ID mapping)
|
||||
# For now, we'll update the most recent experience
|
||||
# In practice, you'd maintain a mapping between episodes and experiences
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating RL experience outcome: {e}")
|
||||
|
||||
def get_integration_statistics(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive integration statistics"""
|
||||
stats = self.integration_stats.copy()
|
||||
|
||||
# Add component statistics
|
||||
stats['data_collector'] = self.data_collector.get_collection_statistics()
|
||||
stats['cnn_trainer'] = self.cnn_trainer.get_training_statistics()
|
||||
stats['rl_trainer'] = self.rl_trainer.get_training_statistics()
|
||||
|
||||
# Add performance metrics
|
||||
stats['is_running'] = self.is_running
|
||||
stats['active_symbols'] = len(self.data_provider.symbols)
|
||||
stats['recent_predictions_count'] = len(self.recent_predictions)
|
||||
stats['validated_outcomes_count'] = len(self.prediction_outcomes)
|
||||
|
||||
# Calculate profitability rate
|
||||
if stats['total_predictions'] > 0:
|
||||
stats['overall_profitability_rate'] = stats['profitable_predictions'] / stats['total_predictions']
|
||||
else:
|
||||
stats['overall_profitability_rate'] = 0.0
|
||||
|
||||
return stats
|
||||
|
||||
def trigger_manual_training(self, training_type: str = 'all', symbol: str = None) -> Dict[str, Any]:
|
||||
"""Manually trigger training"""
|
||||
results = {}
|
||||
|
||||
try:
|
||||
if training_type in ['all', 'cnn']:
|
||||
symbols = [symbol] if symbol else self.data_provider.symbols
|
||||
for sym in symbols:
|
||||
cnn_results = self.cnn_trainer.train_on_profitable_episodes(
|
||||
symbol=sym,
|
||||
min_profitability=0.1,
|
||||
max_episodes=200
|
||||
)
|
||||
results[f'cnn_{sym}'] = cnn_results
|
||||
|
||||
if training_type in ['all', 'rl']:
|
||||
rl_results = self.rl_trainer.train_on_profitable_experiences(
|
||||
min_profitability=0.1,
|
||||
max_experiences=500,
|
||||
batch_size=32
|
||||
)
|
||||
results['rl'] = rl_results
|
||||
|
||||
return {'status': 'success', 'results': results}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in manual training trigger: {e}")
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
|
||||
# Helper methods (simplified implementations)
|
||||
def _get_recent_tick_data(self, symbol: str) -> List[Dict[str, Any]]:
|
||||
"""Get recent tick data"""
|
||||
# Implementation would get tick data from data provider
|
||||
return []
|
||||
|
||||
def _get_cob_data(self, symbol: str) -> Dict[str, Any]:
|
||||
"""Get COB data"""
|
||||
        # Implementation would get COB data from data provider
        return {}

    def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
        """Get technical indicators"""
        # Implementation would get indicators from data provider
        return {}

    def _get_pivot_points(self, symbol: str) -> List[Dict[str, Any]]:
        """Get pivot points"""
        # Implementation would get pivot points from data provider
        return []

    def _get_market_context(self, symbol: str) -> Dict[str, Any]:
        """Get market context"""
        return {
            'symbol': symbol,
            'timestamp': datetime.now(),
            'market_session': 'unknown',
            'volatility_regime': 'unknown'
        }

    def _validate_market_data(self, market_data: Dict[str, Any]) -> bool:
        """Validate market data completeness"""
        required_fields = ['ohlcv', 'indicators']
        return all(field in market_data for field in required_fields)

    def _create_enhanced_cnn_features(self, symbol: str, market_data: Dict[str, Any]) -> Optional[np.ndarray]:
        """Create enhanced CNN features"""
        try:
            # Simplified feature creation
            features = []

            # Add OHLCV features
            for timeframe in ['1m', '5m', '15m', '1h']:
                if timeframe in market_data.get('ohlcv', {}):
                    df = market_data['ohlcv'][timeframe]
                    if not df.empty:
                        ohlcv_values = df[['open', 'high', 'low', 'close', 'volume']].values
                        if len(ohlcv_values) > 0:
                            recent_values = ohlcv_values[-60:].flatten()
                            features.extend(recent_values)

            # Pad to target size
            target_size = 3000  # 10 channels * 300 sequence length
            if len(features) < target_size:
                features.extend([0.0] * (target_size - len(features)))
            else:
                features = features[:target_size]

            return np.array(features, dtype=np.float32)

        except Exception as e:
            logger.warning(f"Error creating CNN features: {e}")
            return None

    def _create_enhanced_rl_state(self, symbol: str, market_data: Dict[str, Any],
                                  predictions: Dict[str, Any] = None) -> Optional[np.ndarray]:
        """Create enhanced RL state"""
        try:
            state_features = []

            # Add market features
            if '1m' in market_data.get('ohlcv', {}):
                df = market_data['ohlcv']['1m']
                if not df.empty:
                    latest = df.iloc[-1]
                    state_features.extend([
                        latest['open'], latest['high'],
                        latest['low'], latest['close'], latest['volume']
                    ])

            # Add technical indicators
            indicators = market_data.get('indicators', {})
            for value in indicators.values():
                state_features.append(value)

            # Add model predictions as features
            if predictions:
                if 'cnn' in predictions:
                    cnn_pred = predictions['cnn']
                    state_features.extend(cnn_pred.get('pivot_logits', [0, 0, 0]))
                    state_features.append(cnn_pred.get('confidence', [0.0])[0])

                if 'cob_rl' in predictions:
                    cob_pred = predictions['cob_rl']
                    state_features.append(cob_pred.get('predicted_direction', 1))
                    state_features.append(cob_pred.get('confidence', 0.5))

            # Pad to target size
            target_size = 2000
            if len(state_features) < target_size:
                state_features.extend([0.0] * (target_size - len(state_features)))
            else:
                state_features = state_features[:target_size]

            return np.array(state_features, dtype=np.float32)

        except Exception as e:
            logger.warning(f"Error creating RL state: {e}")
            return None

    def _get_orchestrator_prediction(self, symbol: str, market_data: Dict[str, Any],
                                     predictions: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Get orchestrator prediction"""
        # This would integrate with your orchestrator
        return None


# Global instance
enhanced_training_integration = None


def get_enhanced_training_integration(data_provider: DataProvider = None,
                                      orchestrator: Orchestrator = None,
                                      trading_executor: TradingExecutor = None) -> EnhancedTrainingIntegration:
    """Get global enhanced training integration instance"""
    global enhanced_training_integration
    if enhanced_training_integration is None:
        if data_provider is None:
            raise ValueError("DataProvider required for first initialization")
        enhanced_training_integration = EnhancedTrainingIntegration(
            data_provider, orchestrator, trading_executor
        )
    return enhanced_training_integration
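
# --- Editor's sketch (not in the original file): intended use of the singleton accessor above.
# The DataProvider import path and its no-argument constructor are assumptions; adjust them
# to the actual project layout. Kept as an uncalled function so it does not run on import.
def _integration_usage_sketch():
    from core.data_provider import DataProvider  # hypothetical import path

    provider = DataProvider()
    integration = get_enhanced_training_integration(data_provider=provider)
    # Later calls return the same instance and may omit the constructor arguments
    assert integration is get_enhanced_training_integration()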
@@ -1,8 +0,0 @@
# MEXC Web Client Module
#
# This module provides web-based trading capabilities for MEXC futures trading,
# which is not supported by their official API.

from .mexc_futures_client import MEXCFuturesWebClient

__all__ = ['MEXCFuturesWebClient']
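
# Editor's note (not in the original file): with this package on the Python path, the
# client is imported as shown below; the package directory name is an assumption.
#
#     from mexc_webclient import MEXCFuturesWebClient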
@@ -1,555 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
MEXC Auto Browser with Request Interception
|
||||
|
||||
This script automatically spawns a ChromeDriver instance and captures
|
||||
all MEXC futures trading requests in real-time, including full request
|
||||
and response data needed for reverse engineering.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
import threading
|
||||
import queue
|
||||
|
||||
# Selenium imports
|
||||
try:
|
||||
from selenium import webdriver
|
||||
from selenium.webdriver.chrome.options import Options
|
||||
from selenium.webdriver.chrome.service import Service
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support.ui import WebDriverWait
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.common.exceptions import TimeoutException, WebDriverException
|
||||
from webdriver_manager.chrome import ChromeDriverManager
|
||||
except ImportError:
|
||||
print("Please install selenium and webdriver-manager:")
|
||||
print("pip install selenium webdriver-manager")
|
||||
sys.exit(1)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MEXCRequestInterceptor:
|
||||
"""
|
||||
Automatically spawns ChromeDriver and intercepts all MEXC API requests
|
||||
"""
|
||||
|
||||
def __init__(self, headless: bool = False, save_to_file: bool = True):
|
||||
"""
|
||||
Initialize the request interceptor
|
||||
|
||||
Args:
|
||||
headless: Run browser in headless mode
|
||||
save_to_file: Save captured requests to JSON file
|
||||
"""
|
||||
self.driver = None
|
||||
self.headless = headless
|
||||
self.save_to_file = save_to_file
|
||||
self.captured_requests = []
|
||||
self.captured_responses = []
|
||||
self.session_cookies = {}
|
||||
self.monitoring = False
|
||||
self.request_queue = queue.Queue()
|
||||
|
||||
# File paths for saving data
|
||||
self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self.requests_file = f"mexc_requests_{self.timestamp}.json"
|
||||
self.cookies_file = f"mexc_cookies_{self.timestamp}.json"
|
||||
|
||||
def setup_browser(self):
|
||||
"""Setup Chrome browser with necessary options"""
|
||||
chrome_options = webdriver.ChromeOptions()
|
||||
# Enable headless mode if needed
|
||||
if self.headless:
|
||||
chrome_options.add_argument('--headless')
|
||||
chrome_options.add_argument('--disable-gpu')
|
||||
chrome_options.add_argument('--window-size=1920,1080')
|
||||
chrome_options.add_argument('--disable-extensions')
|
||||
|
||||
# Set up Chrome options with a user data directory to persist session
|
||||
user_data_base_dir = os.path.join(os.getcwd(), 'chrome_user_data')
|
||||
os.makedirs(user_data_base_dir, exist_ok=True)
|
||||
|
||||
# Check for existing session directories
|
||||
session_dirs = [d for d in os.listdir(user_data_base_dir) if d.startswith('session_')]
|
||||
session_dirs.sort(reverse=True) # Sort descending to get the most recent first
|
||||
|
||||
user_data_dir = None
|
||||
if session_dirs:
|
||||
use_existing = input(f"Found {len(session_dirs)} existing sessions. Use an existing session? (y/n): ").lower().strip() == 'y'
|
||||
if use_existing:
|
||||
print("Available sessions:")
|
||||
for i, session in enumerate(session_dirs[:5], 1): # Show up to 5 most recent
|
||||
print(f"{i}. {session}")
|
||||
choice = input("Enter session number (default 1) or any other key for most recent: ")
|
||||
if choice.isdigit() and 1 <= int(choice) <= len(session_dirs):
|
||||
selected_session = session_dirs[int(choice) - 1]
|
||||
else:
|
||||
selected_session = session_dirs[0]
|
||||
user_data_dir = os.path.join(user_data_base_dir, selected_session)
|
||||
print(f"Using session: {selected_session}")
|
||||
|
||||
if user_data_dir is None:
|
||||
user_data_dir = os.path.join(user_data_base_dir, f'session_{self.timestamp}')
|
||||
os.makedirs(user_data_dir, exist_ok=True)
|
||||
print(f"Creating new session: session_{self.timestamp}")
|
||||
|
||||
chrome_options.add_argument(f'--user-data-dir={user_data_dir}')
|
||||
|
||||
# Enable logging to capture JS console output and network activity
|
||||
chrome_options.set_capability('goog:loggingPrefs', {
|
||||
'browser': 'ALL',
|
||||
'performance': 'ALL'
|
||||
})
|
||||
|
||||
try:
|
||||
self.driver = webdriver.Chrome(options=chrome_options)
|
||||
except Exception as e:
|
||||
print(f"Failed to start browser with session: {e}")
|
||||
print("Falling back to a new session...")
|
||||
user_data_dir = os.path.join(user_data_base_dir, f'session_{self.timestamp}_fallback')
|
||||
os.makedirs(user_data_dir, exist_ok=True)
|
||||
print(f"Creating fallback session: session_{self.timestamp}_fallback")
|
||||
chrome_options = webdriver.ChromeOptions()
|
||||
if self.headless:
|
||||
chrome_options.add_argument('--headless')
|
||||
chrome_options.add_argument('--disable-gpu')
|
||||
chrome_options.add_argument('--window-size=1920,1080')
|
||||
chrome_options.add_argument('--disable-extensions')
|
||||
chrome_options.add_argument(f'--user-data-dir={user_data_dir}')
|
||||
chrome_options.set_capability('goog:loggingPrefs', {
|
||||
'browser': 'ALL',
|
||||
'performance': 'ALL'
|
||||
})
|
||||
self.driver = webdriver.Chrome(options=chrome_options)
|
||||
|
||||
return self.driver
|
||||
|
||||
def start_monitoring(self):
|
||||
"""Start the browser and begin monitoring"""
|
||||
logger.info("Starting MEXC Request Interceptor...")
|
||||
|
||||
try:
|
||||
# Setup ChromeDriver
|
||||
self.driver = self.setup_browser()
|
||||
|
||||
# Navigate to MEXC futures
|
||||
mexc_url = "https://www.mexc.com/en-GB/futures/ETH_USDT?type=linear_swap"
|
||||
logger.info(f"Navigating to: {mexc_url}")
|
||||
self.driver.get(mexc_url)
|
||||
|
||||
# Wait for page load
|
||||
WebDriverWait(self.driver, 10).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
|
||||
logger.info("✅ MEXC page loaded successfully!")
|
||||
logger.info("📝 Please log in manually in the browser window")
|
||||
logger.info("🔍 Request monitoring is now active...")
|
||||
|
||||
# Start monitoring in background thread
|
||||
self.monitoring = True
|
||||
monitor_thread = threading.Thread(target=self._monitor_requests, daemon=True)
|
||||
monitor_thread.start()
|
||||
|
||||
# Wait for manual login
|
||||
self._wait_for_login()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start monitoring: {e}")
|
||||
return False
|
||||
|
||||
def _wait_for_login(self):
|
||||
"""Wait for user to log in and show interactive menu"""
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("MEXC REQUEST INTERCEPTOR - INTERACTIVE MODE")
|
||||
logger.info("="*60)
|
||||
|
||||
while True:
|
||||
print("\nOptions:")
|
||||
print("1. Check login status")
|
||||
print("2. Extract current cookies")
|
||||
print("3. Show captured requests summary")
|
||||
print("4. Save captured data to files")
|
||||
print("5. Perform test trade (manual)")
|
||||
print("6. Monitor for 60 seconds")
|
||||
print("0. Stop and exit")
|
||||
|
||||
choice = input("\nEnter choice (0-6): ").strip()
|
||||
|
||||
if choice == "1":
|
||||
self._check_login_status()
|
||||
elif choice == "2":
|
||||
self._extract_cookies()
|
||||
elif choice == "3":
|
||||
self._show_requests_summary()
|
||||
elif choice == "4":
|
||||
self._save_all_data()
|
||||
elif choice == "5":
|
||||
self._guide_test_trade()
|
||||
elif choice == "6":
|
||||
self._monitor_for_duration(60)
|
||||
elif choice == "0":
|
||||
break
|
||||
else:
|
||||
print("Invalid choice. Please try again.")
|
||||
|
||||
self.stop_monitoring()
|
||||
|
||||
def _check_login_status(self):
|
||||
"""Check if user is logged into MEXC"""
|
||||
try:
|
||||
cookies = self.driver.get_cookies()
|
||||
auth_cookies = ['uc_token', 'u_id', 'x-mxc-fingerprint']
|
||||
found_auth = []
|
||||
|
||||
for cookie in cookies:
|
||||
if cookie['name'] in auth_cookies and cookie['value']:
|
||||
found_auth.append(cookie['name'])
|
||||
|
||||
if len(found_auth) >= 2:
|
||||
print("✅ LOGIN DETECTED - You appear to be logged in!")
|
||||
print(f" Found auth cookies: {', '.join(found_auth)}")
|
||||
return True
|
||||
else:
|
||||
print("❌ NOT LOGGED IN - Please log in to MEXC in the browser")
|
||||
print(" Missing required authentication cookies")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error checking login: {e}")
|
||||
return False
|
||||
|
||||
def _extract_cookies(self):
|
||||
"""Extract and display current session cookies"""
|
||||
try:
|
||||
cookies = self.driver.get_cookies()
|
||||
cookie_dict = {}
|
||||
|
||||
for cookie in cookies:
|
||||
cookie_dict[cookie['name']] = cookie['value']
|
||||
|
||||
self.session_cookies = cookie_dict
|
||||
|
||||
print(f"\n📊 Extracted {len(cookie_dict)} cookies:")
|
||||
|
||||
# Show important cookies
|
||||
important = ['uc_token', 'u_id', 'x-mxc-fingerprint', 'mexc_fingerprint_visitorId']
|
||||
for name in important:
|
||||
if name in cookie_dict:
|
||||
value = cookie_dict[name]
|
||||
display_value = value[:20] + "..." if len(value) > 20 else value
|
||||
print(f" ✅ {name}: {display_value}")
|
||||
else:
|
||||
print(f" ❌ {name}: Missing")
|
||||
|
||||
# Save cookies to file
|
||||
if self.save_to_file:
|
||||
with open(self.cookies_file, 'w') as f:
|
||||
json.dump(cookie_dict, f, indent=2)
|
||||
print(f"\n💾 Cookies saved to: {self.cookies_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error extracting cookies: {e}")
|
||||
|
||||
def _monitor_requests(self):
|
||||
"""Background thread to monitor network requests"""
|
||||
last_log_count = 0
|
||||
|
||||
while self.monitoring:
|
||||
try:
|
||||
# Get performance logs
|
||||
logs = self.driver.get_log('performance')
|
||||
|
||||
for log in logs:
|
||||
try:
|
||||
message = json.loads(log['message'])
|
||||
method = message.get('message', {}).get('method', '')
|
||||
|
||||
# Capture network requests
|
||||
if method == 'Network.requestWillBeSent':
|
||||
self._process_request(message['message']['params'])
|
||||
elif method == 'Network.responseReceived':
|
||||
self._process_response(message['message']['params'])
|
||||
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
continue
|
||||
|
||||
# Show progress every 10 new requests
|
||||
if len(self.captured_requests) >= last_log_count + 10:
|
||||
last_log_count = len(self.captured_requests)
|
||||
logger.info(f"📈 Captured {len(self.captured_requests)} requests, {len(self.captured_responses)} responses")
|
||||
|
||||
except Exception as e:
|
||||
if self.monitoring: # Only log if we're still supposed to be monitoring
|
||||
logger.debug(f"Monitor error: {e}")
|
||||
|
||||
time.sleep(0.5) # Check every 500ms
|
||||
|
||||
def _process_request(self, request_data):
|
||||
"""Process a captured network request"""
|
||||
try:
|
||||
url = request_data.get('request', {}).get('url', '')
|
||||
|
||||
# Filter for MEXC API requests
|
||||
if self._is_mexc_request(url):
|
||||
request_info = {
|
||||
'type': 'request',
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'url': url,
|
||||
'method': request_data.get('request', {}).get('method', ''),
|
||||
'headers': request_data.get('request', {}).get('headers', {}),
|
||||
'postData': request_data.get('request', {}).get('postData', ''),
|
||||
'requestId': request_data.get('requestId', '')
|
||||
}
|
||||
|
||||
self.captured_requests.append(request_info)
|
||||
|
||||
# Show important requests immediately
|
||||
if ('futures.mexc.com' in url or 'captcha' in url):
|
||||
print(f"\n🚀 CAPTURED REQUEST: {request_info['method']} {url}")
|
||||
if request_info['postData']:
|
||||
print(f" 📄 POST Data: {request_info['postData'][:100]}...")
|
||||
|
||||
# Enhanced captcha detection and detailed logging
|
||||
if 'captcha' in url.lower() or 'robot' in url.lower():
|
||||
logger.info(f"CAPTCHA REQUEST DETECTED: {request_data.get('request', {}).get('method', 'UNKNOWN')} {url}")
|
||||
logger.info(f" Headers: {request_data.get('request', {}).get('headers', {})}")
|
||||
if request_data.get('request', {}).get('postData', ''):
|
||||
logger.info(f" Data: {request_data.get('request', {}).get('postData', '')}")
|
||||
# Attempt to capture related JavaScript or DOM elements (if possible)
|
||||
if self.driver is not None:
|
||||
try:
|
||||
js_snippet = self.driver.execute_script("return document.querySelector('script[src*=\"captcha\"]') ? document.querySelector('script[src*=\"captcha\"]').outerHTML : 'No captcha script found';")
|
||||
logger.info(f" Related JS Snippet: {js_snippet}")
|
||||
except Exception as e:
|
||||
logger.warning(f" Could not capture JS snippet: {e}")
|
||||
try:
|
||||
dom_element = self.driver.execute_script("return document.querySelector('div[id*=\"captcha\"]') ? document.querySelector('div[id*=\"captcha\"]').outerHTML : 'No captcha element found';")
|
||||
logger.info(f" Related DOM Element: {dom_element}")
|
||||
except Exception as e:
|
||||
logger.warning(f" Could not capture DOM element: {e}")
|
||||
else:
|
||||
logger.warning(" Driver not initialized, cannot capture JS or DOM elements")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error processing request: {e}")
|
||||
|
||||
def _process_response(self, response_data):
|
||||
"""Process a captured network response"""
|
||||
try:
|
||||
url = response_data.get('response', {}).get('url', '')
|
||||
|
||||
# Filter for MEXC API responses
|
||||
if self._is_mexc_request(url):
|
||||
response_info = {
|
||||
'type': 'response',
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'url': url,
|
||||
'status': response_data.get('response', {}).get('status', 0),
|
||||
'headers': response_data.get('response', {}).get('headers', {}),
|
||||
'requestId': response_data.get('requestId', '')
|
||||
}
|
||||
|
||||
self.captured_responses.append(response_info)
|
||||
|
||||
# Show important responses immediately
|
||||
if ('futures.mexc.com' in url or 'captcha' in url):
|
||||
status = response_info['status']
|
||||
status_emoji = "✅" if status == 200 else "❌"
|
||||
print(f" {status_emoji} RESPONSE: {status} for {url}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error processing response: {e}")
|
||||
|
||||
def _is_mexc_request(self, url: str) -> bool:
|
||||
"""Check if URL is a relevant MEXC API request"""
|
||||
mexc_indicators = [
|
||||
'futures.mexc.com',
|
||||
'ucgateway/captcha_api',
|
||||
'api/v1/private',
|
||||
'api/v3/order',
|
||||
'mexc.com/api'
|
||||
]
|
||||
|
||||
return any(indicator in url for indicator in mexc_indicators)
|
||||
|
||||
def _show_requests_summary(self):
|
||||
"""Show summary of captured requests"""
|
||||
print(f"\n📊 CAPTURE SUMMARY:")
|
||||
print(f" Total Requests: {len(self.captured_requests)}")
|
||||
print(f" Total Responses: {len(self.captured_responses)}")
|
||||
|
||||
# Group by URL pattern
|
||||
url_counts = {}
|
||||
for req in self.captured_requests:
|
||||
base_url = req['url'].split('?')[0] # Remove query params
|
||||
url_counts[base_url] = url_counts.get(base_url, 0) + 1
|
||||
|
||||
print("\n🔗 Top URLs:")
|
||||
for url, count in sorted(url_counts.items(), key=lambda x: x[1], reverse=True)[:5]:
|
||||
print(f" {count}x {url}")
|
||||
|
||||
# Show recent futures API calls
|
||||
futures_requests = [r for r in self.captured_requests if 'futures.mexc.com' in r['url']]
|
||||
if futures_requests:
|
||||
print(f"\n🚀 Futures API Calls: {len(futures_requests)}")
|
||||
for req in futures_requests[-3:]: # Show last 3
|
||||
print(f" {req['method']} {req['url']}")
|
||||
|
||||
def _save_all_data(self):
|
||||
"""Save all captured data to files"""
|
||||
if not self.save_to_file:
|
||||
print("File saving is disabled")
|
||||
return
|
||||
|
||||
try:
|
||||
# Save requests
|
||||
with open(self.requests_file, 'w') as f:
|
||||
json.dump({
|
||||
'requests': self.captured_requests,
|
||||
'responses': self.captured_responses,
|
||||
'summary': {
|
||||
'total_requests': len(self.captured_requests),
|
||||
'total_responses': len(self.captured_responses),
|
||||
'capture_session': self.timestamp
|
||||
}
|
||||
}, f, indent=2)
|
||||
|
||||
# Save cookies if we have them
|
||||
if self.session_cookies:
|
||||
with open(self.cookies_file, 'w') as f:
|
||||
json.dump(self.session_cookies, f, indent=2)
|
||||
|
||||
print(f"\n💾 Data saved to:")
|
||||
print(f" 📋 Requests: {self.requests_file}")
|
||||
if self.session_cookies:
|
||||
print(f" 🍪 Cookies: {self.cookies_file}")
|
||||
|
||||
# Extract and save CAPTCHA tokens from captured requests
|
||||
captcha_tokens = self.extract_captcha_tokens()
|
||||
if captcha_tokens:
|
||||
captcha_file = f"mexc_captcha_tokens_{self.timestamp}.json"
|
||||
with open(captcha_file, 'w') as f:
|
||||
json.dump(captcha_tokens, f, indent=2)
|
||||
logger.info(f"Saved CAPTCHA tokens to {captcha_file}")
|
||||
else:
|
||||
logger.warning("No CAPTCHA tokens found in captured requests")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error saving data: {e}")
|
||||
|
||||
def _guide_test_trade(self):
|
||||
"""Guide user through performing a test trade"""
|
||||
print("\n🧪 TEST TRADE GUIDE:")
|
||||
print("1. Make sure you're logged into MEXC")
|
||||
print("2. Go to the trading interface")
|
||||
print("3. Try to place a SMALL test trade (it may fail, but we'll capture the requests)")
|
||||
print("4. Watch the console for captured API calls")
|
||||
print("\n⚠️ IMPORTANT: Use very small amounts for testing!")
|
||||
input("\nPress Enter when you're ready to start monitoring...")
|
||||
|
||||
self._monitor_for_duration(120) # Monitor for 2 minutes
|
||||
|
||||
def _monitor_for_duration(self, seconds: int):
|
||||
"""Monitor requests for a specific duration"""
|
||||
print(f"\n🔍 Monitoring requests for {seconds} seconds...")
|
||||
print("Perform your trading actions now!")
|
||||
|
||||
start_time = time.time()
|
||||
initial_count = len(self.captured_requests)
|
||||
|
||||
while time.time() - start_time < seconds:
|
||||
current_count = len(self.captured_requests)
|
||||
new_requests = current_count - initial_count
|
||||
|
||||
remaining = seconds - int(time.time() - start_time)
|
||||
print(f"\r⏱️ Time remaining: {remaining}s | New requests: {new_requests}", end="", flush=True)
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
final_count = len(self.captured_requests)
|
||||
new_total = final_count - initial_count
|
||||
print(f"\n✅ Monitoring complete! Captured {new_total} new requests")
|
||||
|
||||
def stop_monitoring(self):
|
||||
"""Stop monitoring and close browser"""
|
||||
logger.info("Stopping request monitoring...")
|
||||
self.monitoring = False
|
||||
|
||||
if self.driver:
|
||||
self.driver.quit()
|
||||
logger.info("Browser closed")
|
||||
|
||||
# Final save
|
||||
if self.save_to_file and (self.captured_requests or self.captured_responses):
|
||||
self._save_all_data()
|
||||
logger.info("Final data save complete")
|
||||
|
||||
    def extract_captcha_tokens(self):
        """Extract CAPTCHA tokens from captured requests and their matching responses"""
        # Responses are stored separately in self.captured_responses, so match them to
        # requests via the shared requestId rather than a non-existent 'response' key.
        responses_by_id = {r.get('requestId', ''): r for r in self.captured_responses}
        captcha_tokens = []
        for request in self.captured_requests:
            if 'captcha-token' in request.get('headers', {}):
                token = request['headers']['captcha-token']
                captcha_tokens.append({
                    'token': token,
                    'url': request.get('url', ''),
                    'timestamp': request.get('timestamp', '')
                })
            elif 'captcha' in request.get('url', '').lower():
                response = responses_by_id.get(request.get('requestId', ''), {})
                if response and 'captcha-token' in response.get('headers', {}):
                    token = response['headers']['captcha-token']
                    captcha_tokens.append({
                        'token': token,
                        'url': request.get('url', ''),
                        'timestamp': request.get('timestamp', '')
                    })
        return captcha_tokens
|
||||
|
||||
def main():
|
||||
"""Main function to run the interceptor"""
|
||||
print("🚀 MEXC Request Interceptor with ChromeDriver")
|
||||
print("=" * 50)
|
||||
print("This will automatically:")
|
||||
print("✅ Download/setup ChromeDriver")
|
||||
print("✅ Open MEXC futures page")
|
||||
print("✅ Capture all API requests/responses")
|
||||
print("✅ Extract session cookies")
|
||||
print("✅ Save data to JSON files")
|
||||
print("\nPress Ctrl+C to stop at any time")
|
||||
|
||||
# Ask for preferences
|
||||
headless = input("\nRun in headless mode? (y/n): ").lower().strip() == 'y'
|
||||
|
||||
interceptor = MEXCRequestInterceptor(headless=headless, save_to_file=True)
|
||||
|
||||
try:
|
||||
success = interceptor.start_monitoring()
|
||||
if not success:
|
||||
print("❌ Failed to start monitoring")
|
||||
return
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n⏹️ Stopping interceptor...")
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error: {e}")
|
||||
finally:
|
||||
interceptor.stop_monitoring()
|
||||
print("\n👋 Goodbye!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
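
# --- Editor's sketch (not in the original file): the bare CDP performance-logging technique
# that MEXCRequestInterceptor builds on. Kept as an uncalled helper so it does not execute
# when this script runs; the target URL is just an example.
def _cdp_logging_sketch():
    options = webdriver.ChromeOptions()
    options.set_capability('goog:loggingPrefs', {'performance': 'ALL'})
    driver = webdriver.Chrome(options=options)
    try:
        driver.get("https://www.mexc.com/en-GB/futures/ETH_USDT?type=linear_swap")
        time.sleep(5)  # let some network traffic accumulate
        for entry in driver.get_log('performance'):
            event = json.loads(entry['message']).get('message', {})
            if event.get('method') == 'Network.requestWillBeSent':
                request = event.get('params', {}).get('request', {})
                print(request.get('method'), request.get('url'))
    finally:
        driver.quit()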
@@ -1,358 +0,0 @@
|
||||
"""
|
||||
MEXC Browser Automation for Cookie Extraction and Request Monitoring
|
||||
|
||||
This module uses Selenium to automate browser interactions and extract
|
||||
session cookies and request data for MEXC futures trading.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import json
|
||||
from typing import Dict, List, Optional, Any
|
||||
from selenium import webdriver
|
||||
from selenium.webdriver.chrome.options import Options
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support.ui import WebDriverWait
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.common.exceptions import TimeoutException, WebDriverException
|
||||
from selenium.webdriver.chrome.service import Service
|
||||
from webdriver_manager.chrome import ChromeDriverManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MEXCBrowserAutomation:
|
||||
"""
|
||||
Browser automation for MEXC futures trading session management
|
||||
"""
|
||||
|
||||
def __init__(self, headless: bool = False, proxy: Optional[str] = None):
|
||||
"""
|
||||
Initialize browser automation
|
||||
|
||||
Args:
|
||||
headless: Run browser in headless mode
|
||||
proxy: HTTP proxy to use (format: host:port)
|
||||
"""
|
||||
self.driver = None
|
||||
self.headless = headless
|
||||
self.proxy = proxy
|
||||
self.logged_in = False
|
||||
|
||||
def setup_chrome_driver(self) -> webdriver.Chrome:
|
||||
"""Setup Chrome driver with appropriate options"""
|
||||
chrome_options = Options()
|
||||
|
||||
if self.headless:
|
||||
chrome_options.add_argument("--headless")
|
||||
|
||||
# Basic Chrome options for automation
|
||||
chrome_options.add_argument("--no-sandbox")
|
||||
chrome_options.add_argument("--disable-dev-shm-usage")
|
||||
chrome_options.add_argument("--disable-blink-features=AutomationControlled")
|
||||
chrome_options.add_experimental_option("excludeSwitches", ["enable-automation"])
|
||||
chrome_options.add_experimental_option('useAutomationExtension', False)
|
||||
|
||||
# Set user agent to avoid detection
|
||||
chrome_options.add_argument("--user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36")
|
||||
|
||||
# Proxy setup if provided
|
||||
if self.proxy:
|
||||
chrome_options.add_argument(f"--proxy-server=http://{self.proxy}")
|
||||
|
||||
# Enable network logging
|
||||
chrome_options.add_argument("--enable-logging")
|
||||
chrome_options.add_argument("--log-level=0")
|
||||
chrome_options.set_capability("goog:loggingPrefs", {"performance": "ALL"})
|
||||
|
||||
# Automatically download and setup ChromeDriver
|
||||
service = Service(ChromeDriverManager().install())
|
||||
|
||||
try:
|
||||
driver = webdriver.Chrome(service=service, options=chrome_options)
|
||||
|
||||
# Execute script to avoid detection
|
||||
driver.execute_script("Object.defineProperty(navigator, 'webdriver', {get: () => undefined})")
|
||||
|
||||
return driver
|
||||
except WebDriverException as e:
|
||||
logger.error(f"Failed to setup Chrome driver: {e}")
|
||||
raise
|
||||
|
||||
def start_browser(self):
|
||||
"""Start the browser session"""
|
||||
if self.driver is None:
|
||||
logger.info("Starting Chrome browser for MEXC automation")
|
||||
self.driver = self.setup_chrome_driver()
|
||||
logger.info("Browser started successfully")
|
||||
|
||||
def stop_browser(self):
|
||||
"""Stop the browser session"""
|
||||
if self.driver:
|
||||
logger.info("Stopping browser")
|
||||
self.driver.quit()
|
||||
self.driver = None
|
||||
|
||||
def navigate_to_mexc_futures(self, symbol: str = "ETH_USDT"):
|
||||
"""
|
||||
Navigate to MEXC futures trading page
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol to navigate to
|
||||
"""
|
||||
if not self.driver:
|
||||
self.start_browser()
|
||||
|
||||
url = f"https://www.mexc.com/en-GB/futures/{symbol}?type=linear_swap"
|
||||
logger.info(f"Navigating to MEXC futures: {url}")
|
||||
|
||||
self.driver.get(url)
|
||||
|
||||
# Wait for page to load
|
||||
try:
|
||||
WebDriverWait(self.driver, 10).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
logger.info("MEXC futures page loaded")
|
||||
except TimeoutException:
|
||||
logger.error("Timeout waiting for MEXC page to load")
|
||||
|
||||
def wait_for_login(self, timeout: int = 300) -> bool:
|
||||
"""
|
||||
Wait for user to manually log in to MEXC
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for login (seconds)
|
||||
|
||||
Returns:
|
||||
bool: True if login detected, False if timeout
|
||||
"""
|
||||
logger.info("Please log in to MEXC manually in the browser window")
|
||||
logger.info("Waiting for login completion...")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
# Check if we can find elements that indicate logged in state
|
||||
try:
|
||||
# Look for user-specific elements that appear after login
|
||||
cookies = self.driver.get_cookies()
|
||||
|
||||
# Check for authentication cookies
|
||||
auth_cookies = ['uc_token', 'u_id']
|
||||
logged_in_indicators = 0
|
||||
|
||||
for cookie in cookies:
|
||||
if cookie['name'] in auth_cookies and cookie['value']:
|
||||
logged_in_indicators += 1
|
||||
|
||||
if logged_in_indicators >= 2:
|
||||
logger.info("Login detected!")
|
||||
self.logged_in = True
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error checking login status: {e}")
|
||||
|
||||
time.sleep(2) # Check every 2 seconds
|
||||
|
||||
logger.error(f"Login timeout after {timeout} seconds")
|
||||
return False
|
||||
|
||||
def extract_session_cookies(self) -> Dict[str, str]:
|
||||
"""
|
||||
Extract all cookies from current browser session
|
||||
|
||||
Returns:
|
||||
Dictionary of cookie name-value pairs
|
||||
"""
|
||||
if not self.driver:
|
||||
logger.error("Browser not started")
|
||||
return {}
|
||||
|
||||
cookies = {}
|
||||
|
||||
try:
|
||||
browser_cookies = self.driver.get_cookies()
|
||||
|
||||
for cookie in browser_cookies:
|
||||
cookies[cookie['name']] = cookie['value']
|
||||
|
||||
logger.info(f"Extracted {len(cookies)} cookies from browser session")
|
||||
|
||||
# Log important cookies (without values for security)
|
||||
important_cookies = ['uc_token', 'u_id', 'x-mxc-fingerprint', 'mexc_fingerprint_visitorId']
|
||||
for cookie_name in important_cookies:
|
||||
if cookie_name in cookies:
|
||||
logger.info(f"Found important cookie: {cookie_name}")
|
||||
else:
|
||||
logger.warning(f"Missing important cookie: {cookie_name}")
|
||||
|
||||
return cookies
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to extract cookies: {e}")
|
||||
return {}
|
||||
|
||||
def monitor_network_requests(self, duration: int = 60) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Monitor network requests for the specified duration
|
||||
|
||||
Args:
|
||||
duration: How long to monitor requests (seconds)
|
||||
|
||||
Returns:
|
||||
List of captured network requests
|
||||
"""
|
||||
if not self.driver:
|
||||
logger.error("Browser not started")
|
||||
return []
|
||||
|
||||
logger.info(f"Starting network monitoring for {duration} seconds")
|
||||
logger.info("Please perform trading actions in the browser (open/close positions)")
|
||||
|
||||
start_time = time.time()
|
||||
captured_requests = []
|
||||
|
||||
while time.time() - start_time < duration:
|
||||
try:
|
||||
# Get performance logs (network requests)
|
||||
logs = self.driver.get_log('performance')
|
||||
|
||||
for log in logs:
|
||||
message = json.loads(log['message'])
|
||||
|
||||
# Filter for relevant MEXC API requests
|
||||
if (message.get('message', {}).get('method') == 'Network.responseReceived'):
|
||||
response = message['message']['params']['response']
|
||||
url = response.get('url', '')
|
||||
|
||||
# Look for futures API calls
|
||||
if ('futures.mexc.com' in url or
|
||||
'ucgateway/captcha_api' in url or
|
||||
'api/v1/private' in url):
|
||||
|
||||
                            request_data = {
                                'url': url,
                                'mimeType': response.get('mimeType', ''),  # responseReceived carries the MIME type, not the HTTP method
                                'status': response.get('status'),
                                'headers': response.get('headers', {}),
                                'timestamp': log['timestamp']
                            }
|
||||
|
||||
captured_requests.append(request_data)
|
||||
logger.info(f"Captured request: {url}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error in network monitoring: {e}")
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
logger.info(f"Network monitoring complete. Captured {len(captured_requests)} requests")
|
||||
return captured_requests
|
||||
|
||||
def perform_test_trade(self, symbol: str = "ETH_USDT", volume: float = 1.0, leverage: int = 200):
|
||||
"""
|
||||
Attempt to perform a test trade to capture the complete request flow
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
volume: Position size
|
||||
leverage: Leverage multiplier
|
||||
"""
|
||||
if not self.logged_in:
|
||||
logger.error("Not logged in - cannot perform test trade")
|
||||
return
|
||||
|
||||
logger.info(f"Attempting test trade: {symbol}, Volume: {volume}, Leverage: {leverage}x")
|
||||
logger.info("This will attempt to click trading interface elements")
|
||||
|
||||
try:
|
||||
# This would need to be implemented based on MEXC's specific UI elements
|
||||
# For now, just wait and let user perform manual actions
|
||||
logger.info("Please manually place a small test trade while monitoring is active")
|
||||
time.sleep(30)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during test trade: {e}")
|
||||
|
||||
def full_session_capture(self, symbol: str = "ETH_USDT") -> Dict[str, Any]:
|
||||
"""
|
||||
Complete session capture workflow
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol to use
|
||||
|
||||
Returns:
|
||||
Dictionary containing cookies and captured requests
|
||||
"""
|
||||
logger.info("Starting full MEXC session capture")
|
||||
|
||||
try:
|
||||
# Start browser and navigate to MEXC
|
||||
self.navigate_to_mexc_futures(symbol)
|
||||
|
||||
# Wait for manual login
|
||||
if not self.wait_for_login():
|
||||
return {'success': False, 'error': 'Login timeout'}
|
||||
|
||||
# Extract session cookies
|
||||
cookies = self.extract_session_cookies()
|
||||
|
||||
if not cookies:
|
||||
return {'success': False, 'error': 'Failed to extract cookies'}
|
||||
|
||||
# Monitor network requests while user performs actions
|
||||
logger.info("Starting network monitoring - please perform trading actions now")
|
||||
requests = self.monitor_network_requests(duration=120) # 2 minutes
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'cookies': cookies,
|
||||
'network_requests': requests,
|
||||
'timestamp': int(time.time())
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in session capture: {e}")
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
finally:
|
||||
self.stop_browser()
|
||||
|
||||
def main():
|
||||
"""Main function for standalone execution"""
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
print("MEXC Browser Automation - Session Capture")
|
||||
print("This will open a browser window for you to log into MEXC")
|
||||
print("Make sure you have Chrome browser installed")
|
||||
|
||||
automation = MEXCBrowserAutomation(headless=False)
|
||||
|
||||
try:
|
||||
result = automation.full_session_capture()
|
||||
|
||||
if result['success']:
|
||||
print(f"\nSession capture successful!")
|
||||
print(f"Extracted {len(result['cookies'])} cookies")
|
||||
print(f"Captured {len(result['network_requests'])} network requests")
|
||||
|
||||
# Save results to file
|
||||
output_file = f"mexc_session_capture_{int(time.time())}.json"
|
||||
with open(output_file, 'w') as f:
|
||||
json.dump(result, f, indent=2)
|
||||
|
||||
print(f"Results saved to: {output_file}")
|
||||
|
||||
else:
|
||||
print(f"Session capture failed: {result['error']}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nSession capture interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
finally:
|
||||
automation.stop_browser()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
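
# --- Editor's sketch (not in the original file): handing the captured session to the web
# client defined in mexc_futures_client.py. The import path and the empty credential values
# are assumptions; load_session_cookies() only needs the cookie dict captured above.
def _session_handoff_sketch():
    from mexc_futures_client import MEXCFuturesWebClient  # hypothetical import path

    automation = MEXCBrowserAutomation(headless=False)
    result = automation.full_session_capture("ETH_USDT")
    if not result.get('success'):
        raise RuntimeError(f"Capture failed: {result.get('error')}")

    client = MEXCFuturesWebClient(api_key="", api_secret="", user_id="")  # placeholders
    client.load_session_cookies(result['cookies'])
    print("Authenticated via cookies:", client.is_authenticated)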
@@ -1,525 +0,0 @@
|
||||
"""
|
||||
MEXC Futures Web Client
|
||||
|
||||
This module implements a web-based client for MEXC futures trading
|
||||
since their official API doesn't support futures (leverage) trading.
|
||||
|
||||
It mimics browser behavior by replicating the exact HTTP requests
|
||||
that the web interface makes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import requests
|
||||
import time
|
||||
import json
|
||||
import hmac
|
||||
import hashlib
|
||||
import base64
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
from urllib.parse import urlencode
|
||||
import glob
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MEXCSessionManager:
|
||||
def __init__(self):
|
||||
self.captcha_token = None
|
||||
|
||||
def get_captcha_token(self) -> str:
|
||||
return self.captcha_token if self.captcha_token else ""
|
||||
|
||||
def save_captcha_token(self, token: str):
|
||||
self.captcha_token = token
|
||||
logger.info("MEXC: Captcha token saved in session manager")
|
||||
|
||||
class MEXCFuturesWebClient:
|
||||
"""
|
||||
MEXC Futures Web Client that mimics browser behavior for futures trading.
|
||||
|
||||
Since MEXC's official API doesn't support futures, this client replicates
|
||||
the exact HTTP requests made by their web interface.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str, api_secret: str, user_id: str, base_url: str = 'https://www.mexc.com', headless: bool = True):
|
||||
"""
|
||||
Initialize the MEXC Futures Web Client
|
||||
|
||||
Args:
|
||||
api_key: API key for authentication
|
||||
api_secret: API secret for authentication
|
||||
user_id: User ID for authentication
|
||||
base_url: Base URL for the MEXC website
|
||||
headless: Whether to run the browser in headless mode
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.api_secret = api_secret
|
||||
self.user_id = user_id
|
||||
self.base_url = base_url
|
||||
self.is_authenticated = False
|
||||
self.headless = headless
|
||||
        self.session = requests.Session()
        self.session_manager = MEXCSessionManager()  # holds captcha tokens between calls

        # Session details, populated by load_session_cookies()
        self.auth_token = None
        self.fingerprint = None
        self.visitor_id = None

        self.captcha_url = f'{base_url}/ucgateway/captcha_api'
        self.futures_api_url = "https://futures.mexc.com/api/v1"
|
||||
|
||||
# Setup default headers that mimic a real browser
|
||||
self.setup_browser_headers()
|
||||
|
||||
def setup_browser_headers(self):
|
||||
"""Setup default headers that mimic Chrome browser"""
|
||||
self.session.headers.update({
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36',
|
||||
'Accept': '*/*',
|
||||
'Accept-Language': 'en-GB,en-US;q=0.9,en;q=0.8',
|
||||
'Accept-Encoding': 'gzip, deflate, br',
|
||||
'sec-ch-ua': '"Chromium";v="136", "Google Chrome";v="136", "Not.A/Brand";v="99"',
|
||||
'sec-ch-ua-mobile': '?0',
|
||||
'sec-ch-ua-platform': '"Windows"',
|
||||
'sec-fetch-dest': 'empty',
|
||||
'sec-fetch-mode': 'cors',
|
||||
'sec-fetch-site': 'same-origin',
|
||||
'Cache-Control': 'no-cache',
|
||||
'Pragma': 'no-cache',
|
||||
'Referer': f'{self.base_url}/en-GB/futures/ETH_USDT?type=linear_swap',
|
||||
'Language': 'English',
|
||||
'X-Language': 'en-GB',
|
||||
'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
|
||||
'trochilus-uid': str(self.user_id) if self.user_id is not None else ''
|
||||
})
|
||||
|
||||
def load_session_cookies(self, cookies: Dict[str, str]):
|
||||
"""
|
||||
Load session cookies from browser
|
||||
|
||||
Args:
|
||||
cookies: Dictionary of cookie name-value pairs
|
||||
"""
|
||||
for name, value in cookies.items():
|
||||
self.session.cookies.set(name, value)
|
||||
|
||||
# Extract important session info from cookies
|
||||
self.auth_token = cookies.get('uc_token')
|
||||
self.user_id = cookies.get('u_id')
|
||||
self.fingerprint = cookies.get('x-mxc-fingerprint')
|
||||
self.visitor_id = cookies.get('mexc_fingerprint_visitorId')
|
||||
|
||||
if self.auth_token and self.user_id:
|
||||
self.is_authenticated = True
|
||||
logger.info("MEXC: Loaded authenticated session")
|
||||
else:
|
||||
logger.warning("MEXC: Session cookies incomplete - authentication may fail")
|
||||
|
||||
def extract_cookies_from_browser(self, cookie_string: str) -> Dict[str, str]:
|
||||
"""
|
||||
Extract cookies from a browser cookie string
|
||||
|
||||
Args:
|
||||
cookie_string: Raw cookie string from browser (copy from Network tab)
|
||||
|
||||
Returns:
|
||||
Dictionary of parsed cookies
|
||||
"""
|
||||
cookies = {}
|
||||
cookie_pairs = cookie_string.split(';')
|
||||
|
||||
for pair in cookie_pairs:
|
||||
if '=' in pair:
|
||||
name, value = pair.strip().split('=', 1)
|
||||
cookies[name] = value
|
||||
|
||||
return cookies
|
||||
|
||||
def verify_captcha(self, symbol: str, side: str, leverage: str) -> bool:
|
||||
"""
|
||||
Verify captcha for robot trading protection
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'ETH_USDT')
|
||||
side: 'openlong', 'closelong', 'openshort', 'closeshort'
|
||||
leverage: Leverage string (e.g., '200X')
|
||||
|
||||
Returns:
|
||||
bool: True if captcha verification successful
|
||||
"""
|
||||
if not self.is_authenticated:
|
||||
logger.error("MEXC: Cannot verify captcha - not authenticated")
|
||||
return False
|
||||
|
||||
# Build captcha endpoint URL
|
||||
endpoint = f"robot.future.{side}.{symbol}.{leverage}"
|
||||
url = f"{self.captcha_url}/{endpoint}"
|
||||
|
||||
# Attempt to get captcha token from session manager
|
||||
captcha_token = self.session_manager.get_captcha_token()
|
||||
if not captcha_token:
|
||||
logger.warning("MEXC: No captcha token available, attempting to fetch from browser")
|
||||
captcha_token = self._extract_captcha_token_from_browser()
|
||||
if captcha_token:
|
||||
self.session_manager.save_captcha_token(captcha_token)
|
||||
else:
|
||||
logger.error("MEXC: Failed to extract captcha token from browser")
|
||||
return False
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Language': 'en-GB',
|
||||
'Referer': f'{self.base_url}/en-GB/futures/{symbol}?type=linear_swap',
|
||||
'trochilus-uid': self.user_id if self.user_id else '',
|
||||
'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
|
||||
'captcha-token': captcha_token
|
||||
}
|
||||
|
||||
logger.info(f"MEXC: Verifying captcha for {endpoint}")
|
||||
try:
|
||||
response = self.session.get(url, headers=headers, timeout=10)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if data.get('success'):
|
||||
logger.info(f"MEXC: Captcha verified successfully for {endpoint}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"MEXC: Captcha verification failed for {endpoint}: {data}")
|
||||
return False
|
||||
else:
|
||||
logger.error(f"MEXC: Captcha verification request failed with status {response.status_code}: {response.text}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"MEXC: Captcha verification error for {endpoint}: {str(e)}")
|
||||
return False
|
||||
|
||||
def _extract_captcha_token_from_browser(self) -> str:
|
||||
"""
|
||||
Extract captcha token from browser session using stored cookies or requests.
|
||||
This method looks for the most recent mexc_captcha_tokens JSON file to retrieve a token.
|
||||
"""
|
||||
try:
|
||||
# Look for the most recent mexc_captcha_tokens file
|
||||
captcha_files = glob.glob("mexc_captcha_tokens_*.json")
|
||||
if not captcha_files:
|
||||
logger.error("MEXC: No CAPTCHA token files found")
|
||||
return ""
|
||||
|
||||
# Sort files by timestamp (most recent first)
|
||||
latest_file = max(captcha_files, key=os.path.getctime)
|
||||
logger.info(f"MEXC: Using CAPTCHA token file {latest_file}")
|
||||
|
||||
with open(latest_file, 'r') as f:
|
||||
captcha_data = json.load(f)
|
||||
|
||||
if captcha_data and isinstance(captcha_data, list) and len(captcha_data) > 0:
|
||||
# Return the most recent token
|
||||
                return captcha_data[-1].get('token', '')
|
||||
else:
|
||||
logger.error("MEXC: No valid CAPTCHA tokens found in file")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"MEXC: Error extracting captcha token from browser data: {str(e)}")
|
||||
return ""
|
||||
|
||||
def generate_signature(self, method: str, path: str, params: Dict[str, Any],
|
||||
timestamp: int, nonce: int) -> str:
|
||||
"""
|
||||
Generate signature for MEXC futures API requests
|
||||
|
||||
This is reverse-engineered from the browser requests
|
||||
"""
|
||||
# This is a placeholder - the actual signature generation would need
|
||||
# to be reverse-engineered from the browser's JavaScript
|
||||
# For now, return empty string and rely on cookie authentication
|
||||
return ""
|
||||
|
||||
def open_long_position(self, symbol: str, volume: float, leverage: int = 200,
|
||||
price: Optional[float] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Open a long futures position
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'ETH_USDT')
|
||||
volume: Position size (contracts)
|
||||
leverage: Leverage multiplier (default 200)
|
||||
price: Limit price (None for market order)
|
||||
|
||||
Returns:
|
||||
dict: Order response with order ID
|
||||
"""
|
||||
if not self.is_authenticated:
|
||||
logger.error("MEXC: Cannot open position - not authenticated")
|
||||
return {'success': False, 'error': 'Not authenticated'}
|
||||
|
||||
# First verify captcha
|
||||
if not self.verify_captcha(symbol, 'openlong', f'{leverage}X'):
|
||||
logger.error("MEXC: Captcha verification failed for opening long position")
|
||||
return {'success': False, 'error': 'Captcha verification failed'}
|
||||
|
||||
# Prepare order parameters based on the request dump
|
||||
timestamp = int(time.time() * 1000)
|
||||
nonce = timestamp
|
||||
|
||||
order_data = {
|
||||
'symbol': symbol,
|
||||
'side': 1, # 1 = long, 2 = short
|
||||
'openType': 2, # Open position
|
||||
'type': '5', # Market order (might be '1' for limit)
|
||||
'vol': volume,
|
||||
'leverage': leverage,
|
||||
'marketCeiling': False,
|
||||
'priceProtect': '0',
|
||||
'ts': timestamp,
|
||||
'mhash': self._generate_mhash(), # This needs to be implemented
|
||||
'mtoken': self.visitor_id
|
||||
}
|
||||
|
||||
# Add price for limit orders
|
||||
if price is not None:
|
||||
order_data['price'] = price
|
||||
order_data['type'] = '1' # Limit order
|
||||
|
||||
# Add encrypted parameters (these would need proper implementation)
|
||||
order_data['p0'] = self._encrypt_p0(order_data) # Placeholder
|
||||
order_data['k0'] = self._encrypt_k0(order_data) # Placeholder
|
||||
order_data['chash'] = self._generate_chash(order_data) # Placeholder
|
||||
|
||||
# Setup headers for the order request
|
||||
headers = {
|
||||
'Authorization': self.auth_token,
|
||||
'Content-Type': 'application/json',
|
||||
'Language': 'English',
|
||||
'x-language': 'en-GB',
|
||||
'x-mxc-nonce': str(nonce),
|
||||
'x-mxc-sign': self.generate_signature('POST', '/private/order/create', order_data, timestamp, nonce),
|
||||
'trochilus-uid': self.user_id,
|
||||
'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
|
||||
'Referer': 'https://www.mexc.com/'
|
||||
}
|
||||
|
||||
# Make the order request
|
||||
url = f"{self.futures_api_url}/private/order/create"
|
||||
|
||||
try:
|
||||
# First make OPTIONS request (preflight)
|
||||
options_response = self.session.options(url, headers=headers, timeout=10)
|
||||
|
||||
if options_response.status_code == 200:
|
||||
# Now make the actual POST request
|
||||
response = self.session.post(url, json=order_data, headers=headers, timeout=15)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if data.get('success') and data.get('code') == 0:
|
||||
order_id = data.get('data', {}).get('orderId')
|
||||
logger.info(f"MEXC: Long position opened successfully - Order ID: {order_id}")
|
||||
return {
|
||||
'success': True,
|
||||
'order_id': order_id,
|
||||
'timestamp': data.get('data', {}).get('ts'),
|
||||
'symbol': symbol,
|
||||
'side': 'long',
|
||||
'volume': volume,
|
||||
'leverage': leverage
|
||||
}
|
||||
else:
|
||||
logger.error(f"MEXC: Order failed: {data}")
|
||||
return {'success': False, 'error': data.get('msg', 'Unknown error')}
|
||||
else:
|
||||
logger.error(f"MEXC: Order request failed with status {response.status_code}")
|
||||
return {'success': False, 'error': f'HTTP {response.status_code}'}
|
||||
else:
|
||||
logger.error(f"MEXC: OPTIONS preflight failed with status {options_response.status_code}")
|
||||
return {'success': False, 'error': f'Preflight failed: HTTP {options_response.status_code}'}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MEXC: Order execution error: {e}")
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
def close_long_position(self, symbol: str, volume: float, leverage: int = 200,
|
||||
price: Optional[float] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Close a long futures position
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'ETH_USDT')
|
||||
volume: Position size to close (contracts)
|
||||
leverage: Leverage multiplier
|
||||
price: Limit price (None for market order)
|
||||
|
||||
Returns:
|
||||
dict: Order response
|
||||
"""
|
||||
if not self.is_authenticated:
|
||||
logger.error("MEXC: Cannot close position - not authenticated")
|
||||
return {'success': False, 'error': 'Not authenticated'}
|
||||
|
||||
# First verify captcha
|
||||
if not self.verify_captcha(symbol, 'closelong', f'{leverage}X'):
|
||||
logger.error("MEXC: Captcha verification failed for closing long position")
|
||||
return {'success': False, 'error': 'Captcha verification failed'}
|
||||
|
||||
# Similar to open_long_position but with closeType instead of openType
|
||||
timestamp = int(time.time() * 1000)
|
||||
nonce = timestamp
|
||||
|
||||
order_data = {
|
||||
'symbol': symbol,
|
||||
'side': 2, # Close side is opposite
|
||||
'closeType': 1, # Close position
|
||||
'type': '5', # Market order
|
||||
'vol': volume,
|
||||
'leverage': leverage,
|
||||
'marketCeiling': False,
|
||||
'priceProtect': '0',
|
||||
'ts': timestamp,
|
||||
'mhash': self._generate_mhash(),
|
||||
'mtoken': self.visitor_id
|
||||
}
|
||||
|
||||
if price is not None:
|
||||
order_data['price'] = price
|
||||
order_data['type'] = '1'
|
||||
|
||||
order_data['p0'] = self._encrypt_p0(order_data)
|
||||
order_data['k0'] = self._encrypt_k0(order_data)
|
||||
order_data['chash'] = self._generate_chash(order_data)
|
||||
|
||||
return self._execute_order(order_data, 'close_long')
|
||||
|
||||
def open_short_position(self, symbol: str, volume: float, leverage: int = 200,
|
||||
price: Optional[float] = None) -> Dict[str, Any]:
|
||||
"""Open a short futures position"""
|
||||
if not self.verify_captcha(symbol, 'openshort', f'{leverage}X'):
|
||||
return {'success': False, 'error': 'Captcha verification failed'}
|
||||
|
||||
order_data = {
|
||||
'symbol': symbol,
|
||||
'side': 2, # 2 = short
|
||||
'openType': 2,
|
||||
'type': '5',
|
||||
'vol': volume,
|
||||
'leverage': leverage,
|
||||
'marketCeiling': False,
|
||||
'priceProtect': '0',
|
||||
'ts': int(time.time() * 1000),
|
||||
'mhash': self._generate_mhash(),
|
||||
'mtoken': self.visitor_id
|
||||
}
|
||||
|
||||
if price is not None:
|
||||
order_data['price'] = price
|
||||
order_data['type'] = '1'
|
||||
|
||||
order_data['p0'] = self._encrypt_p0(order_data)
|
||||
order_data['k0'] = self._encrypt_k0(order_data)
|
||||
order_data['chash'] = self._generate_chash(order_data)
|
||||
|
||||
return self._execute_order(order_data, 'open_short')
|
||||
|
||||
def close_short_position(self, symbol: str, volume: float, leverage: int = 200,
|
||||
price: Optional[float] = None) -> Dict[str, Any]:
|
||||
"""Close a short futures position"""
|
||||
if not self.verify_captcha(symbol, 'closeshort', f'{leverage}X'):
|
||||
return {'success': False, 'error': 'Captcha verification failed'}
|
||||
|
||||
order_data = {
|
||||
'symbol': symbol,
|
||||
'side': 1, # Close side is opposite
|
||||
'closeType': 1,
|
||||
'type': '5',
|
||||
'vol': volume,
|
||||
'leverage': leverage,
|
||||
'marketCeiling': False,
|
||||
'priceProtect': '0',
|
||||
'ts': int(time.time() * 1000),
|
||||
'mhash': self._generate_mhash(),
|
||||
'mtoken': self.visitor_id
|
||||
}
|
||||
|
||||
if price is not None:
|
||||
order_data['price'] = price
|
||||
order_data['type'] = '1'
|
||||
|
||||
order_data['p0'] = self._encrypt_p0(order_data)
|
||||
order_data['k0'] = self._encrypt_k0(order_data)
|
||||
order_data['chash'] = self._generate_chash(order_data)
|
||||
|
||||
return self._execute_order(order_data, 'close_short')
|
||||
|
||||
def _execute_order(self, order_data: Dict[str, Any], action: str) -> Dict[str, Any]:
|
||||
"""Common order execution logic"""
|
||||
timestamp = order_data['ts']
|
||||
nonce = timestamp
|
||||
|
||||
headers = {
|
||||
'Authorization': self.auth_token,
|
||||
'Content-Type': 'application/json',
|
||||
'Language': 'English',
|
||||
'x-language': 'en-GB',
|
||||
'x-mxc-nonce': str(nonce),
|
||||
'x-mxc-sign': self.generate_signature('POST', '/private/order/create', order_data, timestamp, nonce),
|
||||
'trochilus-uid': self.user_id,
|
||||
'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
|
||||
'Referer': 'https://www.mexc.com/'
|
||||
}
|
||||
|
||||
url = f"{self.futures_api_url}/private/order/create"
|
||||
|
||||
try:
|
||||
response = self.session.post(url, json=order_data, headers=headers, timeout=15)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if data.get('success') and data.get('code') == 0:
|
||||
order_id = data.get('data', {}).get('orderId')
|
||||
logger.info(f"MEXC: {action} executed successfully - Order ID: {order_id}")
|
||||
return {
|
||||
'success': True,
|
||||
'order_id': order_id,
|
||||
'timestamp': data.get('data', {}).get('ts'),
|
||||
'action': action
|
||||
}
|
||||
else:
|
||||
logger.error(f"MEXC: {action} failed: {data}")
|
||||
return {'success': False, 'error': data.get('msg', 'Unknown error')}
|
||||
else:
|
||||
logger.error(f"MEXC: {action} request failed with status {response.status_code}")
|
||||
return {'success': False, 'error': f'HTTP {response.status_code}'}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MEXC: {action} execution error: {e}")
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
# Placeholder methods for encryption/hashing - these need proper implementation
|
||||
def _generate_mhash(self) -> str:
|
||||
"""Generate mhash parameter (needs reverse engineering)"""
|
||||
return "a0015441fd4c3b6ba427b894b76cb7dd" # Placeholder from request dump
|
||||
|
||||
def _encrypt_p0(self, order_data: Dict[str, Any]) -> str:
|
||||
"""Encrypt p0 parameter (needs reverse engineering)"""
|
||||
return "placeholder_p0_encryption" # This needs proper implementation
|
||||
|
||||
def _encrypt_k0(self, order_data: Dict[str, Any]) -> str:
|
||||
"""Encrypt k0 parameter (needs reverse engineering)"""
|
||||
return "placeholder_k0_encryption" # This needs proper implementation
|
||||
|
||||
def _generate_chash(self, order_data: Dict[str, Any]) -> str:
|
||||
"""Generate chash parameter (needs reverse engineering)"""
|
||||
return "d6c64d28e362f314071b3f9d78ff7494d9cd7177ae0465e772d1840e9f7905d8" # Placeholder
|
||||
|
||||
def get_account_info(self) -> Dict[str, Any]:
|
||||
"""Get account information including positions and balances"""
|
||||
if not self.is_authenticated:
|
||||
return {'success': False, 'error': 'Not authenticated'}
|
||||
|
||||
# This would need to be implemented by reverse engineering the account info endpoints
|
||||
logger.info("MEXC: Account info endpoint not yet implemented")
|
||||
return {'success': False, 'error': 'Not implemented'}
|
||||
|
||||
def get_open_positions(self) -> List[Dict[str, Any]]:
|
||||
"""Get list of open futures positions"""
|
||||
if not self.is_authenticated:
|
||||
return []
|
||||
|
||||
# This would need to be implemented by reverse engineering the positions endpoint
|
||||
logger.info("MEXC: Open positions endpoint not yet implemented")
|
||||
return []
|
||||
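
# --- Editor's sketch (not in the original file): the intended call sequence for this client.
# Cookie values must come from a real browser session (see the capture tools above); the ones
# shown here are placeholders, and the encrypted order parameters are still stubs, so live
# orders are not expected to succeed yet.
def _trading_flow_sketch():
    client = MEXCFuturesWebClient(api_key="", api_secret="", user_id="")  # placeholders
    client.load_session_cookies({
        'uc_token': 'PLACEHOLDER',
        'u_id': 'PLACEHOLDER',
        'x-mxc-fingerprint': 'PLACEHOLDER',
        'mexc_fingerprint_visitorId': 'PLACEHOLDER',
    })
    if not client.is_authenticated:
        raise RuntimeError("Session cookies incomplete")

    result = client.open_long_position('ETH_USDT', volume=1.0, leverage=200)
    if result.get('success'):
        client.close_long_position('ETH_USDT', volume=1.0, leverage=200)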
@@ -1,259 +0,0 @@
|
||||
"""
|
||||
MEXC Session Manager
|
||||
|
||||
Helper utilities for managing MEXC web sessions and extracting cookies from browser.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, Optional, Any
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MEXCSessionManager:
|
||||
"""
|
||||
Helper class for managing MEXC web sessions and extracting browser cookies
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.session_file = Path("mexc_session.json")
|
||||
|
||||
def extract_cookies_from_network_tab(self, cookie_header: str) -> Dict[str, str]:
|
||||
"""
|
||||
Extract cookies from browser Network tab cookie header
|
||||
|
||||
Args:
|
||||
cookie_header: Raw cookie string from browser (copy from Request Headers)
|
||||
|
||||
Returns:
|
||||
Dictionary of parsed cookies
|
||||
"""
|
||||
cookies = {}
|
||||
|
||||
# Remove 'Cookie: ' prefix if present
|
||||
if cookie_header.startswith('Cookie: '):
|
||||
cookie_header = cookie_header[8:]
|
||||
elif cookie_header.startswith('cookie: '):
|
||||
cookie_header = cookie_header[8:]
|
||||
|
||||
# Split by semicolon and parse each cookie
|
||||
cookie_pairs = cookie_header.split(';')
|
||||
|
||||
for pair in cookie_pairs:
|
||||
pair = pair.strip()
|
||||
if '=' in pair:
|
||||
name, value = pair.split('=', 1)
|
||||
cookies[name.strip()] = value.strip()
|
||||
|
||||
logger.info(f"Extracted {len(cookies)} cookies from browser")
|
||||
return cookies
|
||||
|
||||
def validate_session_cookies(self, cookies: Dict[str, str]) -> bool:
|
||||
"""
|
||||
Validate that essential cookies are present for authentication
|
||||
|
||||
Args:
|
||||
cookies: Dictionary of cookie name-value pairs
|
||||
|
||||
Returns:
|
||||
bool: True if cookies appear valid for authentication
|
||||
"""
|
||||
required_cookies = [
|
||||
'uc_token', # User authentication token
|
||||
'u_id', # User ID
|
||||
'x-mxc-fingerprint', # Browser fingerprint
|
||||
'mexc_fingerprint_visitorId' # Visitor ID
|
||||
]
|
||||
|
||||
missing_cookies = []
|
||||
for cookie_name in required_cookies:
|
||||
if cookie_name not in cookies or not cookies[cookie_name]:
|
||||
missing_cookies.append(cookie_name)
|
||||
|
||||
if missing_cookies:
|
||||
logger.warning(f"Missing required cookies: {missing_cookies}")
|
||||
return False
|
||||
|
||||
logger.info("All required cookies are present")
|
||||
return True
|
||||
|
||||
def save_session(self, cookies: Dict[str, str], metadata: Optional[Dict[str, Any]] = None):
|
||||
"""
|
||||
Save session cookies to file for reuse
|
||||
|
||||
Args:
|
||||
cookies: Dictionary of cookies to save
|
||||
metadata: Optional metadata about the session
|
||||
"""
|
||||
session_data = {
|
||||
'cookies': cookies,
|
||||
'metadata': metadata or {},
|
||||
'timestamp': int(time.time())
|
||||
}
|
||||
|
||||
try:
|
||||
with open(self.session_file, 'w') as f:
|
||||
json.dump(session_data, f, indent=2)
|
||||
logger.info(f"Session saved to {self.session_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save session: {e}")
|
||||
|
||||
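    # For reference, the mexc_session.json written above has this shape
    # (illustrative values only):
    #
    # {
    #   "cookies": {"uc_token": "...", "u_id": "..."},
    #   "metadata": {},
    #   "timestamp": 1719960000
    # }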
def load_session(self) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
Load session cookies from file
|
||||
|
||||
Returns:
|
||||
Dictionary of cookies if successful, None otherwise
|
||||
"""
|
||||
if not self.session_file.exists():
|
||||
logger.info("No saved session found")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(self.session_file, 'r') as f:
|
||||
session_data = json.load(f)
|
||||
|
||||
cookies = session_data.get('cookies', {})
|
||||
timestamp = session_data.get('timestamp', 0)
|
||||
|
||||
            # Check if the session is too old (24 hours)
            if time.time() - timestamp > 24 * 3600:
|
||||
logger.warning("Saved session is too old (>24h), may be expired")
|
||||
|
||||
if self.validate_session_cookies(cookies):
|
||||
logger.info("Loaded valid session from file")
|
||||
return cookies
|
||||
else:
|
||||
logger.warning("Loaded session has invalid cookies")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load session: {e}")
|
||||
return None
|
||||
|
||||
def extract_from_curl_command(self, curl_command: str) -> Dict[str, str]:
|
||||
"""
|
||||
Extract cookies from a curl command copied from browser
|
||||
|
||||
Args:
|
||||
curl_command: Complete curl command from browser "Copy as cURL"
|
||||
|
||||
Returns:
|
||||
Dictionary of extracted cookies
|
||||
"""
|
||||
cookies = {}
|
||||
|
||||
# Find cookie header in curl command
|
||||
cookie_match = re.search(r'-H [\'"]cookie: ([^\'"]+)[\'"]', curl_command, re.IGNORECASE)
|
||||
if not cookie_match:
|
||||
cookie_match = re.search(r'--header [\'"]cookie: ([^\'"]+)[\'"]', curl_command, re.IGNORECASE)
|
||||
|
||||
if cookie_match:
|
||||
cookie_header = cookie_match.group(1)
|
||||
cookies = self.extract_cookies_from_network_tab(cookie_header)
|
||||
logger.info(f"Extracted {len(cookies)} cookies from curl command")
|
||||
else:
|
||||
logger.warning("No cookie header found in curl command")
|
||||
|
||||
return cookies
|
||||
|
||||
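    # Example of the cURL path (trimmed, hypothetical token values): a command like
    #   curl 'https://futures.mexc.com/...' -H 'cookie: uc_token=abc; u_id=123'
    # yields {'uc_token': 'abc', 'u_id': '123'} via the same header parser as above.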
def print_cookie_extraction_guide(self):
|
||||
"""Print instructions for extracting cookies from browser"""
|
||||
print("\n" + "="*80)
|
||||
print("MEXC COOKIE EXTRACTION GUIDE")
|
||||
print("="*80)
|
||||
print("""
|
||||
To extract cookies from your browser for MEXC futures trading:
|
||||
|
||||
METHOD 1: Browser Network Tab
|
||||
1. Open MEXC futures page and log in: https://www.mexc.com/en-GB/futures/ETH_USDT
|
||||
2. Open browser Developer Tools (F12)
|
||||
3. Go to Network tab
|
||||
4. Try to place a small futures trade (it will fail, but we need the request)
|
||||
5. Find the request to 'futures.mexc.com' in the Network tab
|
||||
6. Right-click on the request -> Copy -> Copy request headers
|
||||
7. Find the 'Cookie:' line and copy everything after 'Cookie: '
|
||||
|
||||
METHOD 2: Copy as cURL
|
||||
1. Follow steps 1-5 above
|
||||
2. Right-click on the futures API request -> Copy -> Copy as cURL
|
||||
3. Paste the entire cURL command
|
||||
|
||||
METHOD 3: Manual Cookie Extraction
|
||||
1. While logged into MEXC, press F12 -> Application/Storage tab
|
||||
2. On the left, expand 'Cookies' -> click on 'https://www.mexc.com'
|
||||
3. Copy the values for these important cookies:
|
||||
- uc_token
|
||||
- u_id
|
||||
- x-mxc-fingerprint
|
||||
- mexc_fingerprint_visitorId
|
||||
|
||||
IMPORTANT NOTES:
|
||||
- Cookies expire after some time (usually 24 hours)
|
||||
- You must be logged into MEXC futures (not just spot trading)
|
||||
- Keep your cookies secure - they provide access to your account
|
||||
- Test with small amounts first
|
||||
|
||||
Example usage:
|
||||
session_manager = MEXCSessionManager()
|
||||
|
||||
# Method 1: From cookie header
|
||||
cookie_header = "uc_token=ABC123; u_id=DEF456; ..."
|
||||
cookies = session_manager.extract_cookies_from_network_tab(cookie_header)
|
||||
|
||||
# Method 2: From cURL command
|
||||
curl_cmd = "curl 'https://futures.mexc.com/...' -H 'cookie: uc_token=ABC123...'"
|
||||
cookies = session_manager.extract_from_curl_command(curl_cmd)
|
||||
|
||||
# Save session for reuse
|
||||
session_manager.save_session(cookies)
|
||||
""")
|
||||
print("="*80)
|
||||
|
||||
if __name__ == "__main__":
|
||||
    # When run directly, show the extraction guide
|
||||
|
||||
manager = MEXCSessionManager()
|
||||
manager.print_cookie_extraction_guide()
|
||||
|
||||
print("\nWould you like to:")
|
||||
print("1. Load saved session")
|
||||
print("2. Extract cookies from clipboard")
|
||||
print("3. Exit")
|
||||
|
||||
choice = input("\nEnter choice (1-3): ").strip()
|
||||
|
||||
if choice == "1":
|
||||
cookies = manager.load_session()
|
||||
if cookies:
|
||||
print(f"\nLoaded {len(cookies)} cookies from saved session")
|
||||
if manager.validate_session_cookies(cookies):
|
||||
print("Session appears valid for trading")
|
||||
else:
|
||||
print("Warning: Session may be incomplete or expired")
|
||||
else:
|
||||
print("No valid saved session found")
|
||||
|
||||
elif choice == "2":
|
||||
print("\nPaste your cookie header or cURL command:")
|
||||
user_input = input().strip()
|
||||
|
||||
if user_input.startswith('curl'):
|
||||
cookies = manager.extract_from_curl_command(user_input)
|
||||
else:
|
||||
cookies = manager.extract_cookies_from_network_tab(user_input)
|
||||
|
||||
if cookies and manager.validate_session_cookies(cookies):
|
||||
print(f"\nSuccessfully extracted {len(cookies)} valid cookies")
|
||||
save = input("Save session for reuse? (y/n): ").strip().lower()
|
||||
if save == 'y':
|
||||
manager.save_session(cookies)
|
||||
else:
|
||||
print("Failed to extract valid cookies")
|
||||
|
||||
else:
|
||||
print("Goodbye!")
|
||||
@@ -1,346 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test MEXC Futures Web Client
|
||||
|
||||
This script demonstrates how to use the MEXC Futures Web Client
|
||||
for futures trading that isn't supported by their official API.
|
||||
|
||||
IMPORTANT: This requires extracting cookies from your browser session.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
|
||||
# Add the project root to path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from mexc_futures_client import MEXCFuturesWebClient
|
||||
from session_manager import MEXCSessionManager
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
SYMBOL = "ETH_USDT"
|
||||
LEVERAGE = 300
|
||||
CREDENTIALS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mexc_credentials.json')
|
||||
|
||||
# Read credentials from mexc_credentials.json in JSON format
|
||||
def load_credentials():
|
||||
credentials_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mexc_credentials.json')
|
||||
cookies = {}
|
||||
captcha_token_open = ''
|
||||
captcha_token_close = ''
|
||||
try:
|
||||
with open(credentials_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
cookies = data.get('credentials', {}).get('cookies', {})
|
||||
captcha_token_open = data.get('credentials', {}).get('captcha_token_open', '')
|
||||
captcha_token_close = data.get('credentials', {}).get('captcha_token_close', '')
|
||||
logger.info(f"Loaded credentials from {credentials_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading credentials: {e}")
|
||||
return cookies, captcha_token_open, captcha_token_close
|
||||
|
||||
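# For reference, load_credentials() above expects mexc_credentials.json to look
# roughly like this (key names inferred from the parsing code; values are placeholders):
#
# {
#   "credentials": {
#     "cookies": {"uc_token": "...", "u_id": "..."},
#     "captcha_token_open": "...",
#     "captcha_token_close": "..."
#   }
# }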
def test_basic_connection():
|
||||
"""Test basic connection and authentication"""
|
||||
logger.info("Testing MEXC Futures Web Client")
|
||||
|
||||
# Initialize session manager
|
||||
session_manager = MEXCSessionManager()
|
||||
|
||||
# Try to load saved session first
|
||||
cookies = session_manager.load_session()
|
||||
|
||||
if not cookies:
|
||||
# Explicitly load the cookies from the file we have
|
||||
cookies_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'mexc_cookies_20250703_003625.json')
|
||||
if os.path.exists(cookies_file):
|
||||
try:
|
||||
with open(cookies_file, 'r') as f:
|
||||
cookies = json.load(f)
|
||||
logger.info(f"Loaded cookies from {cookies_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load cookies from {cookies_file}: {e}")
|
||||
cookies = None
|
||||
else:
|
||||
logger.error(f"Cookies file not found at {cookies_file}")
|
||||
cookies = None
|
||||
|
||||
if not cookies:
|
||||
print("\nNo saved session found. You need to extract cookies from your browser.")
|
||||
session_manager.print_cookie_extraction_guide()
|
||||
|
||||
print("\nPaste your cookie header or cURL command (or press Enter to exit):")
|
||||
user_input = input().strip()
|
||||
|
||||
if not user_input:
|
||||
print("No input provided. Exiting.")
|
||||
return False
|
||||
|
||||
# Extract cookies from user input
|
||||
if user_input.startswith('curl'):
|
||||
cookies = session_manager.extract_from_curl_command(user_input)
|
||||
else:
|
||||
cookies = session_manager.extract_cookies_from_network_tab(user_input)
|
||||
|
||||
if not cookies:
|
||||
logger.error("Failed to extract cookies from input")
|
||||
return False
|
||||
|
||||
# Validate and save session
|
||||
if session_manager.validate_session_cookies(cookies):
|
||||
session_manager.save_session(cookies)
|
||||
logger.info("Session saved for future use")
|
||||
else:
|
||||
logger.warning("Extracted cookies may be incomplete")
|
||||
|
||||
# Initialize the web client
|
||||
client = MEXCFuturesWebClient(api_key='', api_secret='', user_id='', base_url='https://www.mexc.com', headless=True)
|
||||
# Load cookies into the client's session
|
||||
for name, value in cookies.items():
|
||||
client.session.cookies.set(name, value)
|
||||
|
||||
# Update headers to include additional parameters from captured requests
|
||||
client.session.headers.update({
|
||||
'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
|
||||
'trochilus-uid': cookies.get('u_id', ''),
|
||||
'Referer': 'https://www.mexc.com/en-GB/futures/ETH_USDT?type=linear_swap',
|
||||
'Language': 'English',
|
||||
'X-Language': 'en-GB'
|
||||
})
|
||||
|
||||
if not client.is_authenticated:
|
||||
logger.error("Failed to authenticate with extracted cookies")
|
||||
return False
|
||||
|
||||
logger.info("Successfully authenticated with MEXC")
|
||||
logger.info(f"User ID: {client.user_id}")
|
||||
logger.info(f"Auth Token: {client.auth_token[:20]}..." if client.auth_token else "No auth token")
|
||||
|
||||
return True
|
||||
|
||||
def test_captcha_verification(client: MEXCFuturesWebClient):
|
||||
"""Test captcha verification system"""
|
||||
logger.info("Testing captcha verification...")
|
||||
|
||||
# Test captcha for ETH_USDT long position with 200x leverage
|
||||
success = client.verify_captcha('ETH_USDT', 'openlong', '200X')
|
||||
|
||||
if success:
|
||||
logger.info("Captcha verification successful")
|
||||
else:
|
||||
logger.warning("Captcha verification failed - this may be normal if no position is being opened")
|
||||
|
||||
return success
|
||||
|
||||
def test_position_opening(client: MEXCFuturesWebClient, dry_run: bool = True):
|
||||
"""Test opening a position (dry run by default)"""
|
||||
if dry_run:
|
||||
logger.info("DRY RUN: Testing position opening (no actual trade)")
|
||||
else:
|
||||
logger.warning("LIVE TRADING: Opening actual position!")
|
||||
|
||||
symbol = 'ETH_USDT'
|
||||
volume = 1 # Small test position
|
||||
leverage = 200
|
||||
|
||||
logger.info(f"Attempting to open long position: {symbol}, Volume: {volume}, Leverage: {leverage}x")
|
||||
|
||||
if not dry_run:
|
||||
result = client.open_long_position(symbol, volume, leverage)
|
||||
|
||||
if result['success']:
|
||||
logger.info(f"Position opened successfully!")
|
||||
logger.info(f"Order ID: {result['order_id']}")
|
||||
logger.info(f"Timestamp: {result['timestamp']}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to open position: {result['error']}")
|
||||
return False
|
||||
else:
|
||||
logger.info("DRY RUN: Would attempt to open position here")
|
||||
# Test just the captcha verification part
|
||||
return client.verify_captcha(symbol, 'openlong', f'{leverage}X')
|
||||
|
||||
def test_position_opening_live(client):
|
||||
symbol = "ETH_USDT"
|
||||
volume = 1 # Small volume for testing
|
||||
leverage = 200
|
||||
|
||||
logger.info(f"LIVE TRADING: Opening actual position!")
|
||||
logger.info(f"Attempting to open long position: {symbol}, Volume: {volume}, Leverage: {leverage}x")
|
||||
|
||||
result = client.open_long_position(symbol, volume, leverage)
|
||||
if result.get('success'):
|
||||
logger.info(f"Successfully opened position: {result}")
|
||||
else:
|
||||
logger.error(f"Failed to open position: {result.get('error', 'Unknown error')}")
|
||||
|
||||
def interactive_menu(client: MEXCFuturesWebClient):
|
||||
"""Interactive menu for testing different functions"""
|
||||
while True:
|
||||
print("\n" + "="*50)
|
||||
print("MEXC Futures Web Client Test Menu")
|
||||
print("="*50)
|
||||
print("1. Test captcha verification")
|
||||
print("2. Test position opening (DRY RUN)")
|
||||
print("3. Test position opening (LIVE - BE CAREFUL!)")
|
||||
print("4. Test position closing (DRY RUN)")
|
||||
print("5. Show session info")
|
||||
print("6. Refresh session")
|
||||
print("0. Exit")
|
||||
|
||||
choice = input("\nEnter choice (0-6): ").strip()
|
||||
|
||||
if choice == "1":
|
||||
test_captcha_verification(client)
|
||||
|
||||
elif choice == "2":
|
||||
test_position_opening(client, dry_run=True)
|
||||
|
||||
elif choice == "3":
|
||||
test_position_opening_live(client)
|
||||
|
||||
elif choice == "4":
|
||||
logger.info("DRY RUN: Position closing test")
|
||||
success = client.verify_captcha('ETH_USDT', 'closelong', '200X')
|
||||
if success:
|
||||
logger.info("DRY RUN: Would close position here")
|
||||
else:
|
||||
logger.warning("Captcha verification failed for position closing")
|
||||
|
||||
elif choice == "5":
|
||||
print(f"\nSession Information:")
|
||||
print(f"Authenticated: {client.is_authenticated}")
|
||||
print(f"User ID: {client.user_id}")
|
||||
print(f"Auth Token: {client.auth_token[:20]}..." if client.auth_token else "None")
|
||||
print(f"Fingerprint: {client.fingerprint}")
|
||||
print(f"Visitor ID: {client.visitor_id}")
|
||||
|
||||
elif choice == "6":
|
||||
session_manager = MEXCSessionManager()
|
||||
session_manager.print_cookie_extraction_guide()
|
||||
|
||||
elif choice == "0":
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
else:
|
||||
print("Invalid choice. Please try again.")
|
||||
|
||||
def main():
|
||||
"""Main test function"""
|
||||
print("MEXC Futures Web Client Test")
|
||||
print("WARNING: This is experimental software for futures trading")
|
||||
print("Use at your own risk and test with small amounts first!")
|
||||
|
||||
# Load cookies and tokens
|
||||
cookies, captcha_token_open, captcha_token_close = load_credentials()
|
||||
if not cookies:
|
||||
logger.error("Failed to load cookies from credentials file")
|
||||
sys.exit(1)
|
||||
|
||||
# Initialize client with loaded cookies and tokens
|
||||
client = MEXCFuturesWebClient(api_key='', api_secret='', user_id='')
|
||||
# Load cookies into the client's session
|
||||
for name, value in cookies.items():
|
||||
client.session.cookies.set(name, value)
|
||||
# Set captcha tokens
|
||||
client.captcha_token_open = captcha_token_open
|
||||
client.captcha_token_close = captcha_token_close
|
||||
|
||||
# Try to load credentials from the new JSON file
|
||||
try:
|
||||
with open(CREDENTIALS_FILE, 'r') as f:
|
||||
credentials_data = json.load(f)
|
||||
cookies = credentials_data['credentials']['cookies']
|
||||
captcha_token_open = credentials_data['credentials']['captcha_token_open']
|
||||
captcha_token_close = credentials_data['credentials']['captcha_token_close']
|
||||
client.load_session_cookies(cookies)
|
||||
client.session_manager.save_captcha_token(captcha_token_open) # Assuming this is for opening
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Credentials file not found at {CREDENTIALS_FILE}")
|
||||
return False
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Error loading credentials: {e}")
|
||||
return False
|
||||
except KeyError as e:
|
||||
logger.error(f"Missing key in credentials file: {e}")
|
||||
return False
|
||||
|
||||
if not client.is_authenticated:
|
||||
logger.error("Client not authenticated. Please ensure valid cookies and tokens are in mexc_credentials.json")
|
||||
return False
|
||||
|
||||
# Test connection and authentication
|
||||
logger.info("Successfully authenticated with MEXC")
|
||||
|
||||
# Set leverage
|
||||
leverage_response = client.update_leverage(symbol=SYMBOL, leverage=LEVERAGE)
|
||||
if leverage_response and leverage_response.get('code') == 200:
|
||||
logger.info(f"Leverage set to {LEVERAGE}x for {SYMBOL}")
|
||||
else:
|
||||
logger.error(f"Failed to set leverage: {leverage_response}")
|
||||
sys.exit(1)
|
||||
|
||||
# Get current price
|
||||
ticker = client.get_ticker_data(symbol=SYMBOL)
|
||||
if ticker and ticker.get('code') == 200:
|
||||
current_price = float(ticker['data']['last'])
|
||||
logger.info(f"Current {SYMBOL} price: {current_price}")
|
||||
else:
|
||||
logger.error(f"Failed to get ticker data: {ticker}")
|
||||
sys.exit(1)
|
||||
|
||||
# Calculate order size for a small test trade (e.g., $10 worth)
|
||||
trade_usdt = 10.0
|
||||
order_qty = round((trade_usdt / current_price) * LEVERAGE, 3)
|
||||
logger.info(f"Calculated order quantity: {order_qty} {SYMBOL} for ~${trade_usdt} at {LEVERAGE}x")
|
||||
|
||||
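    # Worked example: with ETH around $3,000, trade_usdt=10.0 and LEVERAGE=300 give
    # round((10.0 / 3000) * 300, 3) = 1.0, i.e. roughly 1 ETH of notional exposure
    # controlled by about $10 of margin at 300x (illustrative price, not a quote).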
# Test 1: Open LONG position
|
||||
logger.info(f"Opening LONG position for {SYMBOL} at {current_price} with qty {order_qty}")
|
||||
open_long_order = client.create_order(
|
||||
symbol=SYMBOL,
|
||||
side=1, # 1 for BUY
|
||||
position_side=1, # 1 for LONG
|
||||
order_type=1, # 1 for LIMIT
|
||||
price=current_price,
|
||||
vol=order_qty
|
||||
)
|
||||
if open_long_order and open_long_order.get('code') == 200:
|
||||
logger.info(f"✅ Successfully opened LONG position: {open_long_order['data']}")
|
||||
else:
|
||||
logger.error(f"❌ Failed to open LONG position: {open_long_order}")
|
||||
sys.exit(1)
|
||||
|
||||
# Test 2: Close LONG position
|
||||
logger.info(f"Closing LONG position for {SYMBOL}")
|
||||
close_long_order = client.create_order(
|
||||
symbol=SYMBOL,
|
||||
side=2, # 2 for SELL
|
||||
position_side=1, # 1 for LONG
|
||||
order_type=1, # 1 for LIMIT
|
||||
price=current_price,
|
||||
vol=order_qty,
|
||||
reduce_only=True
|
||||
)
|
||||
if close_long_order and close_long_order.get('code') == 200:
|
||||
logger.info(f"✅ Successfully closed LONG position: {close_long_order['data']}")
|
||||
else:
|
||||
logger.error(f"❌ Failed to close LONG position: {close_long_order}")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info("All tests completed successfully!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,595 +0,0 @@
|
||||
"""
|
||||
Negative Case Trainer - Intensive Training on Losing Trades
|
||||
|
||||
This module focuses on learning from losses to prevent future mistakes.
|
||||
Stores negative cases in testcases/negative folder for reuse and retraining.
|
||||
Supports simultaneous inference and training.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import pickle
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from dataclasses import dataclass, asdict
|
||||
from collections import deque
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# Import checkpoint management
|
||||
import torch
|
||||
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
|
||||
from utils.training_integration import get_training_integration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class NegativeCase:
|
||||
"""Represents a losing trade case for intensive training"""
|
||||
case_id: str
|
||||
timestamp: datetime
|
||||
symbol: str
|
||||
action: str # 'BUY' or 'SELL'
|
||||
entry_price: float
|
||||
exit_price: float
|
||||
loss_amount: float
|
||||
loss_percentage: float
|
||||
confidence_used: float
|
||||
market_state_before: Dict[str, Any]
|
||||
market_state_after: Dict[str, Any]
|
||||
tick_data: List[Dict[str, Any]] # 15 minutes of tick data around the trade
|
||||
technical_indicators: Dict[str, float]
|
||||
what_should_have_been_done: str # 'HOLD', 'OPPOSITE', 'WAIT'
|
||||
lesson_learned: str
|
||||
training_priority: int # 1-5, 5 being highest priority
|
||||
retraining_count: int = 0
|
||||
last_retrained: Optional[datetime] = None
|
||||
|
||||
@dataclass
|
||||
class TrainingSession:
|
||||
"""Represents an intensive training session on negative cases"""
|
||||
session_id: str
|
||||
start_time: datetime
|
||||
cases_trained: List[str] # case_ids
|
||||
epochs_completed: int
|
||||
loss_improvement: float
|
||||
accuracy_improvement: float
|
||||
inference_paused: bool = False
|
||||
training_active: bool = True
|
||||
|
||||
class NegativeCaseTrainer:
|
||||
"""
|
||||
Intensive trainer focused on learning from losing trades with checkpoint management
|
||||
|
||||
Features:
|
||||
- Stores all losing trades as negative cases
|
||||
- Intensive retraining on losses
|
||||
- Simultaneous inference and training
|
||||
- Persistent storage in testcases/negative
|
||||
- Priority-based training (bigger losses = higher priority)
|
||||
- Checkpoint management for training progress
|
||||
"""
|
||||
|
||||
def __init__(self, storage_dir: str = "testcases/negative",
|
||||
model_name: str = "negative_case_trainer", enable_checkpoints: bool = True):
|
||||
self.storage_dir = storage_dir
|
||||
self.stored_cases: List[NegativeCase] = []
|
||||
self.training_queue = deque(maxlen=1000)
|
||||
self.training_lock = threading.Lock()
|
||||
self.inference_lock = threading.Lock()
|
||||
|
||||
# Checkpoint management
|
||||
self.model_name = model_name
|
||||
self.enable_checkpoints = enable_checkpoints
|
||||
self.training_integration = get_training_integration() if enable_checkpoints else None
|
||||
self.training_session_count = 0
|
||||
self.best_loss_reduction = 0.0
|
||||
self.checkpoint_frequency = 25 # Save checkpoint every 25 training sessions
|
||||
|
||||
# Training configuration
|
||||
self.max_concurrent_training = 3 # Max parallel training sessions
|
||||
self.intensive_training_epochs = 50 # Epochs per negative case
|
||||
self.priority_multiplier = 2.0 # Training time multiplier for high priority cases
|
||||
|
||||
# Simultaneous inference/training control
|
||||
self.inference_active = True
|
||||
self.training_active = False
|
||||
self.current_training_sessions: List[TrainingSession] = []
|
||||
|
||||
# Performance tracking
|
||||
self.total_cases_processed = 0
|
||||
self.total_training_time = 0.0
|
||||
self.accuracy_improvements = []
|
||||
|
||||
# Initialize storage
|
||||
self._initialize_storage()
|
||||
self._load_existing_cases()
|
||||
|
||||
# Load best checkpoint if available
|
||||
if self.enable_checkpoints:
|
||||
self.load_best_checkpoint()
|
||||
|
||||
# Start background training thread
|
||||
self.training_thread = threading.Thread(target=self._background_training_loop, daemon=True)
|
||||
self.training_thread.start()
|
||||
|
||||
logger.info(f"NegativeCaseTrainer initialized with {len(self.stored_cases)} existing cases")
|
||||
logger.info(f"Storage directory: {self.storage_dir}")
|
||||
logger.info(f"Checkpoint management: {enable_checkpoints}, Model name: {model_name}")
|
||||
logger.info("Background training thread started")
|
||||
|
||||
def _initialize_storage(self):
|
||||
"""Initialize storage directories"""
|
||||
try:
|
||||
os.makedirs(self.storage_dir, exist_ok=True)
|
||||
os.makedirs(f"{self.storage_dir}/cases", exist_ok=True)
|
||||
os.makedirs(f"{self.storage_dir}/sessions", exist_ok=True)
|
||||
os.makedirs(f"{self.storage_dir}/models", exist_ok=True)
|
||||
|
||||
# Create index file if it doesn't exist
|
||||
index_file = f"{self.storage_dir}/case_index.json"
|
||||
if not os.path.exists(index_file):
|
||||
with open(index_file, 'w') as f:
|
||||
json.dump({"cases": [], "last_updated": datetime.now().isoformat()}, f)
|
||||
|
||||
logger.info(f"Storage initialized at {self.storage_dir}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing storage: {e}")
|
||||
|
||||
def _load_existing_cases(self):
|
||||
"""Load existing negative cases from storage"""
|
||||
try:
|
||||
index_file = f"{self.storage_dir}/case_index.json"
|
||||
if os.path.exists(index_file):
|
||||
with open(index_file, 'r') as f:
|
||||
index_data = json.load(f)
|
||||
|
||||
for case_info in index_data.get("cases", []):
|
||||
case_file = f"{self.storage_dir}/cases/{case_info['case_id']}.pkl"
|
||||
if os.path.exists(case_file):
|
||||
try:
|
||||
with open(case_file, 'rb') as f:
|
||||
case = pickle.load(f)
|
||||
self.stored_cases.append(case)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading case {case_info['case_id']}: {e}")
|
||||
|
||||
logger.info(f"Loaded {len(self.stored_cases)} existing negative cases")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading existing cases: {e}")
|
||||
|
||||
def add_losing_trade(self, trade_info: Dict[str, Any], market_data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Add a losing trade as a negative case for intensive training
|
||||
|
||||
Args:
|
||||
trade_info: Trade information including P&L
|
||||
market_data: Market state and tick data around the trade
|
||||
|
||||
Returns:
|
||||
case_id: Unique identifier for the negative case
|
||||
"""
|
||||
try:
|
||||
# Generate unique case ID
|
||||
case_id = f"loss_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{trade_info['symbol'].replace('/', '')}"
|
||||
|
||||
# Calculate loss metrics
|
||||
loss_amount = abs(trade_info.get('pnl', 0))
|
||||
loss_percentage = (loss_amount / trade_info.get('value', 1)) * 100
|
||||
|
||||
# Determine training priority based on loss size
|
||||
if loss_percentage > 10:
|
||||
priority = 5 # Critical loss
|
||||
elif loss_percentage > 5:
|
||||
priority = 4 # High loss
|
||||
elif loss_percentage > 2:
|
||||
priority = 3 # Medium loss
|
||||
elif loss_percentage > 1:
|
||||
priority = 2 # Small loss
|
||||
else:
|
||||
priority = 1 # Minimal loss
|
||||
|
||||
# Analyze what should have been done
|
||||
what_should_have_been_done = self._analyze_optimal_action(trade_info, market_data)
|
||||
lesson_learned = self._generate_lesson(trade_info, market_data, what_should_have_been_done)
|
||||
|
||||
# Create negative case
|
||||
negative_case = NegativeCase(
|
||||
case_id=case_id,
|
||||
timestamp=trade_info['timestamp'],
|
||||
symbol=trade_info['symbol'],
|
||||
action=trade_info['action'],
|
||||
entry_price=trade_info['price'],
|
||||
exit_price=market_data.get('exit_price', trade_info['price']),
|
||||
loss_amount=loss_amount,
|
||||
loss_percentage=loss_percentage,
|
||||
confidence_used=trade_info.get('confidence', 0.5),
|
||||
market_state_before=market_data.get('state_before', {}),
|
||||
market_state_after=market_data.get('state_after', {}),
|
||||
tick_data=market_data.get('tick_data', []),
|
||||
technical_indicators=market_data.get('technical_indicators', {}),
|
||||
what_should_have_been_done=what_should_have_been_done,
|
||||
lesson_learned=lesson_learned,
|
||||
training_priority=priority
|
||||
)
|
||||
|
||||
# Store the case
|
||||
self._store_case(negative_case)
|
||||
|
||||
# Add to training queue with priority
|
||||
with self.training_lock:
|
||||
self.training_queue.append(negative_case)
|
||||
self.stored_cases.append(negative_case)
|
||||
|
||||
logger.error(f"NEGATIVE CASE ADDED: {case_id} | Loss: ${loss_amount:.2f} ({loss_percentage:.1f}%) | Priority: {priority}")
|
||||
logger.error(f"Lesson: {lesson_learned}")
|
||||
|
||||
return case_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding losing trade: {e}")
|
||||
return ""
|
||||
|
||||
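    # Minimal usage sketch for add_losing_trade() (keys mirror the dictionary
    # lookups above; the values are illustrative only):
    #
    #     trainer = NegativeCaseTrainer()
    #     case_id = trainer.add_losing_trade(
    #         trade_info={'symbol': 'ETH/USDT', 'action': 'BUY', 'price': 3000.0,
    #                     'pnl': -15.0, 'value': 300.0, 'confidence': 0.7,
    #                     'timestamp': datetime.now()},
    #         market_data={'exit_price': 2985.0, 'tick_data': [],
    #                      'technical_indicators': {}},
    #     )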
def _analyze_optimal_action(self, trade_info: Dict[str, Any], market_data: Dict[str, Any]) -> str:
|
||||
"""Analyze what the optimal action should have been"""
|
||||
try:
|
||||
# Simple analysis based on price movement
|
||||
entry_price = trade_info['price']
|
||||
exit_price = market_data.get('exit_price', entry_price)
|
||||
action = trade_info['action']
|
||||
|
||||
price_change = (exit_price - entry_price) / entry_price
|
||||
|
||||
if action == 'BUY' and price_change < 0:
|
||||
# Bought but price went down
|
||||
if abs(price_change) > 0.005: # >0.5% move
|
||||
return 'SELL' # Should have sold instead
|
||||
else:
|
||||
return 'HOLD' # Should have waited
|
||||
elif action == 'SELL' and price_change > 0:
|
||||
# Sold but price went up
|
||||
if price_change > 0.005: # >0.5% move
|
||||
return 'BUY' # Should have bought instead
|
||||
else:
|
||||
return 'HOLD' # Should have waited
|
||||
else:
|
||||
return 'HOLD' # Should have done nothing
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing optimal action: {e}")
|
||||
return 'HOLD'
|
||||
|
||||
def _generate_lesson(self, trade_info: Dict[str, Any], market_data: Dict[str, Any], optimal_action: str) -> str:
|
||||
"""Generate a lesson learned from the losing trade"""
|
||||
try:
|
||||
action = trade_info['action']
|
||||
symbol = trade_info['symbol']
|
||||
loss_pct = (abs(trade_info.get('pnl', 0)) / trade_info.get('value', 1)) * 100
|
||||
confidence = trade_info.get('confidence', 0.5)
|
||||
|
||||
if optimal_action == 'HOLD':
|
||||
return f"Should have HELD {symbol} instead of {action}. Confidence {confidence:.1%} was too high for {loss_pct:.1f}% loss."
|
||||
elif optimal_action == 'BUY' and action == 'SELL':
|
||||
return f"Should have BOUGHT {symbol} instead of SELLING. Market moved opposite to prediction."
|
||||
elif optimal_action == 'SELL' and action == 'BUY':
|
||||
return f"Should have SOLD {symbol} instead of BUYING. Market moved opposite to prediction."
|
||||
else:
|
||||
return f"Confidence {confidence:.1%} was too high for {loss_pct:.1f}% loss on {action} {symbol}."
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating lesson: {e}")
|
||||
return "Learn from this loss to improve future decisions."
|
||||
|
||||
def _store_case(self, case: NegativeCase):
|
||||
"""Store negative case to persistent storage"""
|
||||
try:
|
||||
# Store case file
|
||||
case_file = f"{self.storage_dir}/cases/{case.case_id}.pkl"
|
||||
with open(case_file, 'wb') as f:
|
||||
pickle.dump(case, f)
|
||||
|
||||
# Update index
|
||||
index_file = f"{self.storage_dir}/case_index.json"
|
||||
with open(index_file, 'r') as f:
|
||||
index_data = json.load(f)
|
||||
|
||||
# Add case to index
|
||||
case_info = {
|
||||
'case_id': case.case_id,
|
||||
'timestamp': case.timestamp.isoformat(),
|
||||
'symbol': case.symbol,
|
||||
'loss_amount': case.loss_amount,
|
||||
'loss_percentage': case.loss_percentage,
|
||||
'training_priority': case.training_priority,
|
||||
'retraining_count': case.retraining_count
|
||||
}
|
||||
|
||||
index_data['cases'].append(case_info)
|
||||
index_data['last_updated'] = datetime.now().isoformat()
|
||||
|
||||
with open(index_file, 'w') as f:
|
||||
json.dump(index_data, f, indent=2)
|
||||
|
||||
logger.info(f"Stored negative case: {case.case_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error storing case: {e}")
|
||||
|
||||
def _background_training_loop(self):
|
||||
"""Background loop for intensive training on negative cases"""
|
||||
logger.info("Background training loop started")
|
||||
|
||||
while True:
|
||||
try:
|
||||
                # Check if we have cases to train on; pull the highest-priority
                # case inside the lock, but never sleep while holding it so that
                # add_losing_trade() is not blocked
                case_to_train = None
                with self.training_lock:
                    if self.training_queue:
                        cases_by_priority = sorted(self.training_queue, key=lambda x: x.training_priority, reverse=True)
                        case_to_train = cases_by_priority[0]
                        self.training_queue.remove(case_to_train)

                if case_to_train is None:
                    time.sleep(5)  # Wait for new cases
                    continue
|
||||
|
||||
# Start intensive training session
|
||||
self._start_intensive_training_session(case_to_train)
|
||||
|
||||
# Brief pause between training sessions
|
||||
time.sleep(2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in background training loop: {e}")
|
||||
time.sleep(10) # Wait longer on error
|
||||
|
||||
def _start_intensive_training_session(self, case: NegativeCase):
|
||||
"""Start an intensive training session for a negative case"""
|
||||
try:
|
||||
session_id = f"session_{case.case_id}_{int(time.time())}"
|
||||
|
||||
# Create training session
|
||||
session = TrainingSession(
|
||||
session_id=session_id,
|
||||
start_time=datetime.now(),
|
||||
cases_trained=[case.case_id],
|
||||
epochs_completed=0,
|
||||
loss_improvement=0.0,
|
||||
accuracy_improvement=0.0
|
||||
)
|
||||
|
||||
self.current_training_sessions.append(session)
|
||||
self.training_active = True
|
||||
|
||||
logger.warning(f"INTENSIVE TRAINING STARTED: {session_id}")
|
||||
logger.warning(f"Training on loss case: {case.case_id} (Priority: {case.training_priority})")
|
||||
|
||||
# Calculate training epochs based on priority
|
||||
epochs = int(self.intensive_training_epochs * case.training_priority * self.priority_multiplier)
|
||||
|
||||
# Simulate intensive training (replace with actual model training)
|
||||
for epoch in range(epochs):
|
||||
# Pause inference during critical training phases
|
||||
if case.training_priority >= 4 and epoch % 10 == 0:
|
||||
with self.inference_lock:
|
||||
session.inference_paused = True
|
||||
time.sleep(0.1) # Brief pause for critical training
|
||||
session.inference_paused = False
|
||||
|
||||
# Simulate training step
|
||||
session.epochs_completed = epoch + 1
|
||||
|
||||
# Log progress for high priority cases
|
||||
if case.training_priority >= 4 and epoch % 10 == 0:
|
||||
logger.warning(f"Intensive training progress: {epoch}/{epochs} epochs ({case.case_id})")
|
||||
|
||||
time.sleep(0.05) # Simulate training time
|
||||
|
||||
# Update case retraining info
|
||||
case.retraining_count += 1
|
||||
case.last_retrained = datetime.now()
|
||||
|
||||
# Calculate improvements (simulated)
|
||||
session.loss_improvement = np.random.uniform(0.1, 0.5) # 10-50% improvement
|
||||
session.accuracy_improvement = np.random.uniform(0.05, 0.2) # 5-20% improvement
|
||||
|
||||
# Store training session results
|
||||
self._store_training_session(session)
|
||||
|
||||
# Update statistics
|
||||
self.total_cases_processed += 1
|
||||
self.total_training_time += (datetime.now() - session.start_time).total_seconds()
|
||||
self.accuracy_improvements.append(session.accuracy_improvement)
|
||||
|
||||
# Remove from active sessions
|
||||
self.current_training_sessions.remove(session)
|
||||
if not self.current_training_sessions:
|
||||
self.training_active = False
|
||||
|
||||
logger.warning(f"INTENSIVE TRAINING COMPLETED: {session_id}")
|
||||
logger.warning(f"Epochs: {session.epochs_completed} | Loss improvement: {session.loss_improvement:.1%} | Accuracy improvement: {session.accuracy_improvement:.1%}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in intensive training session: {e}")
|
||||
|
||||
def _store_training_session(self, session: TrainingSession):
|
||||
"""Store training session results"""
|
||||
try:
|
||||
session_file = f"{self.storage_dir}/sessions/{session.session_id}.json"
|
||||
session_data = {
|
||||
'session_id': session.session_id,
|
||||
'start_time': session.start_time.isoformat(),
|
||||
'end_time': datetime.now().isoformat(),
|
||||
'cases_trained': session.cases_trained,
|
||||
'epochs_completed': session.epochs_completed,
|
||||
'loss_improvement': session.loss_improvement,
|
||||
'accuracy_improvement': session.accuracy_improvement
|
||||
}
|
||||
|
||||
with open(session_file, 'w') as f:
|
||||
json.dump(session_data, f, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error storing training session: {e}")
|
||||
|
||||
def can_inference_proceed(self) -> bool:
|
||||
"""Check if inference can proceed (not blocked by critical training)"""
|
||||
with self.inference_lock:
|
||||
# Check if any critical training is pausing inference
|
||||
for session in self.current_training_sessions:
|
||||
if session.inference_paused:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_training_stats(self) -> Dict[str, Any]:
|
||||
"""Get training statistics"""
|
||||
try:
|
||||
avg_accuracy_improvement = np.mean(self.accuracy_improvements) if self.accuracy_improvements else 0.0
|
||||
|
||||
return {
|
||||
'total_negative_cases': len(self.stored_cases),
|
||||
'cases_in_queue': len(self.training_queue),
|
||||
'total_cases_processed': self.total_cases_processed,
|
||||
'total_training_time': self.total_training_time,
|
||||
'avg_accuracy_improvement': avg_accuracy_improvement,
|
||||
'active_training_sessions': len(self.current_training_sessions),
|
||||
'training_active': self.training_active,
|
||||
'high_priority_cases': len([c for c in self.stored_cases if c.training_priority >= 4]),
|
||||
'storage_directory': self.storage_dir
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting training stats: {e}")
|
||||
return {}
|
||||
|
||||
def get_recent_lessons(self, count: int = 5) -> List[str]:
|
||||
"""Get recent lessons learned from negative cases"""
|
||||
try:
|
||||
recent_cases = sorted(self.stored_cases, key=lambda x: x.timestamp, reverse=True)[:count]
|
||||
return [case.lesson_learned for case in recent_cases]
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting recent lessons: {e}")
|
||||
return []
|
||||
|
||||
def retrain_all_cases(self):
|
||||
"""Retrain all stored negative cases (for periodic retraining)"""
|
||||
try:
|
||||
logger.warning("RETRAINING ALL NEGATIVE CASES - This may take a while...")
|
||||
|
||||
with self.training_lock:
|
||||
# Add all stored cases back to training queue
|
||||
for case in self.stored_cases:
|
||||
if case not in self.training_queue:
|
||||
self.training_queue.append(case)
|
||||
|
||||
logger.warning(f"Added {len(self.stored_cases)} cases to retraining queue")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retraining all cases: {e}")
|
||||
|
||||
def load_best_checkpoint(self):
|
||||
"""Load the best checkpoint for this negative case trainer"""
|
||||
try:
|
||||
if not self.enable_checkpoints:
|
||||
return
|
||||
|
||||
result = load_best_checkpoint(self.model_name)
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
checkpoint = torch.load(file_path, map_location='cpu')
|
||||
|
||||
# Load training state
|
||||
if 'training_session_count' in checkpoint:
|
||||
self.training_session_count = checkpoint['training_session_count']
|
||||
if 'best_loss_reduction' in checkpoint:
|
||||
self.best_loss_reduction = checkpoint['best_loss_reduction']
|
||||
if 'total_cases_processed' in checkpoint:
|
||||
self.total_cases_processed = checkpoint['total_cases_processed']
|
||||
if 'total_training_time' in checkpoint:
|
||||
self.total_training_time = checkpoint['total_training_time']
|
||||
if 'accuracy_improvements' in checkpoint:
|
||||
self.accuracy_improvements = checkpoint['accuracy_improvements']
|
||||
|
||||
logger.info(f"Loaded NegativeCaseTrainer checkpoint: {metadata.checkpoint_id}")
|
||||
logger.info(f"Session: {self.training_session_count}, Best loss reduction: {self.best_loss_reduction:.4f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")
|
||||
|
||||
def save_checkpoint(self, loss_improvement: float = 0.0, force_save: bool = False):
|
||||
"""Save checkpoint if performance improved or forced"""
|
||||
try:
|
||||
if not self.enable_checkpoints:
|
||||
return False
|
||||
|
||||
self.training_session_count += 1
|
||||
|
||||
# Update best loss reduction
|
||||
improved = False
|
||||
if loss_improvement > self.best_loss_reduction:
|
||||
self.best_loss_reduction = loss_improvement
|
||||
improved = True
|
||||
|
||||
# Save checkpoint if improved, forced, or at regular intervals
|
||||
should_save = (
|
||||
force_save or
|
||||
improved or
|
||||
self.training_session_count % self.checkpoint_frequency == 0
|
||||
)
|
||||
|
||||
if should_save:
|
||||
# Prepare checkpoint data
|
||||
checkpoint_data = {
|
||||
'training_session_count': self.training_session_count,
|
||||
'best_loss_reduction': self.best_loss_reduction,
|
||||
'total_cases_processed': self.total_cases_processed,
|
||||
'total_training_time': self.total_training_time,
|
||||
'accuracy_improvements': self.accuracy_improvements,
|
||||
'storage_dir': self.storage_dir,
|
||||
'max_concurrent_training': self.max_concurrent_training,
|
||||
'intensive_training_epochs': self.intensive_training_epochs
|
||||
}
|
||||
|
||||
# Create performance metrics for checkpoint manager
|
||||
avg_accuracy_improvement = (
|
||||
sum(self.accuracy_improvements) / len(self.accuracy_improvements)
|
||||
if self.accuracy_improvements else 0.0
|
||||
)
|
||||
|
||||
performance_metrics = {
|
||||
'loss_reduction': self.best_loss_reduction,
|
||||
'avg_accuracy_improvement': avg_accuracy_improvement,
|
||||
'total_cases_processed': self.total_cases_processed,
|
||||
'training_efficiency': (
|
||||
self.total_cases_processed / self.total_training_time
|
||||
if self.total_training_time > 0 else 0.0
|
||||
)
|
||||
}
|
||||
|
||||
# Save using checkpoint manager
|
||||
metadata = save_checkpoint(
|
||||
model=checkpoint_data, # We're saving data dict instead of model
|
||||
model_name=self.model_name,
|
||||
model_type="negative_case_trainer",
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata={
|
||||
'session': self.training_session_count,
|
||||
'cases_processed': self.total_cases_processed,
|
||||
'training_time_hours': self.total_training_time / 3600
|
||||
},
|
||||
force_save=force_save
|
||||
)
|
||||
|
||||
if metadata:
|
||||
logger.info(f"Saved NegativeCaseTrainer checkpoint: {metadata.checkpoint_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving NegativeCaseTrainer checkpoint: {e}")
|
||||
return False
|
||||
@@ -1,277 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Neural Network Decision Fusion System
|
||||
Central NN that merges all model outputs + market data for final trading decisions
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class ModelPrediction:
|
||||
"""Standardized prediction from any model"""
|
||||
model_name: str
|
||||
prediction_type: str # 'price', 'direction', 'action'
|
||||
value: float # -1 to 1 for direction, actual price for price predictions
|
||||
confidence: float # 0 to 1
|
||||
timestamp: datetime
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
@dataclass
|
||||
class MarketContext:
|
||||
"""Current market context for decision fusion"""
|
||||
symbol: str
|
||||
current_price: float
|
||||
price_change_1m: float
|
||||
price_change_5m: float
|
||||
volume_ratio: float
|
||||
volatility: float
|
||||
timestamp: datetime
|
||||
|
||||
@dataclass
|
||||
class FusionDecision:
|
||||
"""Final trading decision from fusion NN"""
|
||||
action: str # 'BUY', 'SELL', 'HOLD'
|
||||
confidence: float # 0 to 1
|
||||
expected_return: float # Expected return percentage
|
||||
risk_score: float # 0 to 1, higher = riskier
|
||||
position_size: float # Recommended position size
|
||||
reasoning: str # Human-readable explanation
|
||||
model_contributions: Dict[str, float] # How much each model contributed
|
||||
timestamp: datetime
|
||||
|
||||
class DecisionFusionNetwork(nn.Module):
|
||||
"""Small NN that fuses model predictions with market context"""
|
||||
|
||||
def __init__(self, input_dim: int = 32, hidden_dim: int = 64):
|
||||
super().__init__()
|
||||
|
||||
self.fusion_layers = nn.Sequential(
|
||||
nn.Linear(input_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(hidden_dim, hidden_dim // 2),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim // 2, 16)
|
||||
)
|
||||
|
||||
# Output heads
|
||||
self.action_head = nn.Linear(16, 3) # BUY, SELL, HOLD
|
||||
self.confidence_head = nn.Linear(16, 1)
|
||||
self.return_head = nn.Linear(16, 1)
|
||||
|
||||
def forward(self, features: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
"""Forward pass through fusion network"""
|
||||
fusion_output = self.fusion_layers(features)
|
||||
|
||||
action_logits = self.action_head(fusion_output)
|
||||
action_probs = F.softmax(action_logits, dim=1)
|
||||
|
||||
confidence = torch.sigmoid(self.confidence_head(fusion_output))
|
||||
expected_return = torch.tanh(self.return_head(fusion_output))
|
||||
|
||||
return {
|
||||
'action_probs': action_probs,
|
||||
'confidence': confidence.squeeze(),
|
||||
'expected_return': expected_return.squeeze()
|
||||
}
|
||||
|
||||
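# Shape sanity check for DecisionFusionNetwork (illustrative, random inputs):
#
#     net = DecisionFusionNetwork(input_dim=32, hidden_dim=64)
#     out = net(torch.randn(4, 32))
#     # out['action_probs'].shape == (4, 3); out['confidence'].shape == (4,)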
class NeuralDecisionFusion:
|
||||
"""Main NN-based decision fusion system"""
|
||||
|
||||
def __init__(self, training_mode: bool = True):
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.network = DecisionFusionNetwork().to(self.device)
|
||||
self.training_mode = training_mode
|
||||
self.registered_models = {}
|
||||
self.last_predictions = {}
|
||||
|
||||
logger.info(f"Neural Decision Fusion initialized on {self.device}")
|
||||
|
||||
def register_model(self, model_name: str, model_type: str, prediction_format: str):
|
||||
"""Register a model that will provide predictions"""
|
||||
self.registered_models[model_name] = {
|
||||
'type': model_type,
|
||||
'format': prediction_format,
|
||||
'prediction_count': 0
|
||||
}
|
||||
logger.info(f"Registered NN model: {model_name} ({model_type})")
|
||||
|
||||
def add_prediction(self, prediction: ModelPrediction):
|
||||
"""Add a prediction from a registered model"""
|
||||
self.last_predictions[prediction.model_name] = prediction
|
||||
if prediction.model_name in self.registered_models:
|
||||
self.registered_models[prediction.model_name]['prediction_count'] += 1
|
||||
|
||||
logger.debug(f"🔮 {prediction.model_name}: {prediction.value:.3f} "
|
||||
f"(confidence: {prediction.confidence:.3f})")
|
||||
|
||||
def make_decision(self, symbol: str, market_context: MarketContext,
|
||||
min_confidence: float = 0.25) -> Optional[FusionDecision]:
|
||||
"""Make NN-driven trading decision"""
|
||||
try:
|
||||
if len(self.last_predictions) < 1:
|
||||
logger.debug("No NN predictions available")
|
||||
return None
|
||||
|
||||
# Prepare features
|
||||
features = self._prepare_features(market_context)
|
||||
if features is None:
|
||||
return None
|
||||
|
||||
# Run NN inference
|
||||
with torch.no_grad():
|
||||
self.network.eval()
|
||||
features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(self.device)
|
||||
outputs = self.network(features_tensor)
|
||||
|
||||
action_probs = outputs['action_probs'][0].cpu().numpy()
|
||||
confidence = outputs['confidence'].cpu().item()
|
||||
expected_return = outputs['expected_return'].cpu().item()
|
||||
|
||||
# Determine action
|
||||
action_idx = np.argmax(action_probs)
|
||||
actions = ['BUY', 'SELL', 'HOLD']
|
||||
action = actions[action_idx]
|
||||
|
||||
# Check confidence threshold
|
||||
if confidence < min_confidence:
|
||||
action = 'HOLD'
|
||||
logger.debug(f"Low NN confidence ({confidence:.3f}), defaulting to HOLD")
|
||||
|
||||
# Calculate position size
|
||||
position_size = self._calculate_position_size(confidence, expected_return)
|
||||
|
||||
# Generate reasoning
|
||||
reasoning = self._generate_reasoning(action, confidence, expected_return, action_probs)
|
||||
|
||||
# Calculate risk score and model contributions
|
||||
risk_score = min(1.0, abs(expected_return) * 5 + (1 - confidence) * 0.5)
|
||||
model_contributions = self._calculate_model_contributions()
|
||||
|
||||
decision = FusionDecision(
|
||||
action=action,
|
||||
confidence=confidence,
|
||||
expected_return=expected_return,
|
||||
risk_score=risk_score,
|
||||
position_size=position_size,
|
||||
reasoning=reasoning,
|
||||
model_contributions=model_contributions,
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
|
||||
logger.info(f"🧠 NN DECISION: {action} (conf: {confidence:.3f}, "
|
||||
f"return: {expected_return:.3f}, size: {position_size:.4f})")
|
||||
|
||||
return decision
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in NN decision making: {e}")
|
||||
return None
|
||||
|
||||
def _prepare_features(self, context: MarketContext) -> Optional[np.ndarray]:
|
||||
"""Prepare feature vector for NN"""
|
||||
try:
|
||||
features = np.zeros(32)
|
||||
|
||||
# Model predictions (slots 0-15)
|
||||
idx = 0
|
||||
for model_name, prediction in self.last_predictions.items():
|
||||
if idx < 14: # Leave room for other features
|
||||
features[idx] = prediction.value
|
||||
features[idx + 1] = prediction.confidence
|
||||
idx += 2
|
||||
|
||||
# Market context (slots 16-31)
|
||||
features[16] = np.tanh(context.price_change_1m * 100) # 1m change
|
||||
features[17] = np.tanh(context.price_change_5m * 100) # 5m change
|
||||
features[18] = np.tanh(context.volume_ratio - 1) # Volume ratio
|
||||
features[19] = np.tanh(context.volatility * 100) # Volatility
|
||||
features[20] = context.current_price / 10000.0 # Normalized price
|
||||
|
||||
# Time features
|
||||
now = context.timestamp
|
||||
features[21] = now.hour / 24.0
|
||||
features[22] = now.weekday() / 7.0
|
||||
|
||||
# Model agreement features
|
||||
if len(self.last_predictions) >= 2:
|
||||
values = [p.value for p in self.last_predictions.values()]
|
||||
features[23] = np.mean(values) # Average prediction
|
||||
features[24] = np.std(values) # Prediction variance
|
||||
features[25] = len(self.last_predictions) # Model count
|
||||
|
||||
return features
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing NN features: {e}")
|
||||
return None
|
||||
|
||||
def _calculate_position_size(self, confidence: float, expected_return: float) -> float:
|
||||
"""Calculate position size based on NN outputs"""
|
||||
base_size = 0.01 # 0.01 ETH base
|
||||
|
||||
# Scale by confidence
|
||||
confidence_multiplier = max(0.1, min(2.0, confidence * 1.5))
|
||||
|
||||
# Scale by expected return
|
||||
return_multiplier = 1.0 + abs(expected_return) * 0.5
|
||||
|
||||
final_size = base_size * confidence_multiplier * return_multiplier
|
||||
return max(0.001, min(0.05, final_size))
|
||||
|
||||
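    # Worked example: confidence=0.8, expected_return=0.02 gives
    # confidence_multiplier = min(2.0, 0.8 * 1.5) = 1.2 and
    # return_multiplier = 1.0 + 0.02 * 0.5 = 1.01, so
    # final_size = 0.01 * 1.2 * 1.01 ≈ 0.0121 ETH (within the 0.001-0.05 clamp).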
def _generate_reasoning(self, action: str, confidence: float,
|
||||
expected_return: float, action_probs: np.ndarray) -> str:
|
||||
"""Generate human-readable reasoning"""
|
||||
reasons = []
|
||||
|
||||
if action == 'BUY':
|
||||
reasons.append(f"NN suggests BUY ({action_probs[0]:.1%})")
|
||||
elif action == 'SELL':
|
||||
reasons.append(f"NN suggests SELL ({action_probs[1]:.1%})")
|
||||
else:
|
||||
reasons.append(f"NN suggests HOLD")
|
||||
|
||||
if confidence > 0.7:
|
||||
reasons.append("High confidence")
|
||||
elif confidence > 0.5:
|
||||
reasons.append("Moderate confidence")
|
||||
else:
|
||||
reasons.append("Low confidence")
|
||||
|
||||
if abs(expected_return) > 0.01:
|
||||
direction = "positive" if expected_return > 0 else "negative"
|
||||
reasons.append(f"Expected {direction} return: {expected_return:.2%}")
|
||||
|
||||
reasons.append(f"Based on {len(self.last_predictions)} NN models")
|
||||
|
||||
return " | ".join(reasons)
|
||||
|
||||
def _calculate_model_contributions(self) -> Dict[str, float]:
|
||||
"""Calculate how much each model contributed to the decision"""
|
||||
contributions = {}
|
||||
total_confidence = sum(p.confidence for p in self.last_predictions.values()) if self.last_predictions else 1.0
|
||||
|
||||
if total_confidence > 0:
|
||||
for model_name, prediction in self.last_predictions.items():
|
||||
contributions[model_name] = prediction.confidence / total_confidence
|
||||
|
||||
return contributions
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""Get NN fusion system status"""
|
||||
return {
|
||||
'device': str(self.device),
|
||||
'training_mode': self.training_mode,
|
||||
'registered_models': len(self.registered_models),
|
||||
'recent_predictions': len(self.last_predictions),
|
||||
'model_parameters': sum(p.numel() for p in self.network.parameters())
|
||||
}
|
||||
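# Usage sketch for the fusion system above (model name, context values and
# prediction numbers are placeholders, not real outputs):
#
#     fusion = NeuralDecisionFusion(training_mode=False)
#     fusion.register_model('example_cnn', 'cnn', 'direction')
#     fusion.add_prediction(ModelPrediction('example_cnn', 'direction', 0.4, 0.8, datetime.now()))
#     ctx = MarketContext('ETH/USDT', 3000.0, 0.001, -0.002, 1.2, 0.01, datetime.now())
#     decision = fusion.make_decision('ETH/USDT', ctx)  # FusionDecision or None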
Binary file not shown.
@@ -1,649 +0,0 @@
|
||||
"""
|
||||
Real-Time Tick Processing Neural Network Module
|
||||
|
||||
This module acts as a Neural Network DPS (Data Processing System) alternative,
|
||||
processing raw tick data with ultra-low latency and feeding processed features
|
||||
to trading models in real-time.
|
||||
|
||||
Features:
|
||||
- Real-time tick ingestion with volume processing
|
||||
- Neural network feature extraction from tick streams
|
||||
- Ultra-low latency processing (sub-millisecond)
|
||||
- Volume-weighted price analysis
|
||||
- Microstructure pattern detection
|
||||
- Real-time feature streaming to models
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any, Deque
|
||||
from collections import deque
|
||||
from threading import Thread, Lock
|
||||
import websockets
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class TickData:
|
||||
"""Raw tick data structure"""
|
||||
timestamp: datetime
|
||||
price: float
|
||||
volume: float
|
||||
side: str # 'buy' or 'sell'
|
||||
trade_id: Optional[str] = None
|
||||
|
||||
@dataclass
|
||||
class ProcessedTickFeatures:
|
||||
"""Processed tick features for model consumption"""
|
||||
timestamp: datetime
|
||||
price_features: np.ndarray # Price-based features
|
||||
volume_features: np.ndarray # Volume-based features
|
||||
microstructure_features: np.ndarray # Market microstructure features
|
||||
neural_features: np.ndarray # Neural network extracted features
|
||||
confidence: float # Feature quality confidence
|
||||
|
||||
class TickProcessingNN(nn.Module):
|
||||
"""
|
||||
Neural Network for real-time tick processing
|
||||
Extracts high-level features from raw tick data
|
||||
"""
|
||||
|
||||
def __init__(self, input_size: int = 9, hidden_size: int = 128, output_size: int = 64):
|
||||
super(TickProcessingNN, self).__init__()
|
||||
|
||||
# Tick sequence processing layers
|
||||
self.tick_encoder = nn.Sequential(
|
||||
nn.Linear(input_size, hidden_size),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(hidden_size, hidden_size),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1)
|
||||
)
|
||||
|
||||
# LSTM for temporal patterns
|
||||
self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True, num_layers=2)
|
||||
|
||||
# Attention mechanism for important tick selection
|
||||
self.attention = nn.MultiheadAttention(hidden_size, num_heads=8, batch_first=True)
|
||||
|
||||
# Feature extraction heads
|
||||
self.price_head = nn.Linear(hidden_size, 16) # Price pattern features
|
||||
self.volume_head = nn.Linear(hidden_size, 16) # Volume pattern features
|
||||
self.microstructure_head = nn.Linear(hidden_size, 16) # Microstructure features
|
||||
|
||||
# Final feature fusion
|
||||
self.feature_fusion = nn.Sequential(
|
||||
nn.Linear(48, output_size), # 16+16+16 = 48
|
||||
nn.ReLU(),
|
||||
nn.Linear(output_size, output_size)
|
||||
)
|
||||
|
||||
# Confidence estimation
|
||||
self.confidence_head = nn.Sequential(
|
||||
nn.Linear(output_size, 32),
|
||||
nn.ReLU(),
|
||||
nn.Linear(32, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, tick_sequence: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Process tick sequence and extract features
|
||||
|
||||
Args:
|
||||
tick_sequence: [batch, sequence_length, features]
|
||||
|
||||
Returns:
|
||||
features: [batch, output_size] - extracted features
|
||||
confidence: [batch, 1] - feature confidence
|
||||
"""
|
||||
batch_size, seq_len, _ = tick_sequence.shape
|
||||
|
||||
# Encode each tick
|
||||
encoded = self.tick_encoder(tick_sequence) # [batch, seq_len, hidden_size]
|
||||
|
||||
# LSTM processing for temporal patterns
|
||||
lstm_out, _ = self.lstm(encoded) # [batch, seq_len, hidden_size]
|
||||
|
||||
# Attention to focus on important ticks
|
||||
attended, _ = self.attention(lstm_out, lstm_out, lstm_out) # [batch, seq_len, hidden_size]
|
||||
|
||||
# Use the last attended output
|
||||
final_features = attended[:, -1, :] # [batch, hidden_size]
|
||||
|
||||
# Extract specialized features
|
||||
price_features = self.price_head(final_features)
|
||||
volume_features = self.volume_head(final_features)
|
||||
microstructure_features = self.microstructure_head(final_features)
|
||||
|
||||
# Fuse all features
|
||||
combined_features = torch.cat([price_features, volume_features, microstructure_features], dim=1)
|
||||
final_features = self.feature_fusion(combined_features)
|
||||
|
||||
# Estimate confidence
|
||||
confidence = self.confidence_head(final_features)
|
||||
|
||||
return final_features, confidence
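
As a quick sanity check on the tensor shapes the forward pass above expects and returns, a small sketch (untrained weights, CPU only):

import torch

net = TickProcessingNN(input_size=9, hidden_size=128, output_size=64)
net.eval()
with torch.no_grad():
    dummy = torch.randn(2, 50, 9)       # [batch=2, sequence=50, features=9]
    feats, conf = net(dummy)
print(feats.shape, conf.shape)          # torch.Size([2, 64]) torch.Size([2, 1])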
|
||||
|
||||
class RealTimeTickProcessor:
|
||||
"""
|
||||
Real-time tick processing system with neural network feature extraction
|
||||
Acts as a DPS alternative for ultra-low latency tick processing
|
||||
"""
|
||||
|
||||
def __init__(self, symbols: List[str] = None, tick_buffer_size: int = 1000):
|
||||
"""Initialize the real-time tick processor"""
|
||||
self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
|
||||
self.tick_buffer_size = tick_buffer_size
|
||||
|
||||
# Tick storage buffers
|
||||
self.tick_buffers: Dict[str, Deque[TickData]] = {}
|
||||
self.processed_features: Dict[str, Deque[ProcessedTickFeatures]] = {}
|
||||
|
||||
# Initialize buffers for each symbol
|
||||
for symbol in self.symbols:
|
||||
self.tick_buffers[symbol] = deque(maxlen=tick_buffer_size)
|
||||
self.processed_features[symbol] = deque(maxlen=100) # Keep last 100 processed features
|
||||
|
||||
# Neural network for feature extraction
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.tick_nn = TickProcessingNN(input_size=9).to(self.device)
|
||||
self.tick_nn.eval() # Start in evaluation mode
|
||||
|
||||
# Processing parameters
|
||||
self.processing_window = 50 # Number of ticks to process at once
|
||||
self.min_ticks_for_processing = 10 # Minimum ticks before processing
|
||||
|
||||
# Real-time streaming
|
||||
self.streaming = False
|
||||
self.websocket_tasks = {}
|
||||
self.processing_threads = {}
|
||||
|
||||
# Performance tracking
|
||||
self.processing_times = deque(maxlen=1000)
|
||||
self.tick_counts = {symbol: 0 for symbol in self.symbols}
|
||||
|
||||
# Thread safety
|
||||
self.data_lock = Lock()
|
||||
|
||||
# Feature subscribers (models that want real-time features)
|
||||
self.feature_subscribers = []
|
||||
|
||||
logger.info(f"RealTimeTickProcessor initialized for symbols: {self.symbols}")
|
||||
logger.info(f"Neural network device: {self.device}")
|
||||
logger.info(f"Tick buffer size: {tick_buffer_size}")
|
||||
|
||||
def add_feature_subscriber(self, callback):
|
||||
"""Add a callback function to receive processed features"""
|
||||
self.feature_subscribers.append(callback)
|
||||
logger.info(f"Added feature subscriber: {callback.__name__}")
|
||||
|
||||
def remove_feature_subscriber(self, callback):
|
||||
"""Remove a feature subscriber"""
|
||||
if callback in self.feature_subscribers:
|
||||
self.feature_subscribers.remove(callback)
|
||||
logger.info(f"Removed feature subscriber: {callback.__name__}")
|
||||
|
||||
async def start_processing(self):
|
||||
"""Start real-time tick processing"""
|
||||
logger.info("Starting real-time tick processing...")
|
||||
self.streaming = True
|
||||
|
||||
# Start WebSocket streams for each symbol
|
||||
for symbol in self.symbols:
|
||||
task = asyncio.create_task(self._websocket_stream(symbol))
|
||||
self.websocket_tasks[symbol] = task
|
||||
|
||||
# Start processing thread for each symbol
|
||||
thread = Thread(target=self._processing_loop, args=(symbol,), daemon=True)
|
||||
thread.start()
|
||||
self.processing_threads[symbol] = thread
|
||||
|
||||
logger.info("Real-time tick processing started")
|
||||
|
||||
async def stop_processing(self):
|
||||
"""Stop real-time tick processing"""
|
||||
logger.info("Stopping real-time tick processing...")
|
||||
self.streaming = False
|
||||
|
||||
# Cancel WebSocket tasks
|
||||
for symbol, task in self.websocket_tasks.items():
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self.websocket_tasks.clear()
|
||||
logger.info("Real-time tick processing stopped")
|
||||
|
||||
async def _websocket_stream(self, symbol: str):
|
||||
"""WebSocket stream for real-time tick data"""
|
||||
binance_symbol = symbol.replace('/', '').lower()
|
||||
url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@trade"
|
||||
|
||||
while self.streaming:
|
||||
try:
|
||||
async with websockets.connect(url) as websocket:
|
||||
logger.info(f"Tick WebSocket connected for {symbol}")
|
||||
|
||||
async for message in websocket:
|
||||
if not self.streaming:
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(message)
|
||||
await self._process_raw_tick(symbol, data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing tick for {symbol}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error for {symbol}: {e}")
|
||||
if self.streaming:
|
||||
logger.info(f"Reconnecting tick WebSocket for {symbol} in 2 seconds...")
|
||||
await asyncio.sleep(2)
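
The stream name is derived from the trading symbol exactly as in the two lines above; for example:

# 'ETH/USDT' -> 'ethusdt' -> wss://stream.binance.com:9443/ws/ethusdt@trade
symbol = 'ETH/USDT'
binance_symbol = symbol.replace('/', '').lower()
url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@trade"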
|
||||
|
||||
    async def _process_raw_tick(self, symbol: str, raw_data: Dict):
        """Process raw tick data from the WebSocket trade stream"""
        try:
            # Extract tick information; 'm' is True when the buyer is the market
            # maker, i.e. the aggressive side of the trade was a sell
            tick = TickData(
                timestamp=datetime.fromtimestamp(int(raw_data['T']) / 1000),
                price=float(raw_data['p']),
                volume=float(raw_data['q']),
                side='sell' if raw_data['m'] else 'buy',
                trade_id=raw_data.get('t')
            )

            # Add to buffer
            with self.data_lock:
                self.tick_buffers[symbol].append(tick)
                self.tick_counts[symbol] += 1

        except Exception as e:
            logger.error(f"Error processing raw tick for {symbol}: {e}")
|
||||
|
||||
def _processing_loop(self, symbol: str):
|
||||
"""Main processing loop for a symbol"""
|
||||
logger.info(f"Starting processing loop for {symbol}")
|
||||
|
||||
while self.streaming:
|
||||
try:
|
||||
# Check if we have enough ticks to process
|
||||
with self.data_lock:
|
||||
tick_count = len(self.tick_buffers[symbol])
|
||||
|
||||
if tick_count >= self.min_ticks_for_processing:
|
||||
start_time = time.time()
|
||||
|
||||
# Process ticks
|
||||
features = self._extract_neural_features(symbol)
|
||||
|
||||
if features is not None:
|
||||
# Store processed features
|
||||
with self.data_lock:
|
||||
self.processed_features[symbol].append(features)
|
||||
|
||||
# Notify subscribers
|
||||
self._notify_feature_subscribers(symbol, features)
|
||||
|
||||
# Track processing time
|
||||
processing_time = (time.time() - start_time) * 1000 # Convert to ms
|
||||
self.processing_times.append(processing_time)
|
||||
|
||||
if len(self.processing_times) % 100 == 0:
|
||||
avg_time = np.mean(list(self.processing_times))
|
||||
logger.debug(f"RTP: Average processing time: {avg_time:.2f}ms")
|
||||
|
||||
# Small sleep to prevent CPU overload
|
||||
time.sleep(0.001) # 1ms sleep for ultra-low latency
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in processing loop for {symbol}: {e}")
|
||||
time.sleep(0.01) # Longer sleep on error
|
||||
|
||||
def _extract_neural_features(self, symbol: str) -> Optional[ProcessedTickFeatures]:
|
||||
"""Extract neural network features from recent ticks"""
|
||||
try:
|
||||
with self.data_lock:
|
||||
# Get recent ticks
|
||||
recent_ticks = list(self.tick_buffers[symbol])[-self.processing_window:]
|
||||
|
||||
if len(recent_ticks) < self.min_ticks_for_processing:
|
||||
return None
|
||||
|
||||
# Convert ticks to neural network input
|
||||
tick_features = self._ticks_to_features(recent_ticks)
|
||||
|
||||
# Process with neural network
|
||||
with torch.no_grad():
|
||||
tick_tensor = torch.FloatTensor(tick_features).unsqueeze(0).to(self.device)
|
||||
neural_features, confidence = self.tick_nn(tick_tensor)
|
||||
|
||||
neural_features = neural_features.cpu().numpy().flatten()
|
||||
confidence = confidence.cpu().numpy().item()
|
||||
|
||||
# Extract traditional features
|
||||
price_features = self._extract_price_features(recent_ticks)
|
||||
volume_features = self._extract_volume_features(recent_ticks)
|
||||
microstructure_features = self._extract_microstructure_features(recent_ticks)
|
||||
|
||||
# Create processed features object
|
||||
processed = ProcessedTickFeatures(
|
||||
timestamp=recent_ticks[-1].timestamp,
|
||||
price_features=price_features,
|
||||
volume_features=volume_features,
|
||||
microstructure_features=microstructure_features,
|
||||
neural_features=neural_features,
|
||||
confidence=confidence
|
||||
)
|
||||
|
||||
return processed
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting neural features for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _ticks_to_features(self, ticks: List[TickData]) -> np.ndarray:
|
||||
"""Convert tick data to neural network input features"""
|
||||
features = []
|
||||
|
||||
for i, tick in enumerate(ticks):
|
||||
tick_features = [
|
||||
tick.price,
|
||||
tick.volume,
|
||||
1.0 if tick.side == 'buy' else 0.0, # Buy/sell indicator
|
||||
tick.timestamp.timestamp(), # Timestamp
|
||||
]
|
||||
|
||||
# Add relative features if we have previous ticks
|
||||
if i > 0:
|
||||
prev_tick = ticks[i-1]
|
||||
price_change = (tick.price - prev_tick.price) / prev_tick.price
|
||||
volume_ratio = tick.volume / (prev_tick.volume + 1e-8)
|
||||
time_delta = (tick.timestamp - prev_tick.timestamp).total_seconds()
|
||||
|
||||
tick_features.extend([
|
||||
price_change,
|
||||
volume_ratio,
|
||||
time_delta
|
||||
])
|
||||
else:
|
||||
tick_features.extend([0.0, 1.0, 0.0]) # Default values for first tick
|
||||
|
||||
# Add moving averages if we have enough data
|
||||
if i >= 5:
|
||||
recent_prices = [t.price for t in ticks[max(0, i-4):i+1]]
|
||||
recent_volumes = [t.volume for t in ticks[max(0, i-4):i+1]]
|
||||
|
||||
price_ma = np.mean(recent_prices)
|
||||
volume_ma = np.mean(recent_volumes)
|
||||
|
||||
tick_features.extend([
|
||||
(tick.price - price_ma) / price_ma, # Price deviation from MA
|
||||
(tick.volume - volume_ma) / (volume_ma + 1e-8) # Volume deviation from MA
|
||||
])
|
||||
else:
|
||||
tick_features.extend([0.0, 0.0])
|
||||
|
||||
features.append(tick_features)
|
||||
|
||||
# Pad or truncate to fixed size
|
||||
target_length = self.processing_window
|
||||
if len(features) < target_length:
|
||||
# Pad with zeros
|
||||
padding = [[0.0] * len(features[0])] * (target_length - len(features))
|
||||
features = padding + features
|
||||
elif len(features) > target_length:
|
||||
# Take the most recent ticks
|
||||
features = features[-target_length:]
|
||||
|
||||
return np.array(features, dtype=np.float32)
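
A rough smoke test of the conversion above (note it builds a full RealTimeTickProcessor, so the neural network is instantiated as a side effect; the tick values are arbitrary):

from datetime import datetime, timedelta

processor = RealTimeTickProcessor(symbols=['ETH/USDT'])
now = datetime.now()
ticks = [
    TickData(timestamp=now + timedelta(milliseconds=100 * i),
             price=3000.0 + i, volume=0.5 + 0.1 * i,
             side='buy' if i % 2 == 0 else 'sell')
    for i in range(12)
]
features = processor._ticks_to_features(ticks)
print(features.shape)   # (50, 9): zero-padded to processing_window, 9 features per tick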
|
||||
|
||||
def _extract_price_features(self, ticks: List[TickData]) -> np.ndarray:
|
||||
"""Extract price-based features"""
|
||||
prices = np.array([tick.price for tick in ticks])
|
||||
|
||||
features = [
|
||||
prices[-1], # Current price
|
||||
np.mean(prices), # Average price
|
||||
np.std(prices), # Price volatility
|
||||
np.max(prices), # High
|
||||
np.min(prices), # Low
|
||||
(prices[-1] - prices[0]) / prices[0] if prices[0] != 0 else 0, # Total return
|
||||
]
|
||||
|
||||
# Price momentum features
|
||||
if len(prices) >= 10:
|
||||
short_ma = np.mean(prices[-5:])
|
||||
long_ma = np.mean(prices[-10:])
|
||||
momentum = (short_ma - long_ma) / long_ma if long_ma != 0 else 0
|
||||
features.append(momentum)
|
||||
else:
|
||||
features.append(0.0)
|
||||
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
def _extract_volume_features(self, ticks: List[TickData]) -> np.ndarray:
|
||||
"""Extract volume-based features"""
|
||||
volumes = np.array([tick.volume for tick in ticks])
|
||||
buy_volumes = np.array([tick.volume for tick in ticks if tick.side == 'buy'])
|
||||
sell_volumes = np.array([tick.volume for tick in ticks if tick.side == 'sell'])
|
||||
|
||||
features = [
|
||||
np.sum(volumes), # Total volume
|
||||
np.mean(volumes), # Average volume
|
||||
np.std(volumes), # Volume volatility
|
||||
np.sum(buy_volumes) if len(buy_volumes) > 0 else 0, # Buy volume
|
||||
np.sum(sell_volumes) if len(sell_volumes) > 0 else 0, # Sell volume
|
||||
]
|
||||
|
||||
# Volume imbalance
|
||||
total_buy = np.sum(buy_volumes) if len(buy_volumes) > 0 else 0
|
||||
total_sell = np.sum(sell_volumes) if len(sell_volumes) > 0 else 0
|
||||
total_volume = total_buy + total_sell
|
||||
|
||||
if total_volume > 0:
|
||||
buy_ratio = total_buy / total_volume
|
||||
volume_imbalance = buy_ratio - 0.5 # -0.5 to 0.5 range
|
||||
else:
|
||||
volume_imbalance = 0.0
|
||||
|
||||
features.append(volume_imbalance)
|
||||
|
||||
# VWAP (Volume Weighted Average Price)
|
||||
if np.sum(volumes) > 0:
|
||||
prices = np.array([tick.price for tick in ticks])
|
||||
vwap = np.sum(prices * volumes) / np.sum(volumes)
|
||||
current_price = ticks[-1].price
|
||||
vwap_deviation = (current_price - vwap) / vwap if vwap != 0 else 0
|
||||
else:
|
||||
vwap_deviation = 0.0
|
||||
|
||||
features.append(vwap_deviation)
|
||||
|
||||
return np.array(features, dtype=np.float32)
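
A worked example of the VWAP deviation computed above, with made-up numbers:

import numpy as np

prices = np.array([100.0, 101.0, 102.0])
volumes = np.array([2.0, 1.0, 1.0])
vwap = np.sum(prices * volumes) / np.sum(volumes)    # (200 + 101 + 102) / 4 = 100.75
vwap_deviation = (prices[-1] - vwap) / vwap          # (102 - 100.75) / 100.75 ~= 0.0124
print(round(vwap, 2), round(float(vwap_deviation), 4))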
|
||||
|
||||
def _extract_microstructure_features(self, ticks: List[TickData]) -> np.ndarray:
|
||||
"""Extract market microstructure features"""
|
||||
features = []
|
||||
|
||||
# Trade frequency
|
||||
if len(ticks) >= 2:
|
||||
time_deltas = [(ticks[i].timestamp - ticks[i-1].timestamp).total_seconds()
|
||||
for i in range(1, len(ticks))]
|
||||
avg_time_delta = np.mean(time_deltas)
|
||||
trade_frequency = 1.0 / avg_time_delta if avg_time_delta > 0 else 0
|
||||
else:
|
||||
trade_frequency = 0.0
|
||||
|
||||
features.append(trade_frequency)
|
||||
|
||||
# Price impact features
|
||||
prices = [tick.price for tick in ticks]
|
||||
volumes = [tick.volume for tick in ticks]
|
||||
|
||||
        if len(prices) >= 3:
            # Pair each price change with the volume of the tick that produced it
            # (building both lists together keeps them the same length even when
            # zero-price entries are skipped)
            price_changes = []
            corresponding_volumes = []
            for i in range(1, len(prices)):
                if prices[i-1] != 0:
                    price_changes.append((prices[i] - prices[i-1]) / prices[i-1])
                    corresponding_volumes.append(volumes[i])

            if len(price_changes) >= 2:
                # Simple price impact proxy: correlation of |return| with trade volume
                price_impact = np.corrcoef(np.abs(price_changes), corresponding_volumes)[0, 1]
                if np.isnan(price_impact):
                    price_impact = 0.0
            else:
                price_impact = 0.0
        else:
            price_impact = 0.0

        features.append(price_impact)
|
||||
|
||||
# Bid-ask spread proxy (using price volatility)
|
||||
if len(prices) >= 5:
|
||||
recent_prices = prices[-5:]
|
||||
spread_proxy = (np.max(recent_prices) - np.min(recent_prices)) / np.mean(recent_prices)
|
||||
else:
|
||||
spread_proxy = 0.0
|
||||
|
||||
features.append(spread_proxy)
|
||||
|
||||
# Order flow imbalance (already calculated in volume features, but different perspective)
|
||||
buy_count = sum(1 for tick in ticks if tick.side == 'buy')
|
||||
sell_count = len(ticks) - buy_count
|
||||
total_trades = len(ticks)
|
||||
|
||||
if total_trades > 0:
|
||||
order_flow_imbalance = (buy_count - sell_count) / total_trades
|
||||
else:
|
||||
order_flow_imbalance = 0.0
|
||||
|
||||
features.append(order_flow_imbalance)
|
||||
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
def _notify_feature_subscribers(self, symbol: str, features: ProcessedTickFeatures):
|
||||
"""Notify all feature subscribers of new processed features"""
|
||||
for callback in self.feature_subscribers:
|
||||
try:
|
||||
callback(symbol, features)
|
||||
except Exception as e:
|
||||
logger.error(f"Error notifying feature subscriber {callback.__name__}: {e}")
|
||||
|
||||
def get_latest_features(self, symbol: str) -> Optional[ProcessedTickFeatures]:
|
||||
"""Get the latest processed features for a symbol"""
|
||||
with self.data_lock:
|
||||
if symbol in self.processed_features and self.processed_features[symbol]:
|
||||
return self.processed_features[symbol][-1]
|
||||
return None
|
||||
|
||||
def get_processing_stats(self) -> Dict[str, Any]:
|
||||
"""Get processing performance statistics"""
|
||||
stats = {
|
||||
'symbols': self.symbols,
|
||||
'streaming': self.streaming,
|
||||
'tick_counts': dict(self.tick_counts),
|
||||
'buffer_sizes': {symbol: len(self.tick_buffers[symbol]) for symbol in self.symbols},
|
||||
'feature_counts': {symbol: len(self.processed_features[symbol]) for symbol in self.symbols},
|
||||
'subscribers': len(self.feature_subscribers)
|
||||
}
|
||||
|
||||
if self.processing_times:
|
||||
stats['processing_performance'] = {
|
||||
'avg_time_ms': np.mean(list(self.processing_times)),
|
||||
'min_time_ms': np.min(list(self.processing_times)),
|
||||
'max_time_ms': np.max(list(self.processing_times)),
|
||||
'std_time_ms': np.std(list(self.processing_times))
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
def train_neural_network(self, training_data: List[Tuple[np.ndarray, np.ndarray]], epochs: int = 100):
|
||||
"""Train the tick processing neural network"""
|
||||
logger.info("Training tick processing neural network...")
|
||||
|
||||
self.tick_nn.train()
|
||||
optimizer = torch.optim.Adam(self.tick_nn.parameters(), lr=0.001)
|
||||
criterion = nn.MSELoss()
|
||||
|
||||
for epoch in range(epochs):
|
||||
total_loss = 0.0
|
||||
|
||||
for batch_features, batch_targets in training_data:
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Convert to tensors
|
||||
features_tensor = torch.FloatTensor(batch_features).to(self.device)
|
||||
targets_tensor = torch.FloatTensor(batch_targets).to(self.device)
|
||||
|
||||
# Forward pass
|
||||
outputs, confidence = self.tick_nn(features_tensor)
|
||||
|
||||
# Calculate loss
|
||||
loss = criterion(outputs, targets_tensor)
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
if epoch % 10 == 0:
|
||||
avg_loss = total_loss / len(training_data)
|
||||
logger.info(f"Epoch {epoch}/{epochs}, Average Loss: {avg_loss:.6f}")
|
||||
|
||||
self.tick_nn.eval()
|
||||
logger.info("Neural network training completed")
|
||||
|
||||
# Integration with existing orchestrator
|
||||
def integrate_with_orchestrator(orchestrator, tick_processor: RealTimeTickProcessor):
|
||||
"""Integrate tick processor with enhanced orchestrator"""
|
||||
|
||||
def feature_callback(symbol: str, features: ProcessedTickFeatures):
|
||||
"""Callback to feed processed features to orchestrator"""
|
||||
try:
|
||||
# Convert processed features to format expected by orchestrator
|
||||
feature_dict = {
|
||||
'symbol': symbol,
|
||||
'timestamp': features.timestamp,
|
||||
'neural_features': features.neural_features,
|
||||
'price_features': features.price_features,
|
||||
'volume_features': features.volume_features,
|
||||
'microstructure_features': features.microstructure_features,
|
||||
'confidence': features.confidence
|
||||
}
|
||||
|
||||
# Feed to orchestrator's real-time feature processing
|
||||
if hasattr(orchestrator, 'process_realtime_features'):
|
||||
orchestrator.process_realtime_features(feature_dict)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error integrating features with orchestrator: {e}")
|
||||
|
||||
# Add the callback to tick processor
|
||||
tick_processor.add_feature_subscriber(feature_callback)
|
||||
logger.info("Tick processor integrated with orchestrator")
|
||||
|
||||
# Factory function for easy creation
|
||||
def create_realtime_tick_processor(symbols: List[str] = None) -> RealTimeTickProcessor:
|
||||
"""Create and configure a real-time tick processor"""
|
||||
if symbols is None:
|
||||
symbols = ['ETH/USDT', 'BTC/USDT']
|
||||
|
||||
processor = RealTimeTickProcessor(symbols=symbols)
|
||||
logger.info(f"Created RealTimeTickProcessor for symbols: {symbols}")
|
||||
|
||||
return processor
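
A hedged usage sketch tying the factory and the integration helper together; DummyOrchestrator is a stand-in, and only the process_realtime_features hook is assumed:

import asyncio

class DummyOrchestrator:
    """Hypothetical stand-in; only process_realtime_features is assumed."""
    def process_realtime_features(self, feature_dict):
        print(feature_dict['symbol'], feature_dict['confidence'])

async def main():
    processor = create_realtime_tick_processor(['ETH/USDT'])
    integrate_with_orchestrator(DummyOrchestrator(), processor)
    await processor.start_processing()
    await asyncio.sleep(30)          # let some ticks arrive and be processed
    await processor.stop_processing()

# asyncio.run(main())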
|
||||
@@ -1,453 +0,0 @@
|
||||
"""
|
||||
Retrospective Training System
|
||||
|
||||
This module implements a retrospective training system that:
|
||||
1. Triggers training when trades close with known P&L outcomes
|
||||
2. Uses captured model inputs from trade entry to train models
|
||||
3. Optimizes for profit by learning from profitable vs unprofitable patterns
|
||||
4. Supports simultaneous inference and training without weight reloading
|
||||
5. Implements reinforcement learning with immediate reward feedback
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import queue
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any, Callable
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class TrainingCase:
|
||||
"""Represents a completed trade case for retrospective training"""
|
||||
case_id: str
|
||||
symbol: str
|
||||
action: str # 'BUY' or 'SELL'
|
||||
entry_price: float
|
||||
exit_price: float
|
||||
entry_time: datetime
|
||||
exit_time: datetime
|
||||
pnl: float
|
||||
fees: float
|
||||
confidence: float
|
||||
model_inputs: Dict[str, Any]
|
||||
market_state: Dict[str, Any]
|
||||
outcome_label: int # 1 for profit, 0 for loss, 2 for breakeven
|
||||
reward_signal: float # Scaled reward for RL training
|
||||
leverage: float = 1.0
|
||||
|
||||
class RetrospectiveTrainer:
|
||||
"""Retrospective training system for real-time model optimization"""
|
||||
|
||||
def __init__(self, orchestrator=None, config: Optional[Dict[str, Any]] = None):
|
||||
"""Initialize the retrospective trainer"""
|
||||
self.orchestrator = orchestrator
|
||||
self.config = config or {}
|
||||
|
||||
# Training configuration
|
||||
self.batch_size = self.config.get('batch_size', 32)
|
||||
self.min_cases_for_training = self.config.get('min_cases_for_training', 5)
|
||||
self.profit_threshold = self.config.get('profit_threshold', 0.0)
|
||||
self.training_frequency = self.config.get('training_frequency_seconds', 120) # 2 minutes
|
||||
self.max_training_cases = self.config.get('max_training_cases', 1000)
|
||||
|
||||
# Training state
|
||||
self.training_queue = queue.Queue()
|
||||
self.completed_cases = deque(maxlen=self.max_training_cases)
|
||||
self.training_stats = {
|
||||
'total_cases': 0,
|
||||
'profitable_cases': 0,
|
||||
'loss_cases': 0,
|
||||
'breakeven_cases': 0,
|
||||
'avg_profit': 0.0,
|
||||
'last_training_time': datetime.now(),
|
||||
'training_sessions': 0,
|
||||
'model_updates': 0
|
||||
}
|
||||
|
||||
# Threading
|
||||
self.training_thread = None
|
||||
self.is_training_active = False
|
||||
self.training_lock = threading.Lock()
|
||||
|
||||
logger.info("RetrospectiveTrainer initialized")
|
||||
logger.info(f"Configuration: batch_size={self.batch_size}, "
|
||||
f"min_cases={self.min_cases_for_training}, "
|
||||
f"training_freq={self.training_frequency}s")
|
||||
|
||||
def add_completed_trade(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> bool:
|
||||
"""Add a completed trade for retrospective training"""
|
||||
try:
|
||||
# Create training case from trade record
|
||||
case = self._create_training_case(trade_record, model_inputs)
|
||||
if case is None:
|
||||
return False
|
||||
|
||||
# Add to completed cases
|
||||
self.completed_cases.append(case)
|
||||
self.training_queue.put(case)
|
||||
|
||||
# Update statistics
|
||||
self.training_stats['total_cases'] += 1
|
||||
if case.outcome_label == 1: # Profit
|
||||
self.training_stats['profitable_cases'] += 1
|
||||
elif case.outcome_label == 0: # Loss
|
||||
self.training_stats['loss_cases'] += 1
|
||||
else: # Breakeven
|
||||
self.training_stats['breakeven_cases'] += 1
|
||||
|
||||
# Calculate running average profit
|
||||
total_pnl = sum(c.pnl for c in self.completed_cases)
|
||||
self.training_stats['avg_profit'] = total_pnl / len(self.completed_cases)
|
||||
|
||||
logger.info(f"RETROSPECTIVE: Added training case {case.case_id} "
|
||||
f"(P&L: ${case.pnl:.3f}, Label: {case.outcome_label})")
|
||||
|
||||
# Trigger training if we have enough cases
|
||||
self._maybe_trigger_training()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding completed trade for retrospective training: {e}")
|
||||
return False
|
||||
|
||||
def _create_training_case(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> Optional[TrainingCase]:
|
||||
"""Create a training case from trade record and model inputs"""
|
||||
try:
|
||||
# Extract trade information
|
||||
symbol = trade_record.get('symbol', 'UNKNOWN')
|
||||
side = trade_record.get('side', 'UNKNOWN')
|
||||
pnl = trade_record.get('pnl', 0.0)
|
||||
fees = trade_record.get('fees', 0.0)
|
||||
confidence = trade_record.get('confidence', 0.0)
|
||||
|
||||
# Calculate net P&L after fees
|
||||
net_pnl = pnl - fees
|
||||
|
||||
# Determine outcome label and reward signal
|
||||
if net_pnl > self.profit_threshold:
|
||||
outcome_label = 1 # Profitable
|
||||
# Scale reward by profit magnitude and confidence
|
||||
reward_signal = min(10.0, net_pnl * confidence * 10) # Amplify for training
|
||||
elif net_pnl < -self.profit_threshold:
|
||||
outcome_label = 0 # Loss
|
||||
# Negative reward scaled by loss magnitude
|
||||
reward_signal = max(-10.0, net_pnl * confidence * 10) # Negative reward
|
||||
else:
|
||||
outcome_label = 2 # Breakeven
|
||||
reward_signal = 0.0
|
||||
|
||||
# Create case ID
|
||||
timestamp_str = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
case_id = f"retro_{timestamp_str}_{symbol.replace('/', '')}_{side}_pnl_{abs(net_pnl):.3f}".replace('.', 'p')
|
||||
|
||||
# Create training case
|
||||
case = TrainingCase(
|
||||
case_id=case_id,
|
||||
symbol=symbol,
|
||||
action=side,
|
||||
entry_price=trade_record.get('entry_price', 0.0),
|
||||
exit_price=trade_record.get('exit_price', 0.0),
|
||||
entry_time=trade_record.get('entry_time', datetime.now()),
|
||||
exit_time=trade_record.get('exit_time', datetime.now()),
|
||||
pnl=net_pnl,
|
||||
fees=fees,
|
||||
confidence=confidence,
|
||||
model_inputs=model_inputs,
|
||||
market_state=model_inputs.get('market_state', {}),
|
||||
outcome_label=outcome_label,
|
||||
reward_signal=reward_signal,
|
||||
leverage=trade_record.get('leverage', 1.0)
|
||||
)
|
||||
|
||||
return case
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating training case: {e}")
|
||||
return None
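
To make the labeling rule above concrete (assuming the default profit_threshold of 0.0):

    pnl = +$0.80, fees = $0.10, confidence = 0.6
    net_pnl = 0.80 - 0.10 = 0.70          ->  outcome_label = 1 (profit)
    reward_signal = min(10.0, 0.70 * 0.6 * 10) = 4.2
    a mirrored losing trade (net_pnl = -0.70) gets reward_signal = max(-10.0, -4.2) = -4.2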
|
||||
|
||||
def _maybe_trigger_training(self):
|
||||
"""Check if we should trigger a training session"""
|
||||
try:
|
||||
# Check if we have enough cases
|
||||
if len(self.completed_cases) < self.min_cases_for_training:
|
||||
return
|
||||
|
||||
# Check if enough time has passed since last training
|
||||
time_since_last = (datetime.now() - self.training_stats['last_training_time']).total_seconds()
|
||||
if time_since_last < self.training_frequency:
|
||||
return
|
||||
|
||||
# Check if training thread is not already running
|
||||
if self.is_training_active:
|
||||
logger.debug("Training already in progress, skipping trigger")
|
||||
return
|
||||
|
||||
# Start training in background thread
|
||||
self._start_training_session()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking training trigger: {e}")
|
||||
|
||||
def _start_training_session(self):
|
||||
"""Start a training session in background thread"""
|
||||
try:
|
||||
if self.training_thread and self.training_thread.is_alive():
|
||||
logger.debug("Training thread already running")
|
||||
return
|
||||
|
||||
self.training_thread = threading.Thread(
|
||||
target=self._run_training_session,
|
||||
daemon=True,
|
||||
name="RetrospectiveTrainer"
|
||||
)
|
||||
self.training_thread.start()
|
||||
logger.info("RETROSPECTIVE: Started training session")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting training session: {e}")
|
||||
|
||||
def _run_training_session(self):
|
||||
"""Run a complete training session"""
|
||||
try:
|
||||
with self.training_lock:
|
||||
self.is_training_active = True
|
||||
start_time = time.time()
|
||||
|
||||
logger.info(f"RETROSPECTIVE: Training with {len(self.completed_cases)} cases")
|
||||
|
||||
# Train models if orchestrator available
|
||||
training_results = {}
|
||||
if self.orchestrator:
|
||||
training_results = self._train_models()
|
||||
|
||||
# Update statistics
|
||||
self.training_stats['last_training_time'] = datetime.now()
|
||||
self.training_stats['training_sessions'] += 1
|
||||
self.training_stats['model_updates'] += len(training_results)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(f"RETROSPECTIVE: Training completed in {elapsed_time:.2f}s - {training_results}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in retrospective training session: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
self.is_training_active = False
|
||||
|
||||
def _train_models(self) -> Dict[str, Any]:
|
||||
"""Train available models using retrospective data"""
|
||||
results = {}
|
||||
|
||||
try:
|
||||
# Prepare training data
|
||||
profitable_cases = [c for c in self.completed_cases if c.outcome_label == 1]
|
||||
loss_cases = [c for c in self.completed_cases if c.outcome_label == 0]
|
||||
|
||||
if len(profitable_cases) == 0 and len(loss_cases) == 0:
|
||||
return {'error': 'No labeled cases for training'}
|
||||
|
||||
logger.info(f"RETROSPECTIVE: Training data - Profitable: {len(profitable_cases)}, Loss: {len(loss_cases)}")
|
||||
|
||||
# Train DQN agent if available
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
||||
try:
|
||||
dqn_result = self._train_dqn_retrospective()
|
||||
results['dqn'] = dqn_result
|
||||
logger.info(f"RETROSPECTIVE: DQN training result: {dqn_result}")
|
||||
except Exception as e:
|
||||
logger.warning(f"DQN retrospective training failed: {e}")
|
||||
results['dqn'] = {'error': str(e)}
|
||||
|
||||
# Train other models
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer:
|
||||
try:
|
||||
# Update extrema trainer with retrospective feedback
|
||||
extrema_feedback = self._create_extrema_feedback()
|
||||
if extrema_feedback:
|
||||
results['extrema'] = {'feedback_cases': len(extrema_feedback)}
|
||||
logger.info(f"RETROSPECTIVE: Extrema feedback provided for {len(extrema_feedback)} cases")
|
||||
except Exception as e:
|
||||
logger.warning(f"Extrema retrospective training failed: {e}")
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training models retrospectively: {e}")
|
||||
return {'error': str(e)}
|
||||
|
||||
def _train_dqn_retrospective(self) -> Dict[str, Any]:
|
||||
"""Train DQN agent using retrospective experience replay"""
|
||||
try:
|
||||
if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
|
||||
return {'error': 'DQN agent not available'}
|
||||
|
||||
dqn_agent = self.orchestrator.rl_agent
|
||||
experiences_added = 0
|
||||
|
||||
# Add retrospective experiences to DQN replay buffer
|
||||
for case in self.completed_cases:
|
||||
try:
|
||||
# Extract state from model inputs
|
||||
state = self._extract_state_vector(case.model_inputs)
|
||||
if state is None:
|
||||
continue
|
||||
|
||||
# Action mapping: BUY=0, SELL=1
|
||||
action = 0 if case.action == 'BUY' else 1
|
||||
|
||||
# Use reward signal as immediate reward
|
||||
reward = case.reward_signal
|
||||
|
||||
                    # Retrospective training treats the closed trade as a terminal
                    # transition, so the (unused) next state is a zero placeholder
                    next_state = np.zeros_like(state)
                    done = True
|
||||
|
||||
# Add experience to DQN replay buffer
|
||||
if hasattr(dqn_agent, 'add_experience'):
|
||||
dqn_agent.add_experience(state, action, reward, next_state, done)
|
||||
experiences_added += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error adding DQN experience: {e}")
|
||||
continue
|
||||
|
||||
# Train DQN if we have enough experiences
|
||||
if experiences_added > 0 and hasattr(dqn_agent, 'train'):
|
||||
try:
|
||||
# Perform multiple training steps on retrospective data
|
||||
training_steps = min(10, experiences_added // 4) # Conservative training
|
||||
for _ in range(training_steps):
|
||||
loss = dqn_agent.train()
|
||||
if loss is None:
|
||||
break
|
||||
|
||||
return {
|
||||
'experiences_added': experiences_added,
|
||||
'training_steps': training_steps,
|
||||
'method': 'retrospective_experience_replay'
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"DQN training step failed: {e}")
|
||||
return {'experiences_added': experiences_added, 'training_error': str(e)}
|
||||
|
||||
return {'experiences_added': experiences_added, 'training_steps': 0}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in DQN retrospective training: {e}")
|
||||
return {'error': str(e)}
|
||||
|
||||
def _extract_state_vector(self, model_inputs: Dict[str, Any]) -> Optional[np.ndarray]:
|
||||
"""Extract state vector for DQN training from model inputs"""
|
||||
try:
|
||||
# Try to get pre-built RL state
|
||||
if 'dqn_state' in model_inputs:
|
||||
state = model_inputs['dqn_state']
|
||||
if isinstance(state, dict) and 'state_vector' in state:
|
||||
return np.array(state['state_vector'])
|
||||
|
||||
# Build state from market features
|
||||
market_state = model_inputs.get('market_state', {})
|
||||
features = []
|
||||
|
||||
# Price features
|
||||
for key in ['current_price', 'price_sma_5', 'price_sma_20', 'price_std_20', 'price_rsi']:
|
||||
features.append(market_state.get(key, 0.0))
|
||||
|
||||
# Volume features
|
||||
for key in ['volume_current', 'volume_sma_20', 'volume_ratio']:
|
||||
features.append(market_state.get(key, 0.0))
|
||||
|
||||
# Technical indicators
|
||||
indicators = model_inputs.get('technical_indicators', {})
|
||||
for key in ['sma_10', 'sma_20', 'bb_upper', 'bb_lower', 'bb_position', 'macd', 'volatility']:
|
||||
features.append(indicators.get(key, 0.0))
|
||||
|
||||
if len(features) < 5: # Minimum required features
|
||||
return None
|
||||
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error extracting state vector: {e}")
|
||||
return None
|
||||
|
||||
def _create_extrema_feedback(self) -> List[Dict[str, Any]]:
|
||||
"""Create feedback data for extrema trainer"""
|
||||
feedback = []
|
||||
|
||||
try:
|
||||
for case in self.completed_cases:
|
||||
if case.outcome_label in [0, 1]: # Only profit/loss cases
|
||||
feedback_item = {
|
||||
'symbol': case.symbol,
|
||||
'action': case.action,
|
||||
'entry_price': case.entry_price,
|
||||
'exit_price': case.exit_price,
|
||||
'was_profitable': case.outcome_label == 1,
|
||||
'reward_signal': case.reward_signal,
|
||||
'market_state': case.market_state
|
||||
}
|
||||
feedback.append(feedback_item)
|
||||
|
||||
return feedback
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating extrema feedback: {e}")
|
||||
return []
|
||||
|
||||
def get_training_stats(self) -> Dict[str, Any]:
|
||||
"""Get current training statistics"""
|
||||
stats = self.training_stats.copy()
|
||||
stats['total_cases_in_memory'] = len(self.completed_cases)
|
||||
stats['training_queue_size'] = self.training_queue.qsize()
|
||||
stats['is_training_active'] = self.is_training_active
|
||||
|
||||
# Calculate profit metrics
|
||||
if len(self.completed_cases) > 0:
|
||||
profitable_count = sum(1 for c in self.completed_cases if c.pnl > 0)
|
||||
stats['profit_rate'] = profitable_count / len(self.completed_cases)
|
||||
stats['total_pnl'] = sum(c.pnl for c in self.completed_cases)
|
||||
stats['avg_reward'] = sum(c.reward_signal for c in self.completed_cases) / len(self.completed_cases)
|
||||
|
||||
return stats
|
||||
|
||||
def force_training_session(self) -> bool:
|
||||
"""Force a training session regardless of timing constraints"""
|
||||
try:
|
||||
if self.is_training_active:
|
||||
logger.warning("Training already in progress")
|
||||
return False
|
||||
|
||||
if len(self.completed_cases) < 1:
|
||||
logger.warning("No completed cases available for training")
|
||||
return False
|
||||
|
||||
logger.info("RETROSPECTIVE: Forcing training session")
|
||||
self._start_training_session()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error forcing training session: {e}")
|
||||
return False
|
||||
|
||||
def stop(self):
|
||||
"""Stop the retrospective trainer"""
|
||||
try:
|
||||
self.is_training_active = False
|
||||
if self.training_thread and self.training_thread.is_alive():
|
||||
self.training_thread.join(timeout=10)
|
||||
logger.info("RetrospectiveTrainer stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping RetrospectiveTrainer: {e}")
|
||||
|
||||
|
||||
def create_retrospective_trainer(orchestrator=None, config: Optional[Dict[str, Any]] = None) -> RetrospectiveTrainer:
|
||||
"""Factory function to create a RetrospectiveTrainer instance"""
|
||||
return RetrospectiveTrainer(orchestrator=orchestrator, config=config)
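
A minimal usage sketch of the factory above; no orchestrator is attached, so only the case statistics are updated:

trainer = create_retrospective_trainer(config={'min_cases_for_training': 2})
trade = {
    'symbol': 'ETH/USDT', 'side': 'BUY',
    'entry_price': 3000.0, 'exit_price': 3010.0,
    'pnl': 0.8, 'fees': 0.1, 'confidence': 0.6,
}
trainer.add_completed_trade(trade, model_inputs={'market_state': {'current_price': 3010.0}})
print(trainer.get_training_stats()['profitable_cases'])   # 1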
|
||||
@@ -1,529 +0,0 @@
|
||||
"""
|
||||
RL Training Pipeline with Comprehensive Experience Storage and Replay
|
||||
|
||||
This module implements a robust RL training pipeline that:
|
||||
1. Stores all training experiences with profitability metrics
|
||||
2. Implements profit-weighted experience replay
|
||||
3. Tracks gradient information for each training step
|
||||
4. Enables retraining on most profitable trading sequences
|
||||
5. Maintains comprehensive trading episode analysis
|
||||
"""
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
import pickle
|
||||
from collections import deque
|
||||
import threading
|
||||
import random
|
||||
|
||||
from .training_data_collector import get_training_data_collector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class RLExperience:
|
||||
"""Single RL experience with complete state-action-reward information"""
|
||||
experience_id: str
|
||||
timestamp: datetime
|
||||
episode_id: str
|
||||
|
||||
# Core RL components
|
||||
state: np.ndarray
|
||||
action: int # 0=SELL, 1=HOLD, 2=BUY
|
||||
reward: float
|
||||
next_state: np.ndarray
|
||||
done: bool
|
||||
|
||||
# Extended state information
|
||||
market_context: Dict[str, Any]
|
||||
cnn_predictions: Optional[Dict[str, Any]] = None
|
||||
confidence_score: float = 0.0
|
||||
|
||||
# Actual trading outcome
|
||||
actual_profit: Optional[float] = None
|
||||
actual_holding_time: Optional[timedelta] = None
|
||||
optimal_action: Optional[int] = None
|
||||
|
||||
# Experience value for replay
|
||||
experience_value: float = 0.0
|
||||
profitability_score: float = 0.0
|
||||
learning_priority: float = 0.0
|
||||
|
||||
# Training metadata
|
||||
times_trained: int = 0
|
||||
last_trained: Optional[datetime] = None
|
||||
|
||||
class ProfitWeightedExperienceBuffer:
|
||||
"""Experience buffer with profit-weighted sampling for replay"""
|
||||
|
||||
def __init__(self, max_size: int = 100000):
|
||||
self.max_size = max_size
|
||||
self.experiences: Dict[str, RLExperience] = {}
|
||||
self.experience_order: deque = deque(maxlen=max_size)
|
||||
self.profitable_experiences: List[str] = []
|
||||
self.total_experiences = 0
|
||||
self.total_profitable = 0
|
||||
|
||||
    def add_experience(self, experience: RLExperience):
        """Add experience to buffer, evicting the oldest entry when full"""
        try:
            # Evict the oldest entry before the order deque (which has maxlen)
            # silently drops its id, so the dict and the profitable-id list
            # stay in sync with the deque
            if len(self.experience_order) == self.experience_order.maxlen:
                oldest_id = self.experience_order[0]
                self.experiences.pop(oldest_id, None)
                if oldest_id in self.profitable_experiences:
                    self.profitable_experiences.remove(oldest_id)

            self.experiences[experience.experience_id] = experience
            self.experience_order.append(experience.experience_id)

            if experience.actual_profit is not None and experience.actual_profit > 0:
                self.profitable_experiences.append(experience.experience_id)
                self.total_profitable += 1

            self.total_experiences += 1

        except Exception as e:
            logger.error(f"Error adding experience to buffer: {e}")
|
||||
|
||||
def sample_batch(self, batch_size: int, prioritize_profitable: bool = True) -> List[RLExperience]:
|
||||
"""Sample batch with profit-weighted prioritization"""
|
||||
try:
|
||||
if len(self.experiences) < batch_size:
|
||||
return list(self.experiences.values())
|
||||
|
||||
if prioritize_profitable and len(self.profitable_experiences) > batch_size // 2:
|
||||
# Sample mix of profitable and all experiences
|
||||
profitable_sample_size = min(batch_size // 2, len(self.profitable_experiences))
|
||||
remaining_sample_size = batch_size - profitable_sample_size
|
||||
|
||||
profitable_ids = random.sample(self.profitable_experiences, profitable_sample_size)
|
||||
all_ids = list(self.experiences.keys())
|
||||
remaining_ids = random.sample(all_ids, remaining_sample_size)
|
||||
|
||||
sampled_ids = profitable_ids + remaining_ids
|
||||
else:
|
||||
# Random sampling from all experiences
|
||||
all_ids = list(self.experiences.keys())
|
||||
sampled_ids = random.sample(all_ids, batch_size)
|
||||
|
||||
sampled_experiences = [self.experiences[exp_id] for exp_id in sampled_ids]
|
||||
|
||||
# Update training counts
|
||||
for experience in sampled_experiences:
|
||||
experience.times_trained += 1
|
||||
experience.last_trained = datetime.now()
|
||||
|
||||
return sampled_experiences
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sampling batch: {e}")
|
||||
return list(self.experiences.values())[:batch_size]
|
||||
|
||||
def get_most_profitable_experiences(self, limit: int = 100) -> List[RLExperience]:
|
||||
"""Get most profitable experiences for targeted training"""
|
||||
try:
|
||||
profitable_experiences = [
|
||||
self.experiences[exp_id] for exp_id in self.profitable_experiences
|
||||
if exp_id in self.experiences
|
||||
]
|
||||
|
||||
profitable_experiences.sort(
|
||||
key=lambda x: x.actual_profit if x.actual_profit else 0,
|
||||
reverse=True
|
||||
)
|
||||
|
||||
return profitable_experiences[:limit]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting profitable experiences: {e}")
|
||||
return []
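
A quick exercise of the buffer with synthetic experiences (all field values are arbitrary); with batch_size 32 and more than 16 profitable entries, each sampled batch is half drawn from the profitable pool and half from the whole buffer:

import numpy as np
from datetime import datetime

buf = ProfitWeightedExperienceBuffer(max_size=1000)
for i in range(100):
    buf.add_experience(RLExperience(
        experience_id=f"exp_{i}", timestamp=datetime.now(), episode_id="ep0",
        state=np.zeros(8, dtype=np.float32), action=i % 3, reward=float(i % 5 - 2),
        next_state=np.zeros(8, dtype=np.float32), done=False, market_context={},
        actual_profit=0.5 if i % 4 == 0 else -0.2,
    ))
batch = buf.sample_batch(32, prioritize_profitable=True)
print(len(batch), sum(1 for e in batch if (e.actual_profit or 0) > 0))   # 32, at least 16 profitable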
|
||||
|
||||
class RLTradingAgent(nn.Module):
|
||||
"""RL Trading Agent with comprehensive state processing"""
|
||||
|
||||
def __init__(self, state_dim: int = 2000, action_dim: int = 3, hidden_dim: int = 512):
|
||||
super(RLTradingAgent, self).__init__()
|
||||
|
||||
self.state_dim = state_dim
|
||||
self.action_dim = action_dim
|
||||
self.hidden_dim = hidden_dim
|
||||
|
||||
# State processing network
|
||||
self.state_processor = nn.Sequential(
|
||||
nn.Linear(state_dim, hidden_dim),
|
||||
nn.LayerNorm(hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(hidden_dim, hidden_dim // 2),
|
||||
nn.LayerNorm(hidden_dim // 2),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
# Q-value network
|
||||
self.q_network = nn.Sequential(
|
||||
nn.Linear(hidden_dim // 2, hidden_dim // 4),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(hidden_dim // 4, action_dim)
|
||||
)
|
||||
|
||||
# Policy network
|
||||
self.policy_network = nn.Sequential(
|
||||
nn.Linear(hidden_dim // 2, hidden_dim // 4),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(hidden_dim // 4, action_dim),
|
||||
nn.Softmax(dim=-1)
|
||||
)
|
||||
|
||||
# Value network
|
||||
self.value_network = nn.Sequential(
|
||||
nn.Linear(hidden_dim // 2, hidden_dim // 4),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(hidden_dim // 4, 1)
|
||||
)
|
||||
|
||||
def forward(self, state):
|
||||
"""Forward pass through the agent"""
|
||||
processed_state = self.state_processor(state)
|
||||
|
||||
q_values = self.q_network(processed_state)
|
||||
policy_probs = self.policy_network(processed_state)
|
||||
state_value = self.value_network(processed_state)
|
||||
|
||||
return {
|
||||
'q_values': q_values,
|
||||
'policy_probs': policy_probs,
|
||||
'state_value': state_value,
|
||||
'processed_state': processed_state
|
||||
}
|
||||
|
||||
def select_action(self, state, epsilon: float = 0.1) -> Tuple[int, float]:
|
||||
"""Select action using epsilon-greedy policy"""
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
if isinstance(state, np.ndarray):
|
||||
state = torch.from_numpy(state).float().unsqueeze(0)
|
||||
|
||||
outputs = self.forward(state)
|
||||
|
||||
if random.random() < epsilon:
|
||||
action = random.randint(0, self.action_dim - 1)
|
||||
confidence = 0.33
|
||||
else:
|
||||
q_values = outputs['q_values']
|
||||
action = torch.argmax(q_values, dim=1).item()
|
||||
q_softmax = F.softmax(q_values, dim=1)
|
||||
confidence = torch.max(q_softmax).item()
|
||||
|
||||
return action, confidence
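
A quick check of the epsilon-greedy selection above (untrained weights, CPU):

import numpy as np

agent = RLTradingAgent(state_dim=2000, action_dim=3)
state = np.random.randn(2000).astype(np.float32)
action, confidence = agent.select_action(state, epsilon=0.1)
print(action, round(confidence, 3))    # action in {0: SELL, 1: HOLD, 2: BUY}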
|
||||
|
||||
@dataclass
|
||||
class RLTrainingStep:
|
||||
"""Single RL training step with backpropagation data"""
|
||||
step_id: str
|
||||
timestamp: datetime
|
||||
batch_experiences: List[str]
|
||||
|
||||
# Training data
|
||||
total_loss: float
|
||||
q_loss: float
|
||||
policy_loss: float
|
||||
|
||||
# Gradients
|
||||
gradients: Dict[str, torch.Tensor]
|
||||
gradient_norms: Dict[str, float]
|
||||
|
||||
# Metadata
|
||||
learning_rate: float = 0.001
|
||||
batch_size: int = 32
|
||||
|
||||
# Performance
|
||||
batch_profitability: float = 0.0
|
||||
correct_actions: int = 0
|
||||
total_actions: int = 0
|
||||
step_value: float = 0.0
|
||||
|
||||
@dataclass
|
||||
class RLTrainingSession:
|
||||
"""Complete RL training session"""
|
||||
session_id: str
|
||||
start_timestamp: datetime
|
||||
end_timestamp: Optional[datetime] = None
|
||||
|
||||
training_mode: str = 'experience_replay'
|
||||
symbol: str = ''
|
||||
|
||||
training_steps: List[RLTrainingStep] = field(default_factory=list)
|
||||
|
||||
total_steps: int = 0
|
||||
average_loss: float = 0.0
|
||||
best_loss: float = float('inf')
|
||||
|
||||
profitable_actions: int = 0
|
||||
total_actions: int = 0
|
||||
profitability_rate: float = 0.0
|
||||
session_value: float = 0.0
|
||||
|
||||
class RLTrainer:
|
||||
"""RL trainer with comprehensive experience storage and replay"""
|
||||
|
||||
def __init__(self, agent: RLTradingAgent, device: str = 'cuda', storage_dir: str = "rl_training_storage"):
|
||||
self.agent = agent.to(device)
|
||||
self.device = device
|
||||
self.storage_dir = Path(storage_dir)
|
||||
self.storage_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.optimizer = torch.optim.AdamW(agent.parameters(), lr=0.001)
|
||||
self.experience_buffer = ProfitWeightedExperienceBuffer()
|
||||
self.data_collector = get_training_data_collector()
|
||||
|
||||
self.training_sessions: List[RLTrainingSession] = []
|
||||
self.current_session: Optional[RLTrainingSession] = None
|
||||
|
||||
self.gamma = 0.99
|
||||
|
||||
self.training_stats = {
|
||||
'total_sessions': 0,
|
||||
'total_steps': 0,
|
||||
'total_experiences': 0,
|
||||
'profitable_actions': 0,
|
||||
'total_actions': 0,
|
||||
'average_reward': 0.0
|
||||
}
|
||||
|
||||
logger.info(f"RL Trainer initialized with {sum(p.numel() for p in agent.parameters()):,} parameters")
|
||||
|
||||
def add_experience(self, state: np.ndarray, action: int, reward: float,
|
||||
next_state: np.ndarray, done: bool, market_context: Dict[str, Any],
|
||||
cnn_predictions: Dict[str, Any] = None, confidence_score: float = 0.0) -> str:
|
||||
"""Add experience to the buffer"""
|
||||
try:
|
||||
experience_id = f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
|
||||
|
||||
experience = RLExperience(
|
||||
experience_id=experience_id,
|
||||
timestamp=datetime.now(),
|
||||
episode_id=market_context.get('episode_id', 'unknown'),
|
||||
state=state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done,
|
||||
market_context=market_context,
|
||||
cnn_predictions=cnn_predictions,
|
||||
confidence_score=confidence_score
|
||||
)
|
||||
|
||||
self.experience_buffer.add_experience(experience)
|
||||
self.training_stats['total_experiences'] += 1
|
||||
|
||||
return experience_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding experience: {e}")
|
||||
return None
|
||||
|
||||
def train_on_experiences(self, batch_size: int = 32, num_batches: int = 10) -> Dict[str, Any]:
|
||||
"""Train on experiences with comprehensive data storage"""
|
||||
try:
|
||||
session = RLTrainingSession(
|
||||
session_id=f"rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
||||
start_timestamp=datetime.now(),
|
||||
training_mode='experience_replay'
|
||||
)
|
||||
self.current_session = session
|
||||
|
||||
self.agent.train()
|
||||
total_loss = 0.0
|
||||
|
||||
for batch_idx in range(num_batches):
|
||||
experiences = self.experience_buffer.sample_batch(batch_size, True)
|
||||
|
||||
if len(experiences) < batch_size:
|
||||
continue
|
||||
|
||||
                # Prepare batch tensors (stack the numpy arrays first; building a
                # tensor from a Python list of arrays is slow and deprecated)
                states = torch.from_numpy(np.stack([exp.state for exp in experiences])).float().to(self.device)
                actions = torch.LongTensor([exp.action for exp in experiences]).to(self.device)
                rewards = torch.FloatTensor([exp.reward for exp in experiences]).to(self.device)
                next_states = torch.from_numpy(np.stack([exp.next_state for exp in experiences])).float().to(self.device)
                dones = torch.BoolTensor([exp.done for exp in experiences]).to(self.device)
|
||||
|
||||
# Forward pass
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
current_outputs = self.agent(states)
|
||||
current_q_values = current_outputs['q_values']
|
||||
|
||||
# Calculate target Q-values
|
||||
with torch.no_grad():
|
||||
next_outputs = self.agent(next_states)
|
||||
next_q_values = next_outputs['q_values']
|
||||
max_next_q_values = torch.max(next_q_values, dim=1)[0]
|
||||
target_q_values = rewards + (self.gamma * max_next_q_values * ~dones)
|
||||
|
||||
# Calculate loss
|
||||
current_q_values_for_actions = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
q_loss = F.mse_loss(current_q_values_for_actions, target_q_values)
|
||||
|
||||
# Backward pass
|
||||
q_loss.backward()
|
||||
|
||||
# Store gradients
|
||||
gradients = {}
|
||||
gradient_norms = {}
|
||||
for name, param in self.agent.named_parameters():
|
||||
if param.grad is not None:
|
||||
gradients[name] = param.grad.clone().detach()
|
||||
gradient_norms[name] = param.grad.norm().item()
|
||||
|
||||
torch.nn.utils.clip_grad_norm_(self.agent.parameters(), max_norm=1.0)
|
||||
self.optimizer.step()
|
||||
|
||||
# Create training step record
|
||||
step = RLTrainingStep(
|
||||
step_id=f"{session.session_id}_step_{batch_idx}",
|
||||
timestamp=datetime.now(),
|
||||
batch_experiences=[exp.experience_id for exp in experiences],
|
||||
total_loss=q_loss.item(),
|
||||
q_loss=q_loss.item(),
|
||||
policy_loss=0.0,
|
||||
gradients=gradients,
|
||||
gradient_norms=gradient_norms,
|
||||
batch_size=len(experiences)
|
||||
)
|
||||
|
||||
session.training_steps.append(step)
|
||||
total_loss += q_loss.item()
|
||||
|
||||
            # Finalize session using the number of batches that actually ran
            # (undersized batches are skipped above)
            session.end_timestamp = datetime.now()
            session.total_steps = len(session.training_steps)
            session.average_loss = total_loss / session.total_steps if session.total_steps > 0 else 0.0
|
||||
|
||||
self._save_training_session(session)
|
||||
|
||||
self.training_stats['total_sessions'] += 1
|
||||
self.training_stats['total_steps'] += session.total_steps
|
||||
|
||||
logger.info(f"RL training session completed: {session.session_id}")
|
||||
logger.info(f"Average loss: {session.average_loss:.4f}")
|
||||
|
||||
return {
|
||||
'status': 'success',
|
||||
'session_id': session.session_id,
|
||||
'average_loss': session.average_loss,
|
||||
'total_steps': session.total_steps
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in RL training session: {e}")
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
finally:
|
||||
self.current_session = None
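
For clarity, the update performed on each sampled batch above is plain one-step Q-learning: for every transition (s, a, r, s', done),

    y    = r + gamma * max_a' Q(s', a') * (1 - done)
    loss = MSE(Q(s, a), y)

with gradients clipped to norm 1.0 before the optimizer step.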
|
||||
|
||||
def train_on_profitable_experiences(self, min_profitability: float = 0.1,
|
||||
max_experiences: int = 1000, batch_size: int = 32) -> Dict[str, Any]:
|
||||
"""Train specifically on most profitable experiences"""
|
||||
try:
|
||||
profitable_experiences = self.experience_buffer.get_most_profitable_experiences(max_experiences)
|
||||
|
||||
filtered_experiences = [
|
||||
exp for exp in profitable_experiences
|
||||
if exp.actual_profit is not None and exp.actual_profit >= min_profitability
|
||||
]
|
||||
|
||||
if len(filtered_experiences) < batch_size:
|
||||
return {'status': 'insufficient_data', 'experiences_found': len(filtered_experiences)}
|
||||
|
||||
logger.info(f"Training on {len(filtered_experiences)} profitable experiences")
|
||||
|
||||
num_batches = len(filtered_experiences) // batch_size
|
||||
|
||||
# Temporarily replace buffer sampling
|
||||
original_sample_method = self.experience_buffer.sample_batch
|
||||
|
||||
def profitable_sample_batch(batch_size, prioritize_profitable=True):
|
||||
return random.sample(filtered_experiences, min(batch_size, len(filtered_experiences)))
|
||||
|
||||
self.experience_buffer.sample_batch = profitable_sample_batch
|
||||
|
||||
try:
|
||||
results = self.train_on_experiences(batch_size=batch_size, num_batches=num_batches)
|
||||
results['training_mode'] = 'profitable_replay'
|
||||
results['experiences_used'] = len(filtered_experiences)
|
||||
return results
|
||||
finally:
|
||||
self.experience_buffer.sample_batch = original_sample_method
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training on profitable experiences: {e}")
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
|
||||
def _save_training_session(self, session: RLTrainingSession):
|
||||
"""Save training session to disk"""
|
||||
try:
|
||||
session_dir = self.storage_dir / 'sessions'
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
session_file = session_dir / f"{session.session_id}.pkl"
|
||||
with open(session_file, 'wb') as f:
|
||||
pickle.dump(session, f)
|
||||
|
||||
metadata = {
|
||||
'session_id': session.session_id,
|
||||
'start_timestamp': session.start_timestamp.isoformat(),
|
||||
'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None,
|
||||
'training_mode': session.training_mode,
|
||||
'total_steps': session.total_steps,
|
||||
'average_loss': session.average_loss
|
||||
}
|
||||
|
||||
metadata_file = session_dir / f"{session.session_id}_metadata.json"
|
||||
with open(metadata_file, 'w') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving training session: {e}")
|
||||
|
||||
def get_training_statistics(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive training statistics"""
|
||||
stats = self.training_stats.copy()
|
||||
|
||||
if self.training_sessions:
|
||||
recent_sessions = sorted(self.training_sessions, key=lambda x: x.start_timestamp, reverse=True)[:10]
|
||||
stats['recent_sessions'] = [
|
||||
{
|
||||
'session_id': s.session_id,
|
||||
'timestamp': s.start_timestamp.isoformat(),
|
||||
'mode': s.training_mode,
|
||||
'average_loss': s.average_loss
|
||||
}
|
||||
for s in recent_sessions
|
||||
]
|
||||
|
||||
return stats
|
||||
|
||||
# Global instance
|
||||
rl_trainer = None
|
||||
|
||||
def get_rl_trainer(agent: RLTradingAgent = None) -> RLTrainer:
|
||||
"""Get global RL trainer instance"""
|
||||
global rl_trainer
|
||||
if rl_trainer is None:
|
||||
if agent is None:
|
||||
agent = RLTradingAgent()
|
||||
rl_trainer = RLTrainer(agent)
|
||||
return rl_trainer
|
||||
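
A minimal usage sketch for the trainer helpers above (illustrative only; it assumes an RLTradingAgent can be built with default arguments, as get_rl_trainer does, and that the experience buffer already holds closed trades):

trainer = get_rl_trainer()                     # lazily builds RLTradingAgent() + RLTrainer
# Replay only experiences that earned at least 10% profit, in batches of 32
results = trainer.train_on_profitable_experiences(min_profitability=0.1,
                                                  max_experiences=1000,
                                                  batch_size=32)
if results['status'] == 'success':
    print(results['session_id'], results['average_loss'])
print(trainer.get_training_statistics())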
@@ -1,460 +0,0 @@
"""
Robust COB (Consolidated Order Book) Provider

This module provides a robust COB data provider that handles:
- HTTP 418 errors from Binance (rate limiting)
- Thread safety issues
- API rate limiting and backoff
- Fallback data sources
- Error recovery strategies

Features:
- Automatic rate limiting and backoff
- Multiple exchange support with fallbacks
- Thread-safe operations
- Comprehensive error handling
- Data validation and integrity checking
"""

import asyncio
import logging
import time
import threading
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass, field
from collections import deque
import json
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
import requests

from .api_rate_limiter import get_rate_limiter, RateLimitConfig

logger = logging.getLogger(__name__)

@dataclass
class COBData:
    """Consolidated Order Book data structure"""
    symbol: str
    timestamp: datetime
    bids: List[Tuple[float, float]]  # [(price, quantity), ...]
    asks: List[Tuple[float, float]]  # [(price, quantity), ...]

    # Derived metrics
    spread: float = 0.0
    mid_price: float = 0.0
    total_bid_volume: float = 0.0
    total_ask_volume: float = 0.0

    # Data quality
    data_source: str = 'unknown'
    quality_score: float = 1.0

    def __post_init__(self):
        """Calculate derived metrics"""
        if self.bids and self.asks:
            self.spread = self.asks[0][0] - self.bids[0][0]
            self.mid_price = (self.asks[0][0] + self.bids[0][0]) / 2
            self.total_bid_volume = sum(qty for _, qty in self.bids)
            self.total_ask_volume = sum(qty for _, qty in self.asks)

            # Calculate quality score based on data completeness
            self.quality_score = min(
                len(self.bids) / 20,  # Expect at least 20 bid levels
                len(self.asks) / 20,  # Expect at least 20 ask levels
                1.0
            )

class RobustCOBProvider:
    """Robust COB provider with error handling and rate limiting"""

    def __init__(self, symbols: List[str] = None):
        self.symbols = symbols or ['ETHUSDT', 'BTCUSDT']

        # Rate limiter
        self.rate_limiter = get_rate_limiter()

        # Thread safety
        self.lock = threading.RLock()

        # Data cache
        self.cob_cache: Dict[str, COBData] = {}
        self.cache_timestamps: Dict[str, datetime] = {}
        self.cache_ttl = timedelta(seconds=5)  # 5 second cache TTL

        # Error tracking
        self.error_counts: Dict[str, int] = {}
        self.last_successful_fetch: Dict[str, datetime] = {}

        # Background fetching
        self.is_running = False
        self.fetch_threads: Dict[str, threading.Thread] = {}
        self.executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="COB-Fetcher")

        # Fallback data
        self.fallback_data: Dict[str, COBData] = {}

        # Performance tracking
        self.fetch_stats = {
            'total_requests': 0,
            'successful_requests': 0,
            'failed_requests': 0,
            'rate_limited_requests': 0,
            'cache_hits': 0,
            'fallback_uses': 0
        }

        logger.info(f"Robust COB Provider initialized for symbols: {self.symbols}")

    def start_background_fetching(self):
        """Start background COB data fetching"""
        if self.is_running:
            logger.warning("Background fetching already running")
            return

        self.is_running = True

        # Start fetching thread for each symbol
        for symbol in self.symbols:
            thread = threading.Thread(
                target=self._background_fetch_worker,
                args=(symbol,),
                name=f"COB-{symbol}",
                daemon=True
            )
            self.fetch_threads[symbol] = thread
            thread.start()

        logger.info(f"Started background COB fetching for {len(self.symbols)} symbols")

    def stop_background_fetching(self):
        """Stop background COB data fetching"""
        self.is_running = False

        # Wait for threads to finish
        for symbol, thread in self.fetch_threads.items():
            thread.join(timeout=5)
            logger.debug(f"Stopped COB fetching for {symbol}")

        # Shutdown executor (ThreadPoolExecutor.shutdown accepts no timeout argument)
        self.executor.shutdown(wait=True)

        logger.info("Stopped background COB fetching")

    def _background_fetch_worker(self, symbol: str):
        """Background worker for fetching COB data"""
        logger.info(f"Started COB fetching worker for {symbol}")

        while self.is_running:
            try:
                # Fetch COB data
                cob_data = self._fetch_cob_data_safe(symbol)

                if cob_data:
                    with self.lock:
                        self.cob_cache[symbol] = cob_data
                        self.cache_timestamps[symbol] = datetime.now()
                        self.last_successful_fetch[symbol] = datetime.now()
                        self.error_counts[symbol] = 0  # Reset error count on success

                    logger.debug(f"Updated COB cache for {symbol}")
                else:
                    with self.lock:
                        self.error_counts[symbol] = self.error_counts.get(symbol, 0) + 1

                    logger.debug(f"Failed to fetch COB for {symbol}, error count: {self.error_counts.get(symbol, 0)}")

                # Wait before next fetch (adaptive based on errors)
                error_count = self.error_counts.get(symbol, 0)
                base_interval = 2.0  # Base 2 second interval
                backoff_interval = min(base_interval * (2 ** min(error_count, 5)), 60.0)  # Max 60s

                time.sleep(backoff_interval)

            except Exception as e:
                logger.error(f"Error in COB fetching worker for {symbol}: {e}")
                time.sleep(10)  # Wait 10s on unexpected errors

        logger.info(f"Stopped COB fetching worker for {symbol}")
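
The worker's adaptive sleep doubles the 2-second base interval on every consecutive failure and caps it at 60 seconds. A standalone sketch of the resulting schedule (not part of the original class):

def backoff_interval(error_count: int, base: float = 2.0, cap: float = 60.0) -> float:
    """Sleep interval used after `error_count` consecutive fetch failures."""
    return min(base * (2 ** min(error_count, 5)), cap)

# error_count 0..6 -> 2.0, 4.0, 8.0, 16.0, 32.0, 60.0, 60.0 seconds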
    def _fetch_cob_data_safe(self, symbol: str) -> Optional[COBData]:
        """Safely fetch COB data with error handling"""
        try:
            self.fetch_stats['total_requests'] += 1

            # Try Binance first
            cob_data = self._fetch_binance_cob(symbol)
            if cob_data:
                self.fetch_stats['successful_requests'] += 1
                return cob_data

            # Try MEXC as fallback
            cob_data = self._fetch_mexc_cob(symbol)
            if cob_data:
                self.fetch_stats['successful_requests'] += 1
                cob_data.data_source = 'mexc_fallback'
                return cob_data

            # Use cached fallback data if available
            if symbol in self.fallback_data:
                self.fetch_stats['fallback_uses'] += 1
                fallback = self.fallback_data[symbol]
                fallback.timestamp = datetime.now()
                fallback.data_source = 'fallback_cache'
                fallback.quality_score *= 0.5  # Reduce quality score for old data
                return fallback

            self.fetch_stats['failed_requests'] += 1
            return None

        except Exception as e:
            logger.error(f"Error fetching COB data for {symbol}: {e}")
            self.fetch_stats['failed_requests'] += 1
            return None

    def _fetch_binance_cob(self, symbol: str) -> Optional[COBData]:
        """Fetch COB data from Binance with rate limiting"""
        try:
            url = "https://api.binance.com/api/v3/depth"
            params = {
                'symbol': symbol,
                'limit': 100  # Get 100 levels
            }

            # Use rate limiter
            response = self.rate_limiter.make_request(
                'binance_api',
                url,
                method='GET',
                params=params
            )

            if not response:
                self.fetch_stats['rate_limited_requests'] += 1
                return None

            if response.status_code != 200:
                logger.warning(f"Binance COB API returned {response.status_code} for {symbol}")
                return None

            data = response.json()

            # Parse order book data
            bids = [(float(price), float(qty)) for price, qty in data.get('bids', [])]
            asks = [(float(price), float(qty)) for price, qty in data.get('asks', [])]

            if not bids or not asks:
                logger.warning(f"Empty order book data from Binance for {symbol}")
                return None

            cob_data = COBData(
                symbol=symbol,
                timestamp=datetime.now(),
                bids=bids,
                asks=asks,
                data_source='binance'
            )

            # Store as fallback for future use
            self.fallback_data[symbol] = cob_data

            return cob_data

        except Exception as e:
            logger.error(f"Error fetching Binance COB for {symbol}: {e}")
            return None

    def _fetch_mexc_cob(self, symbol: str) -> Optional[COBData]:
        """Fetch COB data from MEXC as fallback"""
        try:
            url = "https://api.mexc.com/api/v3/depth"
            params = {
                'symbol': symbol,
                'limit': 100
            }

            response = self.rate_limiter.make_request(
                'mexc_api',
                url,
                method='GET',
                params=params
            )

            if not response or response.status_code != 200:
                return None

            data = response.json()

            # Parse order book data
            bids = [(float(price), float(qty)) for price, qty in data.get('bids', [])]
            asks = [(float(price), float(qty)) for price, qty in data.get('asks', [])]

            if not bids or not asks:
                return None

            return COBData(
                symbol=symbol,
                timestamp=datetime.now(),
                bids=bids,
                asks=asks,
                data_source='mexc'
            )

        except Exception as e:
            logger.debug(f"Error fetching MEXC COB for {symbol}: {e}")
            return None

    def get_cob_data(self, symbol: str) -> Optional[COBData]:
        """Get COB data for a symbol (from cache or fresh fetch)"""
        with self.lock:
            # Check cache first
            if symbol in self.cob_cache:
                cached_data = self.cob_cache[symbol]
                cache_time = self.cache_timestamps.get(symbol, datetime.min)

                # Return cached data if still fresh
                if datetime.now() - cache_time < self.cache_ttl:
                    self.fetch_stats['cache_hits'] += 1
                    return cached_data

            # If background fetching is running, return cached data even if stale
            if self.is_running and symbol in self.cob_cache:
                return self.cob_cache[symbol]

            # Fetch fresh data if not running background fetching
            if not self.is_running:
                return self._fetch_cob_data_safe(symbol)

            return None

    def get_cob_features(self, symbol: str, feature_count: int = 120) -> Optional[np.ndarray]:
        """
        Get COB features for ML models

        Args:
            symbol: Trading symbol
            feature_count: Number of features to return

        Returns:
            Numpy array of COB features or None if no data
        """
        cob_data = self.get_cob_data(symbol)
        if not cob_data:
            return None

        try:
            features = []

            # Basic market metrics
            features.extend([
                cob_data.mid_price,
                cob_data.spread,
                cob_data.total_bid_volume,
                cob_data.total_ask_volume,
                cob_data.quality_score
            ])

            # Bid levels (price and volume)
            max_levels = min(len(cob_data.bids), 20)
            for i in range(max_levels):
                price, volume = cob_data.bids[i]
                features.extend([price, volume])

            # Pad bid levels if needed
            for i in range(max_levels, 20):
                features.extend([0.0, 0.0])

            # Ask levels (price and volume)
            max_levels = min(len(cob_data.asks), 20)
            for i in range(max_levels):
                price, volume = cob_data.asks[i]
                features.extend([price, volume])

            # Pad ask levels if needed
            for i in range(max_levels, 20):
                features.extend([0.0, 0.0])

            # Calculate additional features
            if len(cob_data.bids) > 0 and len(cob_data.asks) > 0:
                # Volume imbalance
                bid_volume_5 = sum(vol for _, vol in cob_data.bids[:5])
                ask_volume_5 = sum(vol for _, vol in cob_data.asks[:5])
                volume_imbalance = (bid_volume_5 - ask_volume_5) / (bid_volume_5 + ask_volume_5) if (bid_volume_5 + ask_volume_5) > 0 else 0
                features.append(volume_imbalance)

                # Price levels
                bid_price_levels = [price for price, _ in cob_data.bids[:10]]
                ask_price_levels = [price for price, _ in cob_data.asks[:10]]
                features.extend(bid_price_levels + ask_price_levels)

            # Pad or truncate to desired feature count
            if len(features) < feature_count:
                features.extend([0.0] * (feature_count - len(features)))
            else:
                features = features[:feature_count]

            return np.array(features, dtype=np.float32)

        except Exception as e:
            logger.error(f"Error creating COB features for {symbol}: {e}")
            return None
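
For reference, the default 120-element vector decomposes as 5 summary metrics, 20 (price, volume) bid levels, 20 (price, volume) ask levels, one top-5 volume-imbalance value and 10 + 10 raw price levels, i.e. 106 populated slots zero-padded to feature_count. A minimal sanity-check sketch (assumes a running provider; names are illustrative only):

provider = get_cob_provider(['ETHUSDT'])
features = provider.get_cob_features('ETHUSDT')   # default feature_count=120
if features is not None:
    assert features.shape == (120,) and features.dtype == np.float32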
    def get_provider_status(self) -> Dict[str, Any]:
        """Get provider status and statistics"""
        with self.lock:
            status = {
                'is_running': self.is_running,
                'symbols': self.symbols,
                'cache_status': {},
                'error_counts': self.error_counts.copy(),
                'last_successful_fetch': {
                    symbol: timestamp.isoformat()
                    for symbol, timestamp in self.last_successful_fetch.items()
                },
                'fetch_stats': self.fetch_stats.copy(),
                'rate_limiter_status': self.rate_limiter.get_all_endpoint_status()
            }

            # Cache status for each symbol
            for symbol in self.symbols:
                cache_time = self.cache_timestamps.get(symbol)
                status['cache_status'][symbol] = {
                    'has_data': symbol in self.cob_cache,
                    'cache_time': cache_time.isoformat() if cache_time else None,
                    'cache_age_seconds': (datetime.now() - cache_time).total_seconds() if cache_time else None,
                    'data_quality': self.cob_cache[symbol].quality_score if symbol in self.cob_cache else 0.0
                }

            return status

    def reset_errors(self):
        """Reset error counts and rate limiter"""
        with self.lock:
            self.error_counts.clear()
            self.rate_limiter.reset_all_endpoints()
            logger.info("Reset all error counts and rate limiter")

    def force_refresh(self, symbol: str = None):
        """Force refresh COB data for symbol(s)"""
        symbols_to_refresh = [symbol] if symbol else self.symbols

        for sym in symbols_to_refresh:
            # Clear cache to force refresh
            with self.lock:
                if sym in self.cob_cache:
                    del self.cob_cache[sym]
                if sym in self.cache_timestamps:
                    del self.cache_timestamps[sym]

            logger.info(f"Forced refresh for {sym}")

# Global COB provider instance
_global_cob_provider = None

def get_cob_provider(symbols: List[str] = None) -> RobustCOBProvider:
    """Get global COB provider instance"""
    global _global_cob_provider
    if _global_cob_provider is None:
        _global_cob_provider = RobustCOBProvider(symbols)
    return _global_cob_provider
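
A minimal lifecycle sketch for the provider (illustrative only; assumes the api_rate_limiter module is importable from the same package):

provider = get_cob_provider(['ETHUSDT', 'BTCUSDT'])
provider.start_background_fetching()
try:
    snapshot = provider.get_cob_data('ETHUSDT')   # served from the 5s cache when fresh
    print(provider.get_provider_status()['fetch_stats'])
finally:
    provider.stop_background_fetching()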
@@ -1,350 +0,0 @@
#!/usr/bin/env python3
"""
Shared COB Service - Eliminates Redundant COB Implementations

This service provides a singleton COB integration that can be shared across:
- Dashboard components
- RL trading systems
- Enhanced orchestrators
- Training pipelines

Instead of each component creating its own COBIntegration instance,
they all share this single service, eliminating redundant connections.
"""

import asyncio
import logging
import weakref
from typing import Dict, List, Optional, Any, Callable, Set
from datetime import datetime
from threading import Lock
from dataclasses import dataclass

from .cob_integration import COBIntegration
from .multi_exchange_cob_provider import COBSnapshot
from .data_provider import DataProvider

logger = logging.getLogger(__name__)

@dataclass
class COBSubscription:
    """Represents a subscription to COB updates"""
    subscriber_id: str
    callback: Callable
    symbol_filter: Optional[List[str]] = None
    callback_type: str = "general"  # general, cnn, dqn, dashboard

class SharedCOBService:
    """
    Shared COB Service - Singleton pattern for unified COB data access

    This service eliminates redundant COB integrations by providing a single
    shared instance that all components can subscribe to.
    """

    _instance: Optional['SharedCOBService'] = None
    _lock = Lock()

    def __new__(cls, *args, **kwargs):
        """Singleton pattern implementation"""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super(SharedCOBService, cls).__new__(cls)
        return cls._instance

    def __init__(self, symbols: Optional[List[str]] = None, data_provider: Optional[DataProvider] = None):
        """Initialize shared COB service (only called once due to singleton)"""
        if hasattr(self, '_initialized'):
            return

        self.symbols = symbols or ['BTC/USDT', 'ETH/USDT']
        self.data_provider = data_provider

        # Single COB integration instance
        self.cob_integration: Optional[COBIntegration] = None
        self.is_running = False

        # Subscriber management
        self.subscribers: Dict[str, COBSubscription] = {}
        self.subscriber_counter = 0
        self.subscription_lock = Lock()

        # Cached data for immediate access
        self.latest_snapshots: Dict[str, COBSnapshot] = {}
        self.latest_cnn_features: Dict[str, Any] = {}
        self.latest_dqn_states: Dict[str, Any] = {}

        # Performance tracking
        self.total_subscribers = 0
        self.update_count = 0
        self.start_time = None

        self._initialized = True
        logger.info(f"SharedCOBService initialized for symbols: {self.symbols}")

    async def start(self) -> None:
        """Start the shared COB service"""
        if self.is_running:
            logger.warning("SharedCOBService already running")
            return

        logger.info("Starting SharedCOBService...")

        try:
            # Initialize COB integration if not already done
            if self.cob_integration is None:
                self.cob_integration = COBIntegration(
                    data_provider=self.data_provider,
                    symbols=self.symbols
                )

                # Register internal callbacks
                self.cob_integration.add_cnn_callback(self._on_cob_cnn_update)
                self.cob_integration.add_dqn_callback(self._on_cob_dqn_update)
                self.cob_integration.add_dashboard_callback(self._on_cob_dashboard_update)

            # Start COB integration
            await self.cob_integration.start()

            self.is_running = True
            self.start_time = datetime.now()

            logger.info("SharedCOBService started successfully")
            logger.info(f"Active subscribers: {len(self.subscribers)}")

        except Exception as e:
            logger.error(f"Error starting SharedCOBService: {e}")
            raise

    async def stop(self) -> None:
        """Stop the shared COB service"""
        if not self.is_running:
            return

        logger.info("Stopping SharedCOBService...")

        try:
            if self.cob_integration:
                await self.cob_integration.stop()

            self.is_running = False

            # Notify all subscribers of shutdown
            for subscription in self.subscribers.values():
                try:
                    if hasattr(subscription.callback, '__call__'):
                        subscription.callback("SHUTDOWN", None)
                except Exception as e:
                    logger.warning(f"Error notifying subscriber {subscription.subscriber_id}: {e}")

            logger.info("SharedCOBService stopped")

        except Exception as e:
            logger.error(f"Error stopping SharedCOBService: {e}")

    def subscribe(self,
                  callback: Callable,
                  callback_type: str = "general",
                  symbol_filter: Optional[List[str]] = None,
                  subscriber_name: str = None) -> str:
        """
        Subscribe to COB updates

        Args:
            callback: Function to call on updates
            callback_type: Type of callback ('general', 'cnn', 'dqn', 'dashboard')
            symbol_filter: Only receive updates for these symbols (None = all)
            subscriber_name: Optional name for the subscriber

        Returns:
            Subscription ID for unsubscribing
        """
        with self.subscription_lock:
            self.subscriber_counter += 1
            subscriber_id = f"{callback_type}_{self.subscriber_counter}"
            if subscriber_name:
                subscriber_id = f"{subscriber_name}_{subscriber_id}"

            subscription = COBSubscription(
                subscriber_id=subscriber_id,
                callback=callback,
                symbol_filter=symbol_filter,
                callback_type=callback_type
            )

            self.subscribers[subscriber_id] = subscription
            self.total_subscribers += 1

            logger.info(f"New subscriber: {subscriber_id} ({callback_type})")
            logger.info(f"Total active subscribers: {len(self.subscribers)}")

            return subscriber_id

    def unsubscribe(self, subscriber_id: str) -> bool:
        """
        Unsubscribe from COB updates

        Args:
            subscriber_id: ID returned from subscribe()

        Returns:
            True if successfully unsubscribed
        """
        with self.subscription_lock:
            if subscriber_id in self.subscribers:
                del self.subscribers[subscriber_id]
                logger.info(f"Unsubscribed: {subscriber_id}")
                logger.info(f"Remaining subscribers: {len(self.subscribers)}")
                return True
            else:
                logger.warning(f"Subscriber not found: {subscriber_id}")
                return False
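
A minimal subscriber sketch (illustrative; the callback signature mirrors how _notify_subscribers below invokes callbacks with (symbol, data)):

def on_dashboard_update(symbol: str, data) -> None:
    print(f"COB update for {symbol}")

service = get_shared_cob_service(symbols=['ETH/USDT'])
sub_id = service.subscribe(on_dashboard_update,
                           callback_type="dashboard",
                           symbol_filter=['ETH/USDT'],
                           subscriber_name="example")
# ... later ...
service.unsubscribe(sub_id)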
    # Internal callback handlers

    async def _on_cob_cnn_update(self, symbol: str, data: Dict):
        """Handle CNN feature updates from COB integration"""
        try:
            self.latest_cnn_features[symbol] = data
            await self._notify_subscribers("cnn", symbol, data)
        except Exception as e:
            logger.error(f"Error in CNN update handler: {e}")

    async def _on_cob_dqn_update(self, symbol: str, data: Dict):
        """Handle DQN state updates from COB integration"""
        try:
            self.latest_dqn_states[symbol] = data
            await self._notify_subscribers("dqn", symbol, data)
        except Exception as e:
            logger.error(f"Error in DQN update handler: {e}")

    async def _on_cob_dashboard_update(self, symbol: str, data: Dict):
        """Handle dashboard updates from COB integration"""
        try:
            # Store snapshot if it's a COBSnapshot
            if hasattr(data, 'volume_weighted_mid'):  # Duck typing for COBSnapshot
                self.latest_snapshots[symbol] = data

            await self._notify_subscribers("dashboard", symbol, data)
            await self._notify_subscribers("general", symbol, data)

            self.update_count += 1

        except Exception as e:
            logger.error(f"Error in dashboard update handler: {e}")

    async def _notify_subscribers(self, callback_type: str, symbol: str, data: Any):
        """Notify all relevant subscribers of an update"""
        try:
            relevant_subscribers = [
                sub for sub in self.subscribers.values()
                if (sub.callback_type == callback_type or sub.callback_type == "general") and
                (sub.symbol_filter is None or symbol in sub.symbol_filter)
            ]

            for subscription in relevant_subscribers:
                try:
                    if asyncio.iscoroutinefunction(subscription.callback):
                        asyncio.create_task(subscription.callback(symbol, data))
                    else:
                        subscription.callback(symbol, data)
                except Exception as e:
                    logger.warning(f"Error notifying subscriber {subscription.subscriber_id}: {e}")

        except Exception as e:
            logger.error(f"Error notifying subscribers: {e}")

    # Public data access methods

    def get_cob_snapshot(self, symbol: str) -> Optional[COBSnapshot]:
        """Get latest COB snapshot for a symbol"""
        if self.cob_integration:
            return self.cob_integration.get_cob_snapshot(symbol)
        return self.latest_snapshots.get(symbol)

    def get_cnn_features(self, symbol: str) -> Optional[Any]:
        """Get latest CNN features for a symbol"""
        return self.latest_cnn_features.get(symbol)

    def get_dqn_state(self, symbol: str) -> Optional[Any]:
        """Get latest DQN state for a symbol"""
        return self.latest_dqn_states.get(symbol)

    def get_market_depth_analysis(self, symbol: str) -> Optional[Dict]:
        """Get detailed market depth analysis"""
        if self.cob_integration:
            return self.cob_integration.get_market_depth_analysis(symbol)
        return None

    def get_exchange_breakdown(self, symbol: str) -> Optional[Dict]:
        """Get liquidity breakdown by exchange"""
        if self.cob_integration:
            return self.cob_integration.get_exchange_breakdown(symbol)
        return None

    def get_price_buckets(self, symbol: str) -> Optional[Dict]:
        """Get fine-grain price buckets"""
        if self.cob_integration:
            return self.cob_integration.get_price_buckets(symbol)
        return None

    def get_session_volume_profile(self, symbol: str) -> Optional[Dict]:
        """Get session volume profile"""
        if self.cob_integration and hasattr(self.cob_integration.cob_provider, 'get_session_volume_profile'):
            return self.cob_integration.cob_provider.get_session_volume_profile(symbol)
        return None

    def get_realtime_stats_for_nn(self, symbol: str) -> Dict:
        """Get real-time statistics formatted for NN models"""
        if self.cob_integration:
            return self.cob_integration.get_realtime_stats_for_nn(symbol)
        return {}

    def get_service_statistics(self) -> Dict[str, Any]:
        """Get service statistics"""
        uptime = None
        if self.start_time:
            uptime = (datetime.now() - self.start_time).total_seconds()

        base_stats = {
            'service_name': 'SharedCOBService',
            'is_running': self.is_running,
            'symbols': self.symbols,
            'total_subscribers': len(self.subscribers),
            'lifetime_subscribers': self.total_subscribers,
            'update_count': self.update_count,
            'uptime_seconds': uptime,
            'subscribers_by_type': {}
        }

        # Count subscribers by type
        for subscription in self.subscribers.values():
            callback_type = subscription.callback_type
            if callback_type not in base_stats['subscribers_by_type']:
                base_stats['subscribers_by_type'][callback_type] = 0
            base_stats['subscribers_by_type'][callback_type] += 1

        # Get COB integration stats if available
        if self.cob_integration:
            cob_stats = self.cob_integration.get_statistics()
            base_stats.update(cob_stats)

        return base_stats

# Global service instance access functions

def get_shared_cob_service(symbols: List[str] = None, data_provider: DataProvider = None) -> SharedCOBService:
    """Get the shared COB service instance"""
    return SharedCOBService(symbols=symbols, data_provider=data_provider)

async def start_shared_cob_service(symbols: List[str] = None, data_provider: DataProvider = None) -> SharedCOBService:
    """Start the shared COB service"""
    service = get_shared_cob_service(symbols=symbols, data_provider=data_provider)
    await service.start()
    return service

async def stop_shared_cob_service():
    """Stop the shared COB service"""
    service = get_shared_cob_service()
    await service.stop()
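
A hedged startup sketch for the module-level helpers (illustrative only; the components run between start and stop are placeholders):

async def main():
    service = await start_shared_cob_service(symbols=['ETH/USDT'])
    try:
        ...  # run dashboard / training components that subscribe to the service
    finally:
        await stop_shared_cob_service()

# asyncio.run(main())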
@@ -1,425 +0,0 @@
"""
Shared Data Manager for UI Stability Fix

Manages data sharing between processes through files with proper locking
and atomic operations to prevent corruption and conflicts.
"""

import json
import os
import time
import tempfile
import platform
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import Dict, Any, Optional, Union
from pathlib import Path
import logging

# Windows-compatible file locking
if platform.system() == "Windows":
    import msvcrt
else:
    import fcntl

logger = logging.getLogger(__name__)

@dataclass
class ProcessStatus:
    """Model for process status information"""
    name: str
    pid: int
    status: str  # 'running', 'stopped', 'error'
    start_time: datetime
    last_heartbeat: datetime
    memory_usage: float
    cpu_usage: float
    error_message: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary with datetime serialization"""
        data = asdict(self)
        data['start_time'] = self.start_time.isoformat()
        data['last_heartbeat'] = self.last_heartbeat.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ProcessStatus':
        """Create from dictionary with datetime deserialization"""
        data['start_time'] = datetime.fromisoformat(data['start_time'])
        data['last_heartbeat'] = datetime.fromisoformat(data['last_heartbeat'])
        return cls(**data)

@dataclass
class TrainingStatus:
    """Model for training status information"""
    is_running: bool
    current_epoch: int
    total_epochs: int
    loss: float
    accuracy: float
    last_update: datetime
    model_path: str
    error_message: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary with datetime serialization"""
        data = asdict(self)
        data['last_update'] = self.last_update.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TrainingStatus':
        """Create from dictionary with datetime deserialization"""
        data['last_update'] = datetime.fromisoformat(data['last_update'])
        return cls(**data)

@dataclass
class DashboardState:
    """Model for dashboard state information"""
    is_connected: bool
    last_data_update: datetime
    active_connections: int
    error_count: int
    performance_metrics: Dict[str, float]

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary with datetime serialization"""
        data = asdict(self)
        data['last_data_update'] = self.last_data_update.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'DashboardState':
        """Create from dictionary with datetime deserialization"""
        data['last_data_update'] = datetime.fromisoformat(data['last_data_update'])
        return cls(**data)


class SharedDataManager:
    """
    Manages data sharing between processes through files with proper locking
    and atomic operations to prevent corruption and conflicts.
    """

    def __init__(self, data_dir: str = "shared_data"):
        """
        Initialize the shared data manager

        Args:
            data_dir: Directory to store shared data files
        """
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(exist_ok=True)

        # Define file paths for different data types
        self.training_status_file = self.data_dir / "training_status.json"
        self.dashboard_state_file = self.data_dir / "dashboard_state.json"
        self.process_status_file = self.data_dir / "process_status.json"
        self.market_data_file = self.data_dir / "market_data.json"
        self.model_metrics_file = self.data_dir / "model_metrics.json"

        logger.info(f"SharedDataManager initialized with data directory: {self.data_dir}")

    def _lock_file(self, file_handle, exclusive=True):
        """Cross-platform file locking"""
        if platform.system() == "Windows":
            # Windows file locking: msvcrt has no shared-lock mode, so both
            # exclusive and shared requests fall back to an exclusive lock
            try:
                msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1)
            except IOError:
                pass  # File locking may not be available in all scenarios
        else:
            # Unix file locking
            lock_type = fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH
            fcntl.flock(file_handle.fileno(), lock_type)

    def _unlock_file(self, file_handle):
        """Cross-platform file unlocking"""
        if platform.system() == "Windows":
            try:
                msvcrt.locking(file_handle.fileno(), msvcrt.LK_UNLCK, 1)
            except IOError:
                pass
        else:
            fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN)

    def _write_json_atomic(self, file_path: Path, data: Dict[str, Any]) -> None:
        """
        Write JSON data atomically with file locking

        Args:
            file_path: Path to the file to write
            data: Data to write as JSON
        """
        temp_path = None
        try:
            # Create temporary file in the same directory
            temp_fd, temp_path = tempfile.mkstemp(
                dir=file_path.parent,
                prefix=f".{file_path.name}.",
                suffix=".tmp"
            )

            with os.fdopen(temp_fd, 'w') as temp_file:
                # Lock the temporary file
                self._lock_file(temp_file, exclusive=True)

                # Write data with proper formatting
                json.dump(data, temp_file, indent=2, default=str)
                temp_file.flush()
                os.fsync(temp_file.fileno())

                # Unlock before closing
                self._unlock_file(temp_file)

            # Atomically replace the original file
            os.replace(temp_path, file_path)
            logger.debug(f"Successfully wrote data to {file_path}")

        except Exception as e:
            # Clean up temporary file if it exists
            if temp_path:
                try:
                    os.unlink(temp_path)
                except:
                    pass
            logger.error(f"Failed to write data to {file_path}: {e}")
            raise
    def _read_json_safe(self, file_path: Path) -> Dict[str, Any]:
        """
        Read JSON data safely with file locking

        Args:
            file_path: Path to the file to read

        Returns:
            Dictionary containing the JSON data
        """
        if not file_path.exists():
            logger.debug(f"File {file_path} does not exist, returning empty dict")
            return {}

        try:
            with open(file_path, 'r') as file:
                # Lock the file for reading
                self._lock_file(file, exclusive=False)
                data = json.load(file)
                self._unlock_file(file)
                logger.debug(f"Successfully read data from {file_path}")
                return data

        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in {file_path}: {e}")
            return {}
        except Exception as e:
            logger.error(f"Failed to read data from {file_path}: {e}")
            return {}

    def write_training_status(self, status: TrainingStatus) -> None:
        """
        Write training status to shared file

        Args:
            status: TrainingStatus object to write
        """
        try:
            data = status.to_dict()
            self._write_json_atomic(self.training_status_file, data)
            logger.debug("Training status written successfully")
        except Exception as e:
            logger.error(f"Failed to write training status: {e}")
            raise

    def read_training_status(self) -> Optional[TrainingStatus]:
        """
        Read training status from shared file

        Returns:
            TrainingStatus object or None if not available
        """
        try:
            data = self._read_json_safe(self.training_status_file)
            if not data:
                return None
            return TrainingStatus.from_dict(data)
        except Exception as e:
            logger.error(f"Failed to read training status: {e}")
            return None

    def write_dashboard_state(self, state: DashboardState) -> None:
        """
        Write dashboard state to shared file

        Args:
            state: DashboardState object to write
        """
        try:
            data = state.to_dict()
            self._write_json_atomic(self.dashboard_state_file, data)
            logger.debug("Dashboard state written successfully")
        except Exception as e:
            logger.error(f"Failed to write dashboard state: {e}")
            raise

    def read_dashboard_state(self) -> Optional[DashboardState]:
        """
        Read dashboard state from shared file

        Returns:
            DashboardState object or None if not available
        """
        try:
            data = self._read_json_safe(self.dashboard_state_file)
            if not data:
                return None
            return DashboardState.from_dict(data)
        except Exception as e:
            logger.error(f"Failed to read dashboard state: {e}")
            return None

    def write_process_status(self, status: ProcessStatus) -> None:
        """
        Write process status to shared file

        Args:
            status: ProcessStatus object to write
        """
        try:
            data = status.to_dict()
            self._write_json_atomic(self.process_status_file, data)
            logger.debug("Process status written successfully")
        except Exception as e:
            logger.error(f"Failed to write process status: {e}")
            raise

    def read_process_status(self) -> Optional[ProcessStatus]:
        """
        Read process status from shared file

        Returns:
            ProcessStatus object or None if not available
        """
        try:
            data = self._read_json_safe(self.process_status_file)
            if not data:
                return None
            return ProcessStatus.from_dict(data)
        except Exception as e:
            logger.error(f"Failed to read process status: {e}")
            return None

    def write_market_data(self, data: Dict[str, Any]) -> None:
        """
        Write market data to shared file

        Args:
            data: Market data dictionary to write
        """
        try:
            # Add timestamp to market data
            data['timestamp'] = datetime.now().isoformat()
            self._write_json_atomic(self.market_data_file, data)
            logger.debug("Market data written successfully")
        except Exception as e:
            logger.error(f"Failed to write market data: {e}")
            raise

    def read_market_data(self) -> Dict[str, Any]:
        """
        Read market data from shared file

        Returns:
            Dictionary containing market data
        """
        try:
            return self._read_json_safe(self.market_data_file)
        except Exception as e:
            logger.error(f"Failed to read market data: {e}")
            return {}

    def write_model_metrics(self, metrics: Dict[str, Any]) -> None:
        """
        Write model metrics to shared file

        Args:
            metrics: Model metrics dictionary to write
        """
        try:
            # Add timestamp to metrics
            metrics['timestamp'] = datetime.now().isoformat()
            self._write_json_atomic(self.model_metrics_file, metrics)
            logger.debug("Model metrics written successfully")
        except Exception as e:
            logger.error(f"Failed to write model metrics: {e}")
            raise

    def read_model_metrics(self) -> Dict[str, Any]:
        """
        Read model metrics from shared file

        Returns:
            Dictionary containing model metrics
        """
        try:
            return self._read_json_safe(self.model_metrics_file)
        except Exception as e:
            logger.error(f"Failed to read model metrics: {e}")
            return {}

    def cleanup(self) -> None:
        """
        Clean up shared data files
        """
        try:
            for file_path in [
                self.training_status_file,
                self.dashboard_state_file,
                self.process_status_file,
                self.market_data_file,
                self.model_metrics_file
            ]:
                if file_path.exists():
                    file_path.unlink()
                    logger.debug(f"Removed {file_path}")

            # Remove directory if empty
            if self.data_dir.exists() and not any(self.data_dir.iterdir()):
                self.data_dir.rmdir()
                logger.debug(f"Removed empty directory {self.data_dir}")

        except Exception as e:
            logger.error(f"Failed to cleanup shared data: {e}")

    def get_data_age(self, data_type: str) -> Optional[float]:
        """
        Get the age of data in seconds

        Args:
            data_type: Type of data ('training', 'dashboard', 'process', 'market', 'metrics')

        Returns:
            Age in seconds or None if file doesn't exist
        """
        file_map = {
            'training': self.training_status_file,
            'dashboard': self.dashboard_state_file,
            'process': self.process_status_file,
            'market': self.market_data_file,
            'metrics': self.model_metrics_file
        }

        file_path = file_map.get(data_type)
        if not file_path or not file_path.exists():
            return None

        try:
            mtime = file_path.stat().st_mtime
            return time.time() - mtime
        except Exception as e:
            logger.error(f"Failed to get data age for {data_type}: {e}")
            return None
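
A short round-trip sketch for the manager (illustrative; the directory and model path are assumptions):

manager = SharedDataManager(data_dir="shared_data_demo")
status = TrainingStatus(is_running=True, current_epoch=3, total_epochs=10,
                        loss=0.42, accuracy=0.87, last_update=datetime.now(),
                        model_path="models/checkpoint.pt")
manager.write_training_status(status)          # atomic write via temp file + os.replace
restored = manager.read_training_status()
assert restored is not None and restored.current_epoch == 3
print(manager.get_data_age('training'))        # seconds since the file was written
manager.cleanup()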
@@ -1,59 +0,0 @@
"""
Trading Action Module

Defines the TradingAction class used throughout the trading system.
"""

from dataclasses import dataclass
from datetime import datetime
from typing import Dict, Any, List

@dataclass
class TradingAction:
    """Represents a trading action with full context"""
    symbol: str
    action: str  # 'BUY', 'SELL', 'HOLD'
    quantity: float
    confidence: float
    price: float
    timestamp: datetime
    reasoning: Dict[str, Any]

    def __post_init__(self):
        """Validate the trading action after initialization"""
        if self.action not in ['BUY', 'SELL', 'HOLD']:
            raise ValueError(f"Invalid action: {self.action}. Must be 'BUY', 'SELL', or 'HOLD'")

        if self.confidence < 0.0 or self.confidence > 1.0:
            raise ValueError(f"Invalid confidence: {self.confidence}. Must be between 0.0 and 1.0")

        if self.quantity < 0:
            raise ValueError(f"Invalid quantity: {self.quantity}. Must be non-negative")

        if self.price <= 0:
            raise ValueError(f"Invalid price: {self.price}. Must be positive")

    def to_dict(self) -> Dict[str, Any]:
        """Convert trading action to dictionary"""
        return {
            'symbol': self.symbol,
            'action': self.action,
            'quantity': self.quantity,
            'confidence': self.confidence,
            'price': self.price,
            'timestamp': self.timestamp.isoformat(),
            'reasoning': self.reasoning
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TradingAction':
        """Create trading action from dictionary"""
        return cls(
            symbol=data['symbol'],
            action=data['action'],
            quantity=data['quantity'],
            confidence=data['confidence'],
            price=data['price'],
            timestamp=datetime.fromisoformat(data['timestamp']),
            reasoning=data['reasoning']
        )
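
A quick round-trip sketch showing the validation and serialization helpers (values are illustrative):

action = TradingAction(symbol='ETH/USDT', action='BUY', quantity=0.5,
                       confidence=0.8, price=2500.0,
                       timestamp=datetime.now(),
                       reasoning={'signal': 'cob_imbalance'})
payload = action.to_dict()                      # JSON-safe dict with ISO timestamp
assert TradingAction.from_dict(payload) == action
# TradingAction(..., action='FLIP', ...) would raise ValueError in __post_init__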
@@ -1,401 +0,0 @@
"""
Trading Executor Fix - Addresses issues with entry/exit prices and P&L calculations

This module provides fixes for:
1. Identical entry prices issue
2. Price caching problems
3. Position tracking reset logic
4. Trade cooldown implementation
5. P&L calculation verification

Apply these fixes to the TradingExecutor class to improve trade execution reliability.
"""

import logging
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union

logger = logging.getLogger(__name__)

class TradingExecutorFix:
    """
    Fixes for the TradingExecutor class to address entry/exit price issues
    and improve P&L calculation accuracy.
    """

    def __init__(self, trading_executor):
        """
        Initialize the fix with a reference to the trading executor

        Args:
            trading_executor: The TradingExecutor instance to fix
        """
        self.trading_executor = trading_executor

        # Add cooldown tracking
        self.last_trade_time = {}  # {symbol: timestamp}
        self.min_trade_cooldown = 30  # 30 seconds minimum between trades

        # Add price history for validation
        self.recent_entry_prices = {}  # {symbol: [recent_prices]}
        self.max_price_history = 10  # Keep last 10 entry prices

        # Add position reset tracking
        self.position_reset_flags = {}  # {symbol: bool}

        # Add price update tracking
        self.last_price_update = {}  # {symbol: timestamp}
        self.price_update_threshold = 5  # 5 seconds max since last price update

        # Add P&L verification
        self.trade_history = {}  # {symbol: [trade_records]}

        logger.info("TradingExecutorFix initialized - addressing entry/exit price issues")

    def apply_fixes(self):
        """Apply all fixes to the trading executor"""
        self._patch_execute_action()
        self._patch_close_position()
        self._patch_calculate_pnl()
        self._patch_update_prices()

        logger.info("All trading executor fixes applied successfully")

    def _patch_execute_action(self):
        """Patch the execute_action method to add price validation and cooldown"""
        original_execute_action = self.trading_executor.execute_action

        def execute_action_with_fixes(decision):
            """Enhanced execute_action with price validation and cooldown"""
            try:
                symbol = decision.symbol
                action = decision.action
                current_time = datetime.now()

                # 1. Check cooldown period
                if symbol in self.last_trade_time:
                    time_since_last_trade = (current_time - self.last_trade_time[symbol]).total_seconds()
                    if time_since_last_trade < self.min_trade_cooldown:
                        logger.warning(f"Trade rejected: Cooldown period ({time_since_last_trade:.1f}s < {self.min_trade_cooldown}s) for {symbol}")
                        return False

                # 2. Validate price freshness
                if symbol in self.last_price_update:
                    time_since_update = (current_time - self.last_price_update[symbol]).total_seconds()
                    if time_since_update > self.price_update_threshold:
                        logger.warning(f"Trade rejected: Price data stale ({time_since_update:.1f}s > {self.price_update_threshold}s) for {symbol}")
                        # Force price refresh
                        self._refresh_price(symbol)
                        return False

                # 3. Validate entry price against recent history
                current_price = self._get_current_price(symbol)
                if symbol in self.recent_entry_prices and len(self.recent_entry_prices[symbol]) > 0:
                    # Check if price is identical to any recent entry
                    if current_price in self.recent_entry_prices[symbol]:
                        logger.warning(f"Trade rejected: Duplicate entry price ${current_price} for {symbol}")
                        return False

                # 4. Ensure position is properly reset before new entry
                if not self._ensure_position_reset(symbol):
                    logger.warning(f"Trade rejected: Position not properly reset for {symbol}")
                    return False

                # Execute the original action
                result = original_execute_action(decision)

                # If successful, update tracking
                if result:
                    # Update cooldown timestamp
                    self.last_trade_time[symbol] = current_time

                    # Update price history
                    if symbol not in self.recent_entry_prices:
                        self.recent_entry_prices[symbol] = []

                    self.recent_entry_prices[symbol].append(current_price)
                    # Keep only the most recent prices
                    if len(self.recent_entry_prices[symbol]) > self.max_price_history:
                        self.recent_entry_prices[symbol] = self.recent_entry_prices[symbol][-self.max_price_history:]

                    # Mark position as active
                    self.position_reset_flags[symbol] = False

                    logger.info(f"Trade executed: {action} {symbol} at ${current_price} with validation")

                return result

            except Exception as e:
                logger.error(f"Error in execute_action_with_fixes: {e}")
                return original_execute_action(decision)

        # Replace the original method
        self.trading_executor.execute_action = execute_action_with_fixes
        logger.info("Patched execute_action with price validation and cooldown")
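
A minimal application sketch for this monkey-patching approach (illustrative; the TradingExecutor constructor shown here is an assumption):

executor = TradingExecutor()                 # existing executor instance
fix = TradingExecutorFix(executor)
fix.apply_fixes()                            # wraps execute_action, close_position, etc.
# From here on, executor.execute_action(decision) enforces the 30s cooldown,
# price-freshness and duplicate-entry-price checks before delegating to the original.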
def _patch_close_position(self):
|
||||
"""Patch the close_position method to ensure proper position reset"""
|
||||
original_close_position = self.trading_executor.close_position
|
||||
|
||||
def close_position_with_fixes(symbol, **kwargs):
|
||||
"""Enhanced close_position with proper reset logic"""
|
||||
try:
|
||||
# Get current price for P&L verification
|
||||
exit_price = self._get_current_price(symbol)
|
||||
|
||||
# Call original close position
|
||||
result = original_close_position(symbol, **kwargs)
|
||||
|
||||
if result:
|
||||
# Mark position as reset
|
||||
self.position_reset_flags[symbol] = True
|
||||
|
||||
# Record trade for verification
|
||||
if hasattr(self.trading_executor, 'positions') and symbol in self.trading_executor.positions:
|
||||
position = self.trading_executor.positions[symbol]
|
||||
|
||||
# Create trade record
|
||||
trade_record = {
|
||||
'symbol': symbol,
|
||||
'entry_time': getattr(position, 'entry_time', datetime.now()),
|
||||
'exit_time': datetime.now(),
|
||||
'entry_price': getattr(position, 'entry_price', 0),
|
||||
'exit_price': exit_price,
|
||||
'size': getattr(position, 'size', 0),
|
||||
'side': getattr(position, 'side', 'UNKNOWN'),
|
||||
'pnl': self._calculate_verified_pnl(position, exit_price),
|
||||
'fees': getattr(position, 'fees', 0),
|
||||
'hold_time_seconds': (datetime.now() - getattr(position, 'entry_time', datetime.now())).total_seconds()
|
||||
}
|
||||
|
||||
# Store trade record
|
||||
if symbol not in self.trade_history:
|
||||
self.trade_history[symbol] = []
|
||||
self.trade_history[symbol].append(trade_record)
|
||||
|
||||
logger.info(f"Position closed: {symbol} at ${exit_price} with verified P&L: ${trade_record['pnl']:.2f}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in close_position_with_fixes: {e}")
|
||||
return original_close_position(symbol, **kwargs)
|
||||
|
||||
# Replace the original method
|
||||
self.trading_executor.close_position = close_position_with_fixes
|
||||
logger.info("Patched close_position with proper reset logic")
|
||||
|
||||
def _patch_calculate_pnl(self):
|
||||
"""Patch the calculate_pnl method to ensure accurate P&L calculation"""
|
||||
original_calculate_pnl = getattr(self.trading_executor, 'calculate_pnl', None)
|
||||
|
||||
def calculate_pnl_with_fixes(position, current_price=None):
|
||||
"""Enhanced calculate_pnl with verification"""
|
||||
try:
|
||||
# If no original method, implement our own
|
||||
if original_calculate_pnl is None:
|
||||
return self._calculate_verified_pnl(position, current_price)
|
||||
|
||||
# Call original method
|
||||
original_pnl = original_calculate_pnl(position, current_price)
|
||||
|
||||
# Calculate our verified P&L
|
||||
verified_pnl = self._calculate_verified_pnl(position, current_price)
|
||||
|
||||
# If there's a significant difference, log it
|
||||
if abs(original_pnl - verified_pnl) > 0.01:
|
||||
logger.warning(f"P&L calculation discrepancy: original=${original_pnl:.2f}, verified=${verified_pnl:.2f}")
|
||||
# Use the verified P&L
|
||||
return verified_pnl
|
||||
|
||||
return original_pnl
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in calculate_pnl_with_fixes: {e}")
|
||||
if original_calculate_pnl:
|
||||
return original_calculate_pnl(position, current_price)
|
||||
return 0.0
|
||||
|
||||
# Replace the original method if it exists
|
||||
if original_calculate_pnl:
|
||||
self.trading_executor.calculate_pnl = calculate_pnl_with_fixes
|
||||
logger.info("Patched calculate_pnl with verification")
|
||||
else:
|
||||
# Add the method if it doesn't exist
|
||||
self.trading_executor.calculate_pnl = calculate_pnl_with_fixes
|
||||
logger.info("Added calculate_pnl method with verification")
|
||||
|
||||
def _patch_update_prices(self):
|
||||
"""Patch the update_prices method to track price updates"""
|
||||
original_update_prices = getattr(self.trading_executor, 'update_prices', None)
|
||||
|
||||
def update_prices_with_tracking(prices):
|
||||
"""Enhanced update_prices with timestamp tracking"""
|
||||
try:
|
||||
# Call original method if it exists
|
||||
if original_update_prices:
|
||||
result = original_update_prices(prices)
|
||||
else:
|
||||
# If no original method, update prices directly
|
||||
if hasattr(self.trading_executor, 'current_prices'):
|
||||
self.trading_executor.current_prices.update(prices)
|
||||
result = True
|
||||
|
||||
# Track update timestamps
|
||||
current_time = datetime.now()
|
||||
for symbol in prices:
|
||||
self.last_price_update[symbol] = current_time
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in update_prices_with_tracking: {e}")
|
||||
if original_update_prices:
|
||||
return original_update_prices(prices)
|
||||
return False
|
||||
|
||||
# Replace the original method if it exists
|
||||
if original_update_prices:
|
||||
self.trading_executor.update_prices = update_prices_with_tracking
|
||||
logger.info("Patched update_prices with timestamp tracking")
|
||||
else:
|
||||
# Add the method if it doesn't exist
|
||||
self.trading_executor.update_prices = update_prices_with_tracking
|
||||
logger.info("Added update_prices method with timestamp tracking")
|
||||
|
||||
    def _calculate_verified_pnl(self, position, current_price=None):
        """Calculate verified P&L for a position"""
        try:
            # Get position details
            entry_price = getattr(position, 'entry_price', 0)
            size = getattr(position, 'size', 0)
            side = getattr(position, 'side', 'UNKNOWN')
            leverage = getattr(position, 'leverage', 1.0)
            fees = getattr(position, 'fees', 0.0)

            # If current_price is not provided, try to get it
            if current_price is None:
                symbol = getattr(position, 'symbol', None)
                if symbol:
                    current_price = self._get_current_price(symbol)
                else:
                    return 0.0

            # Calculate P&L based on position side
            if side == 'LONG':
                pnl = (current_price - entry_price) * size * leverage
            elif side == 'SHORT':
                pnl = (entry_price - current_price) * size * leverage
            else:
                pnl = 0.0

            # Subtract fees for net P&L
            net_pnl = pnl - fees

            return net_pnl

        except Exception as e:
            logger.error(f"Error calculating verified P&L: {e}")
            return 0.0

    def _get_current_price(self, symbol):
        """Get current price for a symbol with fallbacks"""
        try:
            # Try to get from trading executor
            if hasattr(self.trading_executor, 'current_prices') and symbol in self.trading_executor.current_prices:
                return self.trading_executor.current_prices[symbol]

            # Try to get from data provider
            if hasattr(self.trading_executor, 'data_provider'):
                data_provider = self.trading_executor.data_provider
                if hasattr(data_provider, 'get_current_price'):
                    price = data_provider.get_current_price(symbol)
                    if price and price > 0:
                        return price

            # Try to get from COB data
            if hasattr(self.trading_executor, 'latest_cob_data') and symbol in self.trading_executor.latest_cob_data:
                cob_data = self.trading_executor.latest_cob_data[symbol]
                if hasattr(cob_data, 'stats') and 'mid_price' in cob_data.stats:
                    return cob_data.stats['mid_price']

            # Default fallback
            return 0.0

        except Exception as e:
            logger.error(f"Error getting current price for {symbol}: {e}")
            return 0.0

    def _refresh_price(self, symbol):
        """Force a price refresh for a symbol"""
        try:
            # Try to refresh from data provider
            if hasattr(self.trading_executor, 'data_provider'):
                data_provider = self.trading_executor.data_provider
                if hasattr(data_provider, 'fetch_current_price'):
                    price = data_provider.fetch_current_price(symbol)
                    if price and price > 0:
                        # Update trading executor price
                        if hasattr(self.trading_executor, 'current_prices'):
                            self.trading_executor.current_prices[symbol] = price

                        # Update timestamp
                        self.last_price_update[symbol] = datetime.now()

                        logger.info(f"Refreshed price for {symbol}: ${price:.2f}")
                        return True

            logger.warning(f"Failed to refresh price for {symbol}")
            return False

        except Exception as e:
            logger.error(f"Error refreshing price for {symbol}: {e}")
            return False

    def _ensure_position_reset(self, symbol):
        """Ensure position is properly reset before new entry"""
        try:
            # Check if we have an active position
            if hasattr(self.trading_executor, 'positions') and symbol in self.trading_executor.positions:
                # Position exists, check if it's valid
                position = self.trading_executor.positions[symbol]
                if position and getattr(position, 'active', False):
                    logger.warning(f"Position already active for {symbol}, cannot enter new position")
                    return False

            # Check reset flag
            if symbol in self.position_reset_flags and not self.position_reset_flags[symbol]:
                # Force position cleanup
                if hasattr(self.trading_executor, 'positions'):
                    self.trading_executor.positions.pop(symbol, None)

                logger.info(f"Forced position reset for {symbol}")
                self.position_reset_flags[symbol] = True

            return True

        except Exception as e:
            logger.error(f"Error ensuring position reset for {symbol}: {e}")
            return False

    def get_trade_history(self, symbol=None):
        """Get verified trade history"""
        if symbol:
            return self.trade_history.get(symbol, [])
        return self.trade_history

    def get_price_update_status(self):
        """Get price update status for all symbols"""
        status = {}
        current_time = datetime.now()

        for symbol, timestamp in self.last_price_update.items():
            time_since_update = (current_time - timestamp).total_seconds()
            status[symbol] = {
                'last_update': timestamp,
                'seconds_ago': time_since_update,
                'is_fresh': time_since_update <= self.price_update_threshold
            }

        return status

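The class that owns these patch helpers is not included in this hunk, so the following usage sketch is illustrative only: the wrapper name TradingExecutorFixes and its constructor are assumptions; only calculate_pnl, get_trade_history, and get_price_update_status appear in the code above.

# Illustrative usage sketch - wrapper class name and constructor are assumed, not from this hunk.
fixes = TradingExecutorFixes(trading_executor)   # hypothetical wrapper that applies the patches above

# The patched calculate_pnl now routes through _calculate_verified_pnl
pnl = trading_executor.calculate_pnl(position)

# Inspect price freshness for every tracked symbol
for symbol, info in fixes.get_price_update_status().items():
    if not info['is_fresh']:
        print(f"{symbol}: last price update {info['seconds_ago']:.0f}s ago")
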
@@ -1,795 +0,0 @@
"""
Comprehensive Training Data Collection System

This module implements a robust training data collection system that:
1. Captures all model inputs with validation and completeness checks
2. Stores training data packages with future outcome validation
3. Detects rapid price changes for high-value training examples
4. Enables replay and retraining on most profitable setups
5. Maintains data integrity and traceability

Key Features:
- Real-time data package creation with all model inputs
- Future outcome validation (profitable vs unprofitable predictions)
- Rapid price change detection for premium training examples
- Comprehensive data validation and completeness verification
- Backpropagation data storage for gradient replay
- Training episode profitability tracking and ranking
"""

import asyncio
import json
import logging
import numpy as np
import pandas as pd
import pickle
import torch
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass, field, asdict
from collections import deque
import hashlib
import threading
from concurrent.futures import ThreadPoolExecutor

logger = logging.getLogger(__name__)

@dataclass
class ModelInputPackage:
    """Complete package of all model inputs at a specific timestamp"""
    timestamp: datetime
    symbol: str

    # Market data inputs
    ohlcv_data: Dict[str, pd.DataFrame]  # {timeframe: DataFrame}
    tick_data: List[Dict[str, Any]]  # Raw tick data
    cob_data: Dict[str, Any]  # Consolidated Order Book data
    technical_indicators: Dict[str, float]  # All technical indicators
    pivot_points: List[Dict[str, Any]]  # Detected pivot points

    # Model-specific inputs
    cnn_features: np.ndarray  # CNN input features
    rl_state: np.ndarray  # RL state representation
    orchestrator_context: Dict[str, Any]  # Orchestrator context

    # Cross-model inputs (outputs from other models)
    cnn_predictions: Optional[Dict[str, Any]] = None
    rl_predictions: Optional[Dict[str, Any]] = None
    orchestrator_decision: Optional[Dict[str, Any]] = None

    # Data validation
    data_hash: str = ""
    completeness_score: float = 0.0
    validation_flags: Dict[str, bool] = field(default_factory=dict)

    def __post_init__(self):
        """Calculate data hash and completeness after initialization"""
        self.data_hash = self._calculate_hash()
        self.completeness_score = self._calculate_completeness()
        self.validation_flags = self._validate_data()

    def _calculate_hash(self) -> str:
        """Calculate hash for data integrity verification"""
        try:
            # Create a string representation of all data
            data_str = f"{self.timestamp}_{self.symbol}"
            data_str += f"_{len(self.ohlcv_data)}_{len(self.tick_data)}"
            data_str += f"_{self.cnn_features.shape if self.cnn_features is not None else 'None'}"
            data_str += f"_{self.rl_state.shape if self.rl_state is not None else 'None'}"

            return hashlib.md5(data_str.encode()).hexdigest()
        except Exception as e:
            logger.warning(f"Error calculating data hash: {e}")
            return "invalid_hash"

    def _calculate_completeness(self) -> float:
        """Calculate completeness score (0.0 to 1.0)"""
        try:
            total_fields = 10  # Total expected data fields
            complete_fields = 0

            # Check each required field
            if self.ohlcv_data and len(self.ohlcv_data) > 0:
                complete_fields += 1
            if self.tick_data and len(self.tick_data) > 0:
                complete_fields += 1
            if self.cob_data and len(self.cob_data) > 0:
                complete_fields += 1
            if self.technical_indicators and len(self.technical_indicators) > 0:
                complete_fields += 1
            if self.pivot_points and len(self.pivot_points) > 0:
                complete_fields += 1
            if self.cnn_features is not None and self.cnn_features.size > 0:
                complete_fields += 1
            if self.rl_state is not None and self.rl_state.size > 0:
                complete_fields += 1
            if self.orchestrator_context and len(self.orchestrator_context) > 0:
                complete_fields += 1
            if self.cnn_predictions is not None:
                complete_fields += 1
            if self.rl_predictions is not None:
                complete_fields += 1

            return complete_fields / total_fields
        except Exception as e:
            logger.warning(f"Error calculating completeness: {e}")
            return 0.0

    def _validate_data(self) -> Dict[str, bool]:
        """Validate data integrity and consistency"""
        flags = {}

        try:
            # Validate timestamp
            flags['valid_timestamp'] = isinstance(self.timestamp, datetime)

            # Validate OHLCV data
            flags['valid_ohlcv'] = (
                self.ohlcv_data is not None and
                len(self.ohlcv_data) > 0 and
                all(isinstance(df, pd.DataFrame) for df in self.ohlcv_data.values())
            )

            # Validate feature arrays
            flags['valid_cnn_features'] = (
                self.cnn_features is not None and
                isinstance(self.cnn_features, np.ndarray) and
                self.cnn_features.size > 0
            )

            flags['valid_rl_state'] = (
                self.rl_state is not None and
                isinstance(self.rl_state, np.ndarray) and
                self.rl_state.size > 0
            )

            # Validate data consistency
            flags['data_consistent'] = self.completeness_score > 0.7

        except Exception as e:
            logger.warning(f"Error validating data: {e}")
            flags['validation_error'] = True

        return flags

@dataclass
class TrainingOutcome:
    """Future outcome validation for training data"""
    input_package_hash: str
    timestamp: datetime
    symbol: str

    # Price movement outcomes
    price_change_1m: float
    price_change_5m: float
    price_change_15m: float
    price_change_1h: float

    # Profitability metrics
    max_profit_potential: float
    max_loss_potential: float
    optimal_entry_price: float
    optimal_exit_price: float
    optimal_holding_time: timedelta

    # Classification labels
    is_profitable: bool
    profitability_score: float  # 0.0 to 1.0
    risk_reward_ratio: float

    # Rapid price change detection
    is_rapid_change: bool
    change_velocity: float  # Price change per minute
    volatility_spike: bool

    # Validation
    outcome_validated: bool = False
    validation_timestamp: datetime = field(default_factory=datetime.now)

@dataclass
class TrainingEpisode:
    """Complete training episode with inputs, predictions, and outcomes"""
    episode_id: str
    input_package: ModelInputPackage
    model_predictions: Dict[str, Any]  # Predictions from all models
    actual_outcome: TrainingOutcome

    # Training metadata
    episode_type: str  # 'normal', 'rapid_change', 'high_profit'
    profitability_rank: float  # Ranking among all episodes
    training_priority: float  # Priority for replay training

    # Backpropagation data storage
    gradient_data: Optional[Dict[str, torch.Tensor]] = None
    loss_components: Optional[Dict[str, float]] = None
    model_states: Optional[Dict[str, Any]] = None

    # Episode statistics
    created_timestamp: datetime = field(default_factory=datetime.now)
    last_trained_timestamp: Optional[datetime] = None
    training_count: int = 0

    def calculate_training_priority(self) -> float:
        """Calculate training priority based on profitability and characteristics"""
        try:
            priority = 0.0

            # Base priority from profitability
            if self.actual_outcome.is_profitable:
                priority += self.actual_outcome.profitability_score * 0.4

            # Bonus for rapid changes (high learning value)
            if self.actual_outcome.is_rapid_change:
                priority += 0.3

            # Bonus for high risk-reward ratio
            if self.actual_outcome.risk_reward_ratio > 2.0:
                priority += 0.2

            # Bonus for data completeness
            priority += self.input_package.completeness_score * 0.1

            # Penalty for frequent training (avoid overfitting)
            if self.training_count > 5:
                priority *= 0.8

            return min(priority, 1.0)

        except Exception as e:
            logger.warning(f"Error calculating training priority: {e}")
            return 0.0

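As a worked example of the weighting in calculate_training_priority (the numbers below are illustrative, not from the source):

# Worked example of the priority weighting above (illustrative values).
priority = 0.0
priority += 0.9 * 0.4   # profitable, profitability_score = 0.9 -> 0.36
priority += 0.3         # rapid price change bonus             -> 0.66
priority += 0.2         # risk_reward_ratio > 2.0 bonus        -> 0.86
priority += 1.0 * 0.1   # completeness_score = 1.0 bonus       -> 0.96
priority *= 0.8         # training_count > 5 penalty           -> 0.768
priority = min(priority, 1.0)   # final priority is about 0.77
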
class RapidChangeDetector:
    """Detects rapid price changes for high-value training examples"""

    def __init__(self,
                 velocity_threshold: float = 0.5,  # % per minute
                 volatility_multiplier: float = 3.0,
                 lookback_minutes: int = 5):
        self.velocity_threshold = velocity_threshold
        self.volatility_multiplier = volatility_multiplier
        self.lookback_minutes = lookback_minutes

        # Price history for change detection
        self.price_history: Dict[str, deque] = {}
        self.volatility_baseline: Dict[str, float] = {}

    def add_price_point(self, symbol: str, timestamp: datetime, price: float):
        """Add new price point for change detection"""
        if symbol not in self.price_history:
            self.price_history[symbol] = deque(maxlen=self.lookback_minutes * 60)  # 1 second resolution
            self.volatility_baseline[symbol] = 0.0

        self.price_history[symbol].append((timestamp, price))
        self._update_volatility_baseline(symbol)

    def detect_rapid_change(self, symbol: str) -> Tuple[bool, float, bool]:
        """
        Detect rapid price changes

        Returns:
            (is_rapid_change, change_velocity, volatility_spike)
        """
        if symbol not in self.price_history or len(self.price_history[symbol]) < 60:
            return False, 0.0, False

        try:
            prices = list(self.price_history[symbol])

            # Calculate recent velocity (last minute)
            recent_prices = prices[-60:]  # Last 60 seconds
            if len(recent_prices) < 2:
                return False, 0.0, False

            start_price = recent_prices[0][1]
            end_price = recent_prices[-1][1]
            time_diff = (recent_prices[-1][0] - recent_prices[0][0]).total_seconds() / 60.0  # minutes

            if time_diff <= 0:
                return False, 0.0, False

            # Calculate velocity (% change per minute)
            velocity = abs((end_price - start_price) / start_price * 100) / time_diff

            # Check for rapid change
            is_rapid = velocity > self.velocity_threshold

            # Check for volatility spike
            current_volatility = self._calculate_current_volatility(symbol)
            baseline_volatility = self.volatility_baseline.get(symbol, 0.0)
            volatility_spike = (
                baseline_volatility > 0 and
                current_volatility > baseline_volatility * self.volatility_multiplier
            )

            return is_rapid, velocity, volatility_spike

        except Exception as e:
            logger.warning(f"Error detecting rapid change for {symbol}: {e}")
            return False, 0.0, False

    def _update_volatility_baseline(self, symbol: str):
        """Update volatility baseline for the symbol"""
        try:
            if len(self.price_history[symbol]) < 120:  # Need at least 2 minutes of data
                return

            # Calculate rolling volatility over longer period
            prices = [p[1] for p in list(self.price_history[symbol])[-300:]]  # Last 5 minutes
            if len(prices) < 2:
                return

            # Calculate standard deviation of price changes
            price_changes = [abs(prices[i] - prices[i-1]) / prices[i-1] for i in range(1, len(prices))]
            volatility = np.std(price_changes) * 100  # Convert to percentage

            # Update baseline with exponential moving average
            alpha = 0.1
            if self.volatility_baseline[symbol] == 0:
                self.volatility_baseline[symbol] = volatility
            else:
                self.volatility_baseline[symbol] = (
                    alpha * volatility + (1 - alpha) * self.volatility_baseline[symbol]
                )

        except Exception as e:
            logger.warning(f"Error updating volatility baseline for {symbol}: {e}")

    def _calculate_current_volatility(self, symbol: str) -> float:
        """Calculate current volatility for the symbol"""
        try:
            if len(self.price_history[symbol]) < 60:
                return 0.0

            # Use last minute of data
            recent_prices = [p[1] for p in list(self.price_history[symbol])[-60:]]
            if len(recent_prices) < 2:
                return 0.0

            price_changes = [abs(recent_prices[i] - recent_prices[i-1]) / recent_prices[i-1]
                             for i in range(1, len(recent_prices))]
            return np.std(price_changes) * 100

        except Exception as e:
            logger.warning(f"Error calculating current volatility for {symbol}: {e}")
            return 0.0

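A minimal sketch of how the detector behaves on synthetic data; the symbol and prices below are made up, and the point is only that a roughly 1%-per-minute move trips the default 0.5%/min threshold:

# Synthetic-data sketch of RapidChangeDetector (illustrative values only).
from datetime import datetime, timedelta

detector = RapidChangeDetector(velocity_threshold=0.5)
start = datetime.now()
# Two minutes of 1-second samples: flat for the first minute, then +0.5 per second.
for i in range(120):
    price = 3000.0 if i < 60 else 3000.0 + (i - 60) * 0.5
    detector.add_price_point('ETH/USDT', start + timedelta(seconds=i), price)

is_rapid, velocity, vol_spike = detector.detect_rapid_change('ETH/USDT')
# velocity is roughly 1.0 %/min, above the 0.5 threshold, so is_rapid is True
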
class TrainingDataCollector:
    """Main training data collection system"""

    def __init__(self,
                 storage_dir: str = "training_data",
                 max_episodes_per_symbol: int = 10000,
                 outcome_validation_delay: timedelta = timedelta(hours=1)):

        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        self.max_episodes_per_symbol = max_episodes_per_symbol
        self.outcome_validation_delay = outcome_validation_delay

        # Data storage
        self.training_episodes: Dict[str, List[TrainingEpisode]] = {}  # {symbol: episodes}
        self.pending_outcomes: Dict[str, List[ModelInputPackage]] = {}  # Awaiting outcome validation

        # Rapid change detection
        self.rapid_change_detector = RapidChangeDetector()

        # Data validation and statistics
        self.collection_stats = {
            'total_episodes': 0,
            'profitable_episodes': 0,
            'rapid_change_episodes': 0,
            'validation_errors': 0,
            'data_completeness_avg': 0.0
        }

        # Background processing
        self.is_collecting = False
        self.collection_thread = None
        self.outcome_validation_thread = None

        # Thread safety
        self.data_lock = threading.Lock()

        logger.info("Training Data Collector initialized")
        logger.info(f"Storage directory: {self.storage_dir}")
        logger.info(f"Max episodes per symbol: {self.max_episodes_per_symbol}")

    def start_collection(self):
        """Start the training data collection system"""
        if self.is_collecting:
            logger.warning("Training data collection already running")
            return

        self.is_collecting = True

        # Start outcome validation thread
        self.outcome_validation_thread = threading.Thread(
            target=self._outcome_validation_worker,
            daemon=True
        )
        self.outcome_validation_thread.start()

        logger.info("Training data collection started")

    def stop_collection(self):
        """Stop the training data collection system"""
        self.is_collecting = False

        if self.outcome_validation_thread:
            self.outcome_validation_thread.join(timeout=5)

        logger.info("Training data collection stopped")

    def collect_training_data(self,
                              symbol: str,
                              ohlcv_data: Dict[str, pd.DataFrame],
                              tick_data: List[Dict[str, Any]],
                              cob_data: Dict[str, Any],
                              technical_indicators: Dict[str, float],
                              pivot_points: List[Dict[str, Any]],
                              cnn_features: np.ndarray,
                              rl_state: np.ndarray,
                              orchestrator_context: Dict[str, Any],
                              model_predictions: Optional[Dict[str, Any]] = None) -> Optional[str]:
        """
        Collect comprehensive training data package

        Returns:
            episode_id for tracking, or None if the data is incomplete

        Note: model_predictions is not consumed here; predictions are attached
        to the episode later when outcomes are validated.
        """
        try:
            # Create input package
            input_package = ModelInputPackage(
                timestamp=datetime.now(),
                symbol=symbol,
                ohlcv_data=ohlcv_data,
                tick_data=tick_data,
                cob_data=cob_data,
                technical_indicators=technical_indicators,
                pivot_points=pivot_points,
                cnn_features=cnn_features,
                rl_state=rl_state,
                orchestrator_context=orchestrator_context
            )

            # Validate data completeness
            if input_package.completeness_score < 0.5:
                logger.warning(f"Low data completeness for {symbol}: {input_package.completeness_score:.2f}")
                self.collection_stats['validation_errors'] += 1
                return None

            # Check for rapid price changes
            current_price = self._extract_current_price(ohlcv_data)
            if current_price:
                self.rapid_change_detector.add_price_point(symbol, input_package.timestamp, current_price)

            # Add to pending outcomes for future validation
            with self.data_lock:
                if symbol not in self.pending_outcomes:
                    self.pending_outcomes[symbol] = []

                self.pending_outcomes[symbol].append(input_package)

                # Limit pending outcomes to prevent memory issues
                if len(self.pending_outcomes[symbol]) > 1000:
                    self.pending_outcomes[symbol] = self.pending_outcomes[symbol][-500:]

            # Generate episode ID
            episode_id = f"{symbol}_{input_package.timestamp.strftime('%Y%m%d_%H%M%S')}_{input_package.data_hash[:8]}"

            # Update statistics
            self.collection_stats['total_episodes'] += 1
            self.collection_stats['data_completeness_avg'] = (
                (self.collection_stats['data_completeness_avg'] * (self.collection_stats['total_episodes'] - 1) +
                 input_package.completeness_score) / self.collection_stats['total_episodes']
            )

            logger.debug(f"Collected training data for {symbol}: {episode_id}")
            logger.debug(f"Data completeness: {input_package.completeness_score:.2f}")

            return episode_id

        except Exception as e:
            logger.error(f"Error collecting training data for {symbol}: {e}")
            self.collection_stats['validation_errors'] += 1
            return None

    def _extract_current_price(self, ohlcv_data: Dict[str, pd.DataFrame]) -> Optional[float]:
        """Extract current price from OHLCV data"""
        try:
            # Try to get price from shortest timeframe first
            for timeframe in ['1s', '1m', '5m', '15m', '1h']:
                if timeframe in ohlcv_data and not ohlcv_data[timeframe].empty:
                    return float(ohlcv_data[timeframe]['close'].iloc[-1])
            return None
        except Exception as e:
            logger.warning(f"Error extracting current price: {e}")
            return None

    def _outcome_validation_worker(self):
        """Background worker for validating training outcomes"""
        logger.info("Outcome validation worker started")

        while self.is_collecting:
            try:
                self._validate_pending_outcomes()
                threading.Event().wait(60)  # Check every minute

            except Exception as e:
                logger.error(f"Error in outcome validation worker: {e}")
                threading.Event().wait(30)  # Wait before retrying

        logger.info("Outcome validation worker stopped")

    def _validate_pending_outcomes(self):
        """Validate outcomes for pending training data"""
        current_time = datetime.now()

        with self.data_lock:
            for symbol in list(self.pending_outcomes.keys()):
                if symbol not in self.pending_outcomes:
                    continue

                validated_packages = []
                remaining_packages = []

                for package in self.pending_outcomes[symbol]:
                    # Check if enough time has passed for outcome validation
                    if current_time - package.timestamp >= self.outcome_validation_delay:
                        outcome = self._calculate_training_outcome(package)
                        if outcome:
                            self._create_training_episode(package, outcome)
                            validated_packages.append(package)
                        else:
                            remaining_packages.append(package)
                    else:
                        remaining_packages.append(package)

                # Update pending outcomes
                self.pending_outcomes[symbol] = remaining_packages

                if validated_packages:
                    logger.info(f"Validated {len(validated_packages)} outcomes for {symbol}")

    def _calculate_training_outcome(self, input_package: ModelInputPackage) -> Optional[TrainingOutcome]:
        """Calculate training outcome based on future price movements"""
        try:
            # This would typically fetch recent price data to calculate outcomes.
            # For now this is a placeholder implementation.

            # Extract base price from input package
            base_price = self._extract_current_price(input_package.ohlcv_data)
            if not base_price:
                return None

            # Simulate outcome calculation (in a real implementation, fetch actual future prices).
            # This is where you would integrate with your data provider to get actual outcomes.

            # Check for rapid change
            is_rapid, velocity, volatility_spike = self.rapid_change_detector.detect_rapid_change(
                input_package.symbol
            )

            # Create outcome (placeholder values - replace with actual calculation)
            outcome = TrainingOutcome(
                input_package_hash=input_package.data_hash,
                timestamp=input_package.timestamp,
                symbol=input_package.symbol,
                price_change_1m=0.0,  # Calculate from actual future data
                price_change_5m=0.0,
                price_change_15m=0.0,
                price_change_1h=0.0,
                max_profit_potential=0.0,
                max_loss_potential=0.0,
                optimal_entry_price=base_price,
                optimal_exit_price=base_price,
                optimal_holding_time=timedelta(minutes=5),
                is_profitable=False,  # Determine from actual outcomes
                profitability_score=0.0,
                risk_reward_ratio=1.0,
                is_rapid_change=is_rapid,
                change_velocity=velocity,
                volatility_spike=volatility_spike,
                outcome_validated=True
            )

            return outcome

        except Exception as e:
            logger.error(f"Error calculating training outcome: {e}")
            return None

    def _create_training_episode(self, input_package: ModelInputPackage, outcome: TrainingOutcome):
        """Create complete training episode"""
        try:
            episode_id = f"{input_package.symbol}_{input_package.timestamp.strftime('%Y%m%d_%H%M%S')}_{input_package.data_hash[:8]}"

            # Determine episode type
            episode_type = 'normal'
            if outcome.is_rapid_change:
                episode_type = 'rapid_change'
                self.collection_stats['rapid_change_episodes'] += 1
            elif outcome.profitability_score > 0.8:
                episode_type = 'high_profit'

            if outcome.is_profitable:
                self.collection_stats['profitable_episodes'] += 1

            # Create training episode
            episode = TrainingEpisode(
                episode_id=episode_id,
                input_package=input_package,
                model_predictions={},  # Will be filled when models make predictions
                actual_outcome=outcome,
                episode_type=episode_type,
                profitability_rank=0.0,  # Will be calculated later
                training_priority=0.0
            )

            # Calculate training priority
            episode.training_priority = episode.calculate_training_priority()

            # Store episode
            symbol = input_package.symbol
            if symbol not in self.training_episodes:
                self.training_episodes[symbol] = []

            self.training_episodes[symbol].append(episode)

            # Limit episodes per symbol
            if len(self.training_episodes[symbol]) > self.max_episodes_per_symbol:
                # Keep highest priority episodes
                self.training_episodes[symbol].sort(key=lambda x: x.training_priority, reverse=True)
                self.training_episodes[symbol] = self.training_episodes[symbol][:self.max_episodes_per_symbol]

            # Save episode to disk
            self._save_episode_to_disk(episode)

            logger.debug(f"Created training episode: {episode_id}")
            logger.debug(f"Episode type: {episode_type}, Priority: {episode.training_priority:.3f}")

        except Exception as e:
            logger.error(f"Error creating training episode: {e}")

    def _save_episode_to_disk(self, episode: TrainingEpisode):
        """Save training episode to disk for persistence"""
        try:
            symbol_dir = self.storage_dir / episode.input_package.symbol
            symbol_dir.mkdir(parents=True, exist_ok=True)

            # Save episode data
            episode_file = symbol_dir / f"{episode.episode_id}.pkl"
            with open(episode_file, 'wb') as f:
                pickle.dump(episode, f)

            # Save episode metadata for quick access
            metadata = {
                'episode_id': episode.episode_id,
                'timestamp': episode.input_package.timestamp.isoformat(),
                'episode_type': episode.episode_type,
                'training_priority': episode.training_priority,
                'profitability_score': episode.actual_outcome.profitability_score,
                'is_profitable': episode.actual_outcome.is_profitable,
                'is_rapid_change': episode.actual_outcome.is_rapid_change,
                'data_completeness': episode.input_package.completeness_score
            }

            metadata_file = symbol_dir / f"{episode.episode_id}_metadata.json"
            with open(metadata_file, 'w') as f:
                json.dump(metadata, f, indent=2)

        except Exception as e:
            logger.error(f"Error saving episode to disk: {e}")

    def get_high_priority_episodes(self,
                                   symbol: str,
                                   limit: int = 100,
                                   min_priority: float = 0.5) -> List[TrainingEpisode]:
        """Get high-priority training episodes for replay training"""
        try:
            if symbol not in self.training_episodes:
                return []

            # Filter and sort by priority
            high_priority = [
                ep for ep in self.training_episodes[symbol]
                if ep.training_priority >= min_priority
            ]

            high_priority.sort(key=lambda x: x.training_priority, reverse=True)

            return high_priority[:limit]

        except Exception as e:
            logger.error(f"Error getting high priority episodes for {symbol}: {e}")
            return []

    def get_collection_statistics(self) -> Dict[str, Any]:
        """Get comprehensive collection statistics"""
        stats = self.collection_stats.copy()

        # Add per-symbol statistics
        stats['episodes_per_symbol'] = {
            symbol: len(episodes)
            for symbol, episodes in self.training_episodes.items()
        }

        # Add pending outcomes count
        stats['pending_outcomes'] = {
            symbol: len(packages)
            for symbol, packages in self.pending_outcomes.items()
        }

        # Calculate profitability rate
        if stats['total_episodes'] > 0:
            stats['profitability_rate'] = stats['profitable_episodes'] / stats['total_episodes']
            stats['rapid_change_rate'] = stats['rapid_change_episodes'] / stats['total_episodes']
        else:
            stats['profitability_rate'] = 0.0
            stats['rapid_change_rate'] = 0.0

        return stats

    def validate_data_integrity(self) -> Dict[str, Any]:
        """Comprehensive data integrity validation"""
        validation_results = {
            'total_episodes_checked': 0,
            'hash_mismatches': 0,
            'completeness_issues': 0,
            'validation_flag_failures': 0,
            'corrupted_episodes': [],
            'integrity_score': 1.0
        }

        try:
            for symbol, episodes in self.training_episodes.items():
                for episode in episodes:
                    validation_results['total_episodes_checked'] += 1

                    # Check data hash
                    expected_hash = episode.input_package._calculate_hash()
                    if expected_hash != episode.input_package.data_hash:
                        validation_results['hash_mismatches'] += 1
                        validation_results['corrupted_episodes'].append(episode.episode_id)

                    # Check completeness
                    if episode.input_package.completeness_score < 0.7:
                        validation_results['completeness_issues'] += 1

                    # Check validation flags
                    if not episode.input_package.validation_flags.get('data_consistent', False):
                        validation_results['validation_flag_failures'] += 1

            # Calculate integrity score
            total_issues = (
                validation_results['hash_mismatches'] +
                validation_results['completeness_issues'] +
                validation_results['validation_flag_failures']
            )

            if validation_results['total_episodes_checked'] > 0:
                validation_results['integrity_score'] = 1.0 - (
                    total_issues / validation_results['total_episodes_checked']
                )

            logger.info("Data integrity validation completed")
            logger.info(f"Integrity score: {validation_results['integrity_score']:.3f}")

        except Exception as e:
            logger.error(f"Error during data integrity validation: {e}")
            validation_results['validation_error'] = str(e)

        return validation_results

# Global instance for easy access
training_data_collector = None

def get_training_data_collector() -> TrainingDataCollector:
    """Get global training data collector instance"""
    global training_data_collector
    if training_data_collector is None:
        training_data_collector = TrainingDataCollector()
    return training_data_collector

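A rough end-to-end usage sketch of the collector; the placeholder market data below is invented purely to satisfy the expected input types and push the completeness score past the 0.5 threshold:

# Usage sketch with placeholder inputs (symbol, prices, and indicators are invented).
import numpy as np
import pandas as pd

collector = get_training_data_collector()
collector.start_collection()

episode_id = collector.collect_training_data(
    symbol='ETH/USDT',
    ohlcv_data={'1m': pd.DataFrame({'close': [3000.0, 3001.5]})},
    tick_data=[{'price': 3001.5, 'size': 0.2}],
    cob_data={'mid_price': 3001.4},
    technical_indicators={'rsi_14': 55.2},
    pivot_points=[{'type': 'low', 'price': 2990.0}],
    cnn_features=np.zeros(128, dtype=np.float32),
    rl_state=np.zeros(64, dtype=np.float32),
    orchestrator_context={'regime': 'ranging'},
)

print(collector.get_collection_statistics())
collector.stop_collection()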