BIG CLEANUP

Dobromir Popov
2025-08-08 14:58:55 +03:00
parent e39e9ee95a
commit 2b0d2679c6
162 changed files with 455 additions and 42814 deletions

View File

@@ -1,402 +0,0 @@
"""
API Rate Limiter and Error Handler
This module provides robust rate limiting and error handling for API requests,
specifically designed to handle Binance's aggressive rate limiting (HTTP 418 errors)
and other exchange API limitations.
Features:
- Exponential backoff for rate limiting
- IP rotation and proxy support
- Request queuing and throttling
- Error recovery strategies
- Thread-safe operations
"""
import asyncio
import logging
import time
import random
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Callable, Any
from dataclasses import dataclass, field
from collections import deque
import threading
from concurrent.futures import ThreadPoolExecutor
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
logger = logging.getLogger(__name__)
@dataclass
class RateLimitConfig:
"""Configuration for rate limiting"""
requests_per_second: float = 0.5 # Very conservative for Binance
requests_per_minute: int = 20
requests_per_hour: int = 1000
# Backoff configuration
initial_backoff: float = 1.0
max_backoff: float = 300.0 # 5 minutes max
backoff_multiplier: float = 2.0
# Error handling
max_retries: int = 3
retry_delay: float = 5.0
# IP blocking detection
block_detection_threshold: int = 3 # 3 consecutive 418s = blocked
block_recovery_time: int = 3600 # 1 hour recovery time
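
The backoff fields above define a capped geometric schedule, applied later when HTTP 429 responses are handled. A minimal illustration, assuming the default RateLimitConfig values (the helper name is for illustration only):

def backoff_seconds(consecutive_errors: int,
                    initial: float = 1.0,
                    multiplier: float = 2.0,
                    maximum: float = 300.0) -> float:
    """Exponential backoff with a hard cap, mirroring the 429 handling below."""
    return min(initial * (multiplier ** consecutive_errors), maximum)

# backoff_seconds(1) -> 2.0, backoff_seconds(5) -> 32.0, backoff_seconds(10) -> 300.0 (capped)
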
@dataclass
class APIEndpoint:
"""API endpoint configuration"""
name: str
base_url: str
rate_limit: RateLimitConfig
last_request_time: float = 0.0
request_count_minute: int = 0
request_count_hour: int = 0
consecutive_errors: int = 0
blocked_until: Optional[datetime] = None
# Request history for rate limiting
request_history: deque = field(default_factory=lambda: deque(maxlen=3600)) # 1 hour history
class APIRateLimiter:
"""Thread-safe API rate limiter with error handling"""
def __init__(self, config: RateLimitConfig = None):
self.config = config or RateLimitConfig()
# Thread safety
self.lock = threading.RLock()
# Endpoint tracking
self.endpoints: Dict[str, APIEndpoint] = {}
# Global rate limiting
self.global_request_history = deque(maxlen=3600)
self.global_blocked_until: Optional[datetime] = None
# Request session with retry strategy
self.session = self._create_session()
# Background cleanup thread
self.cleanup_thread = None
self.is_running = False
logger.info("API Rate Limiter initialized")
logger.info(f"Rate limits: {self.config.requests_per_second}/s, {self.config.requests_per_minute}/m")
def _create_session(self) -> requests.Session:
"""Create requests session with retry strategy"""
session = requests.Session()
# Retry strategy
retry_strategy = Retry(
total=self.config.max_retries,
backoff_factor=1,
status_forcelist=[429, 500, 502, 503, 504],
allowed_methods=["HEAD", "GET", "OPTIONS"]
)
adapter = HTTPAdapter(max_retries=retry_strategy)
session.mount("http://", adapter)
session.mount("https://", adapter)
# Headers to appear more legitimate
session.headers.update({
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
'Accept': 'application/json',
'Accept-Language': 'en-US,en;q=0.9',
'Accept-Encoding': 'gzip, deflate, br',
'Connection': 'keep-alive',
'Upgrade-Insecure-Requests': '1',
})
return session
def register_endpoint(self, name: str, base_url: str, rate_limit: RateLimitConfig = None):
"""Register an API endpoint for rate limiting"""
with self.lock:
self.endpoints[name] = APIEndpoint(
name=name,
base_url=base_url,
rate_limit=rate_limit or self.config
)
logger.info(f"Registered endpoint: {name} -> {base_url}")
def start_background_cleanup(self):
"""Start background cleanup thread"""
if self.is_running:
return
self.is_running = True
self.cleanup_thread = threading.Thread(target=self._cleanup_worker, daemon=True)
self.cleanup_thread.start()
logger.info("Started background cleanup thread")
def stop_background_cleanup(self):
"""Stop background cleanup thread"""
self.is_running = False
if self.cleanup_thread:
self.cleanup_thread.join(timeout=5)
logger.info("Stopped background cleanup thread")
def _cleanup_worker(self):
"""Background worker to clean up old request history"""
while self.is_running:
try:
current_time = time.time()
cutoff_time = current_time - 3600 # 1 hour ago
with self.lock:
# Clean global history
while (self.global_request_history and
self.global_request_history[0] < cutoff_time):
self.global_request_history.popleft()
# Clean endpoint histories
for endpoint in self.endpoints.values():
while (endpoint.request_history and
endpoint.request_history[0] < cutoff_time):
endpoint.request_history.popleft()
# Reset counters
endpoint.request_count_minute = len([
t for t in endpoint.request_history
if t > current_time - 60
])
endpoint.request_count_hour = len(endpoint.request_history)
time.sleep(60) # Clean every minute
except Exception as e:
logger.error(f"Error in cleanup worker: {e}")
time.sleep(30)
def can_make_request(self, endpoint_name: str) -> tuple[bool, float]:
"""
Check if we can make a request to the endpoint
Returns:
(can_make_request, wait_time_seconds)
"""
with self.lock:
current_time = time.time()
# Check global blocking
if self.global_blocked_until and datetime.now() < self.global_blocked_until:
wait_time = (self.global_blocked_until - datetime.now()).total_seconds()
return False, wait_time
# Get endpoint
endpoint = self.endpoints.get(endpoint_name)
if not endpoint:
logger.warning(f"Unknown endpoint: {endpoint_name}")
return False, 60.0
# Check endpoint blocking
if endpoint.blocked_until and datetime.now() < endpoint.blocked_until:
wait_time = (endpoint.blocked_until - datetime.now()).total_seconds()
return False, wait_time
# Check rate limits
config = endpoint.rate_limit
# Per-second rate limit
time_since_last = current_time - endpoint.last_request_time
if time_since_last < (1.0 / config.requests_per_second):
wait_time = (1.0 / config.requests_per_second) - time_since_last
return False, wait_time
# Per-minute rate limit
minute_requests = len([
t for t in endpoint.request_history
if t > current_time - 60
])
if minute_requests >= config.requests_per_minute:
return False, 60.0
# Per-hour rate limit
if len(endpoint.request_history) >= config.requests_per_hour:
return False, 3600.0
return True, 0.0
def make_request(self, endpoint_name: str, url: str, method: str = 'GET',
**kwargs) -> Optional[requests.Response]:
"""
Make a rate-limited request with error handling
Args:
endpoint_name: Name of the registered endpoint
url: Full URL to request
method: HTTP method
**kwargs: Additional arguments for requests
Returns:
Response object or None if failed
"""
with self.lock:
endpoint = self.endpoints.get(endpoint_name)
if not endpoint:
logger.error(f"Unknown endpoint: {endpoint_name}")
return None
# Check if we can make the request
can_request, wait_time = self.can_make_request(endpoint_name)
if not can_request:
logger.debug(f"Rate limited for {endpoint_name}, waiting {wait_time:.2f}s")
time.sleep(min(wait_time, 30)) # Cap wait time
return None
# Record request attempt
current_time = time.time()
endpoint.last_request_time = current_time
endpoint.request_history.append(current_time)
self.global_request_history.append(current_time)
# Add jitter to avoid thundering herd
jitter = random.uniform(0.1, 0.5)
time.sleep(jitter)
# Make the request (outside of lock to avoid blocking other threads)
try:
# Set timeout
kwargs.setdefault('timeout', 10)
# Make request
response = self.session.request(method, url, **kwargs)
# Handle response
with self.lock:
if response.status_code == 200:
# Success - reset error counter
endpoint.consecutive_errors = 0
return response
elif response.status_code == 418:
# Binance "I'm a teapot" - rate limited/blocked
endpoint.consecutive_errors += 1
logger.warning(f"HTTP 418 (rate limited) for {endpoint_name}, consecutive errors: {endpoint.consecutive_errors}")
if endpoint.consecutive_errors >= endpoint.rate_limit.block_detection_threshold:
# We're likely IP blocked
block_time = datetime.now() + timedelta(seconds=endpoint.rate_limit.block_recovery_time)
endpoint.blocked_until = block_time
logger.error(f"Endpoint {endpoint_name} blocked until {block_time}")
return None
elif response.status_code == 429:
# Too many requests
endpoint.consecutive_errors += 1
logger.warning(f"HTTP 429 (too many requests) for {endpoint_name}")
# Implement exponential backoff
backoff_time = min(
endpoint.rate_limit.initial_backoff * (endpoint.rate_limit.backoff_multiplier ** endpoint.consecutive_errors),
endpoint.rate_limit.max_backoff
)
block_time = datetime.now() + timedelta(seconds=backoff_time)
endpoint.blocked_until = block_time
logger.warning(f"Backing off {endpoint_name} for {backoff_time:.2f}s")
return None
else:
# Other error
endpoint.consecutive_errors += 1
logger.warning(f"HTTP {response.status_code} for {endpoint_name}: {response.text[:200]}")
return None
except requests.exceptions.RequestException as e:
with self.lock:
endpoint.consecutive_errors += 1
logger.error(f"Request exception for {endpoint_name}: {e}")
return None
except Exception as e:
with self.lock:
endpoint.consecutive_errors += 1
logger.error(f"Unexpected error for {endpoint_name}: {e}")
return None
def get_endpoint_status(self, endpoint_name: str) -> Dict[str, Any]:
"""Get status information for an endpoint"""
with self.lock:
endpoint = self.endpoints.get(endpoint_name)
if not endpoint:
return {'error': 'Unknown endpoint'}
current_time = time.time()
return {
'name': endpoint.name,
'base_url': endpoint.base_url,
'consecutive_errors': endpoint.consecutive_errors,
'blocked_until': endpoint.blocked_until.isoformat() if endpoint.blocked_until else None,
'requests_last_minute': len([t for t in endpoint.request_history if t > current_time - 60]),
'requests_last_hour': len(endpoint.request_history),
'last_request_time': endpoint.last_request_time,
'can_make_request': self.can_make_request(endpoint_name)[0]
}
def get_all_endpoint_status(self) -> Dict[str, Dict[str, Any]]:
"""Get status for all endpoints"""
return {name: self.get_endpoint_status(name) for name in self.endpoints.keys()}
def reset_endpoint(self, endpoint_name: str):
"""Reset an endpoint's error state"""
with self.lock:
endpoint = self.endpoints.get(endpoint_name)
if endpoint:
endpoint.consecutive_errors = 0
endpoint.blocked_until = None
logger.info(f"Reset endpoint: {endpoint_name}")
def reset_all_endpoints(self):
"""Reset all endpoints' error states"""
with self.lock:
for endpoint in self.endpoints.values():
endpoint.consecutive_errors = 0
endpoint.blocked_until = None
self.global_blocked_until = None
logger.info("Reset all endpoints")
# Global rate limiter instance
_global_rate_limiter = None
def get_rate_limiter() -> APIRateLimiter:
"""Get global rate limiter instance"""
global _global_rate_limiter
if _global_rate_limiter is None:
_global_rate_limiter = APIRateLimiter()
_global_rate_limiter.start_background_cleanup()
# Register common endpoints
_global_rate_limiter.register_endpoint(
'binance_api',
'https://api.binance.com',
RateLimitConfig(
requests_per_second=0.2, # Very conservative
requests_per_minute=10,
requests_per_hour=500
)
)
_global_rate_limiter.register_endpoint(
'mexc_api',
'https://api.mexc.com',
RateLimitConfig(
requests_per_second=0.5,
requests_per_minute=20,
requests_per_hour=1000
)
)
return _global_rate_limiter
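
A minimal usage sketch, assuming the module were importable as api_rate_limiter (the import path is an assumption; get_rate_limiter, make_request, and get_endpoint_status come from the code above):

from api_rate_limiter import get_rate_limiter  # hypothetical module name

limiter = get_rate_limiter()  # registers the binance_api and mexc_api endpoints

# Rate-limited ticker request; returns None when throttled, blocked, or failed
response = limiter.make_request(
    'binance_api',
    'https://api.binance.com/api/v3/ticker/price',
    params={'symbol': 'ETHUSDT'}
)
if response is not None:
    print(response.json())

print(limiter.get_endpoint_status('binance_api'))
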

View File

@@ -1,442 +0,0 @@
"""
Async Handler for UI Stability Fix
Properly handles all async operations in the dashboard with single event loop management,
proper exception handling, and timeout support to prevent async/await errors.
"""
import asyncio
import logging
import threading
import time
from typing import Any, Callable, Coroutine, Dict, Optional, Union
from concurrent.futures import ThreadPoolExecutor
import functools
import weakref
logger = logging.getLogger(__name__)
class AsyncOperationError(Exception):
"""Exception raised for async operation errors"""
pass
class AsyncHandler:
"""
Centralized async operation handler with single event loop management
and proper exception handling for async operations.
"""
def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None):
"""
Initialize the async handler
Args:
loop: Optional event loop to use. If None, creates a new one.
"""
self._loop = loop
self._thread = None
self._executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="AsyncHandler")
self._running = False
self._callbacks = weakref.WeakSet()
self._timeout_default = 30.0 # Default timeout for operations
# Start the event loop in a separate thread if not provided
if self._loop is None:
self._start_event_loop_thread()
logger.info("AsyncHandler initialized with event loop management")
def _start_event_loop_thread(self):
"""Start the event loop in a separate thread"""
def run_event_loop():
"""Run the event loop in a separate thread"""
try:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._running = True
logger.debug("Event loop started in separate thread")
self._loop.run_forever()
except Exception as e:
logger.error(f"Error in event loop thread: {e}")
finally:
self._running = False
logger.debug("Event loop thread stopped")
self._thread = threading.Thread(target=run_event_loop, daemon=True, name="AsyncHandler-EventLoop")
self._thread.start()
# Wait for the loop to be ready
timeout = 5.0
start_time = time.time()
while not self._running and (time.time() - start_time) < timeout:
time.sleep(0.1)
if not self._running:
raise AsyncOperationError("Failed to start event loop within timeout")
def is_running(self) -> bool:
"""Check if the async handler is running"""
return self._running and self._loop is not None and not self._loop.is_closed()
def run_async_safely(self, coro: Coroutine, timeout: Optional[float] = None) -> Any:
"""
Run an async coroutine safely with proper error handling and timeout
Args:
coro: The coroutine to run
timeout: Timeout in seconds (uses default if None)
Returns:
The result of the coroutine
Raises:
AsyncOperationError: If the operation fails or times out
"""
if not self.is_running():
raise AsyncOperationError("AsyncHandler is not running")
timeout = timeout or self._timeout_default
try:
# Schedule the coroutine on the event loop
future = asyncio.run_coroutine_threadsafe(
asyncio.wait_for(coro, timeout=timeout),
self._loop
)
# Wait for the result with timeout
result = future.result(timeout=timeout + 1.0) # Add buffer to future timeout
logger.debug("Async operation completed successfully")
return result
except asyncio.TimeoutError:
logger.error(f"Async operation timed out after {timeout} seconds")
raise AsyncOperationError(f"Operation timed out after {timeout} seconds")
except Exception as e:
logger.error(f"Async operation failed: {e}")
raise AsyncOperationError(f"Async operation failed: {e}")
def schedule_coroutine(self, coro: Coroutine, callback: Optional[Callable] = None) -> None:
"""
Schedule a coroutine to run asynchronously without waiting for result
Args:
coro: The coroutine to schedule
callback: Optional callback to call with the result
"""
if not self.is_running():
logger.warning("Cannot schedule coroutine: AsyncHandler is not running")
return
async def wrapped_coro():
"""Wrapper to handle exceptions and callbacks"""
try:
result = await coro
if callback:
try:
callback(result)
except Exception as e:
logger.error(f"Error in coroutine callback: {e}")
return result
except Exception as e:
logger.error(f"Error in scheduled coroutine: {e}")
if callback:
try:
callback(None) # Call callback with None on error
except Exception as cb_e:
logger.error(f"Error in error callback: {cb_e}")
try:
asyncio.run_coroutine_threadsafe(wrapped_coro(), self._loop)
logger.debug("Coroutine scheduled successfully")
except Exception as e:
logger.error(f"Failed to schedule coroutine: {e}")
def create_task_safely(self, coro: Coroutine, name: Optional[str] = None) -> Optional[asyncio.Task]:
"""
Create an asyncio task safely with proper error handling
Args:
coro: The coroutine to create a task for
name: Optional name for the task
Returns:
The created task or None if failed
"""
if not self.is_running():
logger.warning("Cannot create task: AsyncHandler is not running")
return None
async def create_task():
"""Create the task in the event loop"""
try:
task = asyncio.create_task(coro, name=name)
logger.debug(f"Task created: {name or 'unnamed'}")
return task
except Exception as e:
logger.error(f"Failed to create task {name}: {e}")
return None
try:
future = asyncio.run_coroutine_threadsafe(create_task(), self._loop)
return future.result(timeout=5.0)
except Exception as e:
logger.error(f"Failed to create task {name}: {e}")
return None
async def handle_orchestrator_connection(self, orchestrator) -> bool:
"""
Handle orchestrator connection with proper async patterns
Args:
orchestrator: The orchestrator instance to connect to
Returns:
True if connection successful, False otherwise
"""
try:
logger.info("Connecting to orchestrator...")
# Add decision callback if orchestrator supports it
if hasattr(orchestrator, 'add_decision_callback'):
await orchestrator.add_decision_callback(self._handle_trading_decision)
logger.info("Decision callback added to orchestrator")
# Start COB integration if available
if hasattr(orchestrator, 'start_cob_integration'):
await orchestrator.start_cob_integration()
logger.info("COB integration started")
# Start continuous trading if available
if hasattr(orchestrator, 'start_continuous_trading'):
await orchestrator.start_continuous_trading()
logger.info("Continuous trading started")
logger.info("Successfully connected to orchestrator")
return True
except Exception as e:
logger.error(f"Failed to connect to orchestrator: {e}")
return False
async def handle_cob_integration(self, cob_integration) -> bool:
"""
Handle COB integration startup with proper async patterns
Args:
cob_integration: The COB integration instance
Returns:
True if startup successful, False otherwise
"""
try:
logger.info("Starting COB integration...")
if hasattr(cob_integration, 'start'):
await cob_integration.start()
logger.info("COB integration started successfully")
return True
else:
logger.warning("COB integration does not have start method")
return False
except Exception as e:
logger.error(f"Failed to start COB integration: {e}")
return False
async def _handle_trading_decision(self, decision: Dict[str, Any]) -> None:
"""
Handle trading decision with proper async patterns
Args:
decision: The trading decision dictionary
"""
try:
logger.debug(f"Handling trading decision: {decision.get('action', 'UNKNOWN')}")
# Process the decision (this would be customized based on needs)
# For now, just log it
symbol = decision.get('symbol', 'UNKNOWN')
action = decision.get('action', 'HOLD')
confidence = decision.get('confidence', 0.0)
logger.info(f"Trading decision processed: {action} {symbol} (confidence: {confidence:.2f})")
except Exception as e:
logger.error(f"Error handling trading decision: {e}")
def run_in_executor(self, func: Callable, *args, **kwargs) -> Any:
"""
Run a blocking function in the thread pool executor
Args:
func: The function to run
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
Returns:
The result of the function
"""
if not self.is_running():
raise AsyncOperationError("AsyncHandler is not running")
try:
# Create a partial function with the arguments
partial_func = functools.partial(func, *args, **kwargs)
# Create a coroutine that runs the function in executor
async def run_in_executor_coro():
return await self._loop.run_in_executor(self._executor, partial_func)
# Run the coroutine
future = asyncio.run_coroutine_threadsafe(run_in_executor_coro(), self._loop)
result = future.result(timeout=self._timeout_default)
logger.debug("Executor function completed successfully")
return result
except Exception as e:
logger.error(f"Error running function in executor: {e}")
raise AsyncOperationError(f"Executor function failed: {e}")
def add_periodic_task(self, coro_func: Callable[[], Coroutine], interval: float, name: Optional[str] = None) -> Optional[asyncio.Task]:
"""
Add a periodic task that runs at specified intervals
Args:
coro_func: Function that returns a coroutine to run periodically
interval: Interval in seconds between runs
name: Optional name for the task
Returns:
The created task or None if failed
"""
async def periodic_runner():
"""Run the coroutine periodically"""
task_name = name or "periodic_task"
logger.info(f"Starting periodic task: {task_name} (interval: {interval}s)")
try:
while True:
try:
coro = coro_func()
await coro
logger.debug(f"Periodic task {task_name} completed")
except Exception as e:
logger.error(f"Error in periodic task {task_name}: {e}")
await asyncio.sleep(interval)
except asyncio.CancelledError:
logger.info(f"Periodic task {task_name} cancelled")
raise
except Exception as e:
logger.error(f"Fatal error in periodic task {task_name}: {e}")
return self.create_task_safely(periodic_runner(), name=f"periodic_{name}")
def stop(self) -> None:
"""Stop the async handler and clean up resources"""
try:
logger.info("Stopping AsyncHandler...")
if self._loop and not self._loop.is_closed():
# Cancel all tasks
if self._loop.is_running():
asyncio.run_coroutine_threadsafe(self._cancel_all_tasks(), self._loop)
# Stop the event loop
self._loop.call_soon_threadsafe(self._loop.stop)
# Shutdown executor
if self._executor:
self._executor.shutdown(wait=True)
# Wait for thread to finish
if self._thread and self._thread.is_alive():
self._thread.join(timeout=5.0)
self._running = False
logger.info("AsyncHandler stopped successfully")
except Exception as e:
logger.error(f"Error stopping AsyncHandler: {e}")
async def _cancel_all_tasks(self) -> None:
"""Cancel all running tasks"""
try:
tasks = [task for task in asyncio.all_tasks(self._loop) if not task.done()]
if tasks:
logger.info(f"Cancelling {len(tasks)} running tasks")
for task in tasks:
task.cancel()
# Wait for tasks to be cancelled
await asyncio.gather(*tasks, return_exceptions=True)
logger.debug("All tasks cancelled")
except Exception as e:
logger.error(f"Error cancelling tasks: {e}")
def __enter__(self):
"""Context manager entry"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit"""
self.stop()
class AsyncContextManager:
"""
Context manager for async operations that ensures proper cleanup
"""
def __init__(self, async_handler: AsyncHandler):
self.async_handler = async_handler
self.active_tasks = []
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Cancel any active tasks
for task in self.active_tasks:
if not task.done():
task.cancel()
def create_task(self, coro: Coroutine, name: Optional[str] = None) -> Optional[asyncio.Task]:
"""Create a task and track it for cleanup"""
task = self.async_handler.create_task_safely(coro, name)
if task:
self.active_tasks.append(task)
return task
def create_async_handler(loop: Optional[asyncio.AbstractEventLoop] = None) -> AsyncHandler:
"""
Factory function to create an AsyncHandler instance
Args:
loop: Optional event loop to use
Returns:
AsyncHandler instance
"""
return AsyncHandler(loop=loop)
def run_async_safely(coro: Coroutine, timeout: Optional[float] = None) -> Any:
"""
Convenience function to run a coroutine safely with a temporary AsyncHandler
Args:
coro: The coroutine to run
timeout: Timeout in seconds
Returns:
The result of the coroutine
"""
with AsyncHandler() as handler:
return handler.run_async_safely(coro, timeout=timeout)
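
A short usage sketch (the module name async_handler is an assumption; AsyncHandler, run_async_safely, and schedule_coroutine come from the code above):

import asyncio
from async_handler import AsyncHandler  # hypothetical module name

async def fetch_status() -> str:
    await asyncio.sleep(0.1)  # stand-in for real async work
    return "ok"

# Context manager form: the background event loop is stopped on exit
with AsyncHandler() as handler:
    # Fire-and-forget scheduling with a completion callback
    handler.schedule_coroutine(fetch_status(), callback=lambda r: print("done:", r))

    # Blocking call with a timeout; raises AsyncOperationError on failure
    result = handler.run_async_safely(fetch_status(), timeout=5.0)
    print(result)  # "ok"
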

View File

@@ -1,952 +0,0 @@
"""
Bookmap Order Book Data Provider
This module integrates with Bookmap to gather:
- Current Order Book (COB) data
- Session Volume Profile (SVP) data
- Order book sweeps and momentum trades detection
- Real-time order size heatmap matrix (last 10 minutes)
- Level 2 market depth analysis
The data is processed and fed to CNN and DQN networks for enhanced trading decisions.
"""
import asyncio
import json
import logging
import time
import websockets
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Callable
from collections import deque, defaultdict
from dataclasses import dataclass
from threading import Thread, Lock
import requests
logger = logging.getLogger(__name__)
@dataclass
class OrderBookLevel:
"""Represents a single order book level"""
price: float
size: float
orders: int
side: str # 'bid' or 'ask'
timestamp: datetime
@dataclass
class OrderBookSnapshot:
"""Complete order book snapshot"""
symbol: str
timestamp: datetime
bids: List[OrderBookLevel]
asks: List[OrderBookLevel]
spread: float
mid_price: float
@dataclass
class VolumeProfileLevel:
"""Volume profile level data"""
price: float
volume: float
buy_volume: float
sell_volume: float
trades_count: int
vwap: float
@dataclass
class OrderFlowSignal:
"""Order flow signal detection"""
timestamp: datetime
signal_type: str # 'sweep', 'absorption', 'iceberg', 'momentum'
price: float
volume: float
confidence: float
description: str
class BookmapDataProvider:
"""
Real-time order book data provider using Bookmap-style analysis
Features:
- Level 2 order book monitoring
- Order flow detection (sweeps, absorptions)
- Volume profile analysis
- Order size heatmap generation
- Market microstructure analysis
"""
def __init__(self, symbols: List[str] = None, depth_levels: int = 20):
"""
Initialize Bookmap data provider
Args:
symbols: List of symbols to monitor
depth_levels: Number of order book levels to track
"""
self.symbols = symbols or ['ETHUSDT', 'BTCUSDT']
self.depth_levels = depth_levels
self.is_streaming = False
# Order book data storage
self.order_books: Dict[str, OrderBookSnapshot] = {}
self.order_book_history: Dict[str, deque] = {}
self.volume_profiles: Dict[str, List[VolumeProfileLevel]] = {}
# Heatmap data (10-minute rolling window)
self.heatmap_window = timedelta(minutes=10)
self.order_heatmaps: Dict[str, deque] = {}
self.price_levels: Dict[str, List[float]] = {}
# Order flow detection
self.flow_signals: Dict[str, deque] = {}
self.sweep_threshold = 0.8 # Minimum confidence for sweep detection
self.absorption_threshold = 0.7 # Minimum confidence for absorption
# Market microstructure metrics
self.bid_ask_spreads: Dict[str, deque] = {}
self.order_book_imbalances: Dict[str, deque] = {}
self.liquidity_metrics: Dict[str, Dict] = {}
# WebSocket connections
self.websocket_tasks: Dict[str, asyncio.Task] = {}
self.data_lock = Lock()
# Callbacks for CNN/DQN integration
self.cnn_callbacks: List[Callable] = []
self.dqn_callbacks: List[Callable] = []
# Performance tracking
self.update_counts = defaultdict(int)
self.last_update_times = {}
# Initialize data structures
for symbol in self.symbols:
self.order_book_history[symbol] = deque(maxlen=1000)
self.order_heatmaps[symbol] = deque(maxlen=600) # 10 min at 1s intervals
self.flow_signals[symbol] = deque(maxlen=500)
self.bid_ask_spreads[symbol] = deque(maxlen=1000)
self.order_book_imbalances[symbol] = deque(maxlen=1000)
self.liquidity_metrics[symbol] = {
'total_bid_size': 0.0,
'total_ask_size': 0.0,
'weighted_mid': 0.0,
'liquidity_ratio': 1.0
}
logger.info(f"BookmapDataProvider initialized for {len(self.symbols)} symbols")
logger.info(f"Tracking {depth_levels} order book levels per side")
def add_cnn_callback(self, callback: Callable[[str, Dict], None]):
"""Add callback for CNN model updates"""
self.cnn_callbacks.append(callback)
logger.info(f"Added CNN callback: {len(self.cnn_callbacks)} total")
def add_dqn_callback(self, callback: Callable[[str, Dict], None]):
"""Add callback for DQN model updates"""
self.dqn_callbacks.append(callback)
logger.info(f"Added DQN callback: {len(self.dqn_callbacks)} total")
async def start_streaming(self):
"""Start real-time order book streaming"""
if self.is_streaming:
logger.warning("Bookmap streaming already active")
return
self.is_streaming = True
logger.info("Starting Bookmap order book streaming")
# Start order book streams for each symbol
for symbol in self.symbols:
# Order book depth stream
depth_task = asyncio.create_task(self._stream_order_book_depth(symbol))
self.websocket_tasks[f"{symbol}_depth"] = depth_task
# Trade stream for order flow analysis
trade_task = asyncio.create_task(self._stream_trades(symbol))
self.websocket_tasks[f"{symbol}_trades"] = trade_task
# Start analysis threads
analysis_task = asyncio.create_task(self._continuous_analysis())
self.websocket_tasks["analysis"] = analysis_task
logger.info(f"Started streaming for {len(self.symbols)} symbols")
async def stop_streaming(self):
"""Stop order book streaming"""
if not self.is_streaming:
return
logger.info("Stopping Bookmap streaming")
self.is_streaming = False
# Cancel all tasks
for name, task in self.websocket_tasks.items():
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
self.websocket_tasks.clear()
logger.info("Bookmap streaming stopped")
async def _stream_order_book_depth(self, symbol: str):
"""Stream order book depth data"""
binance_symbol = symbol.lower()
url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@depth20@100ms"
while self.is_streaming:
try:
async with websockets.connect(url) as websocket:
logger.info(f"Order book depth WebSocket connected for {symbol}")
async for message in websocket:
if not self.is_streaming:
break
try:
data = json.loads(message)
await self._process_depth_update(symbol, data)
except Exception as e:
logger.warning(f"Error processing depth for {symbol}: {e}")
except Exception as e:
logger.error(f"Depth WebSocket error for {symbol}: {e}")
if self.is_streaming:
await asyncio.sleep(2)
async def _stream_trades(self, symbol: str):
"""Stream trade data for order flow analysis"""
binance_symbol = symbol.lower()
url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@trade"
while self.is_streaming:
try:
async with websockets.connect(url) as websocket:
logger.info(f"Trade WebSocket connected for {symbol}")
async for message in websocket:
if not self.is_streaming:
break
try:
data = json.loads(message)
await self._process_trade_update(symbol, data)
except Exception as e:
logger.warning(f"Error processing trade for {symbol}: {e}")
except Exception as e:
logger.error(f"Trade WebSocket error for {symbol}: {e}")
if self.is_streaming:
await asyncio.sleep(2)
async def _process_depth_update(self, symbol: str, data: Dict):
"""Process order book depth update"""
try:
timestamp = datetime.now()
# Parse bids and asks
bids = []
asks = []
for bid_data in data.get('bids', []):
price = float(bid_data[0])
size = float(bid_data[1])
bids.append(OrderBookLevel(
price=price,
size=size,
orders=1, # Binance doesn't provide order count
side='bid',
timestamp=timestamp
))
for ask_data in data.get('asks', []):
price = float(ask_data[0])
size = float(ask_data[1])
asks.append(OrderBookLevel(
price=price,
size=size,
orders=1,
side='ask',
timestamp=timestamp
))
# Sort order book levels
bids.sort(key=lambda x: x.price, reverse=True)
asks.sort(key=lambda x: x.price)
# Calculate spread and mid price
if bids and asks:
best_bid = bids[0].price
best_ask = asks[0].price
spread = best_ask - best_bid
mid_price = (best_bid + best_ask) / 2
else:
spread = 0.0
mid_price = 0.0
# Create order book snapshot
snapshot = OrderBookSnapshot(
symbol=symbol,
timestamp=timestamp,
bids=bids,
asks=asks,
spread=spread,
mid_price=mid_price
)
with self.data_lock:
self.order_books[symbol] = snapshot
self.order_book_history[symbol].append(snapshot)
# Update liquidity metrics
self._update_liquidity_metrics(symbol, snapshot)
# Update order book imbalance
self._calculate_order_book_imbalance(symbol, snapshot)
# Update heatmap data
self._update_order_heatmap(symbol, snapshot)
# Update counters
self.update_counts[f"{symbol}_depth"] += 1
self.last_update_times[f"{symbol}_depth"] = timestamp
except Exception as e:
logger.error(f"Error processing depth update for {symbol}: {e}")
async def _process_trade_update(self, symbol: str, data: Dict):
"""Process trade data for order flow analysis"""
try:
timestamp = datetime.fromtimestamp(int(data['T']) / 1000)
price = float(data['p'])
quantity = float(data['q'])
is_buyer_maker = data['m']
# Analyze for order flow signals
await self._analyze_order_flow(symbol, timestamp, price, quantity, is_buyer_maker)
# Update volume profile
self._update_volume_profile(symbol, price, quantity, is_buyer_maker)
self.update_counts[f"{symbol}_trades"] += 1
except Exception as e:
logger.error(f"Error processing trade for {symbol}: {e}")
def _update_liquidity_metrics(self, symbol: str, snapshot: OrderBookSnapshot):
"""Update liquidity metrics from order book snapshot"""
try:
total_bid_size = sum(level.size for level in snapshot.bids)
total_ask_size = sum(level.size for level in snapshot.asks)
# Calculate weighted mid price
if snapshot.bids and snapshot.asks:
bid_weight = total_bid_size / (total_bid_size + total_ask_size)
ask_weight = total_ask_size / (total_bid_size + total_ask_size)
weighted_mid = (snapshot.bids[0].price * ask_weight +
snapshot.asks[0].price * bid_weight)
else:
weighted_mid = snapshot.mid_price
# Liquidity ratio (bid/ask balance)
if total_ask_size > 0:
liquidity_ratio = total_bid_size / total_ask_size
else:
liquidity_ratio = 1.0
self.liquidity_metrics[symbol] = {
'total_bid_size': total_bid_size,
'total_ask_size': total_ask_size,
'weighted_mid': weighted_mid,
'liquidity_ratio': liquidity_ratio,
'spread_bps': (snapshot.spread / snapshot.mid_price) * 10000 if snapshot.mid_price > 0 else 0
}
except Exception as e:
logger.error(f"Error updating liquidity metrics for {symbol}: {e}")
def _calculate_order_book_imbalance(self, symbol: str, snapshot: OrderBookSnapshot):
"""Calculate order book imbalance ratio"""
try:
if not snapshot.bids or not snapshot.asks:
return
# Calculate imbalance for top N levels
n_levels = min(5, len(snapshot.bids), len(snapshot.asks))
total_bid_size = sum(snapshot.bids[i].size for i in range(n_levels))
total_ask_size = sum(snapshot.asks[i].size for i in range(n_levels))
if total_bid_size + total_ask_size > 0:
imbalance = (total_bid_size - total_ask_size) / (total_bid_size + total_ask_size)
else:
imbalance = 0.0
self.order_book_imbalances[symbol].append({
'timestamp': snapshot.timestamp,
'imbalance': imbalance,
'bid_size': total_bid_size,
'ask_size': total_ask_size
})
except Exception as e:
logger.error(f"Error calculating imbalance for {symbol}: {e}")
def _update_order_heatmap(self, symbol: str, snapshot: OrderBookSnapshot):
"""Update order size heatmap matrix"""
try:
# Create heatmap entry
heatmap_entry = {
'timestamp': snapshot.timestamp,
'mid_price': snapshot.mid_price,
'levels': {}
}
# Add bid levels
for level in snapshot.bids:
price_offset = level.price - snapshot.mid_price
heatmap_entry['levels'][price_offset] = {
'side': 'bid',
'size': level.size,
'price': level.price
}
# Add ask levels
for level in snapshot.asks:
price_offset = level.price - snapshot.mid_price
heatmap_entry['levels'][price_offset] = {
'side': 'ask',
'size': level.size,
'price': level.price
}
self.order_heatmaps[symbol].append(heatmap_entry)
# Clean old entries (keep 10 minutes)
cutoff_time = snapshot.timestamp - self.heatmap_window
while (self.order_heatmaps[symbol] and
self.order_heatmaps[symbol][0]['timestamp'] < cutoff_time):
self.order_heatmaps[symbol].popleft()
except Exception as e:
logger.error(f"Error updating heatmap for {symbol}: {e}")
def _update_volume_profile(self, symbol: str, price: float, quantity: float, is_buyer_maker: bool):
"""Update volume profile with new trade"""
try:
# Initialize if not exists
if symbol not in self.volume_profiles:
self.volume_profiles[symbol] = []
# Find or create price level
price_level = None
for level in self.volume_profiles[symbol]:
if abs(level.price - price) < 0.01: # Price tolerance
price_level = level
break
if not price_level:
price_level = VolumeProfileLevel(
price=price,
volume=0.0,
buy_volume=0.0,
sell_volume=0.0,
trades_count=0,
vwap=price
)
self.volume_profiles[symbol].append(price_level)
# Update volume profile
volume = price * quantity
old_total = price_level.volume
price_level.volume += volume
price_level.trades_count += 1
if is_buyer_maker:
price_level.sell_volume += volume
else:
price_level.buy_volume += volume
# Update VWAP
if price_level.volume > 0:
price_level.vwap = ((price_level.vwap * old_total) + (price * volume)) / price_level.volume
except Exception as e:
logger.error(f"Error updating volume profile for {symbol}: {e}")
async def _analyze_order_flow(self, symbol: str, timestamp: datetime, price: float,
quantity: float, is_buyer_maker: bool):
"""Analyze order flow for sweep and absorption patterns"""
try:
# Get recent order book data
if symbol not in self.order_book_history or not self.order_book_history[symbol]:
return
recent_snapshots = list(self.order_book_history[symbol])[-10:] # Last 10 snapshots
# Check for order book sweeps
sweep_signal = self._detect_order_sweep(symbol, recent_snapshots, price, quantity, is_buyer_maker)
if sweep_signal:
self.flow_signals[symbol].append(sweep_signal)
await self._notify_flow_signal(symbol, sweep_signal)
# Check for absorption patterns
absorption_signal = self._detect_absorption(symbol, recent_snapshots, price, quantity)
if absorption_signal:
self.flow_signals[symbol].append(absorption_signal)
await self._notify_flow_signal(symbol, absorption_signal)
# Check for momentum trades
momentum_signal = self._detect_momentum_trade(symbol, price, quantity, is_buyer_maker)
if momentum_signal:
self.flow_signals[symbol].append(momentum_signal)
await self._notify_flow_signal(symbol, momentum_signal)
except Exception as e:
logger.error(f"Error analyzing order flow for {symbol}: {e}")
def _detect_order_sweep(self, symbol: str, snapshots: List[OrderBookSnapshot],
price: float, quantity: float, is_buyer_maker: bool) -> Optional[OrderFlowSignal]:
"""Detect order book sweep patterns"""
try:
if len(snapshots) < 2:
return None
before_snapshot = snapshots[-2]
after_snapshot = snapshots[-1]
# Check if multiple levels were consumed
if is_buyer_maker: # Sell order, check ask side
levels_consumed = 0
total_consumed_size = 0
for level in before_snapshot.asks[:5]: # Check top 5 levels
if level.price <= price:
levels_consumed += 1
total_consumed_size += level.size
if levels_consumed >= 2 and total_consumed_size > quantity * 1.5:
confidence = min(0.9, levels_consumed / 5.0 + 0.3)
return OrderFlowSignal(
timestamp=datetime.now(),
signal_type='sweep',
price=price,
volume=quantity * price,
confidence=confidence,
description=f"Sell sweep: {levels_consumed} levels, {total_consumed_size:.2f} size"
)
else: # Buy order, check bid side
levels_consumed = 0
total_consumed_size = 0
for level in before_snapshot.bids[:5]:
if level.price >= price:
levels_consumed += 1
total_consumed_size += level.size
if levels_consumed >= 2 and total_consumed_size > quantity * 1.5:
confidence = min(0.9, levels_consumed / 5.0 + 0.3)
return OrderFlowSignal(
timestamp=datetime.now(),
signal_type='sweep',
price=price,
volume=quantity * price,
confidence=confidence,
description=f"Buy sweep: {levels_consumed} levels, {total_consumed_size:.2f} size"
)
return None
except Exception as e:
logger.error(f"Error detecting sweep for {symbol}: {e}")
return None
def _detect_absorption(self, symbol: str, snapshots: List[OrderBookSnapshot],
price: float, quantity: float) -> Optional[OrderFlowSignal]:
"""Detect absorption patterns where large orders are absorbed without price movement"""
try:
if len(snapshots) < 3:
return None
# Check if large order was absorbed with minimal price impact
volume_threshold = 10000 # $10K minimum for absorption
price_impact_threshold = 0.001 # 0.1% max price impact
trade_value = price * quantity
if trade_value < volume_threshold:
return None
# Calculate price impact
price_before = snapshots[-3].mid_price
price_after = snapshots[-1].mid_price
price_impact = abs(price_after - price_before) / price_before
if price_impact < price_impact_threshold:
confidence = min(0.8, (trade_value / 50000) * 0.5 + 0.3) # Scale with size
return OrderFlowSignal(
timestamp=datetime.now(),
signal_type='absorption',
price=price,
volume=trade_value,
confidence=confidence,
description=f"Absorption: ${trade_value:.0f} with {price_impact*100:.3f}% impact"
)
return None
except Exception as e:
logger.error(f"Error detecting absorption for {symbol}: {e}")
return None
def _detect_momentum_trade(self, symbol: str, price: float, quantity: float,
is_buyer_maker: bool) -> Optional[OrderFlowSignal]:
"""Detect momentum trades based on size and direction"""
try:
trade_value = price * quantity
momentum_threshold = 25000 # $25K minimum for momentum classification
if trade_value < momentum_threshold:
return None
# Calculate confidence based on trade size
confidence = min(0.9, trade_value / 100000 * 0.6 + 0.3)
direction = "sell" if is_buyer_maker else "buy"
return OrderFlowSignal(
timestamp=datetime.now(),
signal_type='momentum',
price=price,
volume=trade_value,
confidence=confidence,
description=f"Large {direction}: ${trade_value:.0f}"
)
except Exception as e:
logger.error(f"Error detecting momentum for {symbol}: {e}")
return None
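As a worked example of this confidence scale (trade sizes chosen for illustration): a $50,000 trade gives min(0.9, 50000 / 100000 * 0.6 + 0.3) = 0.6, and anything above roughly $100,000 saturates at the 0.9 cap.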
async def _notify_flow_signal(self, symbol: str, signal: OrderFlowSignal):
"""Notify CNN and DQN models of order flow signals"""
try:
signal_data = {
'signal_type': signal.signal_type,
'price': signal.price,
'volume': signal.volume,
'confidence': signal.confidence,
'timestamp': signal.timestamp,
'description': signal.description
}
# Notify CNN callbacks
for callback in self.cnn_callbacks:
try:
callback(symbol, signal_data)
except Exception as e:
logger.warning(f"Error in CNN callback: {e}")
# Notify DQN callbacks
for callback in self.dqn_callbacks:
try:
callback(symbol, signal_data)
except Exception as e:
logger.warning(f"Error in DQN callback: {e}")
except Exception as e:
logger.error(f"Error notifying flow signal: {e}")
async def _continuous_analysis(self):
"""Continuous analysis of market microstructure"""
while self.is_streaming:
try:
await asyncio.sleep(1) # Analyze every second
for symbol in self.symbols:
# Generate CNN features
cnn_features = self.get_cnn_features(symbol)
if cnn_features is not None:
for callback in self.cnn_callbacks:
try:
callback(symbol, {'features': cnn_features, 'type': 'orderbook'})
except Exception as e:
logger.warning(f"Error in CNN feature callback: {e}")
# Generate DQN state features
dqn_features = self.get_dqn_state_features(symbol)
if dqn_features is not None:
for callback in self.dqn_callbacks:
try:
callback(symbol, {'state': dqn_features, 'type': 'orderbook'})
except Exception as e:
logger.warning(f"Error in DQN state callback: {e}")
except Exception as e:
logger.error(f"Error in continuous analysis: {e}")
await asyncio.sleep(5)
def get_cnn_features(self, symbol: str) -> Optional[np.ndarray]:
"""Generate CNN input features from order book data"""
try:
if symbol not in self.order_books:
return None
snapshot = self.order_books[symbol]
features = []
# Order book features (40 features: 20 levels x 2 sides)
for i in range(min(20, len(snapshot.bids))):
bid = snapshot.bids[i]
features.append(bid.size)
features.append(bid.price - snapshot.mid_price) # Price offset
# Pad if not enough bid levels
while len(features) < 40:
features.extend([0.0, 0.0])
for i in range(min(20, len(snapshot.asks))):
ask = snapshot.asks[i]
features.append(ask.size)
features.append(ask.price - snapshot.mid_price) # Price offset
# Pad if not enough ask levels
while len(features) < 80:
features.extend([0.0, 0.0])
# Liquidity metrics (10 features)
metrics = self.liquidity_metrics.get(symbol, {})
features.extend([
metrics.get('total_bid_size', 0.0),
metrics.get('total_ask_size', 0.0),
metrics.get('liquidity_ratio', 1.0),
metrics.get('spread_bps', 0.0),
snapshot.spread,
metrics.get('weighted_mid', snapshot.mid_price) - snapshot.mid_price,
len(snapshot.bids),
len(snapshot.asks),
snapshot.mid_price,
time.time() % 86400 # Time of day
])
# Order book imbalance features (5 features)
if self.order_book_imbalances[symbol]:
latest_imbalance = self.order_book_imbalances[symbol][-1]
features.extend([
latest_imbalance['imbalance'],
latest_imbalance['bid_size'],
latest_imbalance['ask_size'],
latest_imbalance['bid_size'] + latest_imbalance['ask_size'],
abs(latest_imbalance['imbalance'])
])
else:
features.extend([0.0, 0.0, 0.0, 0.0, 0.0])
# Flow signal features (5 features)
recent_signals = [s for s in self.flow_signals[symbol]
if (datetime.now() - s.timestamp).seconds < 60]
sweep_count = sum(1 for s in recent_signals if s.signal_type == 'sweep')
absorption_count = sum(1 for s in recent_signals if s.signal_type == 'absorption')
momentum_count = sum(1 for s in recent_signals if s.signal_type == 'momentum')
max_confidence = max([s.confidence for s in recent_signals], default=0.0)
total_flow_volume = sum(s.volume for s in recent_signals)
features.extend([
sweep_count,
absorption_count,
momentum_count,
max_confidence,
total_flow_volume
])
return np.array(features, dtype=np.float32)
except Exception as e:
logger.error(f"Error generating CNN features for {symbol}: {e}")
return None
def get_dqn_state_features(self, symbol: str) -> Optional[np.ndarray]:
"""Generate DQN state features from order book data"""
try:
if symbol not in self.order_books:
return None
snapshot = self.order_books[symbol]
state_features = []
# Normalized order book state (20 features)
total_bid_size = sum(level.size for level in snapshot.bids[:10])
total_ask_size = sum(level.size for level in snapshot.asks[:10])
total_size = total_bid_size + total_ask_size
if total_size > 0:
for i in range(min(10, len(snapshot.bids))):
state_features.append(snapshot.bids[i].size / total_size)
# Pad bids
while len(state_features) < 10:
state_features.append(0.0)
for i in range(min(10, len(snapshot.asks))):
state_features.append(snapshot.asks[i].size / total_size)
# Pad asks
while len(state_features) < 20:
state_features.append(0.0)
else:
state_features.extend([0.0] * 20)
# Market state indicators (10 features)
metrics = self.liquidity_metrics.get(symbol, {})
# Normalize spread as percentage
spread_pct = (snapshot.spread / snapshot.mid_price) if snapshot.mid_price > 0 else 0
# Liquidity imbalance
liquidity_ratio = metrics.get('liquidity_ratio', 1.0)
liquidity_imbalance = (liquidity_ratio - 1) / (liquidity_ratio + 1)
# Recent flow signals strength
recent_signals = [s for s in self.flow_signals[symbol]
if (datetime.now() - s.timestamp).seconds < 30]
flow_strength = sum(s.confidence for s in recent_signals) / max(len(recent_signals), 1)
# Price volatility (from recent snapshots)
if len(self.order_book_history[symbol]) >= 10:
recent_prices = [s.mid_price for s in list(self.order_book_history[symbol])[-10:]]
price_volatility = np.std(recent_prices) / np.mean(recent_prices) if recent_prices else 0
else:
price_volatility = 0
state_features.extend([
spread_pct * 10000, # Spread in basis points
liquidity_imbalance,
flow_strength,
price_volatility * 100, # Volatility as percentage
min(len(snapshot.bids), 20) / 20, # Book depth ratio
min(len(snapshot.asks), 20) / 20,
sum(1 for s in recent_signals if s.signal_type == 'sweep') / 10, # Recent sweep count (normalized)
sum(1 for s in recent_signals if s.signal_type == 'absorption') / 5, # Recent absorption count (normalized)
sum(1 for s in recent_signals if s.signal_type == 'momentum') / 5, # Recent momentum count (normalized)
(datetime.now().hour * 60 + datetime.now().minute) / 1440 # Time of day normalized
])
return np.array(state_features, dtype=np.float32)
except Exception as e:
logger.error(f"Error generating DQN features for {symbol}: {e}")
return None
def get_order_heatmap_matrix(self, symbol: str, levels: int = 40) -> Optional[np.ndarray]:
"""Generate order size heatmap matrix for dashboard visualization"""
try:
if symbol not in self.order_heatmaps or not self.order_heatmaps[symbol]:
return None
# Create price levels around current mid price
current_snapshot = self.order_books.get(symbol)
if not current_snapshot:
return None
mid_price = current_snapshot.mid_price
price_step = mid_price * 0.0001 # 1 basis point steps
# Create matrix: time x price levels
time_window = min(600, len(self.order_heatmaps[symbol])) # 10 minutes max
heatmap_matrix = np.zeros((time_window, levels))
# Fill matrix with order sizes
for t, entry in enumerate(list(self.order_heatmaps[symbol])[-time_window:]):
for price_offset, level_data in entry['levels'].items():
# Convert price offset to matrix index
level_idx = int((price_offset + (levels/2) * price_step) / price_step)
if 0 <= level_idx < levels:
size_weight = 1.0 if level_data['side'] == 'bid' else -1.0
heatmap_matrix[t, level_idx] = level_data['size'] * size_weight
return heatmap_matrix
except Exception as e:
logger.error(f"Error generating heatmap matrix for {symbol}: {e}")
return None
def get_volume_profile_data(self, symbol: str) -> Optional[List[Dict]]:
"""Get session volume profile data"""
try:
if symbol not in self.volume_profiles:
return None
profile_data = []
for level in sorted(self.volume_profiles[symbol], key=lambda x: x.price):
profile_data.append({
'price': level.price,
'volume': level.volume,
'buy_volume': level.buy_volume,
'sell_volume': level.sell_volume,
'trades_count': level.trades_count,
'vwap': level.vwap,
'net_volume': level.buy_volume - level.sell_volume
})
return profile_data
except Exception as e:
logger.error(f"Error getting volume profile for {symbol}: {e}")
return None
def get_current_order_book(self, symbol: str) -> Optional[Dict]:
"""Get current order book snapshot"""
try:
if symbol not in self.order_books:
return None
snapshot = self.order_books[symbol]
return {
'timestamp': snapshot.timestamp.isoformat(),
'symbol': symbol,
'mid_price': snapshot.mid_price,
'spread': snapshot.spread,
'bids': [{'price': l.price, 'size': l.size} for l in snapshot.bids[:20]],
'asks': [{'price': l.price, 'size': l.size} for l in snapshot.asks[:20]],
'liquidity_metrics': self.liquidity_metrics.get(symbol, {}),
'recent_signals': [
{
'type': s.signal_type,
'price': s.price,
'volume': s.volume,
'confidence': s.confidence,
'timestamp': s.timestamp.isoformat()
}
for s in list(self.flow_signals[symbol])[-5:] # Last 5 signals
]
}
except Exception as e:
logger.error(f"Error getting order book for {symbol}: {e}")
return None
def get_statistics(self) -> Dict[str, Any]:
"""Get provider statistics"""
return {
'symbols': self.symbols,
'is_streaming': self.is_streaming,
'update_counts': dict(self.update_counts),
'last_update_times': {k: v.isoformat() if isinstance(v, datetime) else v
for k, v in self.last_update_times.items()},
'order_books_active': len(self.order_books),
'flow_signals_total': sum(len(signals) for signals in self.flow_signals.values()),
'cnn_callbacks': len(self.cnn_callbacks),
'dqn_callbacks': len(self.dqn_callbacks),
'websocket_tasks': len(self.websocket_tasks)
}
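
A possible wiring of this provider into a model pipeline (the module name bookmap_integration is an assumption; the classes and methods are from the code above):

import asyncio
from bookmap_integration import BookmapDataProvider  # hypothetical module name

def on_cnn_update(symbol: str, payload: dict) -> None:
    # Receives flow-signal dicts or {'features': ..., 'type': 'orderbook'} packets
    print(f"CNN update for {symbol}: {list(payload.keys())}")

async def main() -> None:
    provider = BookmapDataProvider(symbols=['ETHUSDT'], depth_levels=20)
    provider.add_cnn_callback(on_cnn_update)
    await provider.start_streaming()
    await asyncio.sleep(30)  # let some depth/trade data accumulate
    print(provider.get_statistics())
    print(provider.get_current_order_book('ETHUSDT'))
    await provider.stop_streaming()

asyncio.run(main())
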

File diff suppressed because it is too large.

View File

@@ -1,785 +0,0 @@
"""
CNN Training Pipeline with Comprehensive Data Storage and Replay
This module implements a robust CNN training pipeline that:
1. Integrates with the comprehensive training data collection system
2. Stores all backpropagation data for gradient replay
3. Enables retraining on most profitable setups
4. Maintains training episode profitability tracking
5. Supports both real-time and batch training modes
Key Features:
- Integration with TrainingDataCollector for data validation
- Gradient and loss storage for each training step
- Profitable episode prioritization and replay
- Comprehensive training metrics and validation
- Real-time pivot point prediction with outcome tracking
"""
import asyncio
import logging
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass, field
import json
import pickle
from collections import deque, defaultdict
import threading
from concurrent.futures import ThreadPoolExecutor
from .training_data_collector import (
TrainingDataCollector,
TrainingEpisode,
ModelInputPackage,
get_training_data_collector
)
logger = logging.getLogger(__name__)
@dataclass
class CNNTrainingStep:
"""Single CNN training step with complete backpropagation data"""
step_id: str
timestamp: datetime
episode_id: str
# Input data
input_features: torch.Tensor
target_labels: torch.Tensor
# Forward pass results
model_outputs: Dict[str, torch.Tensor]
predictions: Dict[str, Any]
confidence_scores: torch.Tensor
# Loss components
total_loss: float
pivot_prediction_loss: float
confidence_loss: float
regularization_loss: float
# Backpropagation data
gradients: Dict[str, torch.Tensor] # Gradients for each parameter
gradient_norms: Dict[str, float] # Gradient norms for monitoring
# Model state
model_state_dict: Optional[Dict[str, torch.Tensor]] = None
optimizer_state: Optional[Dict[str, Any]] = None
# Training metadata
learning_rate: float = 0.001
batch_size: int = 32
epoch: int = 0
# Profitability tracking
actual_profitability: Optional[float] = None
prediction_accuracy: Optional[float] = None
training_value: float = 0.0 # Value of this training step for replay
@dataclass
class CNNTrainingSession:
"""Complete CNN training session with multiple steps"""
session_id: str
start_timestamp: datetime
end_timestamp: Optional[datetime] = None
# Session configuration
training_mode: str = 'real_time' # 'real_time', 'batch', 'replay'
symbol: str = ''
# Training steps
training_steps: List[CNNTrainingStep] = field(default_factory=list)
# Session metrics
total_steps: int = 0
average_loss: float = 0.0
best_loss: float = float('inf')
convergence_achieved: bool = False
# Profitability metrics
profitable_predictions: int = 0
total_predictions: int = 0
profitability_rate: float = 0.0
# Session value for replay prioritization
session_value: float = 0.0
class CNNPivotPredictor(nn.Module):
"""CNN model for pivot point prediction with comprehensive output"""
def __init__(self,
input_channels: int = 10, # Multiple timeframes
sequence_length: int = 300, # 300 bars
hidden_dim: int = 256,
num_pivot_classes: int = 3, # high, low, none
dropout_rate: float = 0.2):
super(CNNPivotPredictor, self).__init__()
self.input_channels = input_channels
self.sequence_length = sequence_length
self.hidden_dim = hidden_dim
# Convolutional layers for pattern extraction
self.conv_layers = nn.Sequential(
# First conv block
nn.Conv1d(input_channels, 64, kernel_size=7, padding=3),
nn.BatchNorm1d(64),
nn.ReLU(),
nn.Dropout(dropout_rate),
# Second conv block
nn.Conv1d(64, 128, kernel_size=5, padding=2),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.Dropout(dropout_rate),
# Third conv block
nn.Conv1d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Dropout(dropout_rate),
)
# LSTM for temporal dependencies
self.lstm = nn.LSTM(
input_size=256,
hidden_size=hidden_dim,
num_layers=2,
batch_first=True,
dropout=dropout_rate,
bidirectional=True
)
# Attention mechanism
self.attention = nn.MultiheadAttention(
embed_dim=hidden_dim * 2, # Bidirectional LSTM
num_heads=8,
dropout=dropout_rate,
batch_first=True
)
# Output heads
self.pivot_classifier = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Linear(hidden_dim, num_pivot_classes)
)
self.pivot_price_regressor = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Linear(hidden_dim, 1)
)
self.confidence_head = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, 1),
nn.Sigmoid()
)
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
"""Initialize weights with proper scaling"""
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv1d):
torch.nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
def forward(self, x):
"""
Forward pass through CNN pivot predictor
Args:
x: Input tensor [batch_size, input_channels, sequence_length]
Returns:
Dict containing predictions and hidden states
"""
batch_size = x.size(0)
# Convolutional feature extraction
conv_features = self.conv_layers(x) # [batch, 256, sequence_length]
# Prepare for LSTM (transpose to [batch, sequence, features])
lstm_input = conv_features.transpose(1, 2) # [batch, sequence_length, 256]
# LSTM processing
lstm_output, (hidden, cell) = self.lstm(lstm_input) # [batch, sequence_length, hidden_dim*2]
# Attention mechanism
attended_output, attention_weights = self.attention(
lstm_output, lstm_output, lstm_output
)
# Use the last timestep for predictions
final_features = attended_output[:, -1, :] # [batch, hidden_dim*2]
# Generate predictions
pivot_logits = self.pivot_classifier(final_features)
pivot_price = self.pivot_price_regressor(final_features)
confidence = self.confidence_head(final_features)
return {
'pivot_logits': pivot_logits,
'pivot_price': pivot_price,
'confidence': confidence,
'hidden_states': final_features,
'attention_weights': attention_weights,
'conv_features': conv_features,
'lstm_output': lstm_output
}
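# Illustrative shape check (a sketch, not part of the original module): with the
# default constructor arguments above, a forward pass behaves as follows.
#
#     model = CNNPivotPredictor()                   # 10 channels x 300 bars
#     dummy = torch.randn(8, 10, 300)               # [batch, channels, sequence]
#     out = model(dummy)
#     out['pivot_logits'].shape    # torch.Size([8, 3])    high / low / none
#     out['pivot_price'].shape     # torch.Size([8, 1])    regressed pivot price
#     out['confidence'].shape      # torch.Size([8, 1])    sigmoid confidence
#     out['hidden_states'].shape   # torch.Size([8, 512])  hidden_dim * 2 (bidirectional)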
class CNNTrainingDataset(Dataset):
"""Dataset for CNN training with training episodes"""
def __init__(self, training_episodes: List[TrainingEpisode]):
self.episodes = training_episodes
self.valid_episodes = self._validate_episodes()
def _validate_episodes(self) -> List[TrainingEpisode]:
"""Validate and filter episodes for training"""
valid = []
for episode in self.episodes:
try:
# Check if episode has required data
if (episode.input_package.cnn_features is not None and
episode.actual_outcome.outcome_validated):
valid.append(episode)
except Exception as e:
logger.warning(f"Invalid episode {episode.episode_id}: {e}")
logger.info(f"Validated {len(valid)}/{len(self.episodes)} episodes for training")
return valid
def __len__(self):
return len(self.valid_episodes)
def __getitem__(self, idx):
episode = self.valid_episodes[idx]
# Extract features
features = torch.from_numpy(episode.input_package.cnn_features).float()
# Create labels from actual outcomes
pivot_class = self._determine_pivot_class(episode.actual_outcome)
pivot_price = episode.actual_outcome.optimal_exit_price
confidence_target = episode.actual_outcome.profitability_score
return {
'features': features,
'pivot_class': torch.tensor(pivot_class, dtype=torch.long),
'pivot_price': torch.tensor(pivot_price, dtype=torch.float),
'confidence_target': torch.tensor(confidence_target, dtype=torch.float),
'episode_id': episode.episode_id,
'profitability': episode.actual_outcome.profitability_score
}
def _determine_pivot_class(self, outcome) -> int:
"""Determine pivot class from outcome"""
if outcome.price_change_15m > 0.5: # Significant upward movement
return 0 # High pivot
elif outcome.price_change_15m < -0.5: # Significant downward movement
return 1 # Low pivot
else:
return 2 # No significant pivot
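# Worked example of the labelling rule above (assuming price_change_15m is
# expressed in percent, which the original code does not state explicitly):
#     price_change_15m = +0.8  ->  class 0 (high pivot)
#     price_change_15m = -0.6  ->  class 1 (low pivot)
#     price_change_15m = +0.2  ->  class 2 (no significant pivot)
# Each __getitem__ call yields a dict with 'features', 'pivot_class',
# 'pivot_price', 'confidence_target', 'episode_id' and 'profitability',
# which DataLoader collates into batches for CNNTrainer below.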
class CNNTrainer:
"""CNN trainer with comprehensive data storage and replay capabilities"""
def __init__(self,
model: CNNPivotPredictor,
device: str = 'cuda',
learning_rate: float = 0.001,
storage_dir: str = "cnn_training_storage"):
self.model = model.to(device)
self.device = device
self.learning_rate = learning_rate
# Storage
self.storage_dir = Path(storage_dir)
self.storage_dir.mkdir(parents=True, exist_ok=True)
# Optimizer
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=learning_rate,
weight_decay=1e-5
)
# Learning rate scheduler
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer, mode='min', patience=10, factor=0.5
)
# Training data collector
self.data_collector = get_training_data_collector()
# Training sessions storage
self.training_sessions: List[CNNTrainingSession] = []
self.current_session: Optional[CNNTrainingSession] = None
# Training statistics
self.training_stats = {
'total_sessions': 0,
'total_steps': 0,
'best_validation_loss': float('inf'),
'profitable_predictions': 0,
'total_predictions': 0,
'replay_sessions': 0
}
# Background training
self.is_training = False
self.training_thread = None
logger.info(f"CNN Trainer initialized")
logger.info(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
logger.info(f"Storage directory: {self.storage_dir}")
def start_real_time_training(self, symbol: str):
"""Start real-time training for a symbol"""
if self.is_training:
logger.warning("CNN training already running")
return
self.is_training = True
self.training_thread = threading.Thread(
target=self._real_time_training_worker,
args=(symbol,),
daemon=True
)
self.training_thread.start()
logger.info(f"Started real-time CNN training for {symbol}")
def stop_training(self):
"""Stop training"""
self.is_training = False
if self.training_thread:
self.training_thread.join(timeout=10)
if self.current_session:
self._finalize_training_session()
logger.info("CNN training stopped")
def _real_time_training_worker(self, symbol: str):
"""Real-time training worker"""
logger.info(f"Real-time CNN training worker started for {symbol}")
while self.is_training:
try:
# Get high-priority episodes for training
episodes = self.data_collector.get_high_priority_episodes(
symbol=symbol,
limit=100,
min_priority=0.3
)
if len(episodes) >= 32: # Minimum batch size
self._train_on_episodes(episodes, training_mode='real_time')
# Wait before next training cycle
threading.Event().wait(300) # Train every 5 minutes
except Exception as e:
logger.error(f"Error in real-time training worker: {e}")
threading.Event().wait(60) # Wait before retrying
logger.info(f"Real-time CNN training worker stopped for {symbol}")
def train_on_profitable_episodes(self,
symbol: str,
min_profitability: float = 0.7,
max_episodes: int = 500) -> Dict[str, Any]:
"""Train specifically on most profitable episodes"""
try:
# Get all episodes for symbol
all_episodes = self.data_collector.training_episodes.get(symbol, [])
# Filter for profitable episodes
profitable_episodes = [
ep for ep in all_episodes
if (ep.actual_outcome.is_profitable and
ep.actual_outcome.profitability_score >= min_profitability)
]
# Sort by profitability and limit
profitable_episodes.sort(
key=lambda x: x.actual_outcome.profitability_score,
reverse=True
)
profitable_episodes = profitable_episodes[:max_episodes]
if len(profitable_episodes) < 10:
logger.warning(f"Insufficient profitable episodes for {symbol}: {len(profitable_episodes)}")
return {'status': 'insufficient_data', 'episodes_found': len(profitable_episodes)}
# Train on profitable episodes
results = self._train_on_episodes(
profitable_episodes,
training_mode='profitable_replay'
)
logger.info(f"Trained on {len(profitable_episodes)} profitable episodes for {symbol}")
return results
except Exception as e:
logger.error(f"Error training on profitable episodes: {e}")
return {'status': 'error', 'error': str(e)}
def _train_on_episodes(self,
episodes: List[TrainingEpisode],
training_mode: str = 'batch') -> Dict[str, Any]:
"""Train on a batch of episodes with comprehensive data storage"""
try:
# Start new training session
session = CNNTrainingSession(
session_id=f"{training_mode}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
start_timestamp=datetime.now(),
training_mode=training_mode,
symbol=episodes[0].input_package.symbol if episodes else 'unknown'
)
self.current_session = session
# Create dataset and dataloader
dataset = CNNTrainingDataset(episodes)
dataloader = DataLoader(
dataset,
batch_size=32,
shuffle=True,
num_workers=2
)
# Training loop
self.model.train()
total_loss = 0.0
num_batches = 0
for batch_idx, batch in enumerate(dataloader):
# Move to device
features = batch['features'].to(self.device)
pivot_class = batch['pivot_class'].to(self.device)
pivot_price = batch['pivot_price'].to(self.device)
confidence_target = batch['confidence_target'].to(self.device)
# Forward pass
self.optimizer.zero_grad()
outputs = self.model(features)
# Calculate losses
classification_loss = F.cross_entropy(outputs['pivot_logits'], pivot_class)
regression_loss = F.mse_loss(outputs['pivot_price'].squeeze(), pivot_price)
confidence_loss = F.binary_cross_entropy(
outputs['confidence'].squeeze(),
confidence_target
)
# Combined loss
total_batch_loss = classification_loss + 0.5 * regression_loss + 0.3 * confidence_loss
# Backward pass
total_batch_loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
# Store gradients before optimizer step
gradients = {}
gradient_norms = {}
for name, param in self.model.named_parameters():
if param.grad is not None:
gradients[name] = param.grad.clone().detach()
gradient_norms[name] = param.grad.norm().item()
# Optimizer step
self.optimizer.step()
# Create training step record
step = CNNTrainingStep(
step_id=f"{session.session_id}_step_{batch_idx}",
timestamp=datetime.now(),
episode_id=f"batch_{batch_idx}",
input_features=features.detach().cpu(),
target_labels=pivot_class.detach().cpu(),
model_outputs={k: v.detach().cpu() for k, v in outputs.items()},
predictions=self._extract_predictions(outputs),
confidence_scores=outputs['confidence'].detach().cpu(),
total_loss=total_batch_loss.item(),
pivot_prediction_loss=classification_loss.item(),
confidence_loss=confidence_loss.item(),
regularization_loss=0.0,
gradients=gradients,
gradient_norms=gradient_norms,
learning_rate=self.optimizer.param_groups[0]['lr'],
batch_size=features.size(0)
)
# Calculate training value for this step
step.training_value = self._calculate_step_training_value(step, batch)
# Add to session
session.training_steps.append(step)
total_loss += total_batch_loss.item()
num_batches += 1
# Log progress
if batch_idx % 10 == 0:
logger.debug(f"Batch {batch_idx}: Loss = {total_batch_loss.item():.4f}")
# Finalize session
session.end_timestamp = datetime.now()
session.total_steps = num_batches
session.average_loss = total_loss / num_batches if num_batches > 0 else 0.0
session.best_loss = min((step.total_loss for step in session.training_steps), default=float('inf'))
# Calculate session value
session.session_value = self._calculate_session_value(session)
# Update scheduler
self.scheduler.step(session.average_loss)
# Save session
self._save_training_session(session)
# Update statistics
self.training_stats['total_sessions'] += 1
self.training_stats['total_steps'] += session.total_steps
if training_mode == 'profitable_replay':
self.training_stats['replay_sessions'] += 1
logger.info(f"Training session completed: {session.session_id}")
logger.info(f"Average loss: {session.average_loss:.4f}")
logger.info(f"Session value: {session.session_value:.3f}")
return {
'status': 'success',
'session_id': session.session_id,
'average_loss': session.average_loss,
'total_steps': session.total_steps,
'session_value': session.session_value
}
except Exception as e:
logger.error(f"Error in training session: {e}")
return {'status': 'error', 'error': str(e)}
finally:
self.current_session = None
def _extract_predictions(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
"""Extract human-readable predictions from model outputs"""
try:
pivot_probs = F.softmax(outputs['pivot_logits'], dim=1)
predicted_class = torch.argmax(pivot_probs, dim=1)
return {
'pivot_class': predicted_class.cpu().numpy().tolist(),
'pivot_probabilities': pivot_probs.cpu().numpy().tolist(),
'pivot_price': outputs['pivot_price'].cpu().numpy().tolist(),
'confidence': outputs['confidence'].cpu().numpy().tolist()
}
except Exception as e:
logger.warning(f"Error extracting predictions: {e}")
return {}
def _calculate_step_training_value(self,
step: CNNTrainingStep,
batch: Dict[str, Any]) -> float:
"""Calculate the training value of a step for replay prioritization"""
try:
value = 0.0
# Base value from loss (lower loss = higher value)
if step.total_loss > 0:
value += 1.0 / (1.0 + step.total_loss)
# Bonus for high profitability episodes in batch
avg_profitability = torch.mean(batch['profitability']).item()
value += avg_profitability * 0.3
# Bonus for gradient magnitude (indicates learning)
avg_grad_norm = np.mean(list(step.gradient_norms.values()))
value += min(avg_grad_norm / 10.0, 0.2) # Cap at 0.2
return min(value, 1.0)
except Exception as e:
logger.warning(f"Error calculating step training value: {e}")
return 0.0
def _calculate_session_value(self, session: CNNTrainingSession) -> float:
"""Calculate overall session value for replay prioritization"""
try:
if not session.training_steps:
return 0.0
# Average step values
avg_step_value = np.mean([step.training_value for step in session.training_steps])
# Bonus for convergence
convergence_bonus = 0.0
if len(session.training_steps) > 10:
early_loss = np.mean([s.total_loss for s in session.training_steps[:5]])
late_loss = np.mean([s.total_loss for s in session.training_steps[-5:]])
if early_loss > late_loss:
convergence_bonus = min((early_loss - late_loss) / early_loss, 0.3)
# Bonus for profitable replay sessions
mode_bonus = 0.2 if session.training_mode == 'profitable_replay' else 0.0
return min(avg_step_value + convergence_bonus + mode_bonus, 1.0)
except Exception as e:
logger.warning(f"Error calculating session value: {e}")
return 0.0
def _save_training_session(self, session: CNNTrainingSession):
"""Save training session to disk"""
try:
session_dir = self.storage_dir / session.symbol / 'sessions'
session_dir.mkdir(parents=True, exist_ok=True)
# Save full session data
session_file = session_dir / f"{session.session_id}.pkl"
with open(session_file, 'wb') as f:
pickle.dump(session, f)
# Save session metadata
metadata = {
'session_id': session.session_id,
'start_timestamp': session.start_timestamp.isoformat(),
'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None,
'training_mode': session.training_mode,
'symbol': session.symbol,
'total_steps': session.total_steps,
'average_loss': session.average_loss,
'best_loss': session.best_loss,
'session_value': session.session_value
}
metadata_file = session_dir / f"{session.session_id}_metadata.json"
with open(metadata_file, 'w') as f:
json.dump(metadata, f, indent=2)
logger.debug(f"Saved training session: {session.session_id}")
except Exception as e:
logger.error(f"Error saving training session: {e}")
def _finalize_training_session(self):
"""Finalize current training session"""
if self.current_session:
self.current_session.end_timestamp = datetime.now()
self._save_training_session(self.current_session)
self.training_sessions.append(self.current_session)
self.current_session = None
def get_training_statistics(self) -> Dict[str, Any]:
"""Get comprehensive training statistics"""
stats = self.training_stats.copy()
# Add recent session information
if self.training_sessions:
recent_sessions = sorted(
self.training_sessions,
key=lambda x: x.start_timestamp,
reverse=True
)[:10]
stats['recent_sessions'] = [
{
'session_id': s.session_id,
'timestamp': s.start_timestamp.isoformat(),
'mode': s.training_mode,
'average_loss': s.average_loss,
'session_value': s.session_value
}
for s in recent_sessions
]
# Calculate profitability rate
if stats['total_predictions'] > 0:
stats['profitability_rate'] = stats['profitable_predictions'] / stats['total_predictions']
else:
stats['profitability_rate'] = 0.0
return stats
def replay_high_value_sessions(self,
symbol: str,
min_session_value: float = 0.7,
max_sessions: int = 10) -> Dict[str, Any]:
"""Replay high-value training sessions"""
try:
# Find high-value sessions
high_value_sessions = [
s for s in self.training_sessions
if (s.symbol == symbol and
s.session_value >= min_session_value)
]
# Sort by value and limit
high_value_sessions.sort(key=lambda x: x.session_value, reverse=True)
high_value_sessions = high_value_sessions[:max_sessions]
if not high_value_sessions:
return {'status': 'no_high_value_sessions', 'sessions_found': 0}
# Replay sessions
total_replayed = 0
for session in high_value_sessions:
# Extract episodes from session steps
episode_ids = list(set(step.episode_id for step in session.training_steps))
# Get corresponding episodes
episodes = []
for episode_id in episode_ids:
# Find episode in data collector
for ep in self.data_collector.training_episodes.get(symbol, []):
if ep.episode_id == episode_id:
episodes.append(ep)
break
if episodes:
self._train_on_episodes(episodes, training_mode='high_value_replay')
total_replayed += 1
logger.info(f"Replayed {total_replayed} high-value sessions for {symbol}")
return {
'status': 'success',
'sessions_replayed': total_replayed,
'sessions_found': len(high_value_sessions)
}
except Exception as e:
logger.error(f"Error replaying high-value sessions: {e}")
return {'status': 'error', 'error': str(e)}
# Global instance
cnn_trainer = None
def get_cnn_trainer(model: CNNPivotPredictor = None) -> CNNTrainer:
"""Get global CNN trainer instance"""
global cnn_trainer
if cnn_trainer is None:
if model is None:
model = CNNPivotPredictor()
cnn_trainer = CNNTrainer(model)
return cnn_trainer
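# Example wiring (a sketch under the module's own assumptions; the symbol and
# thresholds are illustrative, and the trainer's default device='cuda' assumes
# a GPU is available):
#
#     trainer = get_cnn_trainer(CNNPivotPredictor())
#     trainer.start_real_time_training('ETH/USDT')            # background worker
#     trainer.train_on_profitable_episodes('ETH/USDT',
#                                          min_profitability=0.7,
#                                          max_episodes=500)   # targeted replay
#     stats = trainer.get_training_statistics()
#     trainer.stop_training()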

View File

@ -1,864 +0,0 @@
# """
# Enhanced CNN Adapter for Standardized Input Format
# This module provides an adapter for the EnhancedCNN model to work with the standardized
# BaseDataInput format, enabling seamless integration with the multi-modal trading system.
# """
# import torch
# import numpy as np
# import logging
# import os
# import random
# from datetime import datetime, timedelta
# from typing import Dict, List, Optional, Tuple, Any, Union
# from threading import Lock
# from .data_models import BaseDataInput, ModelOutput, create_model_output
# from NN.models.enhanced_cnn import EnhancedCNN
# from utils.inference_logger import log_model_inference
# logger = logging.getLogger(__name__)
# class EnhancedCNNAdapter:
# """
# Adapter for EnhancedCNN model to work with standardized BaseDataInput format
# This adapter:
# 1. Converts BaseDataInput to the format expected by EnhancedCNN
# 2. Processes model outputs to create standardized ModelOutput
# 3. Manages model training with collected data
# 4. Handles checkpoint management
# """
# def __init__(self, model_path: str = None, checkpoint_dir: str = "models/enhanced_cnn"):
# """
# Initialize the EnhancedCNN adapter
# Args:
# model_path: Path to load model from, if None a new model is created
# checkpoint_dir: Directory to save checkpoints to
# """
# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# self.model = None
# self.model_path = model_path
# self.checkpoint_dir = checkpoint_dir
# self.training_lock = Lock()
# self.training_data = []
# self.max_training_samples = 10000
# self.batch_size = 32
# self.learning_rate = 0.0001
# self.model_name = "enhanced_cnn"
# # Enhanced metrics tracking
# self.last_inference_time = None
# self.last_inference_duration = 0.0
# self.last_prediction_output = None
# self.last_training_time = None
# self.last_training_duration = 0.0
# self.last_training_loss = 0.0
# self.inference_count = 0
# self.training_count = 0
# # Create checkpoint directory if it doesn't exist
# os.makedirs(checkpoint_dir, exist_ok=True)
# # Initialize the model
# self._initialize_model()
# # Load checkpoint if available
# if model_path and os.path.exists(model_path):
# self._load_checkpoint(model_path)
# else:
# self._load_best_checkpoint()
# # Final device check and move
# self._ensure_model_on_device()
# logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
# def _create_realistic_synthetic_features(self, symbol: str) -> torch.Tensor:
# """Create realistic synthetic features instead of random data"""
# try:
# # Create realistic market-like features
# features = torch.zeros(7850, dtype=torch.float32, device=self.device)
# # OHLCV features (6000 features: 300 frames x 4 timeframes x 5 features)
# ohlcv_start = 0
# for timeframe_idx in range(4): # 1s, 1m, 1h, 1d
# base_price = 3500.0 + timeframe_idx * 10 # Slight variation per timeframe
# for frame_idx in range(300):
# # Create realistic price movement
# price_change = torch.sin(torch.tensor(frame_idx * 0.1)) * 0.01 # Cyclical movement
# current_price = base_price * (1 + price_change)
# # Realistic OHLCV values
# open_price = current_price
# high_price = current_price * random.uniform(1.0, 1.005)  # torch has no torch.uniform; use random.uniform
# low_price = current_price * random.uniform(0.995, 1.0)
# close_price = current_price * random.uniform(0.998, 1.002)
# volume = random.uniform(500.0, 2000.0)
# # Set features
# feature_idx = ohlcv_start + frame_idx * 5 + timeframe_idx * 1500
# features[feature_idx:feature_idx+5] = torch.tensor([open_price, high_price, low_price, close_price, volume])
# # BTC OHLCV features (1500 features: 300 frames x 5 features)
# btc_start = 6000
# btc_base_price = 50000.0
# for frame_idx in range(300):
# price_change = torch.sin(torch.tensor(frame_idx * 0.05)) * 0.02
# current_price = btc_base_price * (1 + price_change)
# open_price = current_price
# high_price = current_price * random.uniform(1.0, 1.01)
# low_price = current_price * random.uniform(0.99, 1.0)
# close_price = current_price * random.uniform(0.995, 1.005)
# volume = random.uniform(100.0, 500.0)
# feature_idx = btc_start + frame_idx * 5
# features[feature_idx:feature_idx+5] = torch.tensor([open_price, high_price, low_price, close_price, volume])
# # COB features (200 features) - realistic order book data
# cob_start = 7500
# for i in range(200):
# features[cob_start + i] = random.uniform(0.0, 1000.0) # Realistic COB values
# # Technical indicators (100 features)
# indicator_start = 7700
# for i in range(100):
# features[indicator_start + i] = random.uniform(-1.0, 1.0) # Normalized indicators
# # Last predictions (50 features)
# prediction_start = 7800
# for i in range(50):
# features[prediction_start + i] = random.uniform(0.0, 1.0) # Probability values
# return features
# except Exception as e:
# logger.error(f"Error creating realistic synthetic features: {e}")
# # Fallback to small random variation
# base_features = torch.ones(7850, dtype=torch.float32, device=self.device) * 0.5
# noise = torch.randn(7850, dtype=torch.float32, device=self.device) * 0.1
# return base_features + noise
# def _create_realistic_features(self, symbol: str) -> torch.Tensor:
# """Create features from real market data if available"""
# try:
# # This would need to be implemented to use actual market data
# # For now, fall back to synthetic features
# return self._create_realistic_synthetic_features(symbol)
# except Exception as e:
# logger.error(f"Error creating realistic features: {e}")
# return self._create_realistic_synthetic_features(symbol)
# def _initialize_model(self):
# """Initialize the EnhancedCNN model"""
# try:
# # Calculate input shape based on BaseDataInput structure
# # OHLCV: 300 frames x 4 timeframes x 5 features = 6000 features
# # BTC OHLCV: 300 frames x 5 features = 1500 features
# # COB: ±20 buckets x 4 metrics = 160 features
# # MA: 4 timeframes x 10 buckets = 40 features
# # Technical indicators: 100 features
# # Last predictions: 50 features
# # Total: 7850 features
# input_shape = 7850
# n_actions = 3 # BUY, SELL, HOLD
# # Create model
# self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
# # Ensure model is moved to the correct device
# self.model.to(self.device)
# logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions} on device {self.device}")
# except Exception as e:
# logger.error(f"Error initializing EnhancedCNN model: {e}")
# raise
# def _load_checkpoint(self, checkpoint_path: str) -> bool:
# """Load model from checkpoint path"""
# try:
# if self.model and os.path.exists(checkpoint_path):
# success = self.model.load(checkpoint_path)
# if success:
# # Ensure model is moved to the correct device after loading
# self.model.to(self.device)
# logger.info(f"Loaded model from {checkpoint_path} and moved to {self.device}")
# return True
# else:
# logger.warning(f"Failed to load model from {checkpoint_path}")
# return False
# else:
# logger.warning(f"Checkpoint path does not exist: {checkpoint_path}")
# return False
# except Exception as e:
# logger.error(f"Error loading checkpoint: {e}")
# return False
# def _load_best_checkpoint(self) -> bool:
# """Load the best available checkpoint"""
# try:
# return self.load_best_checkpoint()
# except Exception as e:
# logger.error(f"Error loading best checkpoint: {e}")
# return False
# def load_best_checkpoint(self) -> bool:
# """Load the best checkpoint based on accuracy"""
# try:
# # Import checkpoint manager
# from utils.checkpoint_manager import CheckpointManager
# # Create checkpoint manager
# checkpoint_manager = CheckpointManager(
# checkpoint_dir=self.checkpoint_dir,
# max_checkpoints=10,
# metric_name="accuracy"
# )
# # Load best checkpoint
# best_checkpoint_path, best_checkpoint_metadata = checkpoint_manager.load_best_checkpoint(self.model_name)
# if not best_checkpoint_path:
# logger.info(f"No checkpoints found for {self.model_name} - starting in COLD START mode")
# return False
# # Load model
# success = self.model.load(best_checkpoint_path)
# if success:
# # Ensure model is moved to the correct device after loading
# self.model.to(self.device)
# logger.info(f"Loaded best checkpoint from {best_checkpoint_path} and moved to {self.device}")
# # Log metrics
# metrics = best_checkpoint_metadata.get('metrics', {})
# logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}")
# return True
# else:
# logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}")
# return False
# except Exception as e:
# logger.error(f"Error loading best checkpoint: {e}")
# return False
# def _ensure_model_on_device(self):
# """Ensure model and all its components are on the correct device"""
# try:
# if self.model:
# self.model.to(self.device)
# # Also ensure the model's internal device is set correctly
# if hasattr(self.model, 'device'):
# self.model.device = self.device
# logger.debug(f"Model ensured on device {self.device}")
# except Exception as e:
# logger.error(f"Error ensuring model on device: {e}")
# def _create_default_output(self, symbol: str) -> ModelOutput:
# """Create default output when prediction fails"""
# return create_model_output(
# model_type='cnn',
# model_name=self.model_name,
# symbol=symbol,
# action='HOLD',
# confidence=0.0,
# metadata={'error': 'Prediction failed, using default output'}
# )
# def _process_hidden_states(self, hidden_states: Dict[str, Any]) -> Dict[str, Any]:
# """Process hidden states for cross-model feeding"""
# processed_states = {}
# for key, value in hidden_states.items():
# if isinstance(value, torch.Tensor):
# # Convert tensor to numpy array
# processed_states[key] = value.cpu().numpy().tolist()
# else:
# processed_states[key] = value
# return processed_states
# def _convert_base_data_to_features(self, base_data: BaseDataInput) -> torch.Tensor:
# """
# Convert BaseDataInput to feature vector for EnhancedCNN
# Args:
# base_data: Standardized input data
# Returns:
# torch.Tensor: Feature vector for EnhancedCNN
# """
# try:
# # Use the get_feature_vector method from BaseDataInput
# features = base_data.get_feature_vector()
# # Validate feature quality before using
# self._validate_feature_quality(features)
# # Convert to torch tensor
# features_tensor = torch.tensor(features, dtype=torch.float32, device=self.device)
# return features_tensor
# except Exception as e:
# logger.error(f"Error converting BaseDataInput to features: {e}")
# # Return empty tensor with correct shape
# return torch.zeros(7850, dtype=torch.float32, device=self.device)
# def _validate_feature_quality(self, features: np.ndarray):
# """Validate that features are realistic and not synthetic/placeholder data"""
# try:
# if len(features) != 7850:
# logger.warning(f"Feature vector has wrong size: {len(features)} != 7850")
# return
# # Check for all-zero or all-identical features (indicates placeholder data)
# if np.all(features == 0):
# logger.warning("Feature vector contains all zeros - likely placeholder data")
# return
# # Check for repetitive patterns in OHLCV data (first 6000 features)
# ohlcv_features = features[:6000]
# if len(ohlcv_features) >= 20:
# # Check if first 20 values are identical (indicates padding with same bar)
# if np.allclose(ohlcv_features[:20], ohlcv_features[0], atol=1e-6):
# logger.warning("OHLCV features show repetitive pattern - possible synthetic data")
# # Check for unrealistic values
# if np.any(features > 1e6) or np.any(features < -1e6):
# logger.warning("Feature vector contains unrealistic values")
# # Check for NaN or infinite values
# if np.any(np.isnan(features)) or np.any(np.isinf(features)):
# logger.warning("Feature vector contains NaN or infinite values")
# except Exception as e:
# logger.error(f"Error validating feature quality: {e}")
# def predict(self, base_data: BaseDataInput) -> ModelOutput:
# """
# Make a prediction using the EnhancedCNN model
# Args:
# base_data: Standardized input data
# Returns:
# ModelOutput: Standardized model output
# """
# try:
# # Track inference timing
# start_time = datetime.now()
# inference_start = start_time.timestamp()
# # Convert BaseDataInput to features
# features = self._convert_base_data_to_features(base_data)
# # Ensure features has batch dimension
# if features.dim() == 1:
# features = features.unsqueeze(0)
# # Ensure model is on correct device before prediction
# self._ensure_model_on_device()
# # Set model to evaluation mode
# self.model.eval()
# # Make prediction
# with torch.no_grad():
# q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.model(features)
# # Get action and confidence
# action_probs = torch.softmax(q_values, dim=1)
# action_idx = torch.argmax(action_probs, dim=1).item()
# raw_confidence = float(action_probs[0, action_idx].item())
# # Validate confidence - prevent 100% confidence which indicates overfitting
# if raw_confidence >= 0.99:
# logger.warning(f"CNN produced suspiciously high confidence: {raw_confidence:.4f} - possible overfitting")
# # Cap confidence at 0.95 to prevent unrealistic predictions
# confidence = min(raw_confidence, 0.95)
# logger.info(f"Capped confidence from {raw_confidence:.4f} to {confidence:.4f}")
# else:
# confidence = raw_confidence
# # Map action index to action string
# actions = ['BUY', 'SELL', 'HOLD']
# action = actions[action_idx]
# # Extract pivot price prediction (simplified - take first value from price_pred)
# pivot_price = None
# if price_pred is not None and len(price_pred.squeeze()) > 0:
# # Get current price from base_data for context
# current_price = 0.0
# if base_data.ohlcv_1s and len(base_data.ohlcv_1s) > 0:
# current_price = base_data.ohlcv_1s[-1].close
# # Calculate pivot price as current price + predicted change
# price_change_pct = float(price_pred.squeeze()[0].item()) # First prediction value
# pivot_price = current_price * (1 + price_change_pct * 0.01) # Convert percentage to price
# # Create predictions dictionary
# predictions = {
# 'action': action,
# 'buy_probability': float(action_probs[0, 0].item()),
# 'sell_probability': float(action_probs[0, 1].item()),
# 'hold_probability': float(action_probs[0, 2].item()),
# 'extrema': extrema_pred.squeeze(0).cpu().numpy().tolist(),
# 'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist(),
# 'pivot_price': pivot_price
# }
# # Create hidden states dictionary
# hidden_states = {
# 'features': features_refined.squeeze(0).cpu().numpy().tolist()
# }
# # Calculate inference duration
# end_time = datetime.now()
# inference_duration = (end_time.timestamp() - inference_start) * 1000 # Convert to milliseconds
# # Update metrics
# self.last_inference_time = start_time
# self.last_inference_duration = inference_duration
# self.inference_count += 1
# # Store last prediction output for dashboard
# self.last_prediction_output = {
# 'action': action,
# 'confidence': confidence,
# 'pivot_price': pivot_price,
# 'timestamp': start_time,
# 'symbol': base_data.symbol
# }
# # Create metadata dictionary
# metadata = {
# 'model_version': '1.0',
# 'timestamp': start_time.isoformat(),
# 'input_shape': features.shape,
# 'inference_duration_ms': inference_duration,
# 'inference_count': self.inference_count
# }
# # Create ModelOutput
# model_output = ModelOutput(
# model_type='cnn',
# model_name=self.model_name,
# symbol=base_data.symbol,
# timestamp=start_time,
# confidence=confidence,
# predictions=predictions,
# hidden_states=hidden_states,
# metadata=metadata
# )
# # Log inference with full input data for training feedback
# log_model_inference(
# model_name=self.model_name,
# symbol=base_data.symbol,
# action=action,
# confidence=confidence,
# probabilities={
# 'BUY': predictions['buy_probability'],
# 'SELL': predictions['sell_probability'],
# 'HOLD': predictions['hold_probability']
# },
# input_features=features.cpu().numpy(), # Store full feature vector
# processing_time_ms=inference_duration,
# checkpoint_id=None, # Could be enhanced to track checkpoint
# metadata={
# 'base_data_input': {
# 'symbol': base_data.symbol,
# 'timestamp': base_data.timestamp.isoformat(),
# 'ohlcv_1s_count': len(base_data.ohlcv_1s),
# 'ohlcv_1m_count': len(base_data.ohlcv_1m),
# 'ohlcv_1h_count': len(base_data.ohlcv_1h),
# 'ohlcv_1d_count': len(base_data.ohlcv_1d),
# 'btc_ohlcv_1s_count': len(base_data.btc_ohlcv_1s),
# 'has_cob_data': base_data.cob_data is not None,
# 'technical_indicators_count': len(base_data.technical_indicators),
# 'pivot_points_count': len(base_data.pivot_points),
# 'last_predictions_count': len(base_data.last_predictions)
# },
# 'model_predictions': {
# 'pivot_price': pivot_price,
# 'extrema_prediction': predictions['extrema'],
# 'price_prediction': predictions['price_prediction']
# }
# }
# )
# return model_output
# except Exception as e:
# logger.error(f"Error making prediction with EnhancedCNN: {e}")
# # Return default ModelOutput
# return create_model_output(
# model_type='cnn',
# model_name=self.model_name,
# symbol=base_data.symbol,
# action='HOLD',
# confidence=0.0
# )
# def add_training_sample(self, symbol_or_base_data, actual_action: str, reward: float):
# """
# Add a training sample to the training data
# Args:
# symbol_or_base_data: Either a symbol string or BaseDataInput object
# actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
# reward: Reward received for the action
# """
# try:
# # Handle both symbol string and BaseDataInput object
# if isinstance(symbol_or_base_data, str):
# # For cold start mode - create a simple training sample with current features
# # This is a simplified approach for rapid training
# symbol = symbol_or_base_data
# # Create a realistic feature vector instead of random data
# # Use actual market data if available, otherwise create realistic synthetic data
# try:
# # Try to get real market data first
# if hasattr(self, 'data_provider') and self.data_provider:
# # This would need to be implemented in the adapter
# features = self._create_realistic_features(symbol)
# else:
# # Create realistic synthetic features (not random)
# features = self._create_realistic_synthetic_features(symbol)
# except Exception as e:
# logger.warning(f"Could not create realistic features for {symbol}: {e}")
# # Fallback to small random variation instead of pure random
# base_features = torch.ones(7850, dtype=torch.float32, device=self.device) * 0.5
# noise = torch.randn(7850, dtype=torch.float32, device=self.device) * 0.1
# features = base_features + noise
# logger.debug(f"Added realistic training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")
# else:
# # Full BaseDataInput object
# base_data = symbol_or_base_data
# features = self._convert_base_data_to_features(base_data)
# symbol = base_data.symbol
# logger.debug(f"Added full training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")
# # Convert action to index
# actions = ['BUY', 'SELL', 'HOLD']
# action_idx = actions.index(actual_action)
# # Add to training data
# with self.training_lock:
# self.training_data.append((features, action_idx, reward))
# # Limit training data size
# if len(self.training_data) > self.max_training_samples:
# # Sort by reward (highest first) and keep top samples
# self.training_data.sort(key=lambda x: x[2], reverse=True)
# self.training_data = self.training_data[:self.max_training_samples]
# except Exception as e:
# logger.error(f"Error adding training sample: {e}")
# def train(self, epochs: int = 1) -> Dict[str, float]:
# """
# Train the model with collected data and inference history
# Args:
# epochs: Number of epochs to train for
# Returns:
# Dict[str, float]: Training metrics
# """
# try:
# # Track training timing
# training_start_time = datetime.now()
# training_start = training_start_time.timestamp()
# with self.training_lock:
# # Get additional training data from inference history
# self._load_training_data_from_inference_history()
# # Check if we have enough data
# if len(self.training_data) < self.batch_size:
# logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
# return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}
# # Ensure model is on correct device before training
# self._ensure_model_on_device()
# # Set model to training mode
# self.model.train()
# # Create optimizer
# optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
# # Training metrics
# total_loss = 0.0
# correct_predictions = 0
# total_predictions = 0
# # Train for specified number of epochs
# for epoch in range(epochs):
# # Shuffle training data
# np.random.shuffle(self.training_data)
# # Process in batches
# for i in range(0, len(self.training_data), self.batch_size):
# batch = self.training_data[i:i+self.batch_size]
# # Skip if batch is too small
# if len(batch) < 2:
# continue
# # Prepare batch - ensure all tensors are on the correct device
# features = torch.stack([sample[0].to(self.device) for sample in batch])
# actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
# rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)
# # Zero gradients
# optimizer.zero_grad()
# # Forward pass
# q_values, _, _, _, _ = self.model(features)
# # Calculate loss (CrossEntropyLoss with reward weighting)
# # First, apply softmax to get probabilities
# probs = torch.softmax(q_values, dim=1)
# # Get probability of chosen action
# chosen_probs = probs[torch.arange(len(actions)), actions]
# # Calculate negative log likelihood loss
# nll_loss = -torch.log(chosen_probs + 1e-10)
# # Weight by reward (higher reward = higher weight)
# # Normalize rewards to [0, 1] range
# min_reward = rewards.min()
# max_reward = rewards.max()
# if max_reward > min_reward:
# normalized_rewards = (rewards - min_reward) / (max_reward - min_reward)
# else:
# normalized_rewards = torch.ones_like(rewards)
# # Apply reward weighting (higher reward = higher weight)
# weighted_loss = nll_loss * (normalized_rewards + 0.1) # Add small constant to avoid zero weights
# # Mean loss
# loss = weighted_loss.mean()
# # Backward pass
# loss.backward()
# # Update weights
# optimizer.step()
# # Update metrics
# total_loss += loss.item()
# # Calculate accuracy
# predicted_actions = torch.argmax(q_values, dim=1)
# correct_predictions += (predicted_actions == actions).sum().item()
# total_predictions += len(actions)
# # Validate training - detect overfitting
# if total_predictions > 0:
# current_accuracy = correct_predictions / total_predictions
# if current_accuracy >= 0.99:
# logger.warning(f"CNN training shows suspiciously high accuracy: {current_accuracy:.4f} - possible overfitting")
# # Add regularization to prevent overfitting
# l2_reg = 0.01 * sum(p.pow(2.0).sum() for p in self.model.parameters())
# loss = loss + l2_reg
# logger.info("Added L2 regularization to prevent overfitting")
# # Calculate final metrics
# avg_loss = total_loss / (len(self.training_data) / self.batch_size)
# accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
# # Calculate training duration
# training_end_time = datetime.now()
# training_duration = (training_end_time.timestamp() - training_start) * 1000 # Convert to milliseconds
# # Update training metrics
# self.last_training_time = training_start_time
# self.last_training_duration = training_duration
# self.last_training_loss = avg_loss
# self.training_count += 1
# # Save checkpoint
# self._save_checkpoint(avg_loss, accuracy)
# logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={len(self.training_data)}, duration={training_duration:.1f}ms")
# return {
# 'loss': avg_loss,
# 'accuracy': accuracy,
# 'samples': len(self.training_data),
# 'duration_ms': training_duration,
# 'training_count': self.training_count
# }
# except Exception as e:
# logger.error(f"Error training model: {e}")
# return {'loss': 0.0, 'accuracy': 0.0, 'samples': 0, 'error': str(e)}
# def _save_checkpoint(self, loss: float, accuracy: float):
# """
# Save model checkpoint
# Args:
# loss: Training loss
# accuracy: Training accuracy
# """
# try:
# # Import checkpoint manager
# from utils.checkpoint_manager import CheckpointManager
# # Create checkpoint manager
# checkpoint_manager = CheckpointManager(
# checkpoint_dir=self.checkpoint_dir,
# max_checkpoints=10,
# metric_name="accuracy"
# )
# # Create temporary model file
# temp_path = os.path.join(self.checkpoint_dir, f"{self.model_name}_temp")
# self.model.save(temp_path)
# # Create metrics
# metrics = {
# 'loss': loss,
# 'accuracy': accuracy,
# 'samples': len(self.training_data)
# }
# # Create metadata
# metadata = {
# 'timestamp': datetime.now().isoformat(),
# 'model_name': self.model_name,
# 'input_shape': self.model.input_shape,
# 'n_actions': self.model.n_actions
# }
# # Save checkpoint
# checkpoint_path = checkpoint_manager.save_checkpoint(
# model_name=self.model_name,
# model_path=f"{temp_path}.pt",
# metrics=metrics,
# metadata=metadata
# )
# # Delete temporary model file
# if os.path.exists(f"{temp_path}.pt"):
# os.remove(f"{temp_path}.pt")
# logger.info(f"Model checkpoint saved to {checkpoint_path}")
# except Exception as e:
# logger.error(f"Error saving checkpoint: {e}")
# def _load_training_data_from_inference_history(self):
# """Load training data from inference history for continuous learning"""
# try:
# from utils.database_manager import get_database_manager
# db_manager = get_database_manager()
# # Get recent inference records with input features
# inference_records = db_manager.get_inference_records_for_training(
# model_name=self.model_name,
# hours_back=24, # Last 24 hours
# limit=1000
# )
# if not inference_records:
# logger.debug("No inference records found for training")
# return
# # Convert inference records to training samples
# # For now, use a simple approach: treat high-confidence predictions as ground truth
# for record in inference_records:
# if record.input_features is not None and record.confidence > 0.7:
# # Convert action to index
# actions = ['BUY', 'SELL', 'HOLD']
# if record.action in actions:
# action_idx = actions.index(record.action)
# # Use confidence as a proxy for reward (high confidence = good prediction)
# reward = record.confidence * 2 - 1 # Scale to [-1, 1]
# # Convert features to tensor
# features_tensor = torch.tensor(record.input_features, dtype=torch.float32, device=self.device)
# # Add to training data if not already present (avoid duplicates)
# sample_exists = any(
# torch.equal(features_tensor, existing[0])
# for existing in self.training_data
# )
# if not sample_exists:
# self.training_data.append((features_tensor, action_idx, reward))
# logger.info(f"Loaded {len(inference_records)} inference records for training, total training samples: {len(self.training_data)}")
# except Exception as e:
# logger.error(f"Error loading training data from inference history: {e}")
# def evaluate_predictions_against_outcomes(self, hours_back: int = 1) -> Dict[str, float]:
# """
# Evaluate past predictions against actual market outcomes
# Args:
# hours_back: How many hours back to evaluate
# Returns:
# Dict with evaluation metrics
# """
# try:
# from utils.database_manager import get_database_manager
# db_manager = get_database_manager()
# # Get inference records from the specified time period
# inference_records = db_manager.get_inference_records_for_training(
# model_name=self.model_name,
# hours_back=hours_back,
# limit=100
# )
# if not inference_records:
# return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}
# # For now, use a simple evaluation based on confidence
# # In a real implementation, this would compare against actual price movements
# correct_predictions = 0
# total_predictions = len(inference_records)
# # Simple heuristic: high confidence predictions are more likely to be correct
# for record in inference_records:
# if record.confidence > 0.8: # High confidence threshold
# correct_predictions += 1
# elif record.confidence > 0.6: # Medium confidence
# correct_predictions += 0.5
# accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
# logger.info(f"Prediction evaluation: {correct_predictions:.1f}/{total_predictions} = {accuracy:.3f} accuracy")
# return {
# 'accuracy': accuracy,
# 'total_predictions': total_predictions,
# 'correct_predictions': correct_predictions
# }
# except Exception as e:
# logger.error(f"Error evaluating predictions: {e}")
# return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}
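# Example usage of the adapter above (a sketch; `build_base_data_input` is a
# hypothetical helper standing in for however BaseDataInput is constructed):
#
#     adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
#     base_data = build_base_data_input('ETH/USDT')           # -> BaseDataInput
#     output = adapter.predict(base_data)                     # -> ModelOutput
#     adapter.add_training_sample(base_data,
#                                 actual_action=output.predictions['action'],
#                                 reward=0.5)
#     metrics = adapter.train(epochs=1)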

View File

@ -1,464 +0,0 @@
"""
Enhanced Trading Orchestrator
Central coordination hub for the multi-modal trading system that manages:
- Data subscription and management
- Model inference coordination
- Cross-model data feeding
- Training pipeline orchestration
- Decision making using Mixture of Experts
"""
import asyncio
import logging
import numpy as np
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from core.data_provider import DataProvider
from core.trading_action import TradingAction
from utils.tensorboard_logger import TensorBoardLogger
logger = logging.getLogger(__name__)
@dataclass
class ModelOutput:
"""Extensible model output format supporting all model types"""
model_type: str # 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator'
model_name: str # Specific model identifier
symbol: str
timestamp: datetime
confidence: float
predictions: Dict[str, Any] # Model-specific predictions
hidden_states: Optional[Dict[str, Any]] = None # For cross-model feeding
metadata: Dict[str, Any] = field(default_factory=dict) # Additional info
@dataclass
class BaseDataInput:
"""Unified base data input for all models"""
symbol: str
timestamp: datetime
ohlcv_data: Dict[str, Any] = field(default_factory=dict) # Multi-timeframe OHLCV
cob_data: Optional[Dict[str, Any]] = None # COB buckets for 1s timeframe
technical_indicators: Dict[str, float] = field(default_factory=dict)
pivot_points: List[Any] = field(default_factory=list)
last_predictions: Dict[str, ModelOutput] = field(default_factory=dict) # From all models
market_microstructure: Dict[str, Any] = field(default_factory=dict) # Order flow, etc.
@dataclass
class COBData:
"""Cumulative Order Book data for price buckets"""
symbol: str
timestamp: datetime
current_price: float
bucket_size: float # $1 for ETH, $10 for BTC
price_buckets: Dict[float, Dict[str, float]] = field(default_factory=dict) # price -> {bid_volume, ask_volume, etc.}
bid_ask_imbalance: Dict[float, float] = field(default_factory=dict) # price -> imbalance ratio
volume_weighted_prices: Dict[float, float] = field(default_factory=dict) # price -> VWAP within bucket
order_flow_metrics: Dict[str, float] = field(default_factory=dict) # Various order flow indicators
class EnhancedTradingOrchestrator:
"""
Enhanced Trading Orchestrator implementing the design specification
Coordinates data flow, model inference, and decision making for the multi-modal trading system.
"""
def __init__(self, data_provider: DataProvider, symbols: List[str], enhanced_rl_training: bool = False, model_registry: Dict = None):
"""Initialize the enhanced orchestrator"""
self.data_provider = data_provider
self.symbols = symbols
self.enhanced_rl_training = enhanced_rl_training
self.model_registry = model_registry or {}
# Data management
self.data_buffers = {symbol: {} for symbol in symbols}
self.last_update_times = {symbol: {} for symbol in symbols}
# Model output storage
self.model_outputs = {symbol: {} for symbol in symbols}
self.model_output_history = {symbol: {} for symbol in symbols}
# Training pipeline
self.training_data = {symbol: [] for symbol in symbols}
self.tensorboard_logger = TensorBoardLogger("runs", f"orchestrator_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
# COB integration
self.cob_data = {symbol: None for symbol in symbols}
# Performance tracking
self.performance_metrics = {
'inference_count': 0,
'successful_states': 0,
'total_episodes': 0
}
logger.info("Enhanced Trading Orchestrator initialized")
async def start_cob_integration(self):
"""Start COB data integration for real-time market microstructure"""
try:
# Subscribe to COB data updates
self.data_provider.subscribe_to_cob_data(self._on_cob_data_update)
logger.info("COB integration started")
except Exception as e:
logger.error(f"Error starting COB integration: {e}")
async def start_realtime_processing(self):
"""Start real-time data processing"""
try:
# Subscribe to tick data for real-time processing
for symbol in self.symbols:
self.data_provider.subscribe_to_ticks(
callback=self._on_tick_data,
symbols=[symbol],
subscriber_name=f"orchestrator_{symbol}"
)
logger.info("Real-time processing started")
except Exception as e:
logger.error(f"Error starting real-time processing: {e}")
def _on_cob_data_update(self, symbol: str, cob_data: dict):
"""Handle COB data updates"""
try:
# Process and store COB data
self.cob_data[symbol] = self._process_cob_data(symbol, cob_data)
logger.debug(f"COB data updated for {symbol}")
except Exception as e:
logger.error(f"Error processing COB data for {symbol}: {e}")
def _process_cob_data(self, symbol: str, cob_data: dict) -> COBData:
"""Process raw COB data into structured format"""
try:
# Determine bucket size based on symbol
bucket_size = 1.0 if 'ETH' in symbol else 10.0
# Extract current price
stats = cob_data.get('stats', {})
current_price = stats.get('mid_price', 0)
# Create COB data structure
cob = COBData(
symbol=symbol,
timestamp=datetime.now(),
current_price=current_price,
bucket_size=bucket_size
)
# Process order book data into price buckets
bids = cob_data.get('bids', [])
asks = cob_data.get('asks', [])
# Create price buckets around current price
bucket_count = 20 # ±20 buckets
for i in range(-bucket_count, bucket_count + 1):
bucket_price = current_price + (i * bucket_size)
cob.price_buckets[bucket_price] = {
'bid_volume': 0.0,
'ask_volume': 0.0
}
# Aggregate bid volumes into buckets
for price, volume in bids:
bucket_price = round(price / bucket_size) * bucket_size
if bucket_price in cob.price_buckets:
cob.price_buckets[bucket_price]['bid_volume'] += volume
# Aggregate ask volumes into buckets
for price, volume in asks:
bucket_price = round(price / bucket_size) * bucket_size
if bucket_price in cob.price_buckets:
cob.price_buckets[bucket_price]['ask_volume'] += volume
# Calculate bid/ask imbalances
for price, volumes in cob.price_buckets.items():
bid_vol = volumes['bid_volume']
ask_vol = volumes['ask_volume']
total_vol = bid_vol + ask_vol
if total_vol > 0:
cob.bid_ask_imbalance[price] = (bid_vol - ask_vol) / total_vol
else:
cob.bid_ask_imbalance[price] = 0.0
# Calculate volume-weighted prices
for price, volumes in cob.price_buckets.items():
bid_vol = volumes['bid_volume']
ask_vol = volumes['ask_volume']
total_vol = bid_vol + ask_vol
if total_vol > 0:
cob.volume_weighted_prices[price] = (
(price * bid_vol) + (price * ask_vol)
) / total_vol
else:
cob.volume_weighted_prices[price] = price
# Calculate order flow metrics
total_bid_volume = sum(v['bid_volume'] for v in cob.price_buckets.values())
total_ask_volume = sum(v['ask_volume'] for v in cob.price_buckets.values())
cob.order_flow_metrics = {
    'total_bid_volume': total_bid_volume,
    'total_ask_volume': total_ask_volume,
    # Computed from the local totals; the dict cannot reference itself while it is being built
    'bid_ask_ratio': 0.0 if total_ask_volume == 0 else total_bid_volume / total_ask_volume
}
return cob
except Exception as e:
logger.error(f"Error processing COB data for {symbol}: {e}")
# bucket_size may not have been set if the failure happened early in the try block
return COBData(symbol=symbol, timestamp=datetime.now(), current_price=0.0, bucket_size=1.0 if 'ETH' in str(symbol) else 10.0)
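# Worked example of the bucketing above (illustrative numbers): with an ETH
# mid price of 3500.0 and bucket_size=1.0, a bid (3499.4, 2.0) and an ask
# (3499.3, 1.0) both round to bucket 3499.0, giving bid_volume=2.0,
# ask_volume=1.0 and an imbalance of (2.0 - 1.0) / 3.0 = +0.33; the
# volume-weighted price equals the bucket price (3499.0), since the formula
# above weights the bucket key itself rather than the individual order prices.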
def _on_tick_data(self, tick):
"""Handle incoming tick data"""
try:
# Update data buffers
symbol = tick.symbol
if symbol not in self.data_buffers:
self.data_buffers[symbol] = {}
# Store tick data
if 'ticks' not in self.data_buffers[symbol]:
self.data_buffers[symbol]['ticks'] = []
self.data_buffers[symbol]['ticks'].append(tick)
# Keep only last 1000 ticks
if len(self.data_buffers[symbol]['ticks']) > 1000:
self.data_buffers[symbol]['ticks'] = self.data_buffers[symbol]['ticks'][-1000:]
# Update last update time
self.last_update_times[symbol]['tick'] = datetime.now()
logger.debug(f"Tick data updated for {symbol}")
except Exception as e:
logger.error(f"Error processing tick data: {e}")
def build_comprehensive_rl_state(self, symbol: str) -> Optional[np.ndarray]:
"""
Build comprehensive RL state with 13,400 features as specified
Returns:
np.ndarray: State vector with 13,400 features
"""
try:
# Initialize state vector
state_size = 13400
state = np.zeros(state_size, dtype=np.float32)
# Get latest data
ohlcv_data = self.data_provider.get_latest_candles(symbol, '1s', limit=100)
cob_data = self.cob_data.get(symbol)
# Feature index tracking
idx = 0
# 1. OHLCV features (4000 features)
if ohlcv_data is not None and not ohlcv_data.empty:
# Use last 100 1s candles, 40 feature slots each (O,H,L,C,V plus up to 35 technical indicators)
for i in range(min(100, len(ohlcv_data))):
if idx + 40 <= state_size:
row = ohlcv_data.iloc[-(i+1)]
state[idx] = row.get('open', 0) / 100000 # Normalized
state[idx+1] = row.get('high', 0) / 100000
state[idx+2] = row.get('low', 0) / 100000
state[idx+3] = row.get('close', 0) / 100000
state[idx+4] = row.get('volume', 0) / 1000000
# Add technical indicators if available
indicator_idx = 5
for col in ['sma_10', 'sma_20', 'ema_12', 'ema_26', 'rsi_14',
'macd', 'bb_upper', 'bb_lower', 'atr', 'adx']:
if col in row and idx + indicator_idx < state_size:
state[idx + indicator_idx] = row[col] / 100000
indicator_idx += 1
idx += 40
# 2. COB features (8000 features)
if cob_data and idx + 8000 <= state_size:
# Use 200 price buckets (40 features each)
bucket_prices = sorted(cob_data.price_buckets.keys())
for i, price in enumerate(bucket_prices[:200]):
if idx + 40 <= state_size:
bucket = cob_data.price_buckets[price]
state[idx] = bucket.get('bid_volume', 0) / 1000000 # Normalized
state[idx+1] = bucket.get('ask_volume', 0) / 1000000
state[idx+2] = cob_data.bid_ask_imbalance.get(price, 0)
state[idx+3] = cob_data.volume_weighted_prices.get(price, price) / 100000
# Additional COB metrics
state[idx+4] = cob_data.order_flow_metrics.get('total_bid_volume', 0) / 10000000
state[idx+5] = cob_data.order_flow_metrics.get('total_ask_volume', 0) / 10000000
state[idx+6] = cob_data.order_flow_metrics.get('bid_ask_ratio', 0)
idx += 40
# 3. Technical indicator features (1000 features)
# Already included in OHLCV section above
# 4. Market microstructure features (400 features)
if cob_data and idx + 400 <= state_size:
# Add order flow metrics
metrics = list(cob_data.order_flow_metrics.values())
for i, metric in enumerate(metrics[:400]):
if idx + i < state_size:
state[idx + i] = metric
# Log state building success
self.performance_metrics['successful_states'] += 1
logger.debug(f"Comprehensive RL state built for {symbol}: {len(state)} features")
# Log to TensorBoard
self.tensorboard_logger.log_state_metrics(
symbol=symbol,
state_info={
'size': len(state),
'quality': 1.0,
'feature_counts': {
'total': len(state),
'non_zero': np.count_nonzero(state)
}
},
step=self.performance_metrics['successful_states']
)
return state
except Exception as e:
logger.error(f"Error building comprehensive RL state for {symbol}: {e}")
return None
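# Feature budget for the state above, as laid out in the code: 100 candles
# x 40 slots = 4,000 OHLCV features and 200 COB buckets x 40 slots = 8,000
# features, plus up to 400 market-microstructure slots; the remainder of the
# 13,400-length vector stays zero-padded when data is missing.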
def calculate_enhanced_pivot_reward(self, trade_decision: Dict, market_data: Dict, trade_outcome: Dict) -> float:
"""
Calculate enhanced pivot-based reward
Args:
trade_decision: Trading decision with action and confidence
market_data: Market context data
trade_outcome: Actual trade results
Returns:
float: Enhanced reward value
"""
try:
# Base reward from PnL
pnl_reward = trade_outcome.get('net_pnl', 0) / 100 # Normalize
# Confidence weighting
confidence = trade_decision.get('confidence', 0.5)
confidence_reward = confidence * 0.2
# Volatility adjustment
volatility = market_data.get('volatility', 0.01)
volatility_reward = (1.0 - volatility * 10) * 0.1 # Prefer low volatility
# Order flow alignment
order_flow = market_data.get('order_flow_strength', 0)
order_flow_reward = order_flow * 0.2
# Pivot alignment bonus (if near pivot in favorable direction)
pivot_bonus = 0.0
if market_data.get('near_pivot', False):
action = trade_decision.get('action', '').upper()
pivot_type = market_data.get('pivot_type', '').upper()
# Bonus for buying near support or selling near resistance
if (action == 'BUY' and pivot_type == 'LOW') or \
(action == 'SELL' and pivot_type == 'HIGH'):
pivot_bonus = 0.5
# Calculate final reward
enhanced_reward = pnl_reward + confidence_reward + volatility_reward + order_flow_reward + pivot_bonus
# Log to TensorBoard
self.tensorboard_logger.log_scalars('Rewards/Components', {
'pnl_component': pnl_reward,
'confidence': confidence_reward,
'volatility': volatility_reward,
'order_flow': order_flow_reward,
'pivot_bonus': pivot_bonus
}, self.performance_metrics['total_episodes'])
self.tensorboard_logger.log_scalar('Rewards/Enhanced', enhanced_reward, self.performance_metrics['total_episodes'])
logger.debug(f"Enhanced reward calculated: {enhanced_reward}")
return enhanced_reward
except Exception as e:
logger.error(f"Error calculating enhanced pivot reward: {e}")
return 0.0
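# Worked example (illustrative numbers): net_pnl=+50 -> pnl_reward=0.5;
# confidence=0.8 -> +0.16; volatility=0.01 -> (1 - 0.1) * 0.1 = +0.09;
# order_flow_strength=0.5 -> +0.10; a BUY taken near a LOW pivot adds +0.5,
# for an enhanced reward of roughly 1.35.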
async def make_coordinated_decisions(self) -> Dict[str, TradingAction]:
"""
Make coordinated trading decisions using all available models
Returns:
Dict[str, TradingAction]: Trading actions for each symbol
"""
try:
decisions = {}
# For each symbol, coordinate model inference
for symbol in self.symbols:
# Build comprehensive state for RL model
rl_state = self.build_comprehensive_rl_state(symbol)
if rl_state is not None:
# Store state for training
self.performance_metrics['total_episodes'] += 1
# Create mock RL decision (in a real implementation, this would call the RL model)
action = 'BUY' if np.mean(rl_state[:100]) > 0.5 else 'SELL'
confidence = min(1.0, max(0.0, np.std(rl_state) * 10))
# Create trading action
decisions[symbol] = TradingAction(
symbol=symbol,
timestamp=datetime.now(),
action=action,
confidence=confidence,
source='rl_orchestrator'
)
logger.info(f"Coordinated decision for {symbol}: {action} (confidence: {confidence:.3f})")
else:
logger.warning(f"Failed to build state for {symbol}, skipping decision")
self.performance_metrics['inference_count'] += 1
return decisions
except Exception as e:
logger.error(f"Error making coordinated decisions: {e}")
return {}
def _get_symbol_correlation(self, symbol1: str, symbol2: str) -> float:
"""
Calculate correlation between two symbols
Args:
symbol1: First symbol
symbol2: Second symbol
Returns:
float: Correlation coefficient (-1 to 1)
"""
try:
# Get recent price data for both symbols
data1 = self.data_provider.get_latest_candles(symbol1, '1m', limit=50)
data2 = self.data_provider.get_latest_candles(symbol2, '1m', limit=50)
if data1 is None or data2 is None or data1.empty or data2.empty:
return 0.0
# Align data by timestamp
merged = data1[['close']].join(data2[['close']], lsuffix='_1', rsuffix='_2', how='inner')
if len(merged) < 10:
return 0.0
# Calculate correlation
correlation = merged['close_1'].corr(merged['close_2'])
return correlation if not np.isnan(correlation) else 0.0
except Exception as e:
logger.error(f"Error calculating symbol correlation: {e}")
return 0.0
```
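For reference, a minimal standalone sketch of how the reward components in `calculate_enhanced_pivot_reward` above combine for one hypothetical trade. The formula mirrors the deleted method; the input values are invented purely for illustration.

```python
# Hedged sketch: recompute the enhanced pivot reward with made-up inputs
# (BUY-near-support case only; the SELL-near-resistance branch is analogous).
trade_decision = {'action': 'BUY', 'confidence': 0.7}
market_data = {'volatility': 0.008, 'order_flow_strength': 0.4,
               'near_pivot': True, 'pivot_type': 'LOW'}
trade_outcome = {'net_pnl': 12.5}

pnl_reward = trade_outcome['net_pnl'] / 100                        # 0.125
confidence_reward = trade_decision['confidence'] * 0.2             # 0.140
volatility_reward = (1.0 - market_data['volatility'] * 10) * 0.1   # 0.092
order_flow_reward = market_data['order_flow_strength'] * 0.2       # 0.080
pivot_bonus = 0.5 if (market_data['near_pivot']
                      and trade_decision['action'] == 'BUY'
                      and market_data['pivot_type'] == 'LOW') else 0.0

enhanced_reward = (pnl_reward + confidence_reward + volatility_reward
                   + order_flow_reward + pivot_bonus)
print(f"enhanced_reward = {enhanced_reward:.3f}")                  # 0.937
```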

View File

@ -1,775 +0,0 @@
"""
Enhanced Training Integration Module
This module provides comprehensive integration between the training data collection system,
CNN training pipeline, RL training pipeline, and your existing infrastructure.
Key Features:
- Real-time integration with existing DataProvider
- Coordinated training across CNN and RL models
- Automatic outcome validation and profitability tracking
- Integration with existing COB RL model
- Performance monitoring and optimization
- Seamless connection to existing orchestrator and trading executor
"""
import asyncio
import logging
import numpy as np
import pandas as pd
import torch
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass
import threading
import time
from pathlib import Path
# Import existing components
from .data_provider import DataProvider
from .orchestrator import Orchestrator
from .trading_executor import TradingExecutor
# Import our training system components
from .training_data_collector import (
TrainingDataCollector,
get_training_data_collector
)
from .cnn_training_pipeline import (
CNNPivotPredictor,
CNNTrainer,
get_cnn_trainer
)
from .rl_training_pipeline import (
RLTradingAgent,
RLTrainer,
get_rl_trainer
)
from .training_integration import TrainingIntegration
logger = logging.getLogger(__name__)
# Import existing RL model (logger is defined first so the fallback warning works)
try:
    from NN.models.cob_rl_model import COBRLModelInterface
except ImportError:
    logger.warning("Could not import COBRLModelInterface - using fallback")
    COBRLModelInterface = None
@dataclass
class EnhancedTrainingConfig:
"""Enhanced configuration for comprehensive training integration"""
# Data collection
collection_interval: float = 1.0
min_data_completeness: float = 0.8
# Training triggers
min_episodes_for_cnn_training: int = 100
min_experiences_for_rl_training: int = 200
training_frequency_minutes: int = 30
# Profitability thresholds
min_profitability_for_replay: float = 0.1
high_profitability_threshold: float = 0.5
# Model integration
use_existing_cob_rl_model: bool = True
enable_cross_model_learning: bool = True
# Performance optimization
max_concurrent_training_sessions: int = 2
enable_background_validation: bool = True
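A short usage sketch for the config above. The field names come from the dataclass; the values are arbitrary, and the import path is an assumption about the package layout, not something confirmed by this file.

```python
# Hedged sketch: overriding a few EnhancedTrainingConfig defaults.
from core.enhanced_training_integration import EnhancedTrainingConfig  # assumed path

config = EnhancedTrainingConfig(
    collection_interval=0.5,            # collect twice per second
    min_episodes_for_cnn_training=50,   # allow CNN training sooner
    training_frequency_minutes=15,      # check training conditions more often
)
print(config)
```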
class EnhancedTrainingIntegration:
"""Enhanced training integration with existing infrastructure"""
def __init__(self,
data_provider: DataProvider,
orchestrator: Orchestrator = None,
trading_executor: TradingExecutor = None,
config: EnhancedTrainingConfig = None):
self.data_provider = data_provider
self.orchestrator = orchestrator
self.trading_executor = trading_executor
self.config = config or EnhancedTrainingConfig()
# Initialize training components
self.data_collector = get_training_data_collector()
# Initialize CNN components
self.cnn_model = CNNPivotPredictor()
self.cnn_trainer = get_cnn_trainer(self.cnn_model)
# Initialize RL components
if self.config.use_existing_cob_rl_model and COBRLModelInterface:
self.existing_rl_model = COBRLModelInterface()
logger.info("Using existing COB RL model")
else:
self.existing_rl_model = None
self.rl_agent = RLTradingAgent()
self.rl_trainer = get_rl_trainer(self.rl_agent)
# Integration state
self.is_running = False
self.training_threads = {}
self.validation_thread = None
# Performance tracking
self.integration_stats = {
'total_data_packages': 0,
'cnn_training_sessions': 0,
'rl_training_sessions': 0,
'profitable_predictions': 0,
'total_predictions': 0,
'cross_model_improvements': 0,
'last_update': datetime.now()
}
# Model prediction tracking
self.recent_predictions = {}
self.prediction_outcomes = {}
# Cross-model learning
self.model_performance_history = {
'cnn': [],
'rl': [],
'orchestrator': []
}
logger.info("Enhanced Training Integration initialized")
logger.info(f"CNN model parameters: {sum(p.numel() for p in self.cnn_model.parameters()):,}")
logger.info(f"RL agent parameters: {sum(p.numel() for p in self.rl_agent.parameters()):,}")
logger.info(f"Using existing COB RL model: {self.existing_rl_model is not None}")
def start_enhanced_integration(self):
"""Start the enhanced training integration system"""
if self.is_running:
logger.warning("Enhanced training integration already running")
return
self.is_running = True
# Start data collection
self.data_collector.start_collection()
# Start CNN training
if self.config.min_episodes_for_cnn_training > 0:
for symbol in self.data_provider.symbols:
self.cnn_trainer.start_real_time_training(symbol)
# Start coordinated training thread
self.training_threads['coordinator'] = threading.Thread(
target=self._training_coordinator_worker,
daemon=True
)
self.training_threads['coordinator'].start()
# Start data collection and validation
self.training_threads['data_collector'] = threading.Thread(
target=self._enhanced_data_collection_worker,
daemon=True
)
self.training_threads['data_collector'].start()
# Start outcome validation if enabled
if self.config.enable_background_validation:
self.validation_thread = threading.Thread(
target=self._outcome_validation_worker,
daemon=True
)
self.validation_thread.start()
logger.info("Enhanced training integration started")
def stop_enhanced_integration(self):
"""Stop the enhanced training integration system"""
self.is_running = False
# Stop data collection
self.data_collector.stop_collection()
# Stop CNN training
self.cnn_trainer.stop_training()
# Wait for threads to finish
for thread_name, thread in self.training_threads.items():
thread.join(timeout=10)
logger.info(f"Stopped {thread_name} thread")
if self.validation_thread:
self.validation_thread.join(timeout=5)
logger.info("Enhanced training integration stopped")
def _enhanced_data_collection_worker(self):
"""Enhanced data collection with real-time model integration"""
logger.info("Enhanced data collection worker started")
while self.is_running:
try:
for symbol in self.data_provider.symbols:
self._collect_enhanced_training_data(symbol)
time.sleep(self.config.collection_interval)
except Exception as e:
logger.error(f"Error in enhanced data collection: {e}")
time.sleep(5)
logger.info("Enhanced data collection worker stopped")
def _collect_enhanced_training_data(self, symbol: str):
"""Collect enhanced training data with model predictions"""
try:
# Get comprehensive market data
market_data = self._get_comprehensive_market_data(symbol)
if not market_data or not self._validate_market_data(market_data):
return
# Get current model predictions
model_predictions = self._get_all_model_predictions(symbol, market_data)
# Create enhanced features
cnn_features = self._create_enhanced_cnn_features(symbol, market_data)
rl_state = self._create_enhanced_rl_state(symbol, market_data, model_predictions)
# Collect training data with predictions
episode_id = self.data_collector.collect_training_data(
symbol=symbol,
ohlcv_data=market_data['ohlcv'],
tick_data=market_data['ticks'],
cob_data=market_data['cob'],
technical_indicators=market_data['indicators'],
pivot_points=market_data['pivots'],
cnn_features=cnn_features,
rl_state=rl_state,
orchestrator_context=market_data['context'],
model_predictions=model_predictions
)
if episode_id:
# Store predictions for outcome validation
self.recent_predictions[episode_id] = {
'timestamp': datetime.now(),
'symbol': symbol,
'predictions': model_predictions,
'market_data': market_data
}
# Add RL experience if we have action
if 'rl_action' in model_predictions:
self._add_rl_experience(symbol, market_data, model_predictions, episode_id)
self.integration_stats['total_data_packages'] += 1
except Exception as e:
logger.error(f"Error collecting enhanced training data for {symbol}: {e}")
def _get_comprehensive_market_data(self, symbol: str) -> Dict[str, Any]:
"""Get comprehensive market data from all sources"""
try:
market_data = {}
# OHLCV data
ohlcv_data = {}
for timeframe in ['1s', '1m', '5m', '15m', '1h', '1d']:
df = self.data_provider.get_historical_data(symbol, timeframe, limit=300, refresh=True)
if df is not None and not df.empty:
ohlcv_data[timeframe] = df
market_data['ohlcv'] = ohlcv_data
# Tick data
market_data['ticks'] = self._get_recent_tick_data(symbol)
# COB data
market_data['cob'] = self._get_cob_data(symbol)
# Technical indicators
market_data['indicators'] = self._get_technical_indicators(symbol)
# Pivot points
market_data['pivots'] = self._get_pivot_points(symbol)
# Market context
market_data['context'] = self._get_market_context(symbol)
return market_data
except Exception as e:
logger.error(f"Error getting comprehensive market data: {e}")
return {}
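The `market_data` package assembled above roughly has the following shape; this is a sketch only, and the contents of each sub-dict depend on what the data provider returns.

```python
# Illustrative shape of the market_data dict (values are placeholders).
market_data_example = {
    'ohlcv': {'1s': 'DataFrame', '1m': 'DataFrame', '5m': 'DataFrame',
              '15m': 'DataFrame', '1h': 'DataFrame', '1d': 'DataFrame'},
    'ticks': [],          # recent tick dicts
    'cob': {},            # consolidated order book snapshot
    'indicators': {},     # indicator name -> float
    'pivots': [],         # pivot point dicts
    'context': {'market_session': 'unknown', 'volatility_regime': 'unknown'},
}
```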
def _get_all_model_predictions(self, symbol: str, market_data: Dict[str, Any]) -> Dict[str, Any]:
"""Get predictions from all available models"""
predictions = {}
try:
# CNN predictions
if self.cnn_model and market_data.get('ohlcv'):
cnn_features = self._create_enhanced_cnn_features(symbol, market_data)
if cnn_features is not None:
cnn_input = torch.from_numpy(cnn_features).float().unsqueeze(0)
# Reshape for CNN (add channel dimension)
cnn_input = cnn_input.view(1, 10, -1) # Assuming 10 channels
with torch.no_grad():
cnn_outputs = self.cnn_model(cnn_input)
predictions['cnn'] = {
'pivot_logits': cnn_outputs['pivot_logits'].cpu().numpy(),
'pivot_price': cnn_outputs['pivot_price'].cpu().numpy(),
'confidence': cnn_outputs['confidence'].cpu().numpy(),
'timestamp': datetime.now()
}
# RL predictions
if self.rl_agent and market_data.get('cob'):
rl_state = self._create_enhanced_rl_state(symbol, market_data, predictions)
if rl_state is not None:
action, confidence = self.rl_agent.select_action(rl_state, epsilon=0.1)
predictions['rl'] = {
'action': action,
'confidence': confidence,
'timestamp': datetime.now()
}
predictions['rl_action'] = action
# Existing COB RL model predictions
if self.existing_rl_model and market_data.get('cob'):
cob_features = market_data['cob'].get('cob_features', [])
if cob_features and len(cob_features) >= 2000:
cob_array = np.array(cob_features[:2000], dtype=np.float32)
cob_prediction = self.existing_rl_model.predict(cob_array)
predictions['cob_rl'] = {
'predicted_direction': cob_prediction.get('predicted_direction', 1),
'confidence': cob_prediction.get('confidence', 0.5),
'value': cob_prediction.get('value', 0.0),
'timestamp': datetime.now()
}
# Orchestrator predictions (if available)
if self.orchestrator:
try:
# This would integrate with your orchestrator's prediction method
orchestrator_prediction = self._get_orchestrator_prediction(symbol, market_data, predictions)
if orchestrator_prediction:
predictions['orchestrator'] = orchestrator_prediction
except Exception as e:
logger.debug(f"Could not get orchestrator prediction: {e}")
return predictions
except Exception as e:
logger.error(f"Error getting model predictions: {e}")
return {}
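The dictionary returned above has roughly the shape below. This is a sketch: the numeric values are invented, and the array shapes depend on the CNN model, which is not specified here.

```python
# Illustrative shape of the predictions dict assembled above (values invented).
predictions_example = {
    'cnn': {
        'pivot_logits': [[0.2, -0.1, 0.4]],   # numpy arrays in practice
        'pivot_price': [[1825.3]],
        'confidence': [[0.63]],
        'timestamp': '2025-08-08T12:00:00',
    },
    'rl': {'action': 2, 'confidence': 0.58, 'timestamp': '2025-08-08T12:00:00'},
    'rl_action': 2,
    'cob_rl': {'predicted_direction': 1, 'confidence': 0.5, 'value': 0.0,
               'timestamp': '2025-08-08T12:00:00'},
    # 'orchestrator': {...}  # only present when the orchestrator responds
}
```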
def _add_rl_experience(self, symbol: str, market_data: Dict[str, Any],
predictions: Dict[str, Any], episode_id: str):
"""Add RL experience to the training buffer"""
try:
# Create RL state
state = self._create_enhanced_rl_state(symbol, market_data, predictions)
if state is None:
return
# Get action from predictions
action = predictions.get('rl_action', 1) # Default to HOLD
# Calculate immediate reward (placeholder - would be updated with actual outcome)
reward = 0.0
# Create next state (same as current for now - would be updated)
next_state = state.copy()
# Market context
market_context = {
'symbol': symbol,
'episode_id': episode_id,
'timestamp': datetime.now(),
'market_session': market_data['context'].get('market_session', 'unknown'),
'volatility_regime': market_data['context'].get('volatility_regime', 'unknown')
}
# Add experience
experience_id = self.rl_trainer.add_experience(
state=state,
action=action,
reward=reward,
next_state=next_state,
done=False,
market_context=market_context,
cnn_predictions=predictions.get('cnn'),
confidence_score=predictions.get('rl', {}).get('confidence', 0.0)
)
if experience_id:
logger.debug(f"Added RL experience: {experience_id}")
except Exception as e:
logger.error(f"Error adding RL experience: {e}")
def _training_coordinator_worker(self):
"""Coordinate training across all models"""
logger.info("Training coordinator worker started")
while self.is_running:
try:
# Check if we should trigger training
for symbol in self.data_provider.symbols:
self._check_and_trigger_training(symbol)
# Wait before next check
time.sleep(self.config.training_frequency_minutes * 60)
except Exception as e:
logger.error(f"Error in training coordinator: {e}")
time.sleep(60)
logger.info("Training coordinator worker stopped")
def _check_and_trigger_training(self, symbol: str):
"""Check conditions and trigger training if needed"""
try:
# Get training episodes and experiences
episodes = self.data_collector.get_high_priority_episodes(symbol, limit=1000)
# Check CNN training conditions
if len(episodes) >= self.config.min_episodes_for_cnn_training:
profitable_episodes = [ep for ep in episodes if ep.actual_outcome.is_profitable]
if len(profitable_episodes) >= 20: # Minimum profitable episodes
logger.info(f"Triggering CNN training for {symbol} with {len(profitable_episodes)} profitable episodes")
results = self.cnn_trainer.train_on_profitable_episodes(
symbol=symbol,
min_profitability=self.config.min_profitability_for_replay,
max_episodes=len(profitable_episodes)
)
if results.get('status') == 'success':
self.integration_stats['cnn_training_sessions'] += 1
logger.info(f"CNN training completed for {symbol}")
# Check RL training conditions
buffer_stats = self.rl_trainer.experience_buffer.get_buffer_statistics()
total_experiences = buffer_stats.get('total_experiences', 0)
if total_experiences >= self.config.min_experiences_for_rl_training:
profitable_experiences = buffer_stats.get('profitable_experiences', 0)
if profitable_experiences >= 50: # Minimum profitable experiences
logger.info(f"Triggering RL training with {profitable_experiences} profitable experiences")
results = self.rl_trainer.train_on_profitable_experiences(
min_profitability=self.config.min_profitability_for_replay,
max_experiences=min(profitable_experiences, 500),
batch_size=32
)
if results.get('status') == 'success':
self.integration_stats['rl_training_sessions'] += 1
logger.info("RL training completed")
except Exception as e:
logger.error(f"Error checking training conditions for {symbol}: {e}")
def _outcome_validation_worker(self):
"""Background worker for validating prediction outcomes"""
logger.info("Outcome validation worker started")
while self.is_running:
try:
self._validate_recent_predictions()
time.sleep(300) # Check every 5 minutes
except Exception as e:
logger.error(f"Error in outcome validation: {e}")
time.sleep(60)
logger.info("Outcome validation worker stopped")
def _validate_recent_predictions(self):
"""Validate recent predictions against actual outcomes"""
try:
current_time = datetime.now()
validation_delay = timedelta(hours=1) # Wait 1 hour to validate
validated_predictions = []
for episode_id, prediction_data in self.recent_predictions.items():
prediction_time = prediction_data['timestamp']
if current_time - prediction_time >= validation_delay:
# Validate this prediction
outcome = self._calculate_prediction_outcome(prediction_data)
if outcome:
self.prediction_outcomes[episode_id] = outcome
# Update RL experience if exists
if 'rl_action' in prediction_data['predictions']:
self._update_rl_experience_outcome(episode_id, outcome)
# Update statistics
if outcome['is_profitable']:
self.integration_stats['profitable_predictions'] += 1
self.integration_stats['total_predictions'] += 1
validated_predictions.append(episode_id)
# Remove validated predictions
for episode_id in validated_predictions:
del self.recent_predictions[episode_id]
if validated_predictions:
logger.info(f"Validated {len(validated_predictions)} predictions")
except Exception as e:
logger.error(f"Error validating predictions: {e}")
def _calculate_prediction_outcome(self, prediction_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Calculate actual outcome for a prediction"""
try:
symbol = prediction_data['symbol']
prediction_time = prediction_data['timestamp']
# Get price data after prediction
current_df = self.data_provider.get_historical_data(symbol, '1m', limit=100, refresh=True)
if current_df is None or current_df.empty:
return None
# Find price at prediction time and current price
prediction_price = prediction_data['market_data']['ohlcv'].get('1m', pd.DataFrame())
if prediction_price.empty:
return None
base_price = float(prediction_price['close'].iloc[-1])
current_price = float(current_df['close'].iloc[-1])
# Calculate outcome
price_change = (current_price - base_price) / base_price
is_profitable = abs(price_change) > 0.005 # 0.5% threshold
return {
'episode_id': prediction_data.get('episode_id'),
'base_price': base_price,
'current_price': current_price,
'price_change': price_change,
'is_profitable': is_profitable,
'profitability_score': min(abs(price_change) * 10, 1.0), # Scale to 0-1 range (capped at 1.0)
'validation_time': datetime.now()
}
except Exception as e:
logger.error(f"Error calculating prediction outcome: {e}")
return None
def _update_rl_experience_outcome(self, episode_id: str, outcome: Dict[str, Any]):
"""Update RL experience with actual outcome"""
try:
# Find the experience ID associated with this episode
# This is a simplified approach - in practice you'd maintain better mapping
actual_profit = outcome['price_change']
# Determine optimal action based on outcome
if outcome['price_change'] > 0.01:
optimal_action = 2 # BUY
elif outcome['price_change'] < -0.01:
optimal_action = 0 # SELL
else:
optimal_action = 1 # HOLD
# Update experience (this would need proper experience ID mapping)
# For now, we'll update the most recent experience
# In practice, you'd maintain a mapping between episodes and experiences
except Exception as e:
logger.error(f"Error updating RL experience outcome: {e}")
def get_integration_statistics(self) -> Dict[str, Any]:
"""Get comprehensive integration statistics"""
stats = self.integration_stats.copy()
# Add component statistics
stats['data_collector'] = self.data_collector.get_collection_statistics()
stats['cnn_trainer'] = self.cnn_trainer.get_training_statistics()
stats['rl_trainer'] = self.rl_trainer.get_training_statistics()
# Add performance metrics
stats['is_running'] = self.is_running
stats['active_symbols'] = len(self.data_provider.symbols)
stats['recent_predictions_count'] = len(self.recent_predictions)
stats['validated_outcomes_count'] = len(self.prediction_outcomes)
# Calculate profitability rate
if stats['total_predictions'] > 0:
stats['overall_profitability_rate'] = stats['profitable_predictions'] / stats['total_predictions']
else:
stats['overall_profitability_rate'] = 0.0
return stats
def trigger_manual_training(self, training_type: str = 'all', symbol: str = None) -> Dict[str, Any]:
"""Manually trigger training"""
results = {}
try:
if training_type in ['all', 'cnn']:
symbols = [symbol] if symbol else self.data_provider.symbols
for sym in symbols:
cnn_results = self.cnn_trainer.train_on_profitable_episodes(
symbol=sym,
min_profitability=0.1,
max_episodes=200
)
results[f'cnn_{sym}'] = cnn_results
if training_type in ['all', 'rl']:
rl_results = self.rl_trainer.train_on_profitable_experiences(
min_profitability=0.1,
max_experiences=500,
batch_size=32
)
results['rl'] = rl_results
return {'status': 'success', 'results': results}
except Exception as e:
logger.error(f"Error in manual training trigger: {e}")
return {'status': 'error', 'error': str(e)}
# Helper methods (simplified implementations)
def _get_recent_tick_data(self, symbol: str) -> List[Dict[str, Any]]:
"""Get recent tick data"""
# Implementation would get tick data from data provider
return []
def _get_cob_data(self, symbol: str) -> Dict[str, Any]:
"""Get COB data"""
# Implementation would get COB data from data provider
return {}
def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
"""Get technical indicators"""
# Implementation would get indicators from data provider
return {}
def _get_pivot_points(self, symbol: str) -> List[Dict[str, Any]]:
"""Get pivot points"""
# Implementation would get pivot points from data provider
return []
def _get_market_context(self, symbol: str) -> Dict[str, Any]:
"""Get market context"""
return {
'symbol': symbol,
'timestamp': datetime.now(),
'market_session': 'unknown',
'volatility_regime': 'unknown'
}
def _validate_market_data(self, market_data: Dict[str, Any]) -> bool:
"""Validate market data completeness"""
required_fields = ['ohlcv', 'indicators']
return all(field in market_data for field in required_fields)
def _create_enhanced_cnn_features(self, symbol: str, market_data: Dict[str, Any]) -> Optional[np.ndarray]:
"""Create enhanced CNN features"""
try:
# Simplified feature creation
features = []
# Add OHLCV features
for timeframe in ['1m', '5m', '15m', '1h']:
if timeframe in market_data.get('ohlcv', {}):
df = market_data['ohlcv'][timeframe]
if not df.empty:
ohlcv_values = df[['open', 'high', 'low', 'close', 'volume']].values
if len(ohlcv_values) > 0:
recent_values = ohlcv_values[-60:].flatten()
features.extend(recent_values)
# Pad to target size
target_size = 3000 # 10 channels * 300 sequence length
if len(features) < target_size:
features.extend([0.0] * (target_size - len(features)))
else:
features = features[:target_size]
return np.array(features, dtype=np.float32)
except Exception as e:
logger.warning(f"Error creating CNN features: {e}")
return None
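The 3,000-element padding above is what makes the `view(1, 10, -1)` reshape in `_get_all_model_predictions` resolve to a 300-step sequence per channel. A quick sketch of that shape contract, using zeros rather than real features:

```python
# Sketch: the flat CNN feature vector must stay reshapeable to
# (batch=1, channels=10, sequence=300).
import numpy as np
import torch

features = np.zeros(3000, dtype=np.float32)   # stands in for _create_enhanced_cnn_features output
cnn_input = torch.from_numpy(features).float().unsqueeze(0).view(1, 10, -1)
assert cnn_input.shape == (1, 10, 300)
```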
def _create_enhanced_rl_state(self, symbol: str, market_data: Dict[str, Any],
predictions: Dict[str, Any] = None) -> Optional[np.ndarray]:
"""Create enhanced RL state"""
try:
state_features = []
# Add market features
if '1m' in market_data.get('ohlcv', {}):
df = market_data['ohlcv']['1m']
if not df.empty:
latest = df.iloc[-1]
state_features.extend([
latest['open'], latest['high'],
latest['low'], latest['close'], latest['volume']
])
# Add technical indicators
indicators = market_data.get('indicators', {})
for value in indicators.values():
state_features.append(value)
# Add model predictions as features
if predictions:
if 'cnn' in predictions:
cnn_pred = predictions['cnn']
state_features.extend(cnn_pred.get('pivot_logits', [0, 0, 0]))
state_features.append(cnn_pred.get('confidence', [0.0])[0])
if 'cob_rl' in predictions:
cob_pred = predictions['cob_rl']
state_features.append(cob_pred.get('predicted_direction', 1))
state_features.append(cob_pred.get('confidence', 0.5))
# Pad to target size
target_size = 2000
if len(state_features) < target_size:
state_features.extend([0.0] * (target_size - len(state_features)))
else:
state_features = state_features[:target_size]
return np.array(state_features, dtype=np.float32)
except Exception as e:
logger.warning(f"Error creating RL state: {e}")
return None
def _get_orchestrator_prediction(self, symbol: str, market_data: Dict[str, Any],
predictions: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Get orchestrator prediction"""
# This would integrate with your orchestrator
return None
# Global instance
enhanced_training_integration = None
def get_enhanced_training_integration(data_provider: DataProvider = None,
orchestrator: Orchestrator = None,
trading_executor: TradingExecutor = None) -> EnhancedTrainingIntegration:
"""Get global enhanced training integration instance"""
global enhanced_training_integration
if enhanced_training_integration is None:
if data_provider is None:
raise ValueError("DataProvider required for first initialization")
enhanced_training_integration = EnhancedTrainingIntegration(
data_provider, orchestrator, trading_executor
)
return enhanced_training_integration
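A minimal start-up sketch for the integration above. The import paths and the `DataProvider` constructor arguments are assumptions about the surrounding project, not verified against it.

```python
# Hedged sketch: wiring the global integration instance at startup.
from core.data_provider import DataProvider                      # assumed path
from core.enhanced_training_integration import (                 # assumed path
    get_enhanced_training_integration,
)

data_provider = DataProvider()                                   # constructor args assumed
integration = get_enhanced_training_integration(data_provider=data_provider)
integration.start_enhanced_integration()
try:
    ...  # run the trading loop
finally:
    integration.stop_enhanced_integration()
    print(integration.get_integration_statistics())
```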

View File

@ -1,8 +0,0 @@
# MEXC Web Client Module
#
# This module provides web-based trading capabilities for MEXC futures trading
# which is not supported by their official API.
from .mexc_futures_client import MEXCFuturesWebClient
__all__ = ['MEXCFuturesWebClient']

View File

@ -1,555 +0,0 @@
#!/usr/bin/env python3
"""
MEXC Auto Browser with Request Interception
This script automatically spawns a ChromeDriver instance and captures
all MEXC futures trading requests in real-time, including full request
and response data needed for reverse engineering.
"""
import logging
import time
import json
import sys
import os
from typing import Dict, List, Optional, Any
from datetime import datetime
import threading
import queue
# Selenium imports
try:
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.common.exceptions import TimeoutException, WebDriverException
from webdriver_manager.chrome import ChromeDriverManager
except ImportError:
print("Please install selenium and webdriver-manager:")
print("pip install selenium webdriver-manager")
sys.exit(1)
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class MEXCRequestInterceptor:
"""
Automatically spawns ChromeDriver and intercepts all MEXC API requests
"""
def __init__(self, headless: bool = False, save_to_file: bool = True):
"""
Initialize the request interceptor
Args:
headless: Run browser in headless mode
save_to_file: Save captured requests to JSON file
"""
self.driver = None
self.headless = headless
self.save_to_file = save_to_file
self.captured_requests = []
self.captured_responses = []
self.session_cookies = {}
self.monitoring = False
self.request_queue = queue.Queue()
# File paths for saving data
self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.requests_file = f"mexc_requests_{self.timestamp}.json"
self.cookies_file = f"mexc_cookies_{self.timestamp}.json"
def setup_browser(self):
"""Setup Chrome browser with necessary options"""
chrome_options = webdriver.ChromeOptions()
# Enable headless mode if needed
if self.headless:
chrome_options.add_argument('--headless')
chrome_options.add_argument('--disable-gpu')
chrome_options.add_argument('--window-size=1920,1080')
chrome_options.add_argument('--disable-extensions')
# Set up Chrome options with a user data directory to persist session
user_data_base_dir = os.path.join(os.getcwd(), 'chrome_user_data')
os.makedirs(user_data_base_dir, exist_ok=True)
# Check for existing session directories
session_dirs = [d for d in os.listdir(user_data_base_dir) if d.startswith('session_')]
session_dirs.sort(reverse=True) # Sort descending to get the most recent first
user_data_dir = None
if session_dirs:
use_existing = input(f"Found {len(session_dirs)} existing sessions. Use an existing session? (y/n): ").lower().strip() == 'y'
if use_existing:
print("Available sessions:")
for i, session in enumerate(session_dirs[:5], 1): # Show up to 5 most recent
print(f"{i}. {session}")
choice = input("Enter session number (default 1) or any other key for most recent: ")
if choice.isdigit() and 1 <= int(choice) <= len(session_dirs):
selected_session = session_dirs[int(choice) - 1]
else:
selected_session = session_dirs[0]
user_data_dir = os.path.join(user_data_base_dir, selected_session)
print(f"Using session: {selected_session}")
if user_data_dir is None:
user_data_dir = os.path.join(user_data_base_dir, f'session_{self.timestamp}')
os.makedirs(user_data_dir, exist_ok=True)
print(f"Creating new session: session_{self.timestamp}")
chrome_options.add_argument(f'--user-data-dir={user_data_dir}')
# Enable logging to capture JS console output and network activity
chrome_options.set_capability('goog:loggingPrefs', {
'browser': 'ALL',
'performance': 'ALL'
})
try:
self.driver = webdriver.Chrome(options=chrome_options)
except Exception as e:
print(f"Failed to start browser with session: {e}")
print("Falling back to a new session...")
user_data_dir = os.path.join(user_data_base_dir, f'session_{self.timestamp}_fallback')
os.makedirs(user_data_dir, exist_ok=True)
print(f"Creating fallback session: session_{self.timestamp}_fallback")
chrome_options = webdriver.ChromeOptions()
if self.headless:
chrome_options.add_argument('--headless')
chrome_options.add_argument('--disable-gpu')
chrome_options.add_argument('--window-size=1920,1080')
chrome_options.add_argument('--disable-extensions')
chrome_options.add_argument(f'--user-data-dir={user_data_dir}')
chrome_options.set_capability('goog:loggingPrefs', {
'browser': 'ALL',
'performance': 'ALL'
})
self.driver = webdriver.Chrome(options=chrome_options)
return self.driver
def start_monitoring(self):
"""Start the browser and begin monitoring"""
logger.info("Starting MEXC Request Interceptor...")
try:
# Setup ChromeDriver
self.driver = self.setup_browser()
# Navigate to MEXC futures
mexc_url = "https://www.mexc.com/en-GB/futures/ETH_USDT?type=linear_swap"
logger.info(f"Navigating to: {mexc_url}")
self.driver.get(mexc_url)
# Wait for page load
WebDriverWait(self.driver, 10).until(
EC.presence_of_element_located((By.TAG_NAME, "body"))
)
logger.info("✅ MEXC page loaded successfully!")
logger.info("📝 Please log in manually in the browser window")
logger.info("🔍 Request monitoring is now active...")
# Start monitoring in background thread
self.monitoring = True
monitor_thread = threading.Thread(target=self._monitor_requests, daemon=True)
monitor_thread.start()
# Wait for manual login
self._wait_for_login()
return True
except Exception as e:
logger.error(f"Failed to start monitoring: {e}")
return False
def _wait_for_login(self):
"""Wait for user to log in and show interactive menu"""
logger.info("\n" + "="*60)
logger.info("MEXC REQUEST INTERCEPTOR - INTERACTIVE MODE")
logger.info("="*60)
while True:
print("\nOptions:")
print("1. Check login status")
print("2. Extract current cookies")
print("3. Show captured requests summary")
print("4. Save captured data to files")
print("5. Perform test trade (manual)")
print("6. Monitor for 60 seconds")
print("0. Stop and exit")
choice = input("\nEnter choice (0-6): ").strip()
if choice == "1":
self._check_login_status()
elif choice == "2":
self._extract_cookies()
elif choice == "3":
self._show_requests_summary()
elif choice == "4":
self._save_all_data()
elif choice == "5":
self._guide_test_trade()
elif choice == "6":
self._monitor_for_duration(60)
elif choice == "0":
break
else:
print("Invalid choice. Please try again.")
self.stop_monitoring()
def _check_login_status(self):
"""Check if user is logged into MEXC"""
try:
cookies = self.driver.get_cookies()
auth_cookies = ['uc_token', 'u_id', 'x-mxc-fingerprint']
found_auth = []
for cookie in cookies:
if cookie['name'] in auth_cookies and cookie['value']:
found_auth.append(cookie['name'])
if len(found_auth) >= 2:
print("✅ LOGIN DETECTED - You appear to be logged in!")
print(f" Found auth cookies: {', '.join(found_auth)}")
return True
else:
print("❌ NOT LOGGED IN - Please log in to MEXC in the browser")
print(" Missing required authentication cookies")
return False
except Exception as e:
print(f"❌ Error checking login: {e}")
return False
def _extract_cookies(self):
"""Extract and display current session cookies"""
try:
cookies = self.driver.get_cookies()
cookie_dict = {}
for cookie in cookies:
cookie_dict[cookie['name']] = cookie['value']
self.session_cookies = cookie_dict
print(f"\n📊 Extracted {len(cookie_dict)} cookies:")
# Show important cookies
important = ['uc_token', 'u_id', 'x-mxc-fingerprint', 'mexc_fingerprint_visitorId']
for name in important:
if name in cookie_dict:
value = cookie_dict[name]
display_value = value[:20] + "..." if len(value) > 20 else value
print(f"{name}: {display_value}")
else:
print(f"{name}: Missing")
# Save cookies to file
if self.save_to_file:
with open(self.cookies_file, 'w') as f:
json.dump(cookie_dict, f, indent=2)
print(f"\n💾 Cookies saved to: {self.cookies_file}")
except Exception as e:
print(f"❌ Error extracting cookies: {e}")
def _monitor_requests(self):
"""Background thread to monitor network requests"""
last_log_count = 0
while self.monitoring:
try:
# Get performance logs
logs = self.driver.get_log('performance')
for log in logs:
try:
message = json.loads(log['message'])
method = message.get('message', {}).get('method', '')
# Capture network requests
if method == 'Network.requestWillBeSent':
self._process_request(message['message']['params'])
elif method == 'Network.responseReceived':
self._process_response(message['message']['params'])
except (json.JSONDecodeError, KeyError) as e:
continue
# Show progress every 10 new requests
if len(self.captured_requests) >= last_log_count + 10:
last_log_count = len(self.captured_requests)
logger.info(f"📈 Captured {len(self.captured_requests)} requests, {len(self.captured_responses)} responses")
except Exception as e:
if self.monitoring: # Only log if we're still supposed to be monitoring
logger.debug(f"Monitor error: {e}")
time.sleep(0.5) # Check every 500ms
def _process_request(self, request_data):
"""Process a captured network request"""
try:
url = request_data.get('request', {}).get('url', '')
# Filter for MEXC API requests
if self._is_mexc_request(url):
request_info = {
'type': 'request',
'timestamp': datetime.now().isoformat(),
'url': url,
'method': request_data.get('request', {}).get('method', ''),
'headers': request_data.get('request', {}).get('headers', {}),
'postData': request_data.get('request', {}).get('postData', ''),
'requestId': request_data.get('requestId', '')
}
self.captured_requests.append(request_info)
# Show important requests immediately
if ('futures.mexc.com' in url or 'captcha' in url):
print(f"\n🚀 CAPTURED REQUEST: {request_info['method']} {url}")
if request_info['postData']:
print(f" 📄 POST Data: {request_info['postData'][:100]}...")
# Enhanced captcha detection and detailed logging
if 'captcha' in url.lower() or 'robot' in url.lower():
logger.info(f"CAPTCHA REQUEST DETECTED: {request_data.get('request', {}).get('method', 'UNKNOWN')} {url}")
logger.info(f" Headers: {request_data.get('request', {}).get('headers', {})}")
if request_data.get('request', {}).get('postData', ''):
logger.info(f" Data: {request_data.get('request', {}).get('postData', '')}")
# Attempt to capture related JavaScript or DOM elements (if possible)
if self.driver is not None:
try:
js_snippet = self.driver.execute_script("return document.querySelector('script[src*=\"captcha\"]') ? document.querySelector('script[src*=\"captcha\"]').outerHTML : 'No captcha script found';")
logger.info(f" Related JS Snippet: {js_snippet}")
except Exception as e:
logger.warning(f" Could not capture JS snippet: {e}")
try:
dom_element = self.driver.execute_script("return document.querySelector('div[id*=\"captcha\"]') ? document.querySelector('div[id*=\"captcha\"]').outerHTML : 'No captcha element found';")
logger.info(f" Related DOM Element: {dom_element}")
except Exception as e:
logger.warning(f" Could not capture DOM element: {e}")
else:
logger.warning(" Driver not initialized, cannot capture JS or DOM elements")
except Exception as e:
logger.debug(f"Error processing request: {e}")
def _process_response(self, response_data):
"""Process a captured network response"""
try:
url = response_data.get('response', {}).get('url', '')
# Filter for MEXC API responses
if self._is_mexc_request(url):
response_info = {
'type': 'response',
'timestamp': datetime.now().isoformat(),
'url': url,
'status': response_data.get('response', {}).get('status', 0),
'headers': response_data.get('response', {}).get('headers', {}),
'requestId': response_data.get('requestId', '')
}
self.captured_responses.append(response_info)
# Show important responses immediately
if ('futures.mexc.com' in url or 'captcha' in url):
status = response_info['status']
status_emoji = "" if status == 200 else ""
print(f" {status_emoji} RESPONSE: {status} for {url}")
except Exception as e:
logger.debug(f"Error processing response: {e}")
def _is_mexc_request(self, url: str) -> bool:
"""Check if URL is a relevant MEXC API request"""
mexc_indicators = [
'futures.mexc.com',
'ucgateway/captcha_api',
'api/v1/private',
'api/v3/order',
'mexc.com/api'
]
return any(indicator in url for indicator in mexc_indicators)
def _show_requests_summary(self):
"""Show summary of captured requests"""
print(f"\n📊 CAPTURE SUMMARY:")
print(f" Total Requests: {len(self.captured_requests)}")
print(f" Total Responses: {len(self.captured_responses)}")
# Group by URL pattern
url_counts = {}
for req in self.captured_requests:
base_url = req['url'].split('?')[0] # Remove query params
url_counts[base_url] = url_counts.get(base_url, 0) + 1
print("\n🔗 Top URLs:")
for url, count in sorted(url_counts.items(), key=lambda x: x[1], reverse=True)[:5]:
print(f" {count}x {url}")
# Show recent futures API calls
futures_requests = [r for r in self.captured_requests if 'futures.mexc.com' in r['url']]
if futures_requests:
print(f"\n🚀 Futures API Calls: {len(futures_requests)}")
for req in futures_requests[-3:]: # Show last 3
print(f" {req['method']} {req['url']}")
def _save_all_data(self):
"""Save all captured data to files"""
if not self.save_to_file:
print("File saving is disabled")
return
try:
# Save requests
with open(self.requests_file, 'w') as f:
json.dump({
'requests': self.captured_requests,
'responses': self.captured_responses,
'summary': {
'total_requests': len(self.captured_requests),
'total_responses': len(self.captured_responses),
'capture_session': self.timestamp
}
}, f, indent=2)
# Save cookies if we have them
if self.session_cookies:
with open(self.cookies_file, 'w') as f:
json.dump(self.session_cookies, f, indent=2)
print(f"\n💾 Data saved to:")
print(f" 📋 Requests: {self.requests_file}")
if self.session_cookies:
print(f" 🍪 Cookies: {self.cookies_file}")
# Extract and save CAPTCHA tokens from captured requests
captcha_tokens = self.extract_captcha_tokens()
if captcha_tokens:
captcha_file = f"mexc_captcha_tokens_{self.timestamp}.json"
with open(captcha_file, 'w') as f:
json.dump(captcha_tokens, f, indent=2)
logger.info(f"Saved CAPTCHA tokens to {captcha_file}")
else:
logger.warning("No CAPTCHA tokens found in captured requests")
except Exception as e:
print(f"❌ Error saving data: {e}")
def _guide_test_trade(self):
"""Guide user through performing a test trade"""
print("\n🧪 TEST TRADE GUIDE:")
print("1. Make sure you're logged into MEXC")
print("2. Go to the trading interface")
print("3. Try to place a SMALL test trade (it may fail, but we'll capture the requests)")
print("4. Watch the console for captured API calls")
print("\n⚠️ IMPORTANT: Use very small amounts for testing!")
input("\nPress Enter when you're ready to start monitoring...")
self._monitor_for_duration(120) # Monitor for 2 minutes
def _monitor_for_duration(self, seconds: int):
"""Monitor requests for a specific duration"""
print(f"\n🔍 Monitoring requests for {seconds} seconds...")
print("Perform your trading actions now!")
start_time = time.time()
initial_count = len(self.captured_requests)
while time.time() - start_time < seconds:
current_count = len(self.captured_requests)
new_requests = current_count - initial_count
remaining = seconds - int(time.time() - start_time)
print(f"\r⏱️ Time remaining: {remaining}s | New requests: {new_requests}", end="", flush=True)
time.sleep(1)
final_count = len(self.captured_requests)
new_total = final_count - initial_count
print(f"\n✅ Monitoring complete! Captured {new_total} new requests")
def stop_monitoring(self):
"""Stop monitoring and close browser"""
logger.info("Stopping request monitoring...")
self.monitoring = False
if self.driver:
self.driver.quit()
logger.info("Browser closed")
# Final save
if self.save_to_file and (self.captured_requests or self.captured_responses):
self._save_all_data()
logger.info("Final data save complete")
def extract_captcha_tokens(self):
"""Extract CAPTCHA tokens from captured requests"""
captcha_tokens = []
for request in self.captured_requests:
if 'captcha-token' in request.get('headers', {}):
token = request['headers']['captcha-token']
captcha_tokens.append({
'token': token,
'url': request.get('url', ''),
'timestamp': request.get('timestamp', '')
})
elif 'captcha' in request.get('url', '').lower():
response = request.get('response', {})
if response and 'captcha-token' in response.get('headers', {}):
token = response['headers']['captcha-token']
captcha_tokens.append({
'token': token,
'url': request.get('url', ''),
'timestamp': request.get('timestamp', '')
})
return captcha_tokens
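After a capture session, the saved JSON can be mined offline. A small sketch follows; the filename pattern matches what `_save_all_data` writes, and the filter criterion is just an example.

```python
# Hedged sketch: summarize futures API calls from the latest saved capture file.
import glob
import json

files = sorted(glob.glob("mexc_requests_*.json"))
if files:
    with open(files[-1]) as f:
        capture = json.load(f)
    futures_calls = [r for r in capture["requests"]
                     if "futures.mexc.com" in r["url"]]
    for req in futures_calls:
        print(req["method"], req["url"])
```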
def main():
"""Main function to run the interceptor"""
print("🚀 MEXC Request Interceptor with ChromeDriver")
print("=" * 50)
print("This will automatically:")
print("✅ Download/setup ChromeDriver")
print("✅ Open MEXC futures page")
print("✅ Capture all API requests/responses")
print("✅ Extract session cookies")
print("✅ Save data to JSON files")
print("\nPress Ctrl+C to stop at any time")
# Ask for preferences
headless = input("\nRun in headless mode? (y/n): ").lower().strip() == 'y'
interceptor = MEXCRequestInterceptor(headless=headless, save_to_file=True)
try:
success = interceptor.start_monitoring()
if not success:
print("❌ Failed to start monitoring")
return
except KeyboardInterrupt:
print("\n\n⏹️ Stopping interceptor...")
except Exception as e:
print(f"\n❌ Error: {e}")
finally:
interceptor.stop_monitoring()
print("\n👋 Goodbye!")
if __name__ == "__main__":
main()

View File

@ -1,358 +0,0 @@
"""
MEXC Browser Automation for Cookie Extraction and Request Monitoring
This module uses Selenium to automate browser interactions and extract
session cookies and request data for MEXC futures trading.
"""
import logging
import time
import json
from typing import Dict, List, Optional, Any
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.common.exceptions import TimeoutException, WebDriverException
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager
logger = logging.getLogger(__name__)
class MEXCBrowserAutomation:
"""
Browser automation for MEXC futures trading session management
"""
def __init__(self, headless: bool = False, proxy: Optional[str] = None):
"""
Initialize browser automation
Args:
headless: Run browser in headless mode
proxy: HTTP proxy to use (format: host:port)
"""
self.driver = None
self.headless = headless
self.proxy = proxy
self.logged_in = False
def setup_chrome_driver(self) -> webdriver.Chrome:
"""Setup Chrome driver with appropriate options"""
chrome_options = Options()
if self.headless:
chrome_options.add_argument("--headless")
# Basic Chrome options for automation
chrome_options.add_argument("--no-sandbox")
chrome_options.add_argument("--disable-dev-shm-usage")
chrome_options.add_argument("--disable-blink-features=AutomationControlled")
chrome_options.add_experimental_option("excludeSwitches", ["enable-automation"])
chrome_options.add_experimental_option('useAutomationExtension', False)
# Set user agent to avoid detection
chrome_options.add_argument("--user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36")
# Proxy setup if provided
if self.proxy:
chrome_options.add_argument(f"--proxy-server=http://{self.proxy}")
# Enable network logging
chrome_options.add_argument("--enable-logging")
chrome_options.add_argument("--log-level=0")
chrome_options.set_capability("goog:loggingPrefs", {"performance": "ALL"})
# Automatically download and setup ChromeDriver
service = Service(ChromeDriverManager().install())
try:
driver = webdriver.Chrome(service=service, options=chrome_options)
# Execute script to avoid detection
driver.execute_script("Object.defineProperty(navigator, 'webdriver', {get: () => undefined})")
return driver
except WebDriverException as e:
logger.error(f"Failed to setup Chrome driver: {e}")
raise
def start_browser(self):
"""Start the browser session"""
if self.driver is None:
logger.info("Starting Chrome browser for MEXC automation")
self.driver = self.setup_chrome_driver()
logger.info("Browser started successfully")
def stop_browser(self):
"""Stop the browser session"""
if self.driver:
logger.info("Stopping browser")
self.driver.quit()
self.driver = None
def navigate_to_mexc_futures(self, symbol: str = "ETH_USDT"):
"""
Navigate to MEXC futures trading page
Args:
symbol: Trading symbol to navigate to
"""
if not self.driver:
self.start_browser()
url = f"https://www.mexc.com/en-GB/futures/{symbol}?type=linear_swap"
logger.info(f"Navigating to MEXC futures: {url}")
self.driver.get(url)
# Wait for page to load
try:
WebDriverWait(self.driver, 10).until(
EC.presence_of_element_located((By.TAG_NAME, "body"))
)
logger.info("MEXC futures page loaded")
except TimeoutException:
logger.error("Timeout waiting for MEXC page to load")
def wait_for_login(self, timeout: int = 300) -> bool:
"""
Wait for user to manually log in to MEXC
Args:
timeout: Maximum time to wait for login (seconds)
Returns:
bool: True if login detected, False if timeout
"""
logger.info("Please log in to MEXC manually in the browser window")
logger.info("Waiting for login completion...")
start_time = time.time()
while time.time() - start_time < timeout:
# Check if we can find elements that indicate logged in state
try:
# Look for user-specific elements that appear after login
cookies = self.driver.get_cookies()
# Check for authentication cookies
auth_cookies = ['uc_token', 'u_id']
logged_in_indicators = 0
for cookie in cookies:
if cookie['name'] in auth_cookies and cookie['value']:
logged_in_indicators += 1
if logged_in_indicators >= 2:
logger.info("Login detected!")
self.logged_in = True
return True
except Exception as e:
logger.debug(f"Error checking login status: {e}")
time.sleep(2) # Check every 2 seconds
logger.error(f"Login timeout after {timeout} seconds")
return False
def extract_session_cookies(self) -> Dict[str, str]:
"""
Extract all cookies from current browser session
Returns:
Dictionary of cookie name-value pairs
"""
if not self.driver:
logger.error("Browser not started")
return {}
cookies = {}
try:
browser_cookies = self.driver.get_cookies()
for cookie in browser_cookies:
cookies[cookie['name']] = cookie['value']
logger.info(f"Extracted {len(cookies)} cookies from browser session")
# Log important cookies (without values for security)
important_cookies = ['uc_token', 'u_id', 'x-mxc-fingerprint', 'mexc_fingerprint_visitorId']
for cookie_name in important_cookies:
if cookie_name in cookies:
logger.info(f"Found important cookie: {cookie_name}")
else:
logger.warning(f"Missing important cookie: {cookie_name}")
return cookies
except Exception as e:
logger.error(f"Failed to extract cookies: {e}")
return {}
def monitor_network_requests(self, duration: int = 60) -> List[Dict[str, Any]]:
"""
Monitor network requests for the specified duration
Args:
duration: How long to monitor requests (seconds)
Returns:
List of captured network requests
"""
if not self.driver:
logger.error("Browser not started")
return []
logger.info(f"Starting network monitoring for {duration} seconds")
logger.info("Please perform trading actions in the browser (open/close positions)")
start_time = time.time()
captured_requests = []
while time.time() - start_time < duration:
try:
# Get performance logs (network requests)
logs = self.driver.get_log('performance')
for log in logs:
message = json.loads(log['message'])
# Filter for relevant MEXC API requests
if (message.get('message', {}).get('method') == 'Network.responseReceived'):
response = message['message']['params']['response']
url = response.get('url', '')
# Look for futures API calls
if ('futures.mexc.com' in url or
'ucgateway/captcha_api' in url or
'api/v1/private' in url):
request_data = {
'url': url,
'mimeType': response.get('mimeType', ''),
'status': response.get('status'),
'headers': response.get('headers', {}),
'timestamp': log['timestamp']
}
captured_requests.append(request_data)
logger.info(f"Captured request: {url}")
except Exception as e:
logger.debug(f"Error in network monitoring: {e}")
time.sleep(1)
logger.info(f"Network monitoring complete. Captured {len(captured_requests)} requests")
return captured_requests
def perform_test_trade(self, symbol: str = "ETH_USDT", volume: float = 1.0, leverage: int = 200):
"""
Attempt to perform a test trade to capture the complete request flow
Args:
symbol: Trading symbol
volume: Position size
leverage: Leverage multiplier
"""
if not self.logged_in:
logger.error("Not logged in - cannot perform test trade")
return
logger.info(f"Attempting test trade: {symbol}, Volume: {volume}, Leverage: {leverage}x")
logger.info("This will attempt to click trading interface elements")
try:
# This would need to be implemented based on MEXC's specific UI elements
# For now, just wait and let user perform manual actions
logger.info("Please manually place a small test trade while monitoring is active")
time.sleep(30)
except Exception as e:
logger.error(f"Error during test trade: {e}")
def full_session_capture(self, symbol: str = "ETH_USDT") -> Dict[str, Any]:
"""
Complete session capture workflow
Args:
symbol: Trading symbol to use
Returns:
Dictionary containing cookies and captured requests
"""
logger.info("Starting full MEXC session capture")
try:
# Start browser and navigate to MEXC
self.navigate_to_mexc_futures(symbol)
# Wait for manual login
if not self.wait_for_login():
return {'success': False, 'error': 'Login timeout'}
# Extract session cookies
cookies = self.extract_session_cookies()
if not cookies:
return {'success': False, 'error': 'Failed to extract cookies'}
# Monitor network requests while user performs actions
logger.info("Starting network monitoring - please perform trading actions now")
requests = self.monitor_network_requests(duration=120) # 2 minutes
return {
'success': True,
'cookies': cookies,
'network_requests': requests,
'timestamp': int(time.time())
}
except Exception as e:
logger.error(f"Error in session capture: {e}")
return {'success': False, 'error': str(e)}
finally:
self.stop_browser()
def main():
"""Main function for standalone execution"""
logging.basicConfig(level=logging.INFO)
print("MEXC Browser Automation - Session Capture")
print("This will open a browser window for you to log into MEXC")
print("Make sure you have Chrome browser installed")
automation = MEXCBrowserAutomation(headless=False)
try:
result = automation.full_session_capture()
if result['success']:
print(f"\nSession capture successful!")
print(f"Extracted {len(result['cookies'])} cookies")
print(f"Captured {len(result['network_requests'])} network requests")
# Save results to file
output_file = f"mexc_session_capture_{int(time.time())}.json"
with open(output_file, 'w') as f:
json.dump(result, f, indent=2)
print(f"Results saved to: {output_file}")
else:
print(f"Session capture failed: {result['error']}")
except KeyboardInterrupt:
print("\nSession capture interrupted by user")
except Exception as e:
print(f"Error: {e}")
finally:
automation.stop_browser()
if __name__ == "__main__":
main()

View File

@ -1,525 +0,0 @@
"""
MEXC Futures Web Client
This module implements a web-based client for MEXC futures trading
since their official API doesn't support futures (leverage) trading.
It mimics browser behavior by replicating the exact HTTP requests
that the web interface makes.
"""
import logging
import requests
import time
import json
import hmac
import hashlib
import base64
from typing import Dict, List, Optional, Any
from datetime import datetime
import uuid
from urllib.parse import urlencode
import glob
import os
logger = logging.getLogger(__name__)
class MEXCSessionManager:
def __init__(self):
self.captcha_token = None
def get_captcha_token(self) -> str:
return self.captcha_token if self.captcha_token else ""
def save_captcha_token(self, token: str):
self.captcha_token = token
logger.info("MEXC: Captcha token saved in session manager")
class MEXCFuturesWebClient:
"""
MEXC Futures Web Client that mimics browser behavior for futures trading.
Since MEXC's official API doesn't support futures, this client replicates
the exact HTTP requests made by their web interface.
"""
def __init__(self, api_key: str, api_secret: str, user_id: str, base_url: str = 'https://www.mexc.com', headless: bool = True):
"""
Initialize the MEXC Futures Web Client
Args:
api_key: API key for authentication
api_secret: API secret for authentication
user_id: User ID for authentication
base_url: Base URL for the MEXC website
headless: Whether to run the browser in headless mode
"""
self.api_key = api_key
self.api_secret = api_secret
self.user_id = user_id
self.base_url = base_url
self.is_authenticated = False
self.headless = headless
self.session = requests.Session()
self.session_manager = MEXCSessionManager() # Adding session_manager attribute
self.captcha_url = f'{base_url}/ucgateway/captcha_api'
self.futures_api_url = "https://futures.mexc.com/api/v1"
# Setup default headers that mimic a real browser
self.setup_browser_headers()
def setup_browser_headers(self):
"""Setup default headers that mimic Chrome browser"""
self.session.headers.update({
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36',
'Accept': '*/*',
'Accept-Language': 'en-GB,en-US;q=0.9,en;q=0.8',
'Accept-Encoding': 'gzip, deflate, br',
'sec-ch-ua': '"Chromium";v="136", "Google Chrome";v="136", "Not.A/Brand";v="99"',
'sec-ch-ua-mobile': '?0',
'sec-ch-ua-platform': '"Windows"',
'sec-fetch-dest': 'empty',
'sec-fetch-mode': 'cors',
'sec-fetch-site': 'same-origin',
'Cache-Control': 'no-cache',
'Pragma': 'no-cache',
'Referer': f'{self.base_url}/en-GB/futures/ETH_USDT?type=linear_swap',
'Language': 'English',
'X-Language': 'en-GB',
'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
'trochilus-uid': str(self.user_id) if self.user_id is not None else ''
})
def load_session_cookies(self, cookies: Dict[str, str]):
"""
Load session cookies from browser
Args:
cookies: Dictionary of cookie name-value pairs
"""
for name, value in cookies.items():
self.session.cookies.set(name, value)
# Extract important session info from cookies
self.auth_token = cookies.get('uc_token')
self.user_id = cookies.get('u_id')
self.fingerprint = cookies.get('x-mxc-fingerprint')
self.visitor_id = cookies.get('mexc_fingerprint_visitorId')
if self.auth_token and self.user_id:
self.is_authenticated = True
logger.info("MEXC: Loaded authenticated session")
else:
logger.warning("MEXC: Session cookies incomplete - authentication may fail")
def extract_cookies_from_browser(self, cookie_string: str) -> Dict[str, str]:
"""
Extract cookies from a browser cookie string
Args:
cookie_string: Raw cookie string from browser (copy from Network tab)
Returns:
Dictionary of parsed cookies
"""
cookies = {}
cookie_pairs = cookie_string.split(';')
for pair in cookie_pairs:
if '=' in pair:
name, value = pair.strip().split('=', 1)
cookies[name] = value
return cookies
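Typical use of the two helpers above, assuming the `MEXCFuturesWebClient` class is importable; the cookie string is a placeholder in the format copied from the browser's Network tab, and the credentials shown are fake.

```python
# Hedged sketch: feed a copied browser cookie header into the client.
raw_cookie_string = "uc_token=FAKE_TOKEN; u_id=FAKE_UID; x-mxc-fingerprint=FAKE_FP"

client = MEXCFuturesWebClient(api_key="", api_secret="", user_id="FAKE_UID")
cookies = client.extract_cookies_from_browser(raw_cookie_string)
client.load_session_cookies(cookies)
print("authenticated:", client.is_authenticated)
```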
def verify_captcha(self, symbol: str, side: str, leverage: str) -> bool:
"""
Verify captcha for robot trading protection
Args:
symbol: Trading symbol (e.g., 'ETH_USDT')
side: 'openlong', 'closelong', 'openshort', 'closeshort'
leverage: Leverage string (e.g., '200X')
Returns:
bool: True if captcha verification successful
"""
if not self.is_authenticated:
logger.error("MEXC: Cannot verify captcha - not authenticated")
return False
# Build captcha endpoint URL
endpoint = f"robot.future.{side}.{symbol}.{leverage}"
url = f"{self.captcha_url}/{endpoint}"
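# e.g. side='openlong', symbol='ETH_USDT', leverage='200X' produces
# .../ucgateway/captcha_api/robot.future.openlong.ETH_USDT.200X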
# Attempt to get captcha token from session manager
captcha_token = self.session_manager.get_captcha_token()
if not captcha_token:
logger.warning("MEXC: No captcha token available, attempting to fetch from browser")
captcha_token = self._extract_captcha_token_from_browser()
if captcha_token:
self.session_manager.save_captcha_token(captcha_token)
else:
logger.error("MEXC: Failed to extract captcha token from browser")
return False
headers = {
'Content-Type': 'application/json',
'Language': 'en-GB',
'Referer': f'{self.base_url}/en-GB/futures/{symbol}?type=linear_swap',
'trochilus-uid': self.user_id if self.user_id else '',
'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
'captcha-token': captcha_token
}
logger.info(f"MEXC: Verifying captcha for {endpoint}")
try:
response = self.session.get(url, headers=headers, timeout=10)
if response.status_code == 200:
data = response.json()
if data.get('success'):
logger.info(f"MEXC: Captcha verified successfully for {endpoint}")
return True
else:
logger.error(f"MEXC: Captcha verification failed for {endpoint}: {data}")
return False
else:
logger.error(f"MEXC: Captcha verification request failed with status {response.status_code}: {response.text}")
return False
except Exception as e:
logger.error(f"MEXC: Captcha verification error for {endpoint}: {str(e)}")
return False
def _extract_captcha_token_from_browser(self) -> str:
"""
Extract captcha token from browser session using stored cookies or requests.
This method looks for the most recent mexc_captcha_tokens JSON file to retrieve a token.
"""
try:
# Look for the most recent mexc_captcha_tokens file
captcha_files = glob.glob("mexc_captcha_tokens_*.json")
if not captcha_files:
logger.error("MEXC: No CAPTCHA token files found")
return ""
# Sort files by timestamp (most recent first)
latest_file = max(captcha_files, key=os.path.getctime)
logger.info(f"MEXC: Using CAPTCHA token file {latest_file}")
with open(latest_file, 'r') as f:
captcha_data = json.load(f)
if captcha_data and isinstance(captcha_data, list) and len(captcha_data) > 0:
# Return the most recent token
return captcha_data[0].get('token', '')
else:
logger.error("MEXC: No valid CAPTCHA tokens found in file")
return ""
except Exception as e:
logger.error(f"MEXC: Error extracting captcha token from browser data: {str(e)}")
return ""
def generate_signature(self, method: str, path: str, params: Dict[str, Any],
timestamp: int, nonce: int) -> str:
"""
Generate a signature for MEXC futures API requests.
The real scheme still needs to be reverse-engineered from the browser's
JavaScript; an illustrative generic sketch follows this method.
"""
# This is a placeholder - the actual signature generation would need
# to be reverse-engineered from the browser's JavaScript
# For now, return empty string and rely on cookie authentication
return ""
def open_long_position(self, symbol: str, volume: float, leverage: int = 200,
price: Optional[float] = None) -> Dict[str, Any]:
"""
Open a long futures position
Args:
symbol: Trading symbol (e.g., 'ETH_USDT')
volume: Position size (contracts)
leverage: Leverage multiplier (default 200)
price: Limit price (None for market order)
Returns:
dict: Order response with order ID
"""
if not self.is_authenticated:
logger.error("MEXC: Cannot open position - not authenticated")
return {'success': False, 'error': 'Not authenticated'}
# First verify captcha
if not self.verify_captcha(symbol, 'openlong', f'{leverage}X'):
logger.error("MEXC: Captcha verification failed for opening long position")
return {'success': False, 'error': 'Captcha verification failed'}
# Prepare order parameters based on the request dump
timestamp = int(time.time() * 1000)
nonce = timestamp
order_data = {
'symbol': symbol,
'side': 1, # 1 = long, 2 = short
'openType': 2, # Open position
'type': '5', # Market order (might be '1' for limit)
'vol': volume,
'leverage': leverage,
'marketCeiling': False,
'priceProtect': '0',
'ts': timestamp,
'mhash': self._generate_mhash(), # This needs to be implemented
'mtoken': self.visitor_id
}
# Add price for limit orders
if price is not None:
order_data['price'] = price
order_data['type'] = '1' # Limit order
# Add encrypted parameters (these would need proper implementation)
order_data['p0'] = self._encrypt_p0(order_data) # Placeholder
order_data['k0'] = self._encrypt_k0(order_data) # Placeholder
order_data['chash'] = self._generate_chash(order_data) # Placeholder
# Setup headers for the order request
headers = {
'Authorization': self.auth_token,
'Content-Type': 'application/json',
'Language': 'English',
'x-language': 'en-GB',
'x-mxc-nonce': str(nonce),
'x-mxc-sign': self.generate_signature('POST', '/private/order/create', order_data, timestamp, nonce),
'trochilus-uid': self.user_id,
'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
'Referer': 'https://www.mexc.com/'
}
# Make the order request
url = f"{self.futures_api_url}/private/order/create"
try:
# First make OPTIONS request (preflight)
options_response = self.session.options(url, headers=headers, timeout=10)
if options_response.status_code == 200:
# Now make the actual POST request
response = self.session.post(url, json=order_data, headers=headers, timeout=15)
if response.status_code == 200:
data = response.json()
if data.get('success') and data.get('code') == 0:
order_id = data.get('data', {}).get('orderId')
logger.info(f"MEXC: Long position opened successfully - Order ID: {order_id}")
return {
'success': True,
'order_id': order_id,
'timestamp': data.get('data', {}).get('ts'),
'symbol': symbol,
'side': 'long',
'volume': volume,
'leverage': leverage
}
else:
logger.error(f"MEXC: Order failed: {data}")
return {'success': False, 'error': data.get('msg', 'Unknown error')}
else:
logger.error(f"MEXC: Order request failed with status {response.status_code}")
return {'success': False, 'error': f'HTTP {response.status_code}'}
else:
logger.error(f"MEXC: OPTIONS preflight failed with status {options_response.status_code}")
return {'success': False, 'error': f'Preflight failed: HTTP {options_response.status_code}'}
except Exception as e:
logger.error(f"MEXC: Order execution error: {e}")
return {'success': False, 'error': str(e)}
def close_long_position(self, symbol: str, volume: float, leverage: int = 200,
price: Optional[float] = None) -> Dict[str, Any]:
"""
Close a long futures position
Args:
symbol: Trading symbol (e.g., 'ETH_USDT')
volume: Position size to close (contracts)
leverage: Leverage multiplier
price: Limit price (None for market order)
Returns:
dict: Order response
"""
if not self.is_authenticated:
logger.error("MEXC: Cannot close position - not authenticated")
return {'success': False, 'error': 'Not authenticated'}
# First verify captcha
if not self.verify_captcha(symbol, 'closelong', f'{leverage}X'):
logger.error("MEXC: Captcha verification failed for closing long position")
return {'success': False, 'error': 'Captcha verification failed'}
# Similar to open_long_position but with closeType instead of openType
timestamp = int(time.time() * 1000)
nonce = timestamp
order_data = {
'symbol': symbol,
'side': 2, # Close side is opposite
'closeType': 1, # Close position
'type': '5', # Market order
'vol': volume,
'leverage': leverage,
'marketCeiling': False,
'priceProtect': '0',
'ts': timestamp,
'mhash': self._generate_mhash(),
'mtoken': self.visitor_id
}
if price is not None:
order_data['price'] = price
order_data['type'] = '1'
order_data['p0'] = self._encrypt_p0(order_data)
order_data['k0'] = self._encrypt_k0(order_data)
order_data['chash'] = self._generate_chash(order_data)
return self._execute_order(order_data, 'close_long')
def open_short_position(self, symbol: str, volume: float, leverage: int = 200,
price: Optional[float] = None) -> Dict[str, Any]:
"""Open a short futures position"""
if not self.verify_captcha(symbol, 'openshort', f'{leverage}X'):
return {'success': False, 'error': 'Captcha verification failed'}
order_data = {
'symbol': symbol,
'side': 2, # 2 = short
'openType': 2,
'type': '5',
'vol': volume,
'leverage': leverage,
'marketCeiling': False,
'priceProtect': '0',
'ts': int(time.time() * 1000),
'mhash': self._generate_mhash(),
'mtoken': self.visitor_id
}
if price is not None:
order_data['price'] = price
order_data['type'] = '1'
order_data['p0'] = self._encrypt_p0(order_data)
order_data['k0'] = self._encrypt_k0(order_data)
order_data['chash'] = self._generate_chash(order_data)
return self._execute_order(order_data, 'open_short')
def close_short_position(self, symbol: str, volume: float, leverage: int = 200,
price: Optional[float] = None) -> Dict[str, Any]:
"""Close a short futures position"""
if not self.verify_captcha(symbol, 'closeshort', f'{leverage}X'):
return {'success': False, 'error': 'Captcha verification failed'}
order_data = {
'symbol': symbol,
'side': 1, # Close side is opposite
'closeType': 1,
'type': '5',
'vol': volume,
'leverage': leverage,
'marketCeiling': False,
'priceProtect': '0',
'ts': int(time.time() * 1000),
'mhash': self._generate_mhash(),
'mtoken': self.visitor_id
}
if price is not None:
order_data['price'] = price
order_data['type'] = '1'
order_data['p0'] = self._encrypt_p0(order_data)
order_data['k0'] = self._encrypt_k0(order_data)
order_data['chash'] = self._generate_chash(order_data)
return self._execute_order(order_data, 'close_short')
def _execute_order(self, order_data: Dict[str, Any], action: str) -> Dict[str, Any]:
"""Common order execution logic"""
timestamp = order_data['ts']
nonce = timestamp
headers = {
'Authorization': self.auth_token,
'Content-Type': 'application/json',
'Language': 'English',
'x-language': 'en-GB',
'x-mxc-nonce': str(nonce),
'x-mxc-sign': self.generate_signature('POST', '/private/order/create', order_data, timestamp, nonce),
'trochilus-uid': self.user_id,
'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
'Referer': 'https://www.mexc.com/'
}
url = f"{self.futures_api_url}/private/order/create"
try:
response = self.session.post(url, json=order_data, headers=headers, timeout=15)
if response.status_code == 200:
data = response.json()
if data.get('success') and data.get('code') == 0:
order_id = data.get('data', {}).get('orderId')
logger.info(f"MEXC: {action} executed successfully - Order ID: {order_id}")
return {
'success': True,
'order_id': order_id,
'timestamp': data.get('data', {}).get('ts'),
'action': action
}
else:
logger.error(f"MEXC: {action} failed: {data}")
return {'success': False, 'error': data.get('msg', 'Unknown error')}
else:
logger.error(f"MEXC: {action} request failed with status {response.status_code}")
return {'success': False, 'error': f'HTTP {response.status_code}'}
except Exception as e:
logger.error(f"MEXC: {action} execution error: {e}")
return {'success': False, 'error': str(e)}
# Placeholder methods for encryption/hashing - these need proper implementation
def _generate_mhash(self) -> str:
"""Generate mhash parameter (needs reverse engineering)"""
return "a0015441fd4c3b6ba427b894b76cb7dd" # Placeholder from request dump
def _encrypt_p0(self, order_data: Dict[str, Any]) -> str:
"""Encrypt p0 parameter (needs reverse engineering)"""
return "placeholder_p0_encryption" # This needs proper implementation
def _encrypt_k0(self, order_data: Dict[str, Any]) -> str:
"""Encrypt k0 parameter (needs reverse engineering)"""
return "placeholder_k0_encryption" # This needs proper implementation
def _generate_chash(self, order_data: Dict[str, Any]) -> str:
"""Generate chash parameter (needs reverse engineering)"""
return "d6c64d28e362f314071b3f9d78ff7494d9cd7177ae0465e772d1840e9f7905d8" # Placeholder
def get_account_info(self) -> Dict[str, Any]:
"""Get account information including positions and balances"""
if not self.is_authenticated:
return {'success': False, 'error': 'Not authenticated'}
# This would need to be implemented by reverse engineering the account info endpoints
logger.info("MEXC: Account info endpoint not yet implemented")
return {'success': False, 'error': 'Not implemented'}
def get_open_positions(self) -> List[Dict[str, Any]]:
"""Get list of open futures positions"""
if not self.is_authenticated:
return []
# This would need to be implemented by reverse engineering the positions endpoint
logger.info("MEXC: Open positions endpoint not yet implemented")
return []
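# Minimal usage sketch (assumes a cookie header was copied from the browser's
# Network tab; the cookie values here are placeholders, and only the captcha
# check is exercised - no order is placed):
if __name__ == "__main__":
    client = MEXCFuturesWebClient(api_key="", api_secret="", user_id="")
    raw_cookie_header = "uc_token=<token>; u_id=<uid>; x-mxc-fingerprint=<fp>; mexc_fingerprint_visitorId=<visitor>"
    client.load_session_cookies(client.extract_cookies_from_browser(raw_cookie_header))
    if client.is_authenticated:
        client.verify_captcha("ETH_USDT", "openlong", "200X")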

View File

@ -1,259 +0,0 @@
"""
MEXC Session Manager
Helper utilities for managing MEXC web sessions and extracting cookies from browser.
"""
import logging
import json
import re
import time
from typing import Dict, Optional, Any
from pathlib import Path
logger = logging.getLogger(__name__)
class MEXCSessionManager:
"""
Helper class for managing MEXC web sessions and extracting browser cookies
"""
def __init__(self):
self.session_file = Path("mexc_session.json")
def extract_cookies_from_network_tab(self, cookie_header: str) -> Dict[str, str]:
"""
Extract cookies from browser Network tab cookie header
Args:
cookie_header: Raw cookie string from browser (copy from Request Headers)
Returns:
Dictionary of parsed cookies
"""
cookies = {}
# Remove 'Cookie: ' prefix if present
if cookie_header.startswith('Cookie: '):
cookie_header = cookie_header[8:]
elif cookie_header.startswith('cookie: '):
cookie_header = cookie_header[8:]
# Split by semicolon and parse each cookie
cookie_pairs = cookie_header.split(';')
for pair in cookie_pairs:
pair = pair.strip()
if '=' in pair:
name, value = pair.split('=', 1)
cookies[name.strip()] = value.strip()
logger.info(f"Extracted {len(cookies)} cookies from browser")
return cookies
def validate_session_cookies(self, cookies: Dict[str, str]) -> bool:
"""
Validate that essential cookies are present for authentication
Args:
cookies: Dictionary of cookie name-value pairs
Returns:
bool: True if cookies appear valid for authentication
"""
required_cookies = [
'uc_token', # User authentication token
'u_id', # User ID
'x-mxc-fingerprint', # Browser fingerprint
'mexc_fingerprint_visitorId' # Visitor ID
]
missing_cookies = []
for cookie_name in required_cookies:
if cookie_name not in cookies or not cookies[cookie_name]:
missing_cookies.append(cookie_name)
if missing_cookies:
logger.warning(f"Missing required cookies: {missing_cookies}")
return False
logger.info("All required cookies are present")
return True
def save_session(self, cookies: Dict[str, str], metadata: Optional[Dict[str, Any]] = None):
"""
Save session cookies to file for reuse
Args:
cookies: Dictionary of cookies to save
metadata: Optional metadata about the session
"""
session_data = {
'cookies': cookies,
'metadata': metadata or {},
'timestamp': int(time.time())
}
try:
with open(self.session_file, 'w') as f:
json.dump(session_data, f, indent=2)
logger.info(f"Session saved to {self.session_file}")
except Exception as e:
logger.error(f"Failed to save session: {e}")
def load_session(self) -> Optional[Dict[str, str]]:
"""
Load session cookies from file
Returns:
Dictionary of cookies if successful, None otherwise
"""
if not self.session_file.exists():
logger.info("No saved session found")
return None
try:
with open(self.session_file, 'r') as f:
session_data = json.load(f)
cookies = session_data.get('cookies', {})
timestamp = session_data.get('timestamp', 0)
# Check if session is too old (24 hours)
import time
if time.time() - timestamp > 24 * 3600:
logger.warning("Saved session is too old (>24h), may be expired")
if self.validate_session_cookies(cookies):
logger.info("Loaded valid session from file")
return cookies
else:
logger.warning("Loaded session has invalid cookies")
return None
except Exception as e:
logger.error(f"Failed to load session: {e}")
return None
def extract_from_curl_command(self, curl_command: str) -> Dict[str, str]:
"""
Extract cookies from a curl command copied from browser
Args:
curl_command: Complete curl command from browser "Copy as cURL"
Returns:
Dictionary of extracted cookies
"""
cookies = {}
# Find cookie header in curl command
cookie_match = re.search(r'-H [\'"]cookie: ([^\'"]+)[\'"]', curl_command, re.IGNORECASE)
if not cookie_match:
cookie_match = re.search(r'--header [\'"]cookie: ([^\'"]+)[\'"]', curl_command, re.IGNORECASE)
if cookie_match:
cookie_header = cookie_match.group(1)
cookies = self.extract_cookies_from_network_tab(cookie_header)
logger.info(f"Extracted {len(cookies)} cookies from curl command")
else:
logger.warning("No cookie header found in curl command")
return cookies
def print_cookie_extraction_guide(self):
"""Print instructions for extracting cookies from browser"""
print("\n" + "="*80)
print("MEXC COOKIE EXTRACTION GUIDE")
print("="*80)
print("""
To extract cookies from your browser for MEXC futures trading:
METHOD 1: Browser Network Tab
1. Open MEXC futures page and log in: https://www.mexc.com/en-GB/futures/ETH_USDT
2. Open browser Developer Tools (F12)
3. Go to Network tab
4. Try to place a small futures trade (it will fail, but we need the request)
5. Find the request to 'futures.mexc.com' in the Network tab
6. Right-click on the request -> Copy -> Copy request headers
7. Find the 'Cookie:' line and copy everything after 'Cookie: '
METHOD 2: Copy as cURL
1. Follow steps 1-5 above
2. Right-click on the futures API request -> Copy -> Copy as cURL
3. Paste the entire cURL command
METHOD 3: Manual Cookie Extraction
1. While logged into MEXC, press F12 -> Application/Storage tab
2. On the left, expand 'Cookies' -> click on 'https://www.mexc.com'
3. Copy the values for these important cookies:
- uc_token
- u_id
- x-mxc-fingerprint
- mexc_fingerprint_visitorId
IMPORTANT NOTES:
- Cookies expire after some time (usually 24 hours)
- You must be logged into MEXC futures (not just spot trading)
- Keep your cookies secure - they provide access to your account
- Test with small amounts first
Example usage:
session_manager = MEXCSessionManager()
# Method 1: From cookie header
cookie_header = "uc_token=ABC123; u_id=DEF456; ..."
cookies = session_manager.extract_cookies_from_network_tab(cookie_header)
# Method 2: From cURL command
curl_cmd = "curl 'https://futures.mexc.com/...' -H 'cookie: uc_token=ABC123...'"
cookies = session_manager.extract_from_curl_command(curl_cmd)
# Save session for reuse
session_manager.save_session(cookies)
""")
print("="*80)
if __name__ == "__main__":
# When run directly, show the extraction guide
import time
manager = MEXCSessionManager()
manager.print_cookie_extraction_guide()
print("\nWould you like to:")
print("1. Load saved session")
print("2. Extract cookies from clipboard")
print("3. Exit")
choice = input("\nEnter choice (1-3): ").strip()
if choice == "1":
cookies = manager.load_session()
if cookies:
print(f"\nLoaded {len(cookies)} cookies from saved session")
if manager.validate_session_cookies(cookies):
print("Session appears valid for trading")
else:
print("Warning: Session may be incomplete or expired")
else:
print("No valid saved session found")
elif choice == "2":
print("\nPaste your cookie header or cURL command:")
user_input = input().strip()
if user_input.startswith('curl'):
cookies = manager.extract_from_curl_command(user_input)
else:
cookies = manager.extract_cookies_from_network_tab(user_input)
if cookies and manager.validate_session_cookies(cookies):
print(f"\nSuccessfully extracted {len(cookies)} valid cookies")
save = input("Save session for reuse? (y/n): ").strip().lower()
if save == 'y':
manager.save_session(cookies)
else:
print("Failed to extract valid cookies")
else:
print("Goodbye!")

View File

@ -1,346 +0,0 @@
#!/usr/bin/env python3
"""
Test MEXC Futures Web Client
This script demonstrates how to use the MEXC Futures Web Client
for futures trading that isn't supported by their official API.
IMPORTANT: This requires extracting cookies from your browser session.
"""
import logging
import sys
import os
import time
import json
import uuid
# Add the project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from mexc_futures_client import MEXCFuturesWebClient
from session_manager import MEXCSessionManager
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Constants
SYMBOL = "ETH_USDT"
LEVERAGE = 300
CREDENTIALS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mexc_credentials.json')
# Read credentials from mexc_credentials.json in JSON format
def load_credentials():
credentials_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mexc_credentials.json')
cookies = {}
captcha_token_open = ''
captcha_token_close = ''
try:
with open(credentials_file, 'r') as f:
data = json.load(f)
cookies = data.get('credentials', {}).get('cookies', {})
captcha_token_open = data.get('credentials', {}).get('captcha_token_open', '')
captcha_token_close = data.get('credentials', {}).get('captcha_token_close', '')
logger.info(f"Loaded credentials from {credentials_file}")
except Exception as e:
logger.error(f"Error loading credentials: {e}")
return cookies, captcha_token_open, captcha_token_close
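# Illustrative mexc_credentials.json layout assumed by load_credentials()
# (field names inferred from the reads above; values are placeholders):
# {
#   "credentials": {
#     "cookies": {"uc_token": "<token>", "u_id": "<uid>",
#                 "x-mxc-fingerprint": "<fp>", "mexc_fingerprint_visitorId": "<visitor>"},
#     "captcha_token_open": "<token>",
#     "captcha_token_close": "<token>"
#   }
# }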
def test_basic_connection():
"""Test basic connection and authentication"""
logger.info("Testing MEXC Futures Web Client")
# Initialize session manager
session_manager = MEXCSessionManager()
# Try to load saved session first
cookies = session_manager.load_session()
if not cookies:
# Explicitly load the cookies from the file we have
cookies_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'mexc_cookies_20250703_003625.json')
if os.path.exists(cookies_file):
try:
with open(cookies_file, 'r') as f:
cookies = json.load(f)
logger.info(f"Loaded cookies from {cookies_file}")
except Exception as e:
logger.error(f"Failed to load cookies from {cookies_file}: {e}")
cookies = None
else:
logger.error(f"Cookies file not found at {cookies_file}")
cookies = None
if not cookies:
print("\nNo saved session found. You need to extract cookies from your browser.")
session_manager.print_cookie_extraction_guide()
print("\nPaste your cookie header or cURL command (or press Enter to exit):")
user_input = input().strip()
if not user_input:
print("No input provided. Exiting.")
return False
# Extract cookies from user input
if user_input.startswith('curl'):
cookies = session_manager.extract_from_curl_command(user_input)
else:
cookies = session_manager.extract_cookies_from_network_tab(user_input)
if not cookies:
logger.error("Failed to extract cookies from input")
return False
# Validate and save session
if session_manager.validate_session_cookies(cookies):
session_manager.save_session(cookies)
logger.info("Session saved for future use")
else:
logger.warning("Extracted cookies may be incomplete")
# Initialize the web client
client = MEXCFuturesWebClient(api_key='', api_secret='', user_id='', base_url='https://www.mexc.com', headless=True)
# Load cookies into the client so the auth token, user ID and fingerprint are set
client.load_session_cookies(cookies)
# Update headers to include additional parameters from captured requests
client.session.headers.update({
'trochilus-trace-id': f"{uuid.uuid4()}-{int(time.time() * 1000) % 10000:04d}",
'trochilus-uid': cookies.get('u_id', ''),
'Referer': 'https://www.mexc.com/en-GB/futures/ETH_USDT?type=linear_swap',
'Language': 'English',
'X-Language': 'en-GB'
})
if not client.is_authenticated:
logger.error("Failed to authenticate with extracted cookies")
return False
logger.info("Successfully authenticated with MEXC")
logger.info(f"User ID: {client.user_id}")
logger.info(f"Auth Token: {client.auth_token[:20]}..." if client.auth_token else "No auth token")
return True
def test_captcha_verification(client: MEXCFuturesWebClient):
"""Test captcha verification system"""
logger.info("Testing captcha verification...")
# Test captcha for ETH_USDT long position with 200x leverage
success = client.verify_captcha('ETH_USDT', 'openlong', '200X')
if success:
logger.info("Captcha verification successful")
else:
logger.warning("Captcha verification failed - this may be normal if no position is being opened")
return success
def test_position_opening(client: MEXCFuturesWebClient, dry_run: bool = True):
"""Test opening a position (dry run by default)"""
if dry_run:
logger.info("DRY RUN: Testing position opening (no actual trade)")
else:
logger.warning("LIVE TRADING: Opening actual position!")
symbol = 'ETH_USDT'
volume = 1 # Small test position
leverage = 200
logger.info(f"Attempting to open long position: {symbol}, Volume: {volume}, Leverage: {leverage}x")
if not dry_run:
result = client.open_long_position(symbol, volume, leverage)
if result['success']:
logger.info(f"Position opened successfully!")
logger.info(f"Order ID: {result['order_id']}")
logger.info(f"Timestamp: {result['timestamp']}")
return True
else:
logger.error(f"Failed to open position: {result['error']}")
return False
else:
logger.info("DRY RUN: Would attempt to open position here")
# Test just the captcha verification part
return client.verify_captcha(symbol, 'openlong', f'{leverage}X')
def test_position_opening_live(client):
symbol = "ETH_USDT"
volume = 1 # Small volume for testing
leverage = 200
logger.info(f"LIVE TRADING: Opening actual position!")
logger.info(f"Attempting to open long position: {symbol}, Volume: {volume}, Leverage: {leverage}x")
result = client.open_long_position(symbol, volume, leverage)
if result.get('success'):
logger.info(f"Successfully opened position: {result}")
else:
logger.error(f"Failed to open position: {result.get('error', 'Unknown error')}")
def interactive_menu(client: MEXCFuturesWebClient):
"""Interactive menu for testing different functions"""
while True:
print("\n" + "="*50)
print("MEXC Futures Web Client Test Menu")
print("="*50)
print("1. Test captcha verification")
print("2. Test position opening (DRY RUN)")
print("3. Test position opening (LIVE - BE CAREFUL!)")
print("4. Test position closing (DRY RUN)")
print("5. Show session info")
print("6. Refresh session")
print("0. Exit")
choice = input("\nEnter choice (0-6): ").strip()
if choice == "1":
test_captcha_verification(client)
elif choice == "2":
test_position_opening(client, dry_run=True)
elif choice == "3":
test_position_opening_live(client)
elif choice == "4":
logger.info("DRY RUN: Position closing test")
success = client.verify_captcha('ETH_USDT', 'closelong', '200X')
if success:
logger.info("DRY RUN: Would close position here")
else:
logger.warning("Captcha verification failed for position closing")
elif choice == "5":
print(f"\nSession Information:")
print(f"Authenticated: {client.is_authenticated}")
print(f"User ID: {client.user_id}")
print(f"Auth Token: {client.auth_token[:20]}..." if client.auth_token else "None")
print(f"Fingerprint: {client.fingerprint}")
print(f"Visitor ID: {client.visitor_id}")
elif choice == "6":
session_manager = MEXCSessionManager()
session_manager.print_cookie_extraction_guide()
elif choice == "0":
print("Goodbye!")
break
else:
print("Invalid choice. Please try again.")
def main():
"""Main test function"""
print("MEXC Futures Web Client Test")
print("WARNING: This is experimental software for futures trading")
print("Use at your own risk and test with small amounts first!")
# Load cookies and tokens
cookies, captcha_token_open, captcha_token_close = load_credentials()
if not cookies:
logger.error("Failed to load cookies from credentials file")
sys.exit(1)
# Initialize client with loaded cookies and tokens
client = MEXCFuturesWebClient(api_key='', api_secret='', user_id='')
# Load cookies into the client's session
for name, value in cookies.items():
client.session.cookies.set(name, value)
# Set captcha tokens
client.captcha_token_open = captcha_token_open
client.captcha_token_close = captcha_token_close
# Try to load credentials from the new JSON file
try:
with open(CREDENTIALS_FILE, 'r') as f:
credentials_data = json.load(f)
cookies = credentials_data['credentials']['cookies']
captcha_token_open = credentials_data['credentials']['captcha_token_open']
captcha_token_close = credentials_data['credentials']['captcha_token_close']
client.load_session_cookies(cookies)
client.session_manager.save_captcha_token(captcha_token_open) # Assuming this is for opening
except FileNotFoundError:
logger.error(f"Credentials file not found at {CREDENTIALS_FILE}")
return False
except json.JSONDecodeError as e:
logger.error(f"Error loading credentials: {e}")
return False
except KeyError as e:
logger.error(f"Missing key in credentials file: {e}")
return False
if not client.is_authenticated:
logger.error("Client not authenticated. Please ensure valid cookies and tokens are in mexc_credentials.json")
return False
# Test connection and authentication
logger.info("Successfully authenticated with MEXC")
# Set leverage
leverage_response = client.update_leverage(symbol=SYMBOL, leverage=LEVERAGE)
if leverage_response and leverage_response.get('code') == 200:
logger.info(f"Leverage set to {LEVERAGE}x for {SYMBOL}")
else:
logger.error(f"Failed to set leverage: {leverage_response}")
sys.exit(1)
# Get current price
ticker = client.get_ticker_data(symbol=SYMBOL)
if ticker and ticker.get('code') == 200:
current_price = float(ticker['data']['last'])
logger.info(f"Current {SYMBOL} price: {current_price}")
else:
logger.error(f"Failed to get ticker data: {ticker}")
sys.exit(1)
# Calculate order quantity for a small test trade (~$10 of margin at LEVERAGE x)
trade_usdt = 10.0
order_qty = round((trade_usdt / current_price) * LEVERAGE, 3)
logger.info(f"Calculated order quantity: {order_qty} {SYMBOL} for ~${trade_usdt} at {LEVERAGE}x")
# Test 1: Open LONG position
logger.info(f"Opening LONG position for {SYMBOL} at {current_price} with qty {order_qty}")
open_long_order = client.create_order(
symbol=SYMBOL,
side=1, # 1 for BUY
position_side=1, # 1 for LONG
order_type=1, # 1 for LIMIT
price=current_price,
vol=order_qty
)
if open_long_order and open_long_order.get('code') == 200:
logger.info(f"✅ Successfully opened LONG position: {open_long_order['data']}")
else:
logger.error(f"❌ Failed to open LONG position: {open_long_order}")
sys.exit(1)
# Test 2: Close LONG position
logger.info(f"Closing LONG position for {SYMBOL}")
close_long_order = client.create_order(
symbol=SYMBOL,
side=2, # 2 for SELL
position_side=1, # 1 for LONG
order_type=1, # 1 for LIMIT
price=current_price,
vol=order_qty,
reduce_only=True
)
if close_long_order and close_long_order.get('code') == 200:
logger.info(f"✅ Successfully closed LONG position: {close_long_order['data']}")
else:
logger.error(f"❌ Failed to close LONG position: {close_long_order}")
sys.exit(1)
logger.info("All tests completed successfully!")
if __name__ == "__main__":
main()

View File

@ -1,595 +0,0 @@
"""
Negative Case Trainer - Intensive Training on Losing Trades
This module focuses on learning from losses to prevent future mistakes.
Stores negative cases in testcases/negative folder for reuse and retraining.
Supports simultaneous inference and training.
"""
import os
import json
import logging
import pickle
import threading
import time
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, asdict
from collections import deque
import numpy as np
import pandas as pd
# Import checkpoint management
import torch
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
from utils.training_integration import get_training_integration
logger = logging.getLogger(__name__)
@dataclass
class NegativeCase:
"""Represents a losing trade case for intensive training"""
case_id: str
timestamp: datetime
symbol: str
action: str # 'BUY' or 'SELL'
entry_price: float
exit_price: float
loss_amount: float
loss_percentage: float
confidence_used: float
market_state_before: Dict[str, Any]
market_state_after: Dict[str, Any]
tick_data: List[Dict[str, Any]] # 15 minutes of tick data around the trade
technical_indicators: Dict[str, float]
what_should_have_been_done: str # 'HOLD', 'OPPOSITE', 'WAIT'
lesson_learned: str
training_priority: int # 1-5, 5 being highest priority
retraining_count: int = 0
last_retrained: Optional[datetime] = None
@dataclass
class TrainingSession:
"""Represents an intensive training session on negative cases"""
session_id: str
start_time: datetime
cases_trained: List[str] # case_ids
epochs_completed: int
loss_improvement: float
accuracy_improvement: float
inference_paused: bool = False
training_active: bool = True
class NegativeCaseTrainer:
"""
Intensive trainer focused on learning from losing trades with checkpoint management
Features:
- Stores all losing trades as negative cases
- Intensive retraining on losses
- Simultaneous inference and training
- Persistent storage in testcases/negative
- Priority-based training (bigger losses = higher priority)
- Checkpoint management for training progress
"""
def __init__(self, storage_dir: str = "testcases/negative",
model_name: str = "negative_case_trainer", enable_checkpoints: bool = True):
self.storage_dir = storage_dir
self.stored_cases: List[NegativeCase] = []
self.training_queue = deque(maxlen=1000)
self.training_lock = threading.Lock()
self.inference_lock = threading.Lock()
# Checkpoint management
self.model_name = model_name
self.enable_checkpoints = enable_checkpoints
self.training_integration = get_training_integration() if enable_checkpoints else None
self.training_session_count = 0
self.best_loss_reduction = 0.0
self.checkpoint_frequency = 25 # Save checkpoint every 25 training sessions
# Training configuration
self.max_concurrent_training = 3 # Max parallel training sessions
self.intensive_training_epochs = 50 # Epochs per negative case
self.priority_multiplier = 2.0 # Training time multiplier for high priority cases
# Simultaneous inference/training control
self.inference_active = True
self.training_active = False
self.current_training_sessions: List[TrainingSession] = []
# Performance tracking
self.total_cases_processed = 0
self.total_training_time = 0.0
self.accuracy_improvements = []
# Initialize storage
self._initialize_storage()
self._load_existing_cases()
# Load best checkpoint if available
if self.enable_checkpoints:
self.load_best_checkpoint()
# Start background training thread
self.training_thread = threading.Thread(target=self._background_training_loop, daemon=True)
self.training_thread.start()
logger.info(f"NegativeCaseTrainer initialized with {len(self.stored_cases)} existing cases")
logger.info(f"Storage directory: {self.storage_dir}")
logger.info(f"Checkpoint management: {enable_checkpoints}, Model name: {model_name}")
logger.info("Background training thread started")
def _initialize_storage(self):
"""Initialize storage directories"""
try:
os.makedirs(self.storage_dir, exist_ok=True)
os.makedirs(f"{self.storage_dir}/cases", exist_ok=True)
os.makedirs(f"{self.storage_dir}/sessions", exist_ok=True)
os.makedirs(f"{self.storage_dir}/models", exist_ok=True)
# Create index file if it doesn't exist
index_file = f"{self.storage_dir}/case_index.json"
if not os.path.exists(index_file):
with open(index_file, 'w') as f:
json.dump({"cases": [], "last_updated": datetime.now().isoformat()}, f)
logger.info(f"Storage initialized at {self.storage_dir}")
except Exception as e:
logger.error(f"Error initializing storage: {e}")
def _load_existing_cases(self):
"""Load existing negative cases from storage"""
try:
index_file = f"{self.storage_dir}/case_index.json"
if os.path.exists(index_file):
with open(index_file, 'r') as f:
index_data = json.load(f)
for case_info in index_data.get("cases", []):
case_file = f"{self.storage_dir}/cases/{case_info['case_id']}.pkl"
if os.path.exists(case_file):
try:
with open(case_file, 'rb') as f:
case = pickle.load(f)
self.stored_cases.append(case)
except Exception as e:
logger.warning(f"Error loading case {case_info['case_id']}: {e}")
logger.info(f"Loaded {len(self.stored_cases)} existing negative cases")
except Exception as e:
logger.error(f"Error loading existing cases: {e}")
def add_losing_trade(self, trade_info: Dict[str, Any], market_data: Dict[str, Any]) -> str:
"""
Add a losing trade as a negative case for intensive training
Args:
trade_info: Trade information including P&L
market_data: Market state and tick data around the trade
Returns:
case_id: Unique identifier for the negative case
"""
try:
# Generate unique case ID
case_id = f"loss_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{trade_info['symbol'].replace('/', '')}"
# Calculate loss metrics
loss_amount = abs(trade_info.get('pnl', 0))
loss_percentage = (loss_amount / trade_info.get('value', 1)) * 100
# Determine training priority based on loss size
if loss_percentage > 10:
priority = 5 # Critical loss
elif loss_percentage > 5:
priority = 4 # High loss
elif loss_percentage > 2:
priority = 3 # Medium loss
elif loss_percentage > 1:
priority = 2 # Small loss
else:
priority = 1 # Minimal loss
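# Example: a $15 loss on a $500 position is a 3.0% loss, which falls in the
# 2-5% band above and is queued with training_priority = 3.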
# Analyze what should have been done
what_should_have_been_done = self._analyze_optimal_action(trade_info, market_data)
lesson_learned = self._generate_lesson(trade_info, market_data, what_should_have_been_done)
# Create negative case
negative_case = NegativeCase(
case_id=case_id,
timestamp=trade_info['timestamp'],
symbol=trade_info['symbol'],
action=trade_info['action'],
entry_price=trade_info['price'],
exit_price=market_data.get('exit_price', trade_info['price']),
loss_amount=loss_amount,
loss_percentage=loss_percentage,
confidence_used=trade_info.get('confidence', 0.5),
market_state_before=market_data.get('state_before', {}),
market_state_after=market_data.get('state_after', {}),
tick_data=market_data.get('tick_data', []),
technical_indicators=market_data.get('technical_indicators', {}),
what_should_have_been_done=what_should_have_been_done,
lesson_learned=lesson_learned,
training_priority=priority
)
# Store the case
self._store_case(negative_case)
# Add to training queue with priority
with self.training_lock:
self.training_queue.append(negative_case)
self.stored_cases.append(negative_case)
logger.error(f"NEGATIVE CASE ADDED: {case_id} | Loss: ${loss_amount:.2f} ({loss_percentage:.1f}%) | Priority: {priority}")
logger.error(f"Lesson: {lesson_learned}")
return case_id
except Exception as e:
logger.error(f"Error adding losing trade: {e}")
return ""
def _analyze_optimal_action(self, trade_info: Dict[str, Any], market_data: Dict[str, Any]) -> str:
"""Analyze what the optimal action should have been"""
try:
# Simple analysis based on price movement
entry_price = trade_info['price']
exit_price = market_data.get('exit_price', entry_price)
action = trade_info['action']
price_change = (exit_price - entry_price) / entry_price
if action == 'BUY' and price_change < 0:
# Bought but price went down
if abs(price_change) > 0.005: # >0.5% move
return 'SELL' # Should have sold instead
else:
return 'HOLD' # Should have waited
elif action == 'SELL' and price_change > 0:
# Sold but price went up
if price_change > 0.005: # >0.5% move
return 'BUY' # Should have bought instead
else:
return 'HOLD' # Should have waited
else:
return 'HOLD' # Should have done nothing
except Exception as e:
logger.error(f"Error analyzing optimal action: {e}")
return 'HOLD'
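# Example: a BUY at 3000 that exits at 2982 is a -0.6% move; the drop exceeds
# the 0.5% threshold, so the optimal action is recorded as 'SELL'.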
def _generate_lesson(self, trade_info: Dict[str, Any], market_data: Dict[str, Any], optimal_action: str) -> str:
"""Generate a lesson learned from the losing trade"""
try:
action = trade_info['action']
symbol = trade_info['symbol']
loss_pct = (abs(trade_info.get('pnl', 0)) / trade_info.get('value', 1)) * 100
confidence = trade_info.get('confidence', 0.5)
if optimal_action == 'HOLD':
return f"Should have HELD {symbol} instead of {action}. Confidence {confidence:.1%} was too high for {loss_pct:.1f}% loss."
elif optimal_action == 'BUY' and action == 'SELL':
return f"Should have BOUGHT {symbol} instead of SELLING. Market moved opposite to prediction."
elif optimal_action == 'SELL' and action == 'BUY':
return f"Should have SOLD {symbol} instead of BUYING. Market moved opposite to prediction."
else:
return f"Confidence {confidence:.1%} was too high for {loss_pct:.1f}% loss on {action} {symbol}."
except Exception as e:
logger.error(f"Error generating lesson: {e}")
return "Learn from this loss to improve future decisions."
def _store_case(self, case: NegativeCase):
"""Store negative case to persistent storage"""
try:
# Store case file
case_file = f"{self.storage_dir}/cases/{case.case_id}.pkl"
with open(case_file, 'wb') as f:
pickle.dump(case, f)
# Update index
index_file = f"{self.storage_dir}/case_index.json"
with open(index_file, 'r') as f:
index_data = json.load(f)
# Add case to index
case_info = {
'case_id': case.case_id,
'timestamp': case.timestamp.isoformat(),
'symbol': case.symbol,
'loss_amount': case.loss_amount,
'loss_percentage': case.loss_percentage,
'training_priority': case.training_priority,
'retraining_count': case.retraining_count
}
index_data['cases'].append(case_info)
index_data['last_updated'] = datetime.now().isoformat()
with open(index_file, 'w') as f:
json.dump(index_data, f, indent=2)
logger.info(f"Stored negative case: {case.case_id}")
except Exception as e:
logger.error(f"Error storing case: {e}")
def _background_training_loop(self):
"""Background loop for intensive training on negative cases"""
logger.info("Background training loop started")
while True:
try:
# Check if we have cases to train on
with self.training_lock:
if not self.training_queue:
time.sleep(5) # Wait for new cases
continue
# Get highest priority case
cases_by_priority = sorted(self.training_queue, key=lambda x: x.training_priority, reverse=True)
case_to_train = cases_by_priority[0]
self.training_queue.remove(case_to_train)
# Start intensive training session
self._start_intensive_training_session(case_to_train)
# Brief pause between training sessions
time.sleep(2)
except Exception as e:
logger.error(f"Error in background training loop: {e}")
time.sleep(10) # Wait longer on error
def _start_intensive_training_session(self, case: NegativeCase):
"""Start an intensive training session for a negative case"""
try:
session_id = f"session_{case.case_id}_{int(time.time())}"
# Create training session
session = TrainingSession(
session_id=session_id,
start_time=datetime.now(),
cases_trained=[case.case_id],
epochs_completed=0,
loss_improvement=0.0,
accuracy_improvement=0.0
)
self.current_training_sessions.append(session)
self.training_active = True
logger.warning(f"INTENSIVE TRAINING STARTED: {session_id}")
logger.warning(f"Training on loss case: {case.case_id} (Priority: {case.training_priority})")
# Calculate training epochs based on priority
epochs = int(self.intensive_training_epochs * case.training_priority * self.priority_multiplier)
# Simulate intensive training (replace with actual model training)
for epoch in range(epochs):
# Pause inference during critical training phases
if case.training_priority >= 4 and epoch % 10 == 0:
with self.inference_lock:
session.inference_paused = True
time.sleep(0.1) # Brief pause for critical training
session.inference_paused = False
# Simulate training step
session.epochs_completed = epoch + 1
# Log progress for high priority cases
if case.training_priority >= 4 and epoch % 10 == 0:
logger.warning(f"Intensive training progress: {epoch}/{epochs} epochs ({case.case_id})")
time.sleep(0.05) # Simulate training time
# Update case retraining info
case.retraining_count += 1
case.last_retrained = datetime.now()
# Calculate improvements (simulated)
session.loss_improvement = np.random.uniform(0.1, 0.5) # 10-50% improvement
session.accuracy_improvement = np.random.uniform(0.05, 0.2) # 5-20% improvement
# Store training session results
self._store_training_session(session)
# Update statistics
self.total_cases_processed += 1
self.total_training_time += (datetime.now() - session.start_time).total_seconds()
self.accuracy_improvements.append(session.accuracy_improvement)
# Remove from active sessions
self.current_training_sessions.remove(session)
if not self.current_training_sessions:
self.training_active = False
logger.warning(f"INTENSIVE TRAINING COMPLETED: {session_id}")
logger.warning(f"Epochs: {session.epochs_completed} | Loss improvement: {session.loss_improvement:.1%} | Accuracy improvement: {session.accuracy_improvement:.1%}")
except Exception as e:
logger.error(f"Error in intensive training session: {e}")
def _store_training_session(self, session: TrainingSession):
"""Store training session results"""
try:
session_file = f"{self.storage_dir}/sessions/{session.session_id}.json"
session_data = {
'session_id': session.session_id,
'start_time': session.start_time.isoformat(),
'end_time': datetime.now().isoformat(),
'cases_trained': session.cases_trained,
'epochs_completed': session.epochs_completed,
'loss_improvement': session.loss_improvement,
'accuracy_improvement': session.accuracy_improvement
}
with open(session_file, 'w') as f:
json.dump(session_data, f, indent=2)
except Exception as e:
logger.error(f"Error storing training session: {e}")
def can_inference_proceed(self) -> bool:
"""Check if inference can proceed (not blocked by critical training)"""
with self.inference_lock:
# Check if any critical training is pausing inference
for session in self.current_training_sessions:
if session.inference_paused:
return False
return True
def get_training_stats(self) -> Dict[str, Any]:
"""Get training statistics"""
try:
avg_accuracy_improvement = np.mean(self.accuracy_improvements) if self.accuracy_improvements else 0.0
return {
'total_negative_cases': len(self.stored_cases),
'cases_in_queue': len(self.training_queue),
'total_cases_processed': self.total_cases_processed,
'total_training_time': self.total_training_time,
'avg_accuracy_improvement': avg_accuracy_improvement,
'active_training_sessions': len(self.current_training_sessions),
'training_active': self.training_active,
'high_priority_cases': len([c for c in self.stored_cases if c.training_priority >= 4]),
'storage_directory': self.storage_dir
}
except Exception as e:
logger.error(f"Error getting training stats: {e}")
return {}
def get_recent_lessons(self, count: int = 5) -> List[str]:
"""Get recent lessons learned from negative cases"""
try:
recent_cases = sorted(self.stored_cases, key=lambda x: x.timestamp, reverse=True)[:count]
return [case.lesson_learned for case in recent_cases]
except Exception as e:
logger.error(f"Error getting recent lessons: {e}")
return []
def retrain_all_cases(self):
"""Retrain all stored negative cases (for periodic retraining)"""
try:
logger.warning("RETRAINING ALL NEGATIVE CASES - This may take a while...")
with self.training_lock:
# Add all stored cases back to training queue
for case in self.stored_cases:
if case not in self.training_queue:
self.training_queue.append(case)
logger.warning(f"Added {len(self.stored_cases)} cases to retraining queue")
except Exception as e:
logger.error(f"Error retraining all cases: {e}")
def load_best_checkpoint(self):
"""Load the best checkpoint for this negative case trainer"""
try:
if not self.enable_checkpoints:
return
result = load_best_checkpoint(self.model_name)
if result:
file_path, metadata = result
checkpoint = torch.load(file_path, map_location='cpu')
# Load training state
if 'training_session_count' in checkpoint:
self.training_session_count = checkpoint['training_session_count']
if 'best_loss_reduction' in checkpoint:
self.best_loss_reduction = checkpoint['best_loss_reduction']
if 'total_cases_processed' in checkpoint:
self.total_cases_processed = checkpoint['total_cases_processed']
if 'total_training_time' in checkpoint:
self.total_training_time = checkpoint['total_training_time']
if 'accuracy_improvements' in checkpoint:
self.accuracy_improvements = checkpoint['accuracy_improvements']
logger.info(f"Loaded NegativeCaseTrainer checkpoint: {metadata.checkpoint_id}")
logger.info(f"Session: {self.training_session_count}, Best loss reduction: {self.best_loss_reduction:.4f}")
except Exception as e:
logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")
def save_checkpoint(self, loss_improvement: float = 0.0, force_save: bool = False):
"""Save checkpoint if performance improved or forced"""
try:
if not self.enable_checkpoints:
return False
self.training_session_count += 1
# Update best loss reduction
improved = False
if loss_improvement > self.best_loss_reduction:
self.best_loss_reduction = loss_improvement
improved = True
# Save checkpoint if improved, forced, or at regular intervals
should_save = (
force_save or
improved or
self.training_session_count % self.checkpoint_frequency == 0
)
if should_save:
# Prepare checkpoint data
checkpoint_data = {
'training_session_count': self.training_session_count,
'best_loss_reduction': self.best_loss_reduction,
'total_cases_processed': self.total_cases_processed,
'total_training_time': self.total_training_time,
'accuracy_improvements': self.accuracy_improvements,
'storage_dir': self.storage_dir,
'max_concurrent_training': self.max_concurrent_training,
'intensive_training_epochs': self.intensive_training_epochs
}
# Create performance metrics for checkpoint manager
avg_accuracy_improvement = (
sum(self.accuracy_improvements) / len(self.accuracy_improvements)
if self.accuracy_improvements else 0.0
)
performance_metrics = {
'loss_reduction': self.best_loss_reduction,
'avg_accuracy_improvement': avg_accuracy_improvement,
'total_cases_processed': self.total_cases_processed,
'training_efficiency': (
self.total_cases_processed / self.total_training_time
if self.total_training_time > 0 else 0.0
)
}
# Save using checkpoint manager
metadata = save_checkpoint(
model=checkpoint_data, # We're saving data dict instead of model
model_name=self.model_name,
model_type="negative_case_trainer",
performance_metrics=performance_metrics,
training_metadata={
'session': self.training_session_count,
'cases_processed': self.total_cases_processed,
'training_time_hours': self.total_training_time / 3600
},
force_save=force_save
)
if metadata:
logger.info(f"Saved NegativeCaseTrainer checkpoint: {metadata.checkpoint_id}")
return True
return False
except Exception as e:
logger.error(f"Error saving NegativeCaseTrainer checkpoint: {e}")
return False
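# Minimal usage sketch (hypothetical trade/market dicts; checkpoints disabled so
# only local storage under testcases/negative is touched; assumes the project's
# utils.checkpoint_manager and utils.training_integration modules are importable):
if __name__ == "__main__":
    trainer = NegativeCaseTrainer(enable_checkpoints=False)
    trade = {'symbol': 'ETH/USDT', 'action': 'BUY', 'price': 3000.0, 'pnl': -15.0,
             'value': 500.0, 'confidence': 0.7, 'timestamp': datetime.now()}
    market = {'exit_price': 2985.0, 'state_before': {}, 'state_after': {},
              'tick_data': [], 'technical_indicators': {}}
    case_id = trainer.add_losing_trade(trade, market)
    print(case_id, trainer.get_training_stats())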

View File

@ -1,277 +0,0 @@
#!/usr/bin/env python3
"""
Neural Network Decision Fusion System
Central NN that merges all model outputs + market data for final trading decisions
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from datetime import datetime
import logging
logger = logging.getLogger(__name__)
@dataclass
class ModelPrediction:
"""Standardized prediction from any model"""
model_name: str
prediction_type: str # 'price', 'direction', 'action'
value: float # -1 to 1 for direction, actual price for price predictions
confidence: float # 0 to 1
timestamp: datetime
metadata: Optional[Dict[str, Any]] = None
@dataclass
class MarketContext:
"""Current market context for decision fusion"""
symbol: str
current_price: float
price_change_1m: float
price_change_5m: float
volume_ratio: float
volatility: float
timestamp: datetime
@dataclass
class FusionDecision:
"""Final trading decision from fusion NN"""
action: str # 'BUY', 'SELL', 'HOLD'
confidence: float # 0 to 1
expected_return: float # Expected return percentage
risk_score: float # 0 to 1, higher = riskier
position_size: float # Recommended position size
reasoning: str # Human-readable explanation
model_contributions: Dict[str, float] # How much each model contributed
timestamp: datetime
class DecisionFusionNetwork(nn.Module):
"""Small NN that fuses model predictions with market context"""
def __init__(self, input_dim: int = 32, hidden_dim: int = 64):
super().__init__()
self.fusion_layers = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, 16)
)
# Output heads
self.action_head = nn.Linear(16, 3) # BUY, SELL, HOLD
self.confidence_head = nn.Linear(16, 1)
self.return_head = nn.Linear(16, 1)
def forward(self, features: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Forward pass through fusion network"""
fusion_output = self.fusion_layers(features)
action_logits = self.action_head(fusion_output)
action_probs = F.softmax(action_logits, dim=1)
confidence = torch.sigmoid(self.confidence_head(fusion_output))
expected_return = torch.tanh(self.return_head(fusion_output))
return {
'action_probs': action_probs,
'confidence': confidence.squeeze(),
'expected_return': expected_return.squeeze()
}
class NeuralDecisionFusion:
"""Main NN-based decision fusion system"""
def __init__(self, training_mode: bool = True):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.network = DecisionFusionNetwork().to(self.device)
self.training_mode = training_mode
self.registered_models = {}
self.last_predictions = {}
logger.info(f"Neural Decision Fusion initialized on {self.device}")
def register_model(self, model_name: str, model_type: str, prediction_format: str):
"""Register a model that will provide predictions"""
self.registered_models[model_name] = {
'type': model_type,
'format': prediction_format,
'prediction_count': 0
}
logger.info(f"Registered NN model: {model_name} ({model_type})")
def add_prediction(self, prediction: ModelPrediction):
"""Add a prediction from a registered model"""
self.last_predictions[prediction.model_name] = prediction
if prediction.model_name in self.registered_models:
self.registered_models[prediction.model_name]['prediction_count'] += 1
logger.debug(f"🔮 {prediction.model_name}: {prediction.value:.3f} "
f"(confidence: {prediction.confidence:.3f})")
def make_decision(self, symbol: str, market_context: MarketContext,
min_confidence: float = 0.25) -> Optional[FusionDecision]:
"""Make NN-driven trading decision"""
try:
if len(self.last_predictions) < 1:
logger.debug("No NN predictions available")
return None
# Prepare features
features = self._prepare_features(market_context)
if features is None:
return None
# Run NN inference
with torch.no_grad():
self.network.eval()
features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(self.device)
outputs = self.network(features_tensor)
action_probs = outputs['action_probs'][0].cpu().numpy()
confidence = outputs['confidence'].cpu().item()
expected_return = outputs['expected_return'].cpu().item()
# Determine action
action_idx = np.argmax(action_probs)
actions = ['BUY', 'SELL', 'HOLD']
action = actions[action_idx]
# Check confidence threshold
if confidence < min_confidence:
action = 'HOLD'
logger.debug(f"Low NN confidence ({confidence:.3f}), defaulting to HOLD")
# Calculate position size
position_size = self._calculate_position_size(confidence, expected_return)
# Generate reasoning
reasoning = self._generate_reasoning(action, confidence, expected_return, action_probs)
# Calculate risk score and model contributions
risk_score = min(1.0, abs(expected_return) * 5 + (1 - confidence) * 0.5)
model_contributions = self._calculate_model_contributions()
decision = FusionDecision(
action=action,
confidence=confidence,
expected_return=expected_return,
risk_score=risk_score,
position_size=position_size,
reasoning=reasoning,
model_contributions=model_contributions,
timestamp=datetime.now()
)
logger.info(f"🧠 NN DECISION: {action} (conf: {confidence:.3f}, "
f"return: {expected_return:.3f}, size: {position_size:.4f})")
return decision
except Exception as e:
logger.error(f"Error in NN decision making: {e}")
return None
def _prepare_features(self, context: MarketContext) -> Optional[np.ndarray]:
"""Prepare feature vector for NN"""
try:
features = np.zeros(32)
# Model predictions (slots 0-13: value/confidence pairs for up to 7 models)
idx = 0
for model_name, prediction in self.last_predictions.items():
if idx < 14: # Leave room for other features
features[idx] = prediction.value
features[idx + 1] = prediction.confidence
idx += 2
# Market context (slots 16-31)
features[16] = np.tanh(context.price_change_1m * 100) # 1m change
features[17] = np.tanh(context.price_change_5m * 100) # 5m change
features[18] = np.tanh(context.volume_ratio - 1) # Volume ratio
features[19] = np.tanh(context.volatility * 100) # Volatility
features[20] = context.current_price / 10000.0 # Normalized price
# Time features
now = context.timestamp
features[21] = now.hour / 24.0
features[22] = now.weekday() / 7.0
# Model agreement features
if len(self.last_predictions) >= 2:
values = [p.value for p in self.last_predictions.values()]
features[23] = np.mean(values) # Average prediction
features[24] = np.std(values) # Prediction variance
features[25] = len(self.last_predictions) # Model count
return features
except Exception as e:
logger.error(f"Error preparing NN features: {e}")
return None
def _calculate_position_size(self, confidence: float, expected_return: float) -> float:
"""Calculate position size based on NN outputs"""
base_size = 0.01 # 0.01 ETH base
# Scale by confidence
confidence_multiplier = max(0.1, min(2.0, confidence * 1.5))
# Scale by expected return
return_multiplier = 1.0 + abs(expected_return) * 0.5
final_size = base_size * confidence_multiplier * return_multiplier
return max(0.001, min(0.05, final_size))
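# Worked example (illustrative): confidence=0.6, expected_return=0.02
#   confidence_multiplier = clamp(0.6 * 1.5, 0.1, 2.0) = 0.9
#   return_multiplier     = 1.0 + |0.02| * 0.5        = 1.01
#   final_size            = 0.01 * 0.9 * 1.01         ≈ 0.00909 ETH (inside the 0.001-0.05 clamp)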
def _generate_reasoning(self, action: str, confidence: float,
expected_return: float, action_probs: np.ndarray) -> str:
"""Generate human-readable reasoning"""
reasons = []
if action == 'BUY':
reasons.append(f"NN suggests BUY ({action_probs[0]:.1%})")
elif action == 'SELL':
reasons.append(f"NN suggests SELL ({action_probs[1]:.1%})")
else:
reasons.append(f"NN suggests HOLD")
if confidence > 0.7:
reasons.append("High confidence")
elif confidence > 0.5:
reasons.append("Moderate confidence")
else:
reasons.append("Low confidence")
if abs(expected_return) > 0.01:
direction = "positive" if expected_return > 0 else "negative"
reasons.append(f"Expected {direction} return: {expected_return:.2%}")
reasons.append(f"Based on {len(self.last_predictions)} NN models")
return " | ".join(reasons)
def _calculate_model_contributions(self) -> Dict[str, float]:
"""Calculate how much each model contributed to the decision"""
contributions = {}
total_confidence = sum(p.confidence for p in self.last_predictions.values()) if self.last_predictions else 1.0
if total_confidence > 0:
for model_name, prediction in self.last_predictions.items():
contributions[model_name] = prediction.confidence / total_confidence
return contributions
def get_status(self) -> Dict[str, Any]:
"""Get NN fusion system status"""
return {
'device': str(self.device),
'training_mode': self.training_mode,
'registered_models': len(self.registered_models),
'recent_predictions': len(self.last_predictions),
'model_parameters': sum(p.numel() for p in self.network.parameters())
}
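# Illustrative usage sketch. `fusion` is assumed to be an instance of the fusion class defined above,
# and `context` a MarketContext carrying the fields read in _prepare_features (current_price,
# price_change_1m, price_change_5m, volume_ratio, volatility, timestamp); the exact constructor
# signatures are not shown here.
def _example_fusion_usage(fusion, symbol: str, context: 'MarketContext'):
    decision = fusion.make_decision(symbol, context, min_confidence=0.25)
    if decision and decision.action != 'HOLD':
        print(decision.action, decision.position_size, decision.reasoning)
    print(fusion.get_status())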

Binary file not shown.

View File

@ -1,649 +0,0 @@
"""
Real-Time Tick Processing Neural Network Module
This module acts as a Neural Network DPS (Data Processing System) alternative,
processing raw tick data with ultra-low latency and feeding processed features
to trading models in real-time.
Features:
- Real-time tick ingestion with volume processing
- Neural network feature extraction from tick streams
- Ultra-low latency processing (sub-millisecond)
- Volume-weighted price analysis
- Microstructure pattern detection
- Real-time feature streaming to models
"""
import asyncio
import logging
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Deque
from collections import deque
from threading import Thread, Lock
import websockets
import json
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class TickData:
"""Raw tick data structure"""
timestamp: datetime
price: float
volume: float
side: str # 'buy' or 'sell'
trade_id: Optional[str] = None
@dataclass
class ProcessedTickFeatures:
"""Processed tick features for model consumption"""
timestamp: datetime
price_features: np.ndarray # Price-based features
volume_features: np.ndarray # Volume-based features
microstructure_features: np.ndarray # Market microstructure features
neural_features: np.ndarray # Neural network extracted features
confidence: float # Feature quality confidence
class TickProcessingNN(nn.Module):
"""
Neural Network for real-time tick processing
Extracts high-level features from raw tick data
"""
def __init__(self, input_size: int = 9, hidden_size: int = 128, output_size: int = 64):
super(TickProcessingNN, self).__init__()
# Tick sequence processing layers
self.tick_encoder = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Dropout(0.1)
)
# LSTM for temporal patterns
self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True, num_layers=2)
# Attention mechanism for important tick selection
self.attention = nn.MultiheadAttention(hidden_size, num_heads=8, batch_first=True)
# Feature extraction heads
self.price_head = nn.Linear(hidden_size, 16) # Price pattern features
self.volume_head = nn.Linear(hidden_size, 16) # Volume pattern features
self.microstructure_head = nn.Linear(hidden_size, 16) # Microstructure features
# Final feature fusion
self.feature_fusion = nn.Sequential(
nn.Linear(48, output_size), # 16+16+16 = 48
nn.ReLU(),
nn.Linear(output_size, output_size)
)
# Confidence estimation
self.confidence_head = nn.Sequential(
nn.Linear(output_size, 32),
nn.ReLU(),
nn.Linear(32, 1),
nn.Sigmoid()
)
def forward(self, tick_sequence: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Process tick sequence and extract features
Args:
tick_sequence: [batch, sequence_length, features]
Returns:
features: [batch, output_size] - extracted features
confidence: [batch, 1] - feature confidence
"""
batch_size, seq_len, _ = tick_sequence.shape
# Encode each tick
encoded = self.tick_encoder(tick_sequence) # [batch, seq_len, hidden_size]
# LSTM processing for temporal patterns
lstm_out, _ = self.lstm(encoded) # [batch, seq_len, hidden_size]
# Attention to focus on important ticks
attended, _ = self.attention(lstm_out, lstm_out, lstm_out) # [batch, seq_len, hidden_size]
# Use the last attended output
final_features = attended[:, -1, :] # [batch, hidden_size]
# Extract specialized features
price_features = self.price_head(final_features)
volume_features = self.volume_head(final_features)
microstructure_features = self.microstructure_head(final_features)
# Fuse all features
combined_features = torch.cat([price_features, volume_features, microstructure_features], dim=1)
final_features = self.feature_fusion(combined_features)
# Estimate confidence
confidence = self.confidence_head(final_features)
return final_features, confidence
class RealTimeTickProcessor:
"""
Real-time tick processing system with neural network feature extraction
Acts as a DPS alternative for ultra-low latency tick processing
"""
def __init__(self, symbols: List[str] = None, tick_buffer_size: int = 1000):
"""Initialize the real-time tick processor"""
self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
self.tick_buffer_size = tick_buffer_size
# Tick storage buffers
self.tick_buffers: Dict[str, Deque[TickData]] = {}
self.processed_features: Dict[str, Deque[ProcessedTickFeatures]] = {}
# Initialize buffers for each symbol
for symbol in self.symbols:
self.tick_buffers[symbol] = deque(maxlen=tick_buffer_size)
self.processed_features[symbol] = deque(maxlen=100) # Keep last 100 processed features
# Neural network for feature extraction
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.tick_nn = TickProcessingNN(input_size=9).to(self.device)
self.tick_nn.eval() # Start in evaluation mode
# Processing parameters
self.processing_window = 50 # Number of ticks to process at once
self.min_ticks_for_processing = 10 # Minimum ticks before processing
# Real-time streaming
self.streaming = False
self.websocket_tasks = {}
self.processing_threads = {}
# Performance tracking
self.processing_times = deque(maxlen=1000)
self.tick_counts = {symbol: 0 for symbol in self.symbols}
# Thread safety
self.data_lock = Lock()
# Feature subscribers (models that want real-time features)
self.feature_subscribers = []
logger.info(f"RealTimeTickProcessor initialized for symbols: {self.symbols}")
logger.info(f"Neural network device: {self.device}")
logger.info(f"Tick buffer size: {tick_buffer_size}")
def add_feature_subscriber(self, callback):
"""Add a callback function to receive processed features"""
self.feature_subscribers.append(callback)
logger.info(f"Added feature subscriber: {callback.__name__}")
def remove_feature_subscriber(self, callback):
"""Remove a feature subscriber"""
if callback in self.feature_subscribers:
self.feature_subscribers.remove(callback)
logger.info(f"Removed feature subscriber: {callback.__name__}")
async def start_processing(self):
"""Start real-time tick processing"""
logger.info("Starting real-time tick processing...")
self.streaming = True
# Start WebSocket streams for each symbol
for symbol in self.symbols:
task = asyncio.create_task(self._websocket_stream(symbol))
self.websocket_tasks[symbol] = task
# Start processing thread for each symbol
thread = Thread(target=self._processing_loop, args=(symbol,), daemon=True)
thread.start()
self.processing_threads[symbol] = thread
logger.info("Real-time tick processing started")
async def stop_processing(self):
"""Stop real-time tick processing"""
logger.info("Stopping real-time tick processing...")
self.streaming = False
# Cancel WebSocket tasks
for symbol, task in self.websocket_tasks.items():
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
self.websocket_tasks.clear()
logger.info("Real-time tick processing stopped")
async def _websocket_stream(self, symbol: str):
"""WebSocket stream for real-time tick data"""
binance_symbol = symbol.replace('/', '').lower()
url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@trade"
while self.streaming:
try:
async with websockets.connect(url) as websocket:
logger.info(f"Tick WebSocket connected for {symbol}")
async for message in websocket:
if not self.streaming:
break
try:
data = json.loads(message)
await self._process_raw_tick(symbol, data)
except Exception as e:
logger.warning(f"Error processing tick for {symbol}: {e}")
except Exception as e:
logger.error(f"WebSocket error for {symbol}: {e}")
if self.streaming:
logger.info(f"Reconnecting tick WebSocket for {symbol} in 2 seconds...")
await asyncio.sleep(2)
async def _process_raw_tick(self, symbol: str, raw_data: Dict):
"""Process raw tick data from WebSocket"""
try:
# Extract tick information
tick = TickData(
timestamp=datetime.fromtimestamp(int(raw_data['T']) / 1000),
price=float(raw_data['p']),
volume=float(raw_data['q']),
side='sell' if raw_data['m'] else 'buy',  # m=True: buyer is the market maker, so the taker side is a sell
trade_id=raw_data.get('t')
)
# Add to buffer
with self.data_lock:
self.tick_buffers[symbol].append(tick)
self.tick_counts[symbol] += 1
except Exception as e:
logger.error(f"Error processing raw tick for {symbol}: {e}")
def _processing_loop(self, symbol: str):
"""Main processing loop for a symbol"""
logger.info(f"Starting processing loop for {symbol}")
while self.streaming:
try:
# Check if we have enough ticks to process
with self.data_lock:
tick_count = len(self.tick_buffers[symbol])
if tick_count >= self.min_ticks_for_processing:
start_time = time.time()
# Process ticks
features = self._extract_neural_features(symbol)
if features is not None:
# Store processed features
with self.data_lock:
self.processed_features[symbol].append(features)
# Notify subscribers
self._notify_feature_subscribers(symbol, features)
# Track processing time
processing_time = (time.time() - start_time) * 1000 # Convert to ms
self.processing_times.append(processing_time)
if len(self.processing_times) % 100 == 0:
avg_time = np.mean(list(self.processing_times))
logger.debug(f"RTP: Average processing time: {avg_time:.2f}ms")
# Small sleep to prevent CPU overload
time.sleep(0.001) # 1ms sleep for ultra-low latency
except Exception as e:
logger.error(f"Error in processing loop for {symbol}: {e}")
time.sleep(0.01) # Longer sleep on error
def _extract_neural_features(self, symbol: str) -> Optional[ProcessedTickFeatures]:
"""Extract neural network features from recent ticks"""
try:
with self.data_lock:
# Get recent ticks
recent_ticks = list(self.tick_buffers[symbol])[-self.processing_window:]
if len(recent_ticks) < self.min_ticks_for_processing:
return None
# Convert ticks to neural network input
tick_features = self._ticks_to_features(recent_ticks)
# Process with neural network
with torch.no_grad():
tick_tensor = torch.FloatTensor(tick_features).unsqueeze(0).to(self.device)
neural_features, confidence = self.tick_nn(tick_tensor)
neural_features = neural_features.cpu().numpy().flatten()
confidence = confidence.cpu().numpy().item()
# Extract traditional features
price_features = self._extract_price_features(recent_ticks)
volume_features = self._extract_volume_features(recent_ticks)
microstructure_features = self._extract_microstructure_features(recent_ticks)
# Create processed features object
processed = ProcessedTickFeatures(
timestamp=recent_ticks[-1].timestamp,
price_features=price_features,
volume_features=volume_features,
microstructure_features=microstructure_features,
neural_features=neural_features,
confidence=confidence
)
return processed
except Exception as e:
logger.error(f"Error extracting neural features for {symbol}: {e}")
return None
def _ticks_to_features(self, ticks: List[TickData]) -> np.ndarray:
"""Convert tick data to neural network input features"""
features = []
for i, tick in enumerate(ticks):
tick_features = [
tick.price,
tick.volume,
1.0 if tick.side == 'buy' else 0.0, # Buy/sell indicator
tick.timestamp.timestamp(), # Timestamp
]
# Add relative features if we have previous ticks
if i > 0:
prev_tick = ticks[i-1]
price_change = (tick.price - prev_tick.price) / prev_tick.price
volume_ratio = tick.volume / (prev_tick.volume + 1e-8)
time_delta = (tick.timestamp - prev_tick.timestamp).total_seconds()
tick_features.extend([
price_change,
volume_ratio,
time_delta
])
else:
tick_features.extend([0.0, 1.0, 0.0]) # Default values for first tick
# Add moving averages if we have enough data
if i >= 5:
recent_prices = [t.price for t in ticks[max(0, i-4):i+1]]
recent_volumes = [t.volume for t in ticks[max(0, i-4):i+1]]
price_ma = np.mean(recent_prices)
volume_ma = np.mean(recent_volumes)
tick_features.extend([
(tick.price - price_ma) / price_ma, # Price deviation from MA
(tick.volume - volume_ma) / (volume_ma + 1e-8) # Volume deviation from MA
])
else:
tick_features.extend([0.0, 0.0])
features.append(tick_features)
# Pad or truncate to fixed size
target_length = self.processing_window
if len(features) < target_length:
# Pad with zeros
padding = [[0.0] * len(features[0])] * (target_length - len(features))
features = padding + features
elif len(features) > target_length:
# Take the most recent ticks
features = features[-target_length:]
return np.array(features, dtype=np.float32)
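# Per-tick feature layout produced above (9 values, matching TickProcessingNN input_size=9):
#   [price, volume, buy_flag, unix_timestamp,
#    price_change, volume_ratio, time_delta,
#    price_dev_from_5tick_MA, volume_dev_from_5tick_MA]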
def _extract_price_features(self, ticks: List[TickData]) -> np.ndarray:
"""Extract price-based features"""
prices = np.array([tick.price for tick in ticks])
features = [
prices[-1], # Current price
np.mean(prices), # Average price
np.std(prices), # Price volatility
np.max(prices), # High
np.min(prices), # Low
(prices[-1] - prices[0]) / prices[0] if prices[0] != 0 else 0, # Total return
]
# Price momentum features
if len(prices) >= 10:
short_ma = np.mean(prices[-5:])
long_ma = np.mean(prices[-10:])
momentum = (short_ma - long_ma) / long_ma if long_ma != 0 else 0
features.append(momentum)
else:
features.append(0.0)
return np.array(features, dtype=np.float32)
def _extract_volume_features(self, ticks: List[TickData]) -> np.ndarray:
"""Extract volume-based features"""
volumes = np.array([tick.volume for tick in ticks])
buy_volumes = np.array([tick.volume for tick in ticks if tick.side == 'buy'])
sell_volumes = np.array([tick.volume for tick in ticks if tick.side == 'sell'])
features = [
np.sum(volumes), # Total volume
np.mean(volumes), # Average volume
np.std(volumes), # Volume volatility
np.sum(buy_volumes) if len(buy_volumes) > 0 else 0, # Buy volume
np.sum(sell_volumes) if len(sell_volumes) > 0 else 0, # Sell volume
]
# Volume imbalance
total_buy = np.sum(buy_volumes) if len(buy_volumes) > 0 else 0
total_sell = np.sum(sell_volumes) if len(sell_volumes) > 0 else 0
total_volume = total_buy + total_sell
if total_volume > 0:
buy_ratio = total_buy / total_volume
volume_imbalance = buy_ratio - 0.5 # -0.5 to 0.5 range
else:
volume_imbalance = 0.0
features.append(volume_imbalance)
# VWAP (Volume Weighted Average Price)
if np.sum(volumes) > 0:
prices = np.array([tick.price for tick in ticks])
vwap = np.sum(prices * volumes) / np.sum(volumes)
current_price = ticks[-1].price
vwap_deviation = (current_price - vwap) / vwap if vwap != 0 else 0
else:
vwap_deviation = 0.0
features.append(vwap_deviation)
return np.array(features, dtype=np.float32)
def _extract_microstructure_features(self, ticks: List[TickData]) -> np.ndarray:
"""Extract market microstructure features"""
features = []
# Trade frequency
if len(ticks) >= 2:
time_deltas = [(ticks[i].timestamp - ticks[i-1].timestamp).total_seconds()
for i in range(1, len(ticks))]
avg_time_delta = np.mean(time_deltas)
trade_frequency = 1.0 / avg_time_delta if avg_time_delta > 0 else 0
else:
trade_frequency = 0.0
features.append(trade_frequency)
# Price impact features
prices = [tick.price for tick in ticks]
volumes = [tick.volume for tick in ticks]
if len(prices) >= 3:
# Calculate price changes and corresponding volumes
price_changes = [(prices[i] - prices[i-1]) / prices[i-1]
for i in range(1, len(prices)) if prices[i-1] != 0]
corresponding_volumes = volumes[1:len(price_changes)+1]
if len(price_changes) > 0 and len(corresponding_volumes) > 0:
# Simple price impact measure
price_impact = np.corrcoef(np.abs(price_changes), corresponding_volumes)[0, 1]
if np.isnan(price_impact):
price_impact = 0.0
else:
price_impact = 0.0
else:
price_impact = 0.0
features.append(price_impact)
# Bid-ask spread proxy (using price volatility)
if len(prices) >= 5:
recent_prices = prices[-5:]
spread_proxy = (np.max(recent_prices) - np.min(recent_prices)) / np.mean(recent_prices)
else:
spread_proxy = 0.0
features.append(spread_proxy)
# Order flow imbalance (already calculated in volume features, but different perspective)
buy_count = sum(1 for tick in ticks if tick.side == 'buy')
sell_count = len(ticks) - buy_count
total_trades = len(ticks)
if total_trades > 0:
order_flow_imbalance = (buy_count - sell_count) / total_trades
else:
order_flow_imbalance = 0.0
features.append(order_flow_imbalance)
return np.array(features, dtype=np.float32)
def _notify_feature_subscribers(self, symbol: str, features: ProcessedTickFeatures):
"""Notify all feature subscribers of new processed features"""
for callback in self.feature_subscribers:
try:
callback(symbol, features)
except Exception as e:
logger.error(f"Error notifying feature subscriber {callback.__name__}: {e}")
def get_latest_features(self, symbol: str) -> Optional[ProcessedTickFeatures]:
"""Get the latest processed features for a symbol"""
with self.data_lock:
if symbol in self.processed_features and self.processed_features[symbol]:
return self.processed_features[symbol][-1]
return None
def get_processing_stats(self) -> Dict[str, Any]:
"""Get processing performance statistics"""
stats = {
'symbols': self.symbols,
'streaming': self.streaming,
'tick_counts': dict(self.tick_counts),
'buffer_sizes': {symbol: len(self.tick_buffers[symbol]) for symbol in self.symbols},
'feature_counts': {symbol: len(self.processed_features[symbol]) for symbol in self.symbols},
'subscribers': len(self.feature_subscribers)
}
if self.processing_times:
stats['processing_performance'] = {
'avg_time_ms': np.mean(list(self.processing_times)),
'min_time_ms': np.min(list(self.processing_times)),
'max_time_ms': np.max(list(self.processing_times)),
'std_time_ms': np.std(list(self.processing_times))
}
return stats
def train_neural_network(self, training_data: List[Tuple[np.ndarray, np.ndarray]], epochs: int = 100):
"""Train the tick processing neural network"""
logger.info("Training tick processing neural network...")
self.tick_nn.train()
optimizer = torch.optim.Adam(self.tick_nn.parameters(), lr=0.001)
criterion = nn.MSELoss()
for epoch in range(epochs):
total_loss = 0.0
for batch_features, batch_targets in training_data:
optimizer.zero_grad()
# Convert to tensors
features_tensor = torch.FloatTensor(batch_features).to(self.device)
targets_tensor = torch.FloatTensor(batch_targets).to(self.device)
# Forward pass
outputs, confidence = self.tick_nn(features_tensor)
# Calculate loss
loss = criterion(outputs, targets_tensor)
# Backward pass
loss.backward()
optimizer.step()
total_loss += loss.item()
if epoch % 10 == 0:
avg_loss = total_loss / len(training_data)
logger.info(f"Epoch {epoch}/{epochs}, Average Loss: {avg_loss:.6f}")
self.tick_nn.eval()
logger.info("Neural network training completed")
# Integration with existing orchestrator
def integrate_with_orchestrator(orchestrator, tick_processor: RealTimeTickProcessor):
"""Integrate tick processor with enhanced orchestrator"""
def feature_callback(symbol: str, features: ProcessedTickFeatures):
"""Callback to feed processed features to orchestrator"""
try:
# Convert processed features to format expected by orchestrator
feature_dict = {
'symbol': symbol,
'timestamp': features.timestamp,
'neural_features': features.neural_features,
'price_features': features.price_features,
'volume_features': features.volume_features,
'microstructure_features': features.microstructure_features,
'confidence': features.confidence
}
# Feed to orchestrator's real-time feature processing
if hasattr(orchestrator, 'process_realtime_features'):
orchestrator.process_realtime_features(feature_dict)
except Exception as e:
logger.error(f"Error integrating features with orchestrator: {e}")
# Add the callback to tick processor
tick_processor.add_feature_subscriber(feature_callback)
logger.info("Tick processor integrated with orchestrator")
# Factory function for easy creation
def create_realtime_tick_processor(symbols: List[str] = None) -> RealTimeTickProcessor:
"""Create and configure a real-time tick processor"""
if symbols is None:
symbols = ['ETH/USDT', 'BTC/USDT']
processor = RealTimeTickProcessor(symbols=symbols)
logger.info(f"Created RealTimeTickProcessor for symbols: {symbols}")
return processor
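# Illustrative usage sketch. start_processing() schedules WebSocket tasks on the running event loop,
# so it must be awaited from inside one; the run duration below is arbitrary.
async def _example_tick_processing(run_seconds: float = 60.0):
    def on_features(symbol: str, features: ProcessedTickFeatures):
        print(symbol, features.confidence, len(features.neural_features))
    processor = create_realtime_tick_processor(['ETH/USDT'])
    processor.add_feature_subscriber(on_features)
    await processor.start_processing()
    await asyncio.sleep(run_seconds)        # let ticks stream for a while
    await processor.stop_processing()
    print(processor.get_processing_stats())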

View File

@ -1,453 +0,0 @@
"""
Retrospective Training System
This module implements a retrospective training system that:
1. Triggers training when trades close with known P&L outcomes
2. Uses captured model inputs from trade entry to train models
3. Optimizes for profit by learning from profitable vs unprofitable patterns
4. Supports simultaneous inference and training without weight reloading
5. Implements reinforcement learning with immediate reward feedback
"""
import logging
import threading
import time
import queue
from datetime import datetime
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass
import numpy as np
from collections import deque
logger = logging.getLogger(__name__)
@dataclass
class TrainingCase:
"""Represents a completed trade case for retrospective training"""
case_id: str
symbol: str
action: str # 'BUY' or 'SELL'
entry_price: float
exit_price: float
entry_time: datetime
exit_time: datetime
pnl: float
fees: float
confidence: float
model_inputs: Dict[str, Any]
market_state: Dict[str, Any]
outcome_label: int # 1 for profit, 0 for loss, 2 for breakeven
reward_signal: float # Scaled reward for RL training
leverage: float = 1.0
class RetrospectiveTrainer:
"""Retrospective training system for real-time model optimization"""
def __init__(self, orchestrator=None, config: Optional[Dict[str, Any]] = None):
"""Initialize the retrospective trainer"""
self.orchestrator = orchestrator
self.config = config or {}
# Training configuration
self.batch_size = self.config.get('batch_size', 32)
self.min_cases_for_training = self.config.get('min_cases_for_training', 5)
self.profit_threshold = self.config.get('profit_threshold', 0.0)
self.training_frequency = self.config.get('training_frequency_seconds', 120) # 2 minutes
self.max_training_cases = self.config.get('max_training_cases', 1000)
# Training state
self.training_queue = queue.Queue()
self.completed_cases = deque(maxlen=self.max_training_cases)
self.training_stats = {
'total_cases': 0,
'profitable_cases': 0,
'loss_cases': 0,
'breakeven_cases': 0,
'avg_profit': 0.0,
'last_training_time': datetime.now(),
'training_sessions': 0,
'model_updates': 0
}
# Threading
self.training_thread = None
self.is_training_active = False
self.training_lock = threading.Lock()
logger.info("RetrospectiveTrainer initialized")
logger.info(f"Configuration: batch_size={self.batch_size}, "
f"min_cases={self.min_cases_for_training}, "
f"training_freq={self.training_frequency}s")
def add_completed_trade(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> bool:
"""Add a completed trade for retrospective training"""
try:
# Create training case from trade record
case = self._create_training_case(trade_record, model_inputs)
if case is None:
return False
# Add to completed cases
self.completed_cases.append(case)
self.training_queue.put(case)
# Update statistics
self.training_stats['total_cases'] += 1
if case.outcome_label == 1: # Profit
self.training_stats['profitable_cases'] += 1
elif case.outcome_label == 0: # Loss
self.training_stats['loss_cases'] += 1
else: # Breakeven
self.training_stats['breakeven_cases'] += 1
# Calculate running average profit
total_pnl = sum(c.pnl for c in self.completed_cases)
self.training_stats['avg_profit'] = total_pnl / len(self.completed_cases)
logger.info(f"RETROSPECTIVE: Added training case {case.case_id} "
f"(P&L: ${case.pnl:.3f}, Label: {case.outcome_label})")
# Trigger training if we have enough cases
self._maybe_trigger_training()
return True
except Exception as e:
logger.error(f"Error adding completed trade for retrospective training: {e}")
return False
def _create_training_case(self, trade_record: Dict[str, Any], model_inputs: Dict[str, Any]) -> Optional[TrainingCase]:
"""Create a training case from trade record and model inputs"""
try:
# Extract trade information
symbol = trade_record.get('symbol', 'UNKNOWN')
side = trade_record.get('side', 'UNKNOWN')
pnl = trade_record.get('pnl', 0.0)
fees = trade_record.get('fees', 0.0)
confidence = trade_record.get('confidence', 0.0)
# Calculate net P&L after fees
net_pnl = pnl - fees
# Determine outcome label and reward signal
if net_pnl > self.profit_threshold:
outcome_label = 1 # Profitable
# Scale reward by profit magnitude and confidence
reward_signal = min(10.0, net_pnl * confidence * 10) # Amplify for training
elif net_pnl < -self.profit_threshold:
outcome_label = 0 # Loss
# Negative reward scaled by loss magnitude
reward_signal = max(-10.0, net_pnl * confidence * 10) # Negative reward
else:
outcome_label = 2 # Breakeven
reward_signal = 0.0
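# Worked example (illustrative): net_pnl = +0.50, confidence = 0.8
#   reward_signal = min(10.0, 0.50 * 0.8 * 10) = 4.0
# and for net_pnl = -0.50 at the same confidence:
#   reward_signal = max(-10.0, -0.50 * 0.8 * 10) = -4.0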
# Create case ID
timestamp_str = datetime.now().strftime('%Y%m%d_%H%M%S')
case_id = f"retro_{timestamp_str}_{symbol.replace('/', '')}_{side}_pnl_{abs(net_pnl):.3f}".replace('.', 'p')
# Create training case
case = TrainingCase(
case_id=case_id,
symbol=symbol,
action=side,
entry_price=trade_record.get('entry_price', 0.0),
exit_price=trade_record.get('exit_price', 0.0),
entry_time=trade_record.get('entry_time', datetime.now()),
exit_time=trade_record.get('exit_time', datetime.now()),
pnl=net_pnl,
fees=fees,
confidence=confidence,
model_inputs=model_inputs,
market_state=model_inputs.get('market_state', {}),
outcome_label=outcome_label,
reward_signal=reward_signal,
leverage=trade_record.get('leverage', 1.0)
)
return case
except Exception as e:
logger.error(f"Error creating training case: {e}")
return None
def _maybe_trigger_training(self):
"""Check if we should trigger a training session"""
try:
# Check if we have enough cases
if len(self.completed_cases) < self.min_cases_for_training:
return
# Check if enough time has passed since last training
time_since_last = (datetime.now() - self.training_stats['last_training_time']).total_seconds()
if time_since_last < self.training_frequency:
return
# Check if training thread is not already running
if self.is_training_active:
logger.debug("Training already in progress, skipping trigger")
return
# Start training in background thread
self._start_training_session()
except Exception as e:
logger.error(f"Error checking training trigger: {e}")
def _start_training_session(self):
"""Start a training session in background thread"""
try:
if self.training_thread and self.training_thread.is_alive():
logger.debug("Training thread already running")
return
self.training_thread = threading.Thread(
target=self._run_training_session,
daemon=True,
name="RetrospectiveTrainer"
)
self.training_thread.start()
logger.info("RETROSPECTIVE: Started training session")
except Exception as e:
logger.error(f"Error starting training session: {e}")
def _run_training_session(self):
"""Run a complete training session"""
try:
with self.training_lock:
self.is_training_active = True
start_time = time.time()
logger.info(f"RETROSPECTIVE: Training with {len(self.completed_cases)} cases")
# Train models if orchestrator available
training_results = {}
if self.orchestrator:
training_results = self._train_models()
# Update statistics
self.training_stats['last_training_time'] = datetime.now()
self.training_stats['training_sessions'] += 1
self.training_stats['model_updates'] += len(training_results)
elapsed_time = time.time() - start_time
logger.info(f"RETROSPECTIVE: Training completed in {elapsed_time:.2f}s - {training_results}")
except Exception as e:
logger.error(f"Error in retrospective training session: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
self.is_training_active = False
def _train_models(self) -> Dict[str, Any]:
"""Train available models using retrospective data"""
results = {}
try:
# Prepare training data
profitable_cases = [c for c in self.completed_cases if c.outcome_label == 1]
loss_cases = [c for c in self.completed_cases if c.outcome_label == 0]
if len(profitable_cases) == 0 and len(loss_cases) == 0:
return {'error': 'No labeled cases for training'}
logger.info(f"RETROSPECTIVE: Training data - Profitable: {len(profitable_cases)}, Loss: {len(loss_cases)}")
# Train DQN agent if available
if self.orchestrator and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
try:
dqn_result = self._train_dqn_retrospective()
results['dqn'] = dqn_result
logger.info(f"RETROSPECTIVE: DQN training result: {dqn_result}")
except Exception as e:
logger.warning(f"DQN retrospective training failed: {e}")
results['dqn'] = {'error': str(e)}
# Train other models
if self.orchestrator and hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer:
try:
# Update extrema trainer with retrospective feedback
extrema_feedback = self._create_extrema_feedback()
if extrema_feedback:
results['extrema'] = {'feedback_cases': len(extrema_feedback)}
logger.info(f"RETROSPECTIVE: Extrema feedback provided for {len(extrema_feedback)} cases")
except Exception as e:
logger.warning(f"Extrema retrospective training failed: {e}")
return results
except Exception as e:
logger.error(f"Error training models retrospectively: {e}")
return {'error': str(e)}
def _train_dqn_retrospective(self) -> Dict[str, Any]:
"""Train DQN agent using retrospective experience replay"""
try:
if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
return {'error': 'DQN agent not available'}
dqn_agent = self.orchestrator.rl_agent
experiences_added = 0
# Add retrospective experiences to DQN replay buffer
for case in self.completed_cases:
try:
# Extract state from model inputs
state = self._extract_state_vector(case.model_inputs)
if state is None:
continue
# Action mapping: BUY=0, SELL=1
action = 0 if case.action == 'BUY' else 1
# Use reward signal as immediate reward
reward = case.reward_signal
# Retrospective cases are terminal (the trade is already closed), so use a zero next_state
next_state = np.zeros_like(state)
done = True
# Add experience to DQN replay buffer
if hasattr(dqn_agent, 'add_experience'):
dqn_agent.add_experience(state, action, reward, next_state, done)
experiences_added += 1
except Exception as e:
logger.debug(f"Error adding DQN experience: {e}")
continue
# Train DQN if we have enough experiences
if experiences_added > 0 and hasattr(dqn_agent, 'train'):
try:
# Perform multiple training steps on retrospective data
training_steps = min(10, experiences_added // 4) # Conservative training
for _ in range(training_steps):
loss = dqn_agent.train()
if loss is None:
break
return {
'experiences_added': experiences_added,
'training_steps': training_steps,
'method': 'retrospective_experience_replay'
}
except Exception as e:
logger.warning(f"DQN training step failed: {e}")
return {'experiences_added': experiences_added, 'training_error': str(e)}
return {'experiences_added': experiences_added, 'training_steps': 0}
except Exception as e:
logger.error(f"Error in DQN retrospective training: {e}")
return {'error': str(e)}
def _extract_state_vector(self, model_inputs: Dict[str, Any]) -> Optional[np.ndarray]:
"""Extract state vector for DQN training from model inputs"""
try:
# Try to get pre-built RL state
if 'dqn_state' in model_inputs:
state = model_inputs['dqn_state']
if isinstance(state, dict) and 'state_vector' in state:
return np.array(state['state_vector'])
# Build state from market features
market_state = model_inputs.get('market_state', {})
features = []
# Price features
for key in ['current_price', 'price_sma_5', 'price_sma_20', 'price_std_20', 'price_rsi']:
features.append(market_state.get(key, 0.0))
# Volume features
for key in ['volume_current', 'volume_sma_20', 'volume_ratio']:
features.append(market_state.get(key, 0.0))
# Technical indicators
indicators = model_inputs.get('technical_indicators', {})
for key in ['sma_10', 'sma_20', 'bb_upper', 'bb_lower', 'bb_position', 'macd', 'volatility']:
features.append(indicators.get(key, 0.0))
if len(features) < 5: # Minimum required features
return None
return np.array(features, dtype=np.float32)
except Exception as e:
logger.debug(f"Error extracting state vector: {e}")
return None
def _create_extrema_feedback(self) -> List[Dict[str, Any]]:
"""Create feedback data for extrema trainer"""
feedback = []
try:
for case in self.completed_cases:
if case.outcome_label in [0, 1]: # Only profit/loss cases
feedback_item = {
'symbol': case.symbol,
'action': case.action,
'entry_price': case.entry_price,
'exit_price': case.exit_price,
'was_profitable': case.outcome_label == 1,
'reward_signal': case.reward_signal,
'market_state': case.market_state
}
feedback.append(feedback_item)
return feedback
except Exception as e:
logger.error(f"Error creating extrema feedback: {e}")
return []
def get_training_stats(self) -> Dict[str, Any]:
"""Get current training statistics"""
stats = self.training_stats.copy()
stats['total_cases_in_memory'] = len(self.completed_cases)
stats['training_queue_size'] = self.training_queue.qsize()
stats['is_training_active'] = self.is_training_active
# Calculate profit metrics
if len(self.completed_cases) > 0:
profitable_count = sum(1 for c in self.completed_cases if c.pnl > 0)
stats['profit_rate'] = profitable_count / len(self.completed_cases)
stats['total_pnl'] = sum(c.pnl for c in self.completed_cases)
stats['avg_reward'] = sum(c.reward_signal for c in self.completed_cases) / len(self.completed_cases)
return stats
def force_training_session(self) -> bool:
"""Force a training session regardless of timing constraints"""
try:
if self.is_training_active:
logger.warning("Training already in progress")
return False
if len(self.completed_cases) < 1:
logger.warning("No completed cases available for training")
return False
logger.info("RETROSPECTIVE: Forcing training session")
self._start_training_session()
return True
except Exception as e:
logger.error(f"Error forcing training session: {e}")
return False
def stop(self):
"""Stop the retrospective trainer"""
try:
self.is_training_active = False
if self.training_thread and self.training_thread.is_alive():
self.training_thread.join(timeout=10)
logger.info("RetrospectiveTrainer stopped")
except Exception as e:
logger.error(f"Error stopping RetrospectiveTrainer: {e}")
def create_retrospective_trainer(orchestrator=None, config: Optional[Dict[str, Any]] = None) -> RetrospectiveTrainer:
"""Factory function to create a RetrospectiveTrainer instance"""
return RetrospectiveTrainer(orchestrator=orchestrator, config=config)
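# Illustrative usage sketch; the trade_record keys mirror those read in _create_training_case, and
# the numeric values are placeholders.
def _example_retrospective_usage():
    trainer = create_retrospective_trainer(orchestrator=None)
    trade_record = {
        'symbol': 'ETH/USDT', 'side': 'BUY',
        'entry_price': 3450.0, 'exit_price': 3462.0,
        'entry_time': datetime.now(), 'exit_time': datetime.now(),
        'pnl': 0.12, 'fees': 0.02, 'confidence': 0.7, 'leverage': 1.0,
    }
    model_inputs = {'market_state': {'current_price': 3462.0}}
    trainer.add_completed_trade(trade_record, model_inputs)
    print(trainer.get_training_stats())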

View File

@ -1,529 +0,0 @@
"""
RL Training Pipeline with Comprehensive Experience Storage and Replay
This module implements a robust RL training pipeline that:
1. Stores all training experiences with profitability metrics
2. Implements profit-weighted experience replay
3. Tracks gradient information for each training step
4. Enables retraining on most profitable trading sequences
5. Maintains comprehensive trading episode analysis
"""
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, field
import json
import pickle
from collections import deque
import threading
import random
from .training_data_collector import get_training_data_collector
logger = logging.getLogger(__name__)
@dataclass
class RLExperience:
"""Single RL experience with complete state-action-reward information"""
experience_id: str
timestamp: datetime
episode_id: str
# Core RL components
state: np.ndarray
action: int # 0=SELL, 1=HOLD, 2=BUY
reward: float
next_state: np.ndarray
done: bool
# Extended state information
market_context: Dict[str, Any]
cnn_predictions: Optional[Dict[str, Any]] = None
confidence_score: float = 0.0
# Actual trading outcome
actual_profit: Optional[float] = None
actual_holding_time: Optional[timedelta] = None
optimal_action: Optional[int] = None
# Experience value for replay
experience_value: float = 0.0
profitability_score: float = 0.0
learning_priority: float = 0.0
# Training metadata
times_trained: int = 0
last_trained: Optional[datetime] = None
class ProfitWeightedExperienceBuffer:
"""Experience buffer with profit-weighted sampling for replay"""
def __init__(self, max_size: int = 100000):
self.max_size = max_size
self.experiences: Dict[str, RLExperience] = {}
self.experience_order: deque = deque(maxlen=max_size)
self.profitable_experiences: List[str] = []
self.total_experiences = 0
self.total_profitable = 0
def add_experience(self, experience: RLExperience):
"""Add experience to buffer"""
try:
self.experiences[experience.experience_id] = experience
self.experience_order.append(experience.experience_id)
if experience.actual_profit is not None and experience.actual_profit > 0:
self.profitable_experiences.append(experience.experience_id)
self.total_profitable += 1
# Remove oldest if buffer is full
if len(self.experiences) > self.max_size:
oldest_id = self.experience_order[0]
if oldest_id in self.experiences:
del self.experiences[oldest_id]
if oldest_id in self.profitable_experiences:
self.profitable_experiences.remove(oldest_id)
self.total_experiences += 1
except Exception as e:
logger.error(f"Error adding experience to buffer: {e}")
def sample_batch(self, batch_size: int, prioritize_profitable: bool = True) -> List[RLExperience]:
"""Sample batch with profit-weighted prioritization"""
try:
if len(self.experiences) < batch_size:
return list(self.experiences.values())
if prioritize_profitable and len(self.profitable_experiences) > batch_size // 2:
# Sample mix of profitable and all experiences
profitable_sample_size = min(batch_size // 2, len(self.profitable_experiences))
remaining_sample_size = batch_size - profitable_sample_size
profitable_ids = random.sample(self.profitable_experiences, profitable_sample_size)
all_ids = list(self.experiences.keys())
remaining_ids = random.sample(all_ids, remaining_sample_size)
sampled_ids = profitable_ids + remaining_ids
else:
# Random sampling from all experiences
all_ids = list(self.experiences.keys())
sampled_ids = random.sample(all_ids, batch_size)
sampled_experiences = [self.experiences[exp_id] for exp_id in sampled_ids]
# Update training counts
for experience in sampled_experiences:
experience.times_trained += 1
experience.last_trained = datetime.now()
return sampled_experiences
except Exception as e:
logger.error(f"Error sampling batch: {e}")
return list(self.experiences.values())[:batch_size]
def get_most_profitable_experiences(self, limit: int = 100) -> List[RLExperience]:
"""Get most profitable experiences for targeted training"""
try:
profitable_experiences = [
self.experiences[exp_id] for exp_id in self.profitable_experiences
if exp_id in self.experiences
]
profitable_experiences.sort(
key=lambda x: x.actual_profit if x.actual_profit else 0,
reverse=True
)
return profitable_experiences[:limit]
except Exception as e:
logger.error(f"Error getting profitable experiences: {e}")
return []
class RLTradingAgent(nn.Module):
"""RL Trading Agent with comprehensive state processing"""
def __init__(self, state_dim: int = 2000, action_dim: int = 3, hidden_dim: int = 512):
super(RLTradingAgent, self).__init__()
self.state_dim = state_dim
self.action_dim = action_dim
self.hidden_dim = hidden_dim
# State processing network
self.state_processor = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LayerNorm(hidden_dim // 2),
nn.ReLU()
)
# Q-value network
self.q_network = nn.Sequential(
nn.Linear(hidden_dim // 2, hidden_dim // 4),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim // 4, action_dim)
)
# Policy network
self.policy_network = nn.Sequential(
nn.Linear(hidden_dim // 2, hidden_dim // 4),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim // 4, action_dim),
nn.Softmax(dim=-1)
)
# Value network
self.value_network = nn.Sequential(
nn.Linear(hidden_dim // 2, hidden_dim // 4),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim // 4, 1)
)
def forward(self, state):
"""Forward pass through the agent"""
processed_state = self.state_processor(state)
q_values = self.q_network(processed_state)
policy_probs = self.policy_network(processed_state)
state_value = self.value_network(processed_state)
return {
'q_values': q_values,
'policy_probs': policy_probs,
'state_value': state_value,
'processed_state': processed_state
}
def select_action(self, state, epsilon: float = 0.1) -> Tuple[int, float]:
"""Select action using epsilon-greedy policy"""
self.eval()
with torch.no_grad():
if isinstance(state, np.ndarray):
    state = torch.from_numpy(state).float().unsqueeze(0)
state = state.to(next(self.parameters()).device)  # keep the input on the same device as the model
outputs = self.forward(state)
if random.random() < epsilon:
action = random.randint(0, self.action_dim - 1)
confidence = 0.33
else:
q_values = outputs['q_values']
action = torch.argmax(q_values, dim=1).item()
q_softmax = F.softmax(q_values, dim=1)
confidence = torch.max(q_softmax).item()
return action, confidence
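# Illustrative usage sketch: build an agent (CPU by default) and query an action for a random state.
def _example_agent_usage():
    agent = RLTradingAgent(state_dim=2000, action_dim=3)
    state = np.random.randn(2000).astype(np.float32)
    action, confidence = agent.select_action(state, epsilon=0.1)
    # action: 0=SELL, 1=HOLD, 2=BUY (see RLExperience); confidence is in [0, 1]
    print(action, confidence)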
@dataclass
class RLTrainingStep:
"""Single RL training step with backpropagation data"""
step_id: str
timestamp: datetime
batch_experiences: List[str]
# Training data
total_loss: float
q_loss: float
policy_loss: float
# Gradients
gradients: Dict[str, torch.Tensor]
gradient_norms: Dict[str, float]
# Metadata
learning_rate: float = 0.001
batch_size: int = 32
# Performance
batch_profitability: float = 0.0
correct_actions: int = 0
total_actions: int = 0
step_value: float = 0.0
@dataclass
class RLTrainingSession:
"""Complete RL training session"""
session_id: str
start_timestamp: datetime
end_timestamp: Optional[datetime] = None
training_mode: str = 'experience_replay'
symbol: str = ''
training_steps: List[RLTrainingStep] = field(default_factory=list)
total_steps: int = 0
average_loss: float = 0.0
best_loss: float = float('inf')
profitable_actions: int = 0
total_actions: int = 0
profitability_rate: float = 0.0
session_value: float = 0.0
class RLTrainer:
"""RL trainer with comprehensive experience storage and replay"""
def __init__(self, agent: RLTradingAgent, device: str = 'cuda', storage_dir: str = "rl_training_storage"):
self.agent = agent.to(device)
self.device = device
self.storage_dir = Path(storage_dir)
self.storage_dir.mkdir(parents=True, exist_ok=True)
self.optimizer = torch.optim.AdamW(agent.parameters(), lr=0.001)
self.experience_buffer = ProfitWeightedExperienceBuffer()
self.data_collector = get_training_data_collector()
self.training_sessions: List[RLTrainingSession] = []
self.current_session: Optional[RLTrainingSession] = None
self.gamma = 0.99
self.training_stats = {
'total_sessions': 0,
'total_steps': 0,
'total_experiences': 0,
'profitable_actions': 0,
'total_actions': 0,
'average_reward': 0.0
}
logger.info(f"RL Trainer initialized with {sum(p.numel() for p in agent.parameters()):,} parameters")
def add_experience(self, state: np.ndarray, action: int, reward: float,
next_state: np.ndarray, done: bool, market_context: Dict[str, Any],
cnn_predictions: Dict[str, Any] = None, confidence_score: float = 0.0) -> str:
"""Add experience to the buffer"""
try:
experience_id = f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
experience = RLExperience(
experience_id=experience_id,
timestamp=datetime.now(),
episode_id=market_context.get('episode_id', 'unknown'),
state=state,
action=action,
reward=reward,
next_state=next_state,
done=done,
market_context=market_context,
cnn_predictions=cnn_predictions,
confidence_score=confidence_score
)
self.experience_buffer.add_experience(experience)
self.training_stats['total_experiences'] += 1
return experience_id
except Exception as e:
logger.error(f"Error adding experience: {e}")
return None
def train_on_experiences(self, batch_size: int = 32, num_batches: int = 10) -> Dict[str, Any]:
"""Train on experiences with comprehensive data storage"""
try:
session = RLTrainingSession(
session_id=f"rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
start_timestamp=datetime.now(),
training_mode='experience_replay'
)
self.current_session = session
self.agent.train()
total_loss = 0.0
for batch_idx in range(num_batches):
experiences = self.experience_buffer.sample_batch(batch_size, True)
if len(experiences) < batch_size:
continue
# Prepare batch tensors
states = torch.FloatTensor([exp.state for exp in experiences]).to(self.device)
actions = torch.LongTensor([exp.action for exp in experiences]).to(self.device)
rewards = torch.FloatTensor([exp.reward for exp in experiences]).to(self.device)
next_states = torch.FloatTensor([exp.next_state for exp in experiences]).to(self.device)
dones = torch.BoolTensor([exp.done for exp in experiences]).to(self.device)
# Forward pass
self.optimizer.zero_grad()
current_outputs = self.agent(states)
current_q_values = current_outputs['q_values']
# Calculate target Q-values
with torch.no_grad():
next_outputs = self.agent(next_states)
next_q_values = next_outputs['q_values']
max_next_q_values = torch.max(next_q_values, dim=1)[0]
target_q_values = rewards + (self.gamma * max_next_q_values * ~dones)
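# Bellman target: y = r + γ · max_a' Q(s', a'); the ~dones mask drops the bootstrap term for terminal transitions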
# Calculate loss
current_q_values_for_actions = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
q_loss = F.mse_loss(current_q_values_for_actions, target_q_values)
# Backward pass
q_loss.backward()
# Store gradients
gradients = {}
gradient_norms = {}
for name, param in self.agent.named_parameters():
if param.grad is not None:
gradients[name] = param.grad.clone().detach()
gradient_norms[name] = param.grad.norm().item()
torch.nn.utils.clip_grad_norm_(self.agent.parameters(), max_norm=1.0)
self.optimizer.step()
# Create training step record
step = RLTrainingStep(
step_id=f"{session.session_id}_step_{batch_idx}",
timestamp=datetime.now(),
batch_experiences=[exp.experience_id for exp in experiences],
total_loss=q_loss.item(),
q_loss=q_loss.item(),
policy_loss=0.0,
gradients=gradients,
gradient_norms=gradient_norms,
batch_size=len(experiences)
)
session.training_steps.append(step)
total_loss += q_loss.item()
# Finalize session
session.end_timestamp = datetime.now()
session.total_steps = len(session.training_steps)  # count only batches that actually trained
session.average_loss = total_loss / session.total_steps if session.total_steps > 0 else 0.0
self._save_training_session(session)
self.training_stats['total_sessions'] += 1
self.training_stats['total_steps'] += session.total_steps
logger.info(f"RL training session completed: {session.session_id}")
logger.info(f"Average loss: {session.average_loss:.4f}")
return {
'status': 'success',
'session_id': session.session_id,
'average_loss': session.average_loss,
'total_steps': session.total_steps
}
except Exception as e:
logger.error(f"Error in RL training session: {e}")
return {'status': 'error', 'error': str(e)}
finally:
self.current_session = None
def train_on_profitable_experiences(self, min_profitability: float = 0.1,
max_experiences: int = 1000, batch_size: int = 32) -> Dict[str, Any]:
"""Train specifically on most profitable experiences"""
try:
profitable_experiences = self.experience_buffer.get_most_profitable_experiences(max_experiences)
filtered_experiences = [
exp for exp in profitable_experiences
if exp.actual_profit is not None and exp.actual_profit >= min_profitability
]
if len(filtered_experiences) < batch_size:
return {'status': 'insufficient_data', 'experiences_found': len(filtered_experiences)}
logger.info(f"Training on {len(filtered_experiences)} profitable experiences")
num_batches = len(filtered_experiences) // batch_size
# Temporarily replace buffer sampling
original_sample_method = self.experience_buffer.sample_batch
def profitable_sample_batch(batch_size, prioritize_profitable=True):
return random.sample(filtered_experiences, min(batch_size, len(filtered_experiences)))
self.experience_buffer.sample_batch = profitable_sample_batch
try:
results = self.train_on_experiences(batch_size=batch_size, num_batches=num_batches)
results['training_mode'] = 'profitable_replay'
results['experiences_used'] = len(filtered_experiences)
return results
finally:
self.experience_buffer.sample_batch = original_sample_method
except Exception as e:
logger.error(f"Error training on profitable experiences: {e}")
return {'status': 'error', 'error': str(e)}
def _save_training_session(self, session: RLTrainingSession):
"""Save training session to disk"""
try:
session_dir = self.storage_dir / 'sessions'
session_dir.mkdir(parents=True, exist_ok=True)
session_file = session_dir / f"{session.session_id}.pkl"
with open(session_file, 'wb') as f:
pickle.dump(session, f)
metadata = {
'session_id': session.session_id,
'start_timestamp': session.start_timestamp.isoformat(),
'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None,
'training_mode': session.training_mode,
'total_steps': session.total_steps,
'average_loss': session.average_loss
}
metadata_file = session_dir / f"{session.session_id}_metadata.json"
with open(metadata_file, 'w') as f:
json.dump(metadata, f, indent=2)
except Exception as e:
logger.error(f"Error saving training session: {e}")
def get_training_statistics(self) -> Dict[str, Any]:
"""Get comprehensive training statistics"""
stats = self.training_stats.copy()
if self.training_sessions:
recent_sessions = sorted(self.training_sessions, key=lambda x: x.start_timestamp, reverse=True)[:10]
stats['recent_sessions'] = [
{
'session_id': s.session_id,
'timestamp': s.start_timestamp.isoformat(),
'mode': s.training_mode,
'average_loss': s.average_loss
}
for s in recent_sessions
]
return stats
# Global instance
rl_trainer = None
def get_rl_trainer(agent: RLTradingAgent = None) -> RLTrainer:
"""Get global RL trainer instance"""
global rl_trainer
if rl_trainer is None:
if agent is None:
agent = RLTradingAgent()
rl_trainer = RLTrainer(agent)
return rl_trainer
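# Illustrative usage sketch. Note that RLTrainer defaults to device='cuda'; on a CPU-only machine
# construct RLTrainer(RLTradingAgent(), device='cpu') directly instead of using the global factory.
def _example_rl_training():
    trainer = get_rl_trainer()
    state = np.random.randn(2000).astype(np.float32)
    next_state = np.random.randn(2000).astype(np.float32)
    trainer.add_experience(state, action=2, reward=1.0, next_state=next_state, done=False,
                           market_context={'episode_id': 'demo'})
    # Once enough experiences are buffered:
    # result = trainer.train_on_experiences(batch_size=32, num_batches=10)
    print(trainer.get_training_statistics())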

View File

@ -1,460 +0,0 @@
"""
Robust COB (Consolidated Order Book) Provider
This module provides a robust COB data provider that handles:
- HTTP 418 errors from Binance (rate limiting)
- Thread safety issues
- API rate limiting and backoff
- Fallback data sources
- Error recovery strategies
Features:
- Automatic rate limiting and backoff
- Multiple exchange support with fallbacks
- Thread-safe operations
- Comprehensive error handling
- Data validation and integrity checking
"""
import asyncio
import logging
import time
import threading
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass, field
from collections import deque
import json
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
import requests
from .api_rate_limiter import get_rate_limiter, RateLimitConfig
logger = logging.getLogger(__name__)
@dataclass
class COBData:
"""Consolidated Order Book data structure"""
symbol: str
timestamp: datetime
bids: List[Tuple[float, float]] # [(price, quantity), ...]
asks: List[Tuple[float, float]] # [(price, quantity), ...]
# Derived metrics
spread: float = 0.0
mid_price: float = 0.0
total_bid_volume: float = 0.0
total_ask_volume: float = 0.0
# Data quality
data_source: str = 'unknown'
quality_score: float = 1.0
def __post_init__(self):
"""Calculate derived metrics"""
if self.bids and self.asks:
self.spread = self.asks[0][0] - self.bids[0][0]
self.mid_price = (self.asks[0][0] + self.bids[0][0]) / 2
self.total_bid_volume = sum(qty for _, qty in self.bids)
self.total_ask_volume = sum(qty for _, qty in self.asks)
# Calculate quality score based on data completeness
self.quality_score = min(
len(self.bids) / 20, # Expect at least 20 bid levels
len(self.asks) / 20, # Expect at least 20 ask levels
1.0
)
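# Worked example (illustrative): two levels per side; __post_init__ fills in the derived metrics.
def _example_cob_data() -> COBData:
    cob = COBData(symbol='ETHUSDT', timestamp=datetime.now(),
                  bids=[(3449.9, 2.0), (3449.8, 1.5)],
                  asks=[(3450.1, 1.0), (3450.2, 3.0)])
    # Derived: spread ≈ 0.2, mid_price = 3450.0, total_bid_volume = 3.5,
    # total_ask_volume = 4.0, quality_score = min(2/20, 2/20, 1.0) = 0.1
    return cob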
class RobustCOBProvider:
"""Robust COB provider with error handling and rate limiting"""
def __init__(self, symbols: List[str] = None):
self.symbols = symbols or ['ETHUSDT', 'BTCUSDT']
# Rate limiter
self.rate_limiter = get_rate_limiter()
# Thread safety
self.lock = threading.RLock()
# Data cache
self.cob_cache: Dict[str, COBData] = {}
self.cache_timestamps: Dict[str, datetime] = {}
self.cache_ttl = timedelta(seconds=5) # 5 second cache TTL
# Error tracking
self.error_counts: Dict[str, int] = {}
self.last_successful_fetch: Dict[str, datetime] = {}
# Background fetching
self.is_running = False
self.fetch_threads: Dict[str, threading.Thread] = {}
self.executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="COB-Fetcher")
# Fallback data
self.fallback_data: Dict[str, COBData] = {}
# Performance tracking
self.fetch_stats = {
'total_requests': 0,
'successful_requests': 0,
'failed_requests': 0,
'rate_limited_requests': 0,
'cache_hits': 0,
'fallback_uses': 0
}
logger.info(f"Robust COB Provider initialized for symbols: {self.symbols}")
def start_background_fetching(self):
"""Start background COB data fetching"""
if self.is_running:
logger.warning("Background fetching already running")
return
self.is_running = True
# Start fetching thread for each symbol
for symbol in self.symbols:
thread = threading.Thread(
target=self._background_fetch_worker,
args=(symbol,),
name=f"COB-{symbol}",
daemon=True
)
self.fetch_threads[symbol] = thread
thread.start()
logger.info(f"Started background COB fetching for {len(self.symbols)} symbols")
def stop_background_fetching(self):
"""Stop background COB data fetching"""
self.is_running = False
# Wait for threads to finish
for symbol, thread in self.fetch_threads.items():
thread.join(timeout=5)
logger.debug(f"Stopped COB fetching for {symbol}")
# Shutdown executor
self.executor.shutdown(wait=True)  # ThreadPoolExecutor.shutdown() does not accept a timeout argument
logger.info("Stopped background COB fetching")
def _background_fetch_worker(self, symbol: str):
"""Background worker for fetching COB data"""
logger.info(f"Started COB fetching worker for {symbol}")
while self.is_running:
try:
# Fetch COB data
cob_data = self._fetch_cob_data_safe(symbol)
if cob_data:
with self.lock:
self.cob_cache[symbol] = cob_data
self.cache_timestamps[symbol] = datetime.now()
self.last_successful_fetch[symbol] = datetime.now()
self.error_counts[symbol] = 0 # Reset error count on success
logger.debug(f"Updated COB cache for {symbol}")
else:
with self.lock:
self.error_counts[symbol] = self.error_counts.get(symbol, 0) + 1
logger.debug(f"Failed to fetch COB for {symbol}, error count: {self.error_counts.get(symbol, 0)}")
# Wait before next fetch (adaptive based on errors)
error_count = self.error_counts.get(symbol, 0)
base_interval = 2.0 # Base 2 second interval
backoff_interval = min(base_interval * (2 ** min(error_count, 5)), 60.0) # Max 60s
time.sleep(backoff_interval)
except Exception as e:
logger.error(f"Error in COB fetching worker for {symbol}: {e}")
time.sleep(10) # Wait 10s on unexpected errors
logger.info(f"Stopped COB fetching worker for {symbol}")
def _fetch_cob_data_safe(self, symbol: str) -> Optional[COBData]:
"""Safely fetch COB data with error handling"""
try:
self.fetch_stats['total_requests'] += 1
# Try Binance first
cob_data = self._fetch_binance_cob(symbol)
if cob_data:
self.fetch_stats['successful_requests'] += 1
return cob_data
# Try MEXC as fallback
cob_data = self._fetch_mexc_cob(symbol)
if cob_data:
self.fetch_stats['successful_requests'] += 1
cob_data.data_source = 'mexc_fallback'
return cob_data
# Use cached fallback data if available
if symbol in self.fallback_data:
self.fetch_stats['fallback_uses'] += 1
fallback = self.fallback_data[symbol]
fallback.timestamp = datetime.now()
fallback.data_source = 'fallback_cache'
fallback.quality_score *= 0.5 # Reduce quality score for old data
return fallback
self.fetch_stats['failed_requests'] += 1
return None
except Exception as e:
logger.error(f"Error fetching COB data for {symbol}: {e}")
self.fetch_stats['failed_requests'] += 1
return None
def _fetch_binance_cob(self, symbol: str) -> Optional[COBData]:
"""Fetch COB data from Binance with rate limiting"""
try:
url = f"https://api.binance.com/api/v3/depth"
params = {
'symbol': symbol,
'limit': 100 # Get 100 levels
}
# Use rate limiter
response = self.rate_limiter.make_request(
'binance_api',
url,
method='GET',
params=params
)
if not response:
self.fetch_stats['rate_limited_requests'] += 1
return None
if response.status_code != 200:
logger.warning(f"Binance COB API returned {response.status_code} for {symbol}")
return None
data = response.json()
# Parse order book data
bids = [(float(price), float(qty)) for price, qty in data.get('bids', [])]
asks = [(float(price), float(qty)) for price, qty in data.get('asks', [])]
if not bids or not asks:
logger.warning(f"Empty order book data from Binance for {symbol}")
return None
cob_data = COBData(
symbol=symbol,
timestamp=datetime.now(),
bids=bids,
asks=asks,
data_source='binance'
)
# Store as fallback for future use
self.fallback_data[symbol] = cob_data
return cob_data
except Exception as e:
logger.error(f"Error fetching Binance COB for {symbol}: {e}")
return None
def _fetch_mexc_cob(self, symbol: str) -> Optional[COBData]:
"""Fetch COB data from MEXC as fallback"""
try:
url = f"https://api.mexc.com/api/v3/depth"
params = {
'symbol': symbol,
'limit': 100
}
response = self.rate_limiter.make_request(
'mexc_api',
url,
method='GET',
params=params
)
if not response or response.status_code != 200:
return None
data = response.json()
# Parse order book data
bids = [(float(price), float(qty)) for price, qty in data.get('bids', [])]
asks = [(float(price), float(qty)) for price, qty in data.get('asks', [])]
if not bids or not asks:
return None
return COBData(
symbol=symbol,
timestamp=datetime.now(),
bids=bids,
asks=asks,
data_source='mexc'
)
except Exception as e:
logger.debug(f"Error fetching MEXC COB for {symbol}: {e}")
return None
def get_cob_data(self, symbol: str) -> Optional[COBData]:
"""Get COB data for a symbol (from cache or fresh fetch)"""
with self.lock:
# Check cache first
if symbol in self.cob_cache:
cached_data = self.cob_cache[symbol]
cache_time = self.cache_timestamps.get(symbol, datetime.min)
# Return cached data if still fresh
if datetime.now() - cache_time < self.cache_ttl:
self.fetch_stats['cache_hits'] += 1
return cached_data
# If background fetching is running, return cached data even if stale
if self.is_running and symbol in self.cob_cache:
return self.cob_cache[symbol]
# Fetch fresh data if not running background fetching
if not self.is_running:
return self._fetch_cob_data_safe(symbol)
return None
def get_cob_features(self, symbol: str, feature_count: int = 120) -> Optional[np.ndarray]:
"""
Get COB features for ML models
Args:
symbol: Trading symbol
feature_count: Number of features to return
Returns:
            NumPy array of COB features, or None if no data is available
"""
cob_data = self.get_cob_data(symbol)
if not cob_data:
return None
try:
features = []
# Basic market metrics
features.extend([
cob_data.mid_price,
cob_data.spread,
cob_data.total_bid_volume,
cob_data.total_ask_volume,
cob_data.quality_score
])
# Bid levels (price and volume)
max_levels = min(len(cob_data.bids), 20)
for i in range(max_levels):
price, volume = cob_data.bids[i]
features.extend([price, volume])
# Pad bid levels if needed
for i in range(max_levels, 20):
features.extend([0.0, 0.0])
# Ask levels (price and volume)
max_levels = min(len(cob_data.asks), 20)
for i in range(max_levels):
price, volume = cob_data.asks[i]
features.extend([price, volume])
# Pad ask levels if needed
for i in range(max_levels, 20):
features.extend([0.0, 0.0])
# Calculate additional features
if len(cob_data.bids) > 0 and len(cob_data.asks) > 0:
# Volume imbalance
bid_volume_5 = sum(vol for _, vol in cob_data.bids[:5])
ask_volume_5 = sum(vol for _, vol in cob_data.asks[:5])
volume_imbalance = (bid_volume_5 - ask_volume_5) / (bid_volume_5 + ask_volume_5) if (bid_volume_5 + ask_volume_5) > 0 else 0
features.append(volume_imbalance)
# Price levels
bid_price_levels = [price for price, _ in cob_data.bids[:10]]
ask_price_levels = [price for price, _ in cob_data.asks[:10]]
features.extend(bid_price_levels + ask_price_levels)
# Pad or truncate to desired feature count
if len(features) < feature_count:
features.extend([0.0] * (feature_count - len(features)))
else:
features = features[:feature_count]
return np.array(features, dtype=np.float32)
except Exception as e:
logger.error(f"Error creating COB features for {symbol}: {e}")
return None
def get_provider_status(self) -> Dict[str, Any]:
"""Get provider status and statistics"""
with self.lock:
status = {
'is_running': self.is_running,
'symbols': self.symbols,
'cache_status': {},
'error_counts': self.error_counts.copy(),
'last_successful_fetch': {
symbol: timestamp.isoformat()
for symbol, timestamp in self.last_successful_fetch.items()
},
'fetch_stats': self.fetch_stats.copy(),
'rate_limiter_status': self.rate_limiter.get_all_endpoint_status()
}
# Cache status for each symbol
for symbol in self.symbols:
cache_time = self.cache_timestamps.get(symbol)
status['cache_status'][symbol] = {
'has_data': symbol in self.cob_cache,
'cache_time': cache_time.isoformat() if cache_time else None,
'cache_age_seconds': (datetime.now() - cache_time).total_seconds() if cache_time else None,
'data_quality': self.cob_cache[symbol].quality_score if symbol in self.cob_cache else 0.0
}
return status
def reset_errors(self):
"""Reset error counts and rate limiter"""
with self.lock:
self.error_counts.clear()
self.rate_limiter.reset_all_endpoints()
logger.info("Reset all error counts and rate limiter")
def force_refresh(self, symbol: str = None):
"""Force refresh COB data for symbol(s)"""
symbols_to_refresh = [symbol] if symbol else self.symbols
for sym in symbols_to_refresh:
# Clear cache to force refresh
with self.lock:
if sym in self.cob_cache:
del self.cob_cache[sym]
if sym in self.cache_timestamps:
del self.cache_timestamps[sym]
logger.info(f"Forced refresh for {sym}")
# Global COB provider instance
_global_cob_provider = None
def get_cob_provider(symbols: List[str] = None) -> RobustCOBProvider:
"""Get global COB provider instance"""
global _global_cob_provider
if _global_cob_provider is None:
_global_cob_provider = RobustCOBProvider(symbols)
return _global_cob_provider
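
# Illustrative usage sketch (not part of the original module). It assumes the
# provider accepts Binance-style symbols such as "ETHUSDT" and that a rate
# limiter is wired up in __init__; the sleep duration and symbols are placeholders.
if __name__ == "__main__":
    import time

    provider = get_cob_provider(symbols=["ETHUSDT", "BTCUSDT"])
    provider.start_background_fetching()
    try:
        time.sleep(10)  # give the background workers time to populate the cache
        features = provider.get_cob_features("ETHUSDT", feature_count=120)
        if features is not None:
            print(f"COB feature vector shape: {features.shape}")
        print(provider.get_provider_status()["fetch_stats"])
    finally:
        provider.stop_background_fetching()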

View File

@ -1,350 +0,0 @@
#!/usr/bin/env python3
"""
Shared COB Service - Eliminates Redundant COB Implementations
This service provides a singleton COB integration that can be shared across:
- Dashboard components
- RL trading systems
- Enhanced orchestrators
- Training pipelines
Instead of each component creating its own COBIntegration instance,
they all share this single service, eliminating redundant connections.
"""
import asyncio
import logging
import weakref
from typing import Dict, List, Optional, Any, Callable, Set
from datetime import datetime
from threading import Lock
from dataclasses import dataclass
from .cob_integration import COBIntegration
from .multi_exchange_cob_provider import COBSnapshot
from .data_provider import DataProvider
logger = logging.getLogger(__name__)
@dataclass
class COBSubscription:
"""Represents a subscription to COB updates"""
subscriber_id: str
callback: Callable
symbol_filter: Optional[List[str]] = None
callback_type: str = "general" # general, cnn, dqn, dashboard
class SharedCOBService:
"""
Shared COB Service - Singleton pattern for unified COB data access
This service eliminates redundant COB integrations by providing a single
shared instance that all components can subscribe to.
"""
_instance: Optional['SharedCOBService'] = None
_lock = Lock()
def __new__(cls, *args, **kwargs):
"""Singleton pattern implementation"""
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super(SharedCOBService, cls).__new__(cls)
return cls._instance
def __init__(self, symbols: Optional[List[str]] = None, data_provider: Optional[DataProvider] = None):
"""Initialize shared COB service (only called once due to singleton)"""
if hasattr(self, '_initialized'):
return
self.symbols = symbols or ['BTC/USDT', 'ETH/USDT']
self.data_provider = data_provider
# Single COB integration instance
self.cob_integration: Optional[COBIntegration] = None
self.is_running = False
# Subscriber management
self.subscribers: Dict[str, COBSubscription] = {}
self.subscriber_counter = 0
self.subscription_lock = Lock()
# Cached data for immediate access
self.latest_snapshots: Dict[str, COBSnapshot] = {}
self.latest_cnn_features: Dict[str, Any] = {}
self.latest_dqn_states: Dict[str, Any] = {}
# Performance tracking
self.total_subscribers = 0
self.update_count = 0
self.start_time = None
self._initialized = True
logger.info(f"SharedCOBService initialized for symbols: {self.symbols}")
async def start(self) -> None:
"""Start the shared COB service"""
if self.is_running:
logger.warning("SharedCOBService already running")
return
logger.info("Starting SharedCOBService...")
try:
# Initialize COB integration if not already done
if self.cob_integration is None:
self.cob_integration = COBIntegration(
data_provider=self.data_provider,
symbols=self.symbols
)
# Register internal callbacks
self.cob_integration.add_cnn_callback(self._on_cob_cnn_update)
self.cob_integration.add_dqn_callback(self._on_cob_dqn_update)
self.cob_integration.add_dashboard_callback(self._on_cob_dashboard_update)
# Start COB integration
await self.cob_integration.start()
self.is_running = True
self.start_time = datetime.now()
logger.info("SharedCOBService started successfully")
logger.info(f"Active subscribers: {len(self.subscribers)}")
except Exception as e:
logger.error(f"Error starting SharedCOBService: {e}")
raise
async def stop(self) -> None:
"""Stop the shared COB service"""
if not self.is_running:
return
logger.info("Stopping SharedCOBService...")
try:
if self.cob_integration:
await self.cob_integration.stop()
self.is_running = False
# Notify all subscribers of shutdown
for subscription in self.subscribers.values():
try:
                    if callable(subscription.callback):
subscription.callback("SHUTDOWN", None)
except Exception as e:
logger.warning(f"Error notifying subscriber {subscription.subscriber_id}: {e}")
logger.info("SharedCOBService stopped")
except Exception as e:
logger.error(f"Error stopping SharedCOBService: {e}")
def subscribe(self,
callback: Callable,
callback_type: str = "general",
symbol_filter: Optional[List[str]] = None,
subscriber_name: str = None) -> str:
"""
Subscribe to COB updates
Args:
callback: Function to call on updates
callback_type: Type of callback ('general', 'cnn', 'dqn', 'dashboard')
symbol_filter: Only receive updates for these symbols (None = all)
subscriber_name: Optional name for the subscriber
Returns:
Subscription ID for unsubscribing
"""
with self.subscription_lock:
self.subscriber_counter += 1
subscriber_id = f"{callback_type}_{self.subscriber_counter}"
if subscriber_name:
subscriber_id = f"{subscriber_name}_{subscriber_id}"
subscription = COBSubscription(
subscriber_id=subscriber_id,
callback=callback,
symbol_filter=symbol_filter,
callback_type=callback_type
)
self.subscribers[subscriber_id] = subscription
self.total_subscribers += 1
logger.info(f"New subscriber: {subscriber_id} ({callback_type})")
logger.info(f"Total active subscribers: {len(self.subscribers)}")
return subscriber_id
def unsubscribe(self, subscriber_id: str) -> bool:
"""
Unsubscribe from COB updates
Args:
subscriber_id: ID returned from subscribe()
Returns:
True if successfully unsubscribed
"""
with self.subscription_lock:
if subscriber_id in self.subscribers:
del self.subscribers[subscriber_id]
logger.info(f"Unsubscribed: {subscriber_id}")
logger.info(f"Remaining subscribers: {len(self.subscribers)}")
return True
else:
logger.warning(f"Subscriber not found: {subscriber_id}")
return False
# Internal callback handlers
async def _on_cob_cnn_update(self, symbol: str, data: Dict):
"""Handle CNN feature updates from COB integration"""
try:
self.latest_cnn_features[symbol] = data
await self._notify_subscribers("cnn", symbol, data)
except Exception as e:
logger.error(f"Error in CNN update handler: {e}")
async def _on_cob_dqn_update(self, symbol: str, data: Dict):
"""Handle DQN state updates from COB integration"""
try:
self.latest_dqn_states[symbol] = data
await self._notify_subscribers("dqn", symbol, data)
except Exception as e:
logger.error(f"Error in DQN update handler: {e}")
async def _on_cob_dashboard_update(self, symbol: str, data: Dict):
"""Handle dashboard updates from COB integration"""
try:
# Store snapshot if it's a COBSnapshot
if hasattr(data, 'volume_weighted_mid'): # Duck typing for COBSnapshot
self.latest_snapshots[symbol] = data
await self._notify_subscribers("dashboard", symbol, data)
await self._notify_subscribers("general", symbol, data)
self.update_count += 1
except Exception as e:
logger.error(f"Error in dashboard update handler: {e}")
async def _notify_subscribers(self, callback_type: str, symbol: str, data: Any):
"""Notify all relevant subscribers of an update"""
try:
relevant_subscribers = [
sub for sub in self.subscribers.values()
if (sub.callback_type == callback_type or sub.callback_type == "general") and
(sub.symbol_filter is None or symbol in sub.symbol_filter)
]
for subscription in relevant_subscribers:
try:
if asyncio.iscoroutinefunction(subscription.callback):
asyncio.create_task(subscription.callback(symbol, data))
else:
subscription.callback(symbol, data)
except Exception as e:
logger.warning(f"Error notifying subscriber {subscription.subscriber_id}: {e}")
except Exception as e:
logger.error(f"Error notifying subscribers: {e}")
# Public data access methods
def get_cob_snapshot(self, symbol: str) -> Optional[COBSnapshot]:
"""Get latest COB snapshot for a symbol"""
if self.cob_integration:
return self.cob_integration.get_cob_snapshot(symbol)
return self.latest_snapshots.get(symbol)
def get_cnn_features(self, symbol: str) -> Optional[Any]:
"""Get latest CNN features for a symbol"""
return self.latest_cnn_features.get(symbol)
def get_dqn_state(self, symbol: str) -> Optional[Any]:
"""Get latest DQN state for a symbol"""
return self.latest_dqn_states.get(symbol)
def get_market_depth_analysis(self, symbol: str) -> Optional[Dict]:
"""Get detailed market depth analysis"""
if self.cob_integration:
return self.cob_integration.get_market_depth_analysis(symbol)
return None
def get_exchange_breakdown(self, symbol: str) -> Optional[Dict]:
"""Get liquidity breakdown by exchange"""
if self.cob_integration:
return self.cob_integration.get_exchange_breakdown(symbol)
return None
def get_price_buckets(self, symbol: str) -> Optional[Dict]:
"""Get fine-grain price buckets"""
if self.cob_integration:
return self.cob_integration.get_price_buckets(symbol)
return None
def get_session_volume_profile(self, symbol: str) -> Optional[Dict]:
"""Get session volume profile"""
if self.cob_integration and hasattr(self.cob_integration.cob_provider, 'get_session_volume_profile'):
return self.cob_integration.cob_provider.get_session_volume_profile(symbol)
return None
def get_realtime_stats_for_nn(self, symbol: str) -> Dict:
"""Get real-time statistics formatted for NN models"""
if self.cob_integration:
return self.cob_integration.get_realtime_stats_for_nn(symbol)
return {}
def get_service_statistics(self) -> Dict[str, Any]:
"""Get service statistics"""
uptime = None
if self.start_time:
uptime = (datetime.now() - self.start_time).total_seconds()
base_stats = {
'service_name': 'SharedCOBService',
'is_running': self.is_running,
'symbols': self.symbols,
'total_subscribers': len(self.subscribers),
'lifetime_subscribers': self.total_subscribers,
'update_count': self.update_count,
'uptime_seconds': uptime,
'subscribers_by_type': {}
}
# Count subscribers by type
for subscription in self.subscribers.values():
callback_type = subscription.callback_type
if callback_type not in base_stats['subscribers_by_type']:
base_stats['subscribers_by_type'][callback_type] = 0
base_stats['subscribers_by_type'][callback_type] += 1
# Get COB integration stats if available
if self.cob_integration:
cob_stats = self.cob_integration.get_statistics()
base_stats.update(cob_stats)
return base_stats
# Global service instance access functions
def get_shared_cob_service(symbols: List[str] = None, data_provider: DataProvider = None) -> SharedCOBService:
"""Get the shared COB service instance"""
return SharedCOBService(symbols=symbols, data_provider=data_provider)
async def start_shared_cob_service(symbols: List[str] = None, data_provider: DataProvider = None) -> SharedCOBService:
"""Start the shared COB service"""
service = get_shared_cob_service(symbols=symbols, data_provider=data_provider)
await service.start()
return service
async def stop_shared_cob_service():
"""Stop the shared COB service"""
service = get_shared_cob_service()
await service.stop()
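
# Illustrative usage sketch (not part of the original module). The callback
# signature (symbol, data) follows the subscribe()/_notify_subscribers()
# contract above; the symbols and sleep duration are placeholders.
if __name__ == "__main__":
    def on_dashboard_update(symbol: str, data) -> None:
        print(f"Dashboard update for {symbol}: {type(data).__name__}")

    async def _demo() -> None:
        service = await start_shared_cob_service(symbols=['BTC/USDT', 'ETH/USDT'])
        sub_id = service.subscribe(
            on_dashboard_update,
            callback_type="dashboard",
            symbol_filter=['ETH/USDT'],
            subscriber_name="demo"
        )
        await asyncio.sleep(10)  # let COBIntegration stream a few updates
        print(service.get_service_statistics())
        service.unsubscribe(sub_id)
        await stop_shared_cob_service()

    asyncio.run(_demo())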

View File

@ -1,425 +0,0 @@
"""
Shared Data Manager for UI Stability Fix
Manages data sharing between processes through files with proper locking
and atomic operations to prevent corruption and conflicts.
"""
import json
import os
import time
import tempfile
import platform
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import Dict, Any, Optional, Union
from pathlib import Path
import logging
# Windows-compatible file locking
if platform.system() == "Windows":
import msvcrt
else:
import fcntl
logger = logging.getLogger(__name__)
@dataclass
class ProcessStatus:
"""Model for process status information"""
name: str
pid: int
status: str # 'running', 'stopped', 'error'
start_time: datetime
last_heartbeat: datetime
memory_usage: float
cpu_usage: float
error_message: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary with datetime serialization"""
data = asdict(self)
data['start_time'] = self.start_time.isoformat()
data['last_heartbeat'] = self.last_heartbeat.isoformat()
return data
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'ProcessStatus':
"""Create from dictionary with datetime deserialization"""
data['start_time'] = datetime.fromisoformat(data['start_time'])
data['last_heartbeat'] = datetime.fromisoformat(data['last_heartbeat'])
return cls(**data)
@dataclass
class TrainingStatus:
"""Model for training status information"""
is_running: bool
current_epoch: int
total_epochs: int
loss: float
accuracy: float
last_update: datetime
model_path: str
error_message: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary with datetime serialization"""
data = asdict(self)
data['last_update'] = self.last_update.isoformat()
return data
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'TrainingStatus':
"""Create from dictionary with datetime deserialization"""
data['last_update'] = datetime.fromisoformat(data['last_update'])
return cls(**data)
@dataclass
class DashboardState:
"""Model for dashboard state information"""
is_connected: bool
last_data_update: datetime
active_connections: int
error_count: int
performance_metrics: Dict[str, float]
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary with datetime serialization"""
data = asdict(self)
data['last_data_update'] = self.last_data_update.isoformat()
return data
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'DashboardState':
"""Create from dictionary with datetime deserialization"""
data['last_data_update'] = datetime.fromisoformat(data['last_data_update'])
return cls(**data)
class SharedDataManager:
"""
Manages data sharing between processes through files with proper locking
and atomic operations to prevent corruption and conflicts.
"""
def __init__(self, data_dir: str = "shared_data"):
"""
Initialize the shared data manager
Args:
data_dir: Directory to store shared data files
"""
self.data_dir = Path(data_dir)
self.data_dir.mkdir(exist_ok=True)
# Define file paths for different data types
self.training_status_file = self.data_dir / "training_status.json"
self.dashboard_state_file = self.data_dir / "dashboard_state.json"
self.process_status_file = self.data_dir / "process_status.json"
self.market_data_file = self.data_dir / "market_data.json"
self.model_metrics_file = self.data_dir / "model_metrics.json"
logger.info(f"SharedDataManager initialized with data directory: {self.data_dir}")
def _lock_file(self, file_handle, exclusive=True):
"""Cross-platform file locking"""
if platform.system() == "Windows":
# Windows file locking
try:
                # msvcrt has no shared-lock mode, so both shared and exclusive
                # requests acquire an exclusive lock on the first byte.
                msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1)
except IOError:
pass # File locking may not be available in all scenarios
else:
# Unix file locking
lock_type = fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH
fcntl.flock(file_handle.fileno(), lock_type)
def _unlock_file(self, file_handle):
"""Cross-platform file unlocking"""
if platform.system() == "Windows":
try:
msvcrt.locking(file_handle.fileno(), msvcrt.LK_UNLCK, 1)
except IOError:
pass
else:
fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN)
def _write_json_atomic(self, file_path: Path, data: Dict[str, Any]) -> None:
"""
Write JSON data atomically with file locking
Args:
file_path: Path to the file to write
data: Data to write as JSON
"""
temp_path = None
try:
# Create temporary file in the same directory
temp_fd, temp_path = tempfile.mkstemp(
dir=file_path.parent,
prefix=f".{file_path.name}.",
suffix=".tmp"
)
with os.fdopen(temp_fd, 'w') as temp_file:
# Lock the temporary file
self._lock_file(temp_file, exclusive=True)
# Write data with proper formatting
json.dump(data, temp_file, indent=2, default=str)
temp_file.flush()
os.fsync(temp_file.fileno())
# Unlock before closing
self._unlock_file(temp_file)
# Atomically replace the original file
os.replace(temp_path, file_path)
logger.debug(f"Successfully wrote data to {file_path}")
except Exception as e:
# Clean up temporary file if it exists
if temp_path:
try:
os.unlink(temp_path)
                except OSError:
                    pass  # best-effort cleanup; the temp file may already be gone
logger.error(f"Failed to write data to {file_path}: {e}")
raise
def _read_json_safe(self, file_path: Path) -> Dict[str, Any]:
"""
Read JSON data safely with file locking
Args:
file_path: Path to the file to read
Returns:
Dictionary containing the JSON data
"""
if not file_path.exists():
logger.debug(f"File {file_path} does not exist, returning empty dict")
return {}
try:
with open(file_path, 'r') as file:
# Lock the file for reading
self._lock_file(file, exclusive=False)
data = json.load(file)
self._unlock_file(file)
logger.debug(f"Successfully read data from {file_path}")
return data
except json.JSONDecodeError as e:
logger.error(f"Invalid JSON in {file_path}: {e}")
return {}
except Exception as e:
logger.error(f"Failed to read data from {file_path}: {e}")
return {}
def write_training_status(self, status: TrainingStatus) -> None:
"""
Write training status to shared file
Args:
status: TrainingStatus object to write
"""
try:
data = status.to_dict()
self._write_json_atomic(self.training_status_file, data)
logger.debug("Training status written successfully")
except Exception as e:
logger.error(f"Failed to write training status: {e}")
raise
def read_training_status(self) -> Optional[TrainingStatus]:
"""
Read training status from shared file
Returns:
TrainingStatus object or None if not available
"""
try:
data = self._read_json_safe(self.training_status_file)
if not data:
return None
return TrainingStatus.from_dict(data)
except Exception as e:
logger.error(f"Failed to read training status: {e}")
return None
def write_dashboard_state(self, state: DashboardState) -> None:
"""
Write dashboard state to shared file
Args:
state: DashboardState object to write
"""
try:
data = state.to_dict()
self._write_json_atomic(self.dashboard_state_file, data)
logger.debug("Dashboard state written successfully")
except Exception as e:
logger.error(f"Failed to write dashboard state: {e}")
raise
def read_dashboard_state(self) -> Optional[DashboardState]:
"""
Read dashboard state from shared file
Returns:
DashboardState object or None if not available
"""
try:
data = self._read_json_safe(self.dashboard_state_file)
if not data:
return None
return DashboardState.from_dict(data)
except Exception as e:
logger.error(f"Failed to read dashboard state: {e}")
return None
def write_process_status(self, status: ProcessStatus) -> None:
"""
Write process status to shared file
Args:
status: ProcessStatus object to write
"""
try:
data = status.to_dict()
self._write_json_atomic(self.process_status_file, data)
logger.debug("Process status written successfully")
except Exception as e:
logger.error(f"Failed to write process status: {e}")
raise
def read_process_status(self) -> Optional[ProcessStatus]:
"""
Read process status from shared file
Returns:
ProcessStatus object or None if not available
"""
try:
data = self._read_json_safe(self.process_status_file)
if not data:
return None
return ProcessStatus.from_dict(data)
except Exception as e:
logger.error(f"Failed to read process status: {e}")
return None
def write_market_data(self, data: Dict[str, Any]) -> None:
"""
Write market data to shared file
Args:
data: Market data dictionary to write
"""
try:
# Add timestamp to market data
data['timestamp'] = datetime.now().isoformat()
self._write_json_atomic(self.market_data_file, data)
logger.debug("Market data written successfully")
except Exception as e:
logger.error(f"Failed to write market data: {e}")
raise
def read_market_data(self) -> Dict[str, Any]:
"""
Read market data from shared file
Returns:
Dictionary containing market data
"""
try:
return self._read_json_safe(self.market_data_file)
except Exception as e:
logger.error(f"Failed to read market data: {e}")
return {}
def write_model_metrics(self, metrics: Dict[str, Any]) -> None:
"""
Write model metrics to shared file
Args:
metrics: Model metrics dictionary to write
"""
try:
# Add timestamp to metrics
metrics['timestamp'] = datetime.now().isoformat()
self._write_json_atomic(self.model_metrics_file, metrics)
logger.debug("Model metrics written successfully")
except Exception as e:
logger.error(f"Failed to write model metrics: {e}")
raise
def read_model_metrics(self) -> Dict[str, Any]:
"""
Read model metrics from shared file
Returns:
Dictionary containing model metrics
"""
try:
return self._read_json_safe(self.model_metrics_file)
except Exception as e:
logger.error(f"Failed to read model metrics: {e}")
return {}
def cleanup(self) -> None:
"""
Clean up shared data files
"""
try:
for file_path in [
self.training_status_file,
self.dashboard_state_file,
self.process_status_file,
self.market_data_file,
self.model_metrics_file
]:
if file_path.exists():
file_path.unlink()
logger.debug(f"Removed {file_path}")
# Remove directory if empty
if self.data_dir.exists() and not any(self.data_dir.iterdir()):
self.data_dir.rmdir()
logger.debug(f"Removed empty directory {self.data_dir}")
except Exception as e:
logger.error(f"Failed to cleanup shared data: {e}")
def get_data_age(self, data_type: str) -> Optional[float]:
"""
Get the age of data in seconds
Args:
data_type: Type of data ('training', 'dashboard', 'process', 'market', 'metrics')
Returns:
Age in seconds or None if file doesn't exist
"""
file_map = {
'training': self.training_status_file,
'dashboard': self.dashboard_state_file,
'process': self.process_status_file,
'market': self.market_data_file,
'metrics': self.model_metrics_file
}
file_path = file_map.get(data_type)
if not file_path or not file_path.exists():
return None
try:
mtime = file_path.stat().st_mtime
return time.time() - mtime
except Exception as e:
logger.error(f"Failed to get data age for {data_type}: {e}")
return None
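
# Illustrative usage sketch (not part of the original module); the epoch, loss
# and model path values are placeholders.
if __name__ == "__main__":
    manager = SharedDataManager(data_dir="shared_data")

    status = TrainingStatus(
        is_running=True,
        current_epoch=3,
        total_epochs=50,
        loss=0.042,
        accuracy=0.87,
        last_update=datetime.now(),
        model_path="models/checkpoint_epoch_3.pt",
    )
    manager.write_training_status(status)  # atomic, locked write

    restored = manager.read_training_status()
    if restored:
        age = manager.get_data_age('training')
        print(f"Epoch {restored.current_epoch}/{restored.total_epochs}, "
              f"loss={restored.loss:.3f}, data age={age:.1f}s")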

View File

@ -1,59 +0,0 @@
"""
Trading Action Module
Defines the TradingAction class used throughout the trading system.
"""
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, Any, List
@dataclass
class TradingAction:
"""Represents a trading action with full context"""
symbol: str
action: str # 'BUY', 'SELL', 'HOLD'
quantity: float
confidence: float
price: float
timestamp: datetime
reasoning: Dict[str, Any]
def __post_init__(self):
"""Validate the trading action after initialization"""
if self.action not in ['BUY', 'SELL', 'HOLD']:
raise ValueError(f"Invalid action: {self.action}. Must be 'BUY', 'SELL', or 'HOLD'")
if self.confidence < 0.0 or self.confidence > 1.0:
raise ValueError(f"Invalid confidence: {self.confidence}. Must be between 0.0 and 1.0")
if self.quantity < 0:
raise ValueError(f"Invalid quantity: {self.quantity}. Must be non-negative")
if self.price <= 0:
raise ValueError(f"Invalid price: {self.price}. Must be positive")
def to_dict(self) -> Dict[str, Any]:
"""Convert trading action to dictionary"""
return {
'symbol': self.symbol,
'action': self.action,
'quantity': self.quantity,
'confidence': self.confidence,
'price': self.price,
'timestamp': self.timestamp.isoformat(),
'reasoning': self.reasoning
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'TradingAction':
"""Create trading action from dictionary"""
return cls(
symbol=data['symbol'],
action=data['action'],
quantity=data['quantity'],
confidence=data['confidence'],
price=data['price'],
timestamp=datetime.fromisoformat(data['timestamp']),
reasoning=data['reasoning']
)
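
# Illustrative round-trip sketch (not part of the original module); the values
# below are placeholders.
if __name__ == "__main__":
    action = TradingAction(
        symbol='ETH/USDT',
        action='BUY',
        quantity=0.25,
        confidence=0.82,
        price=3150.0,
        timestamp=datetime.now(),
        reasoning={'signal': 'cnn+dqn agreement', 'cob_imbalance': 0.31},
    )
    payload = action.to_dict()                   # JSON-serializable dict
    restored = TradingAction.from_dict(payload)  # validation re-runs in __post_init__
    assert restored.action == action.action and restored.price == action.price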

View File

@ -1,401 +0,0 @@
"""
Trading Executor Fix - Addresses issues with entry/exit prices and P&L calculations
This module provides fixes for:
1. Identical entry prices issue
2. Price caching problems
3. Position tracking reset logic
4. Trade cooldown implementation
5. P&L calculation verification
Apply these fixes to the TradingExecutor class to improve trade execution reliability.
"""
import logging
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union
logger = logging.getLogger(__name__)
class TradingExecutorFix:
"""
Fixes for the TradingExecutor class to address entry/exit price issues
and improve P&L calculation accuracy.
"""
def __init__(self, trading_executor):
"""
Initialize the fix with a reference to the trading executor
Args:
trading_executor: The TradingExecutor instance to fix
"""
self.trading_executor = trading_executor
# Add cooldown tracking
self.last_trade_time = {} # {symbol: timestamp}
self.min_trade_cooldown = 30 # 30 seconds minimum between trades
# Add price history for validation
self.recent_entry_prices = {} # {symbol: [recent_prices]}
self.max_price_history = 10 # Keep last 10 entry prices
# Add position reset tracking
self.position_reset_flags = {} # {symbol: bool}
# Add price update tracking
self.last_price_update = {} # {symbol: timestamp}
self.price_update_threshold = 5 # 5 seconds max since last price update
# Add P&L verification
self.trade_history = {} # {symbol: [trade_records]}
logger.info("TradingExecutorFix initialized - addressing entry/exit price issues")
def apply_fixes(self):
"""Apply all fixes to the trading executor"""
self._patch_execute_action()
self._patch_close_position()
self._patch_calculate_pnl()
self._patch_update_prices()
logger.info("All trading executor fixes applied successfully")
def _patch_execute_action(self):
"""Patch the execute_action method to add price validation and cooldown"""
original_execute_action = self.trading_executor.execute_action
def execute_action_with_fixes(decision):
"""Enhanced execute_action with price validation and cooldown"""
try:
symbol = decision.symbol
action = decision.action
current_time = datetime.now()
# 1. Check cooldown period
if symbol in self.last_trade_time:
time_since_last_trade = (current_time - self.last_trade_time[symbol]).total_seconds()
if time_since_last_trade < self.min_trade_cooldown:
logger.warning(f"Trade rejected: Cooldown period ({time_since_last_trade:.1f}s < {self.min_trade_cooldown}s) for {symbol}")
return False
# 2. Validate price freshness
if symbol in self.last_price_update:
time_since_update = (current_time - self.last_price_update[symbol]).total_seconds()
if time_since_update > self.price_update_threshold:
logger.warning(f"Trade rejected: Price data stale ({time_since_update:.1f}s > {self.price_update_threshold}s) for {symbol}")
# Force price refresh
self._refresh_price(symbol)
return False
# 3. Validate entry price against recent history
current_price = self._get_current_price(symbol)
if symbol in self.recent_entry_prices and len(self.recent_entry_prices[symbol]) > 0:
# Check if price is identical to any recent entry
if current_price in self.recent_entry_prices[symbol]:
logger.warning(f"Trade rejected: Duplicate entry price ${current_price} for {symbol}")
return False
# 4. Ensure position is properly reset before new entry
if not self._ensure_position_reset(symbol):
logger.warning(f"Trade rejected: Position not properly reset for {symbol}")
return False
# Execute the original action
result = original_execute_action(decision)
# If successful, update tracking
if result:
# Update cooldown timestamp
self.last_trade_time[symbol] = current_time
# Update price history
if symbol not in self.recent_entry_prices:
self.recent_entry_prices[symbol] = []
self.recent_entry_prices[symbol].append(current_price)
# Keep only the most recent prices
if len(self.recent_entry_prices[symbol]) > self.max_price_history:
self.recent_entry_prices[symbol] = self.recent_entry_prices[symbol][-self.max_price_history:]
# Mark position as active
self.position_reset_flags[symbol] = False
logger.info(f"Trade executed: {action} {symbol} at ${current_price} with validation")
return result
except Exception as e:
logger.error(f"Error in execute_action_with_fixes: {e}")
return original_execute_action(decision)
# Replace the original method
self.trading_executor.execute_action = execute_action_with_fixes
logger.info("Patched execute_action with price validation and cooldown")
def _patch_close_position(self):
"""Patch the close_position method to ensure proper position reset"""
original_close_position = self.trading_executor.close_position
def close_position_with_fixes(symbol, **kwargs):
"""Enhanced close_position with proper reset logic"""
try:
# Get current price for P&L verification
exit_price = self._get_current_price(symbol)
# Call original close position
result = original_close_position(symbol, **kwargs)
if result:
# Mark position as reset
self.position_reset_flags[symbol] = True
                    # Record trade for verification (assumes the executor still exposes
                    # the just-closed position via `positions` when this hook runs)
if hasattr(self.trading_executor, 'positions') and symbol in self.trading_executor.positions:
position = self.trading_executor.positions[symbol]
# Create trade record
trade_record = {
'symbol': symbol,
'entry_time': getattr(position, 'entry_time', datetime.now()),
'exit_time': datetime.now(),
'entry_price': getattr(position, 'entry_price', 0),
'exit_price': exit_price,
'size': getattr(position, 'size', 0),
'side': getattr(position, 'side', 'UNKNOWN'),
'pnl': self._calculate_verified_pnl(position, exit_price),
'fees': getattr(position, 'fees', 0),
'hold_time_seconds': (datetime.now() - getattr(position, 'entry_time', datetime.now())).total_seconds()
}
# Store trade record
if symbol not in self.trade_history:
self.trade_history[symbol] = []
self.trade_history[symbol].append(trade_record)
logger.info(f"Position closed: {symbol} at ${exit_price} with verified P&L: ${trade_record['pnl']:.2f}")
return result
except Exception as e:
logger.error(f"Error in close_position_with_fixes: {e}")
return original_close_position(symbol, **kwargs)
# Replace the original method
self.trading_executor.close_position = close_position_with_fixes
logger.info("Patched close_position with proper reset logic")
def _patch_calculate_pnl(self):
"""Patch the calculate_pnl method to ensure accurate P&L calculation"""
original_calculate_pnl = getattr(self.trading_executor, 'calculate_pnl', None)
def calculate_pnl_with_fixes(position, current_price=None):
"""Enhanced calculate_pnl with verification"""
try:
# If no original method, implement our own
if original_calculate_pnl is None:
return self._calculate_verified_pnl(position, current_price)
# Call original method
original_pnl = original_calculate_pnl(position, current_price)
# Calculate our verified P&L
verified_pnl = self._calculate_verified_pnl(position, current_price)
# If there's a significant difference, log it
if abs(original_pnl - verified_pnl) > 0.01:
logger.warning(f"P&L calculation discrepancy: original=${original_pnl:.2f}, verified=${verified_pnl:.2f}")
# Use the verified P&L
return verified_pnl
return original_pnl
except Exception as e:
logger.error(f"Error in calculate_pnl_with_fixes: {e}")
if original_calculate_pnl:
return original_calculate_pnl(position, current_price)
return 0.0
# Replace the original method if it exists
if original_calculate_pnl:
self.trading_executor.calculate_pnl = calculate_pnl_with_fixes
logger.info("Patched calculate_pnl with verification")
else:
# Add the method if it doesn't exist
self.trading_executor.calculate_pnl = calculate_pnl_with_fixes
logger.info("Added calculate_pnl method with verification")
def _patch_update_prices(self):
"""Patch the update_prices method to track price updates"""
original_update_prices = getattr(self.trading_executor, 'update_prices', None)
def update_prices_with_tracking(prices):
"""Enhanced update_prices with timestamp tracking"""
try:
# Call original method if it exists
if original_update_prices:
result = original_update_prices(prices)
else:
# If no original method, update prices directly
if hasattr(self.trading_executor, 'current_prices'):
self.trading_executor.current_prices.update(prices)
result = True
# Track update timestamps
current_time = datetime.now()
for symbol in prices:
self.last_price_update[symbol] = current_time
return result
except Exception as e:
logger.error(f"Error in update_prices_with_tracking: {e}")
if original_update_prices:
return original_update_prices(prices)
return False
# Replace the original method if it exists
if original_update_prices:
self.trading_executor.update_prices = update_prices_with_tracking
logger.info("Patched update_prices with timestamp tracking")
else:
# Add the method if it doesn't exist
self.trading_executor.update_prices = update_prices_with_tracking
logger.info("Added update_prices method with timestamp tracking")
def _calculate_verified_pnl(self, position, current_price=None):
"""Calculate verified P&L for a position"""
try:
# Get position details
entry_price = getattr(position, 'entry_price', 0)
size = getattr(position, 'size', 0)
side = getattr(position, 'side', 'UNKNOWN')
leverage = getattr(position, 'leverage', 1.0)
fees = getattr(position, 'fees', 0.0)
# If current_price is not provided, try to get it
if current_price is None:
symbol = getattr(position, 'symbol', None)
if symbol:
current_price = self._get_current_price(symbol)
else:
return 0.0
# Calculate P&L based on position side
if side == 'LONG':
pnl = (current_price - entry_price) * size * leverage
elif side == 'SHORT':
pnl = (entry_price - current_price) * size * leverage
else:
pnl = 0.0
# Subtract fees for net P&L
net_pnl = pnl - fees
return net_pnl
except Exception as e:
logger.error(f"Error calculating verified P&L: {e}")
return 0.0
def _get_current_price(self, symbol):
"""Get current price for a symbol with fallbacks"""
try:
# Try to get from trading executor
if hasattr(self.trading_executor, 'current_prices') and symbol in self.trading_executor.current_prices:
return self.trading_executor.current_prices[symbol]
# Try to get from data provider
if hasattr(self.trading_executor, 'data_provider'):
data_provider = self.trading_executor.data_provider
if hasattr(data_provider, 'get_current_price'):
price = data_provider.get_current_price(symbol)
if price and price > 0:
return price
# Try to get from COB data
if hasattr(self.trading_executor, 'latest_cob_data') and symbol in self.trading_executor.latest_cob_data:
cob_data = self.trading_executor.latest_cob_data[symbol]
if hasattr(cob_data, 'stats') and 'mid_price' in cob_data.stats:
return cob_data.stats['mid_price']
# Default fallback
return 0.0
except Exception as e:
logger.error(f"Error getting current price for {symbol}: {e}")
return 0.0
def _refresh_price(self, symbol):
"""Force a price refresh for a symbol"""
try:
# Try to refresh from data provider
if hasattr(self.trading_executor, 'data_provider'):
data_provider = self.trading_executor.data_provider
if hasattr(data_provider, 'fetch_current_price'):
price = data_provider.fetch_current_price(symbol)
if price and price > 0:
# Update trading executor price
if hasattr(self.trading_executor, 'current_prices'):
self.trading_executor.current_prices[symbol] = price
# Update timestamp
self.last_price_update[symbol] = datetime.now()
logger.info(f"Refreshed price for {symbol}: ${price:.2f}")
return True
logger.warning(f"Failed to refresh price for {symbol}")
return False
except Exception as e:
logger.error(f"Error refreshing price for {symbol}: {e}")
return False
def _ensure_position_reset(self, symbol):
"""Ensure position is properly reset before new entry"""
try:
# Check if we have an active position
if hasattr(self.trading_executor, 'positions') and symbol in self.trading_executor.positions:
# Position exists, check if it's valid
position = self.trading_executor.positions[symbol]
if position and getattr(position, 'active', False):
logger.warning(f"Position already active for {symbol}, cannot enter new position")
return False
# Check reset flag
if symbol in self.position_reset_flags and not self.position_reset_flags[symbol]:
# Force position cleanup
if hasattr(self.trading_executor, 'positions'):
self.trading_executor.positions.pop(symbol, None)
logger.info(f"Forced position reset for {symbol}")
self.position_reset_flags[symbol] = True
return True
except Exception as e:
logger.error(f"Error ensuring position reset for {symbol}: {e}")
return False
def get_trade_history(self, symbol=None):
"""Get verified trade history"""
if symbol:
return self.trade_history.get(symbol, [])
return self.trade_history
def get_price_update_status(self):
"""Get price update status for all symbols"""
status = {}
current_time = datetime.now()
for symbol, timestamp in self.last_price_update.items():
time_since_update = (current_time - timestamp).total_seconds()
status[symbol] = {
'last_update': timestamp,
'seconds_ago': time_since_update,
'is_fresh': time_since_update <= self.price_update_threshold
}
return status
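
# Illustrative wiring sketch (not part of the original module). The import path
# for TradingExecutor is hypothetical; adjust it to wherever the executor lives.
if __name__ == "__main__":
    from core.trading_executor import TradingExecutor  # hypothetical path

    executor = TradingExecutor()
    fixes = TradingExecutorFix(executor)
    fixes.apply_fixes()  # monkey-patches execute_action, close_position, calculate_pnl, update_prices

    # After some trading activity the fix layer exposes verification data:
    print(fixes.get_price_update_status())
    print(fixes.get_trade_history('ETH/USDT'))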

View File

@ -1,795 +0,0 @@
"""
Comprehensive Training Data Collection System
This module implements a robust training data collection system that:
1. Captures all model inputs with validation and completeness checks
2. Stores training data packages with future outcome validation
3. Detects rapid price changes for high-value training examples
4. Enables replay and retraining on most profitable setups
5. Maintains data integrity and traceability
Key Features:
- Real-time data package creation with all model inputs
- Future outcome validation (profitable vs unprofitable predictions)
- Rapid price change detection for premium training examples
- Comprehensive data validation and completeness verification
- Backpropagation data storage for gradient replay
- Training episode profitability tracking and ranking
"""
import asyncio
import json
import logging
import numpy as np
import pandas as pd
import pickle
import torch
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass, field, asdict
from collections import deque
import hashlib
import threading
from concurrent.futures import ThreadPoolExecutor
logger = logging.getLogger(__name__)
@dataclass
class ModelInputPackage:
"""Complete package of all model inputs at a specific timestamp"""
timestamp: datetime
symbol: str
# Market data inputs
ohlcv_data: Dict[str, pd.DataFrame] # {timeframe: DataFrame}
tick_data: List[Dict[str, Any]] # Raw tick data
cob_data: Dict[str, Any] # Consolidated Order Book data
technical_indicators: Dict[str, float] # All technical indicators
pivot_points: List[Dict[str, Any]] # Detected pivot points
# Model-specific inputs
cnn_features: np.ndarray # CNN input features
rl_state: np.ndarray # RL state representation
orchestrator_context: Dict[str, Any] # Orchestrator context
# Cross-model inputs (outputs from other models)
cnn_predictions: Optional[Dict[str, Any]] = None
rl_predictions: Optional[Dict[str, Any]] = None
orchestrator_decision: Optional[Dict[str, Any]] = None
# Data validation
data_hash: str = ""
completeness_score: float = 0.0
validation_flags: Dict[str, bool] = field(default_factory=dict)
def __post_init__(self):
"""Calculate data hash and completeness after initialization"""
self.data_hash = self._calculate_hash()
self.completeness_score = self._calculate_completeness()
self.validation_flags = self._validate_data()
def _calculate_hash(self) -> str:
"""Calculate hash for data integrity verification"""
try:
# Create a string representation of all data
data_str = f"{self.timestamp}_{self.symbol}"
data_str += f"_{len(self.ohlcv_data)}_{len(self.tick_data)}"
data_str += f"_{self.cnn_features.shape if self.cnn_features is not None else 'None'}"
data_str += f"_{self.rl_state.shape if self.rl_state is not None else 'None'}"
return hashlib.md5(data_str.encode()).hexdigest()
except Exception as e:
logger.warning(f"Error calculating data hash: {e}")
return "invalid_hash"
def _calculate_completeness(self) -> float:
"""Calculate completeness score (0.0 to 1.0)"""
try:
total_fields = 10 # Total expected data fields
complete_fields = 0
# Check each required field
if self.ohlcv_data and len(self.ohlcv_data) > 0:
complete_fields += 1
if self.tick_data and len(self.tick_data) > 0:
complete_fields += 1
if self.cob_data and len(self.cob_data) > 0:
complete_fields += 1
if self.technical_indicators and len(self.technical_indicators) > 0:
complete_fields += 1
if self.pivot_points and len(self.pivot_points) > 0:
complete_fields += 1
if self.cnn_features is not None and self.cnn_features.size > 0:
complete_fields += 1
if self.rl_state is not None and self.rl_state.size > 0:
complete_fields += 1
if self.orchestrator_context and len(self.orchestrator_context) > 0:
complete_fields += 1
if self.cnn_predictions is not None:
complete_fields += 1
if self.rl_predictions is not None:
complete_fields += 1
return complete_fields / total_fields
except Exception as e:
logger.warning(f"Error calculating completeness: {e}")
return 0.0
def _validate_data(self) -> Dict[str, bool]:
"""Validate data integrity and consistency"""
flags = {}
try:
# Validate timestamp
flags['valid_timestamp'] = isinstance(self.timestamp, datetime)
# Validate OHLCV data
flags['valid_ohlcv'] = (
self.ohlcv_data is not None and
len(self.ohlcv_data) > 0 and
all(isinstance(df, pd.DataFrame) for df in self.ohlcv_data.values())
)
# Validate feature arrays
flags['valid_cnn_features'] = (
self.cnn_features is not None and
isinstance(self.cnn_features, np.ndarray) and
self.cnn_features.size > 0
)
flags['valid_rl_state'] = (
self.rl_state is not None and
isinstance(self.rl_state, np.ndarray) and
self.rl_state.size > 0
)
# Validate data consistency
flags['data_consistent'] = self.completeness_score > 0.7
except Exception as e:
logger.warning(f"Error validating data: {e}")
flags['validation_error'] = True
return flags
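
# Illustrative helper (not part of the original module): builds a minimal
# ModelInputPackage from placeholder data so that __post_init__ computes the
# data hash, completeness score and validation flags.
def _demo_model_input_package() -> ModelInputPackage:
    return ModelInputPackage(
        timestamp=datetime.now(),
        symbol='ETH/USDT',
        ohlcv_data={'1m': pd.DataFrame({'close': [3000.0, 3001.5]})},
        tick_data=[{'price': 3001.5, 'qty': 0.1}],
        cob_data={'mid_price': 3001.4},
        technical_indicators={'rsi_14': 55.2},
        pivot_points=[{'price': 2990.0, 'type': 'low'}],
        cnn_features=np.zeros(120, dtype=np.float32),
        rl_state=np.zeros(64, dtype=np.float32),
        orchestrator_context={'mode': 'shadow'},
    )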
@dataclass
class TrainingOutcome:
"""Future outcome validation for training data"""
input_package_hash: str
timestamp: datetime
symbol: str
# Price movement outcomes
price_change_1m: float
price_change_5m: float
price_change_15m: float
price_change_1h: float
# Profitability metrics
max_profit_potential: float
max_loss_potential: float
optimal_entry_price: float
optimal_exit_price: float
optimal_holding_time: timedelta
# Classification labels
is_profitable: bool
profitability_score: float # 0.0 to 1.0
risk_reward_ratio: float
# Rapid price change detection
is_rapid_change: bool
change_velocity: float # Price change per minute
volatility_spike: bool
# Validation
outcome_validated: bool = False
validation_timestamp: datetime = field(default_factory=datetime.now)
@dataclass
class TrainingEpisode:
"""Complete training episode with inputs, predictions, and outcomes"""
episode_id: str
input_package: ModelInputPackage
model_predictions: Dict[str, Any] # Predictions from all models
actual_outcome: TrainingOutcome
# Training metadata
episode_type: str # 'normal', 'rapid_change', 'high_profit'
profitability_rank: float # Ranking among all episodes
training_priority: float # Priority for replay training
# Backpropagation data storage
gradient_data: Optional[Dict[str, torch.Tensor]] = None
loss_components: Optional[Dict[str, float]] = None
model_states: Optional[Dict[str, Any]] = None
# Episode statistics
created_timestamp: datetime = field(default_factory=datetime.now)
last_trained_timestamp: Optional[datetime] = None
training_count: int = 0
def calculate_training_priority(self) -> float:
"""Calculate training priority based on profitability and characteristics"""
try:
priority = 0.0
# Base priority from profitability
if self.actual_outcome.is_profitable:
priority += self.actual_outcome.profitability_score * 0.4
# Bonus for rapid changes (high learning value)
if self.actual_outcome.is_rapid_change:
priority += 0.3
# Bonus for high risk-reward ratio
if self.actual_outcome.risk_reward_ratio > 2.0:
priority += 0.2
# Bonus for data completeness
priority += self.input_package.completeness_score * 0.1
# Penalty for frequent training (avoid overfitting)
if self.training_count > 5:
priority *= 0.8
return min(priority, 1.0)
except Exception as e:
logger.warning(f"Error calculating training priority: {e}")
return 0.0
class RapidChangeDetector:
"""Detects rapid price changes for high-value training examples"""
def __init__(self,
velocity_threshold: float = 0.5, # % per minute
volatility_multiplier: float = 3.0,
lookback_minutes: int = 5):
self.velocity_threshold = velocity_threshold
self.volatility_multiplier = volatility_multiplier
self.lookback_minutes = lookback_minutes
# Price history for change detection
self.price_history: Dict[str, deque] = {}
self.volatility_baseline: Dict[str, float] = {}
def add_price_point(self, symbol: str, timestamp: datetime, price: float):
"""Add new price point for change detection"""
if symbol not in self.price_history:
self.price_history[symbol] = deque(maxlen=self.lookback_minutes * 60) # 1 second resolution
self.volatility_baseline[symbol] = 0.0
self.price_history[symbol].append((timestamp, price))
self._update_volatility_baseline(symbol)
def detect_rapid_change(self, symbol: str) -> Tuple[bool, float, bool]:
"""
Detect rapid price changes
Returns:
(is_rapid_change, change_velocity, volatility_spike)
"""
if symbol not in self.price_history or len(self.price_history[symbol]) < 60:
return False, 0.0, False
try:
prices = list(self.price_history[symbol])
# Calculate recent velocity (last minute)
recent_prices = prices[-60:] # Last 60 seconds
if len(recent_prices) < 2:
return False, 0.0, False
start_price = recent_prices[0][1]
end_price = recent_prices[-1][1]
time_diff = (recent_prices[-1][0] - recent_prices[0][0]).total_seconds() / 60.0 # minutes
if time_diff <= 0:
return False, 0.0, False
# Calculate velocity (% change per minute)
velocity = abs((end_price - start_price) / start_price * 100) / time_diff
# Check for rapid change
is_rapid = velocity > self.velocity_threshold
# Check for volatility spike
current_volatility = self._calculate_current_volatility(symbol)
baseline_volatility = self.volatility_baseline.get(symbol, 0.0)
volatility_spike = (
baseline_volatility > 0 and
current_volatility > baseline_volatility * self.volatility_multiplier
)
return is_rapid, velocity, volatility_spike
except Exception as e:
logger.warning(f"Error detecting rapid change for {symbol}: {e}")
return False, 0.0, False
def _update_volatility_baseline(self, symbol: str):
"""Update volatility baseline for the symbol"""
try:
if len(self.price_history[symbol]) < 120: # Need at least 2 minutes of data
return
# Calculate rolling volatility over longer period
prices = [p[1] for p in list(self.price_history[symbol])[-300:]] # Last 5 minutes
if len(prices) < 2:
return
# Calculate standard deviation of price changes
price_changes = [abs(prices[i] - prices[i-1]) / prices[i-1] for i in range(1, len(prices))]
volatility = np.std(price_changes) * 100 # Convert to percentage
# Update baseline with exponential moving average
alpha = 0.1
if self.volatility_baseline[symbol] == 0:
self.volatility_baseline[symbol] = volatility
else:
self.volatility_baseline[symbol] = (
alpha * volatility + (1 - alpha) * self.volatility_baseline[symbol]
)
except Exception as e:
logger.warning(f"Error updating volatility baseline for {symbol}: {e}")
def _calculate_current_volatility(self, symbol: str) -> float:
"""Calculate current volatility for the symbol"""
try:
if len(self.price_history[symbol]) < 60:
return 0.0
# Use last minute of data
recent_prices = [p[1] for p in list(self.price_history[symbol])[-60:]]
if len(recent_prices) < 2:
return 0.0
price_changes = [abs(recent_prices[i] - recent_prices[i-1]) / recent_prices[i-1]
for i in range(1, len(recent_prices))]
return np.std(price_changes) * 100
except Exception as e:
logger.warning(f"Error calculating current volatility for {symbol}: {e}")
return 0.0
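
# Illustrative helper (not part of the original module): feeds one minute of
# synthetic 1-second prices into RapidChangeDetector and reads back the
# (is_rapid, velocity, volatility_spike) tuple. Default thresholds are used.
def _demo_rapid_change_detector(symbol: str = 'ETH/USDT') -> None:
    detector = RapidChangeDetector()
    start = datetime.now()
    price = 3000.0
    for second in range(60):
        price *= 1.0001  # roughly a 0.6% move over the minute
        detector.add_price_point(symbol, start + timedelta(seconds=second), price)
    is_rapid, velocity, spike = detector.detect_rapid_change(symbol)
    logger.info(f"rapid={is_rapid} velocity={velocity:.2f}%/min volatility_spike={spike}")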
class TrainingDataCollector:
"""Main training data collection system"""
def __init__(self,
storage_dir: str = "training_data",
max_episodes_per_symbol: int = 10000,
outcome_validation_delay: timedelta = timedelta(hours=1)):
self.storage_dir = Path(storage_dir)
self.storage_dir.mkdir(parents=True, exist_ok=True)
self.max_episodes_per_symbol = max_episodes_per_symbol
self.outcome_validation_delay = outcome_validation_delay
# Data storage
self.training_episodes: Dict[str, List[TrainingEpisode]] = {} # {symbol: episodes}
self.pending_outcomes: Dict[str, List[ModelInputPackage]] = {} # Awaiting outcome validation
# Rapid change detection
self.rapid_change_detector = RapidChangeDetector()
# Data validation and statistics
self.collection_stats = {
'total_episodes': 0,
'profitable_episodes': 0,
'rapid_change_episodes': 0,
'validation_errors': 0,
'data_completeness_avg': 0.0
}
# Background processing
self.is_collecting = False
self.collection_thread = None
self.outcome_validation_thread = None
# Thread safety
self.data_lock = threading.Lock()
logger.info(f"Training Data Collector initialized")
logger.info(f"Storage directory: {self.storage_dir}")
logger.info(f"Max episodes per symbol: {self.max_episodes_per_symbol}")
def start_collection(self):
"""Start the training data collection system"""
if self.is_collecting:
logger.warning("Training data collection already running")
return
self.is_collecting = True
# Start outcome validation thread
self.outcome_validation_thread = threading.Thread(
target=self._outcome_validation_worker,
daemon=True
)
self.outcome_validation_thread.start()
logger.info("Training data collection started")
def stop_collection(self):
"""Stop the training data collection system"""
self.is_collecting = False
if self.outcome_validation_thread:
self.outcome_validation_thread.join(timeout=5)
logger.info("Training data collection stopped")
def collect_training_data(self,
symbol: str,
ohlcv_data: Dict[str, pd.DataFrame],
tick_data: List[Dict[str, Any]],
cob_data: Dict[str, Any],
technical_indicators: Dict[str, float],
pivot_points: List[Dict[str, Any]],
cnn_features: np.ndarray,
rl_state: np.ndarray,
orchestrator_context: Dict[str, Any],
model_predictions: Dict[str, Any] = None) -> str:
"""
Collect comprehensive training data package
Returns:
episode_id for tracking
"""
try:
# Create input package
input_package = ModelInputPackage(
timestamp=datetime.now(),
symbol=symbol,
ohlcv_data=ohlcv_data,
tick_data=tick_data,
cob_data=cob_data,
technical_indicators=technical_indicators,
pivot_points=pivot_points,
cnn_features=cnn_features,
rl_state=rl_state,
orchestrator_context=orchestrator_context
)
# Validate data completeness
if input_package.completeness_score < 0.5:
logger.warning(f"Low data completeness for {symbol}: {input_package.completeness_score:.2f}")
self.collection_stats['validation_errors'] += 1
return None
# Check for rapid price changes
current_price = self._extract_current_price(ohlcv_data)
if current_price:
self.rapid_change_detector.add_price_point(symbol, input_package.timestamp, current_price)
# Add to pending outcomes for future validation
with self.data_lock:
if symbol not in self.pending_outcomes:
self.pending_outcomes[symbol] = []
self.pending_outcomes[symbol].append(input_package)
# Limit pending outcomes to prevent memory issues
if len(self.pending_outcomes[symbol]) > 1000:
self.pending_outcomes[symbol] = self.pending_outcomes[symbol][-500:]
# Generate episode ID
episode_id = f"{symbol}_{input_package.timestamp.strftime('%Y%m%d_%H%M%S')}_{input_package.data_hash[:8]}"
# Update statistics
self.collection_stats['total_episodes'] += 1
self.collection_stats['data_completeness_avg'] = (
(self.collection_stats['data_completeness_avg'] * (self.collection_stats['total_episodes'] - 1) +
input_package.completeness_score) / self.collection_stats['total_episodes']
)
logger.debug(f"Collected training data for {symbol}: {episode_id}")
logger.debug(f"Data completeness: {input_package.completeness_score:.2f}")
return episode_id
except Exception as e:
logger.error(f"Error collecting training data for {symbol}: {e}")
self.collection_stats['validation_errors'] += 1
return None
def _extract_current_price(self, ohlcv_data: Dict[str, pd.DataFrame]) -> Optional[float]:
"""Extract current price from OHLCV data"""
try:
# Try to get price from shortest timeframe first
for timeframe in ['1s', '1m', '5m', '15m', '1h']:
if timeframe in ohlcv_data and not ohlcv_data[timeframe].empty:
return float(ohlcv_data[timeframe]['close'].iloc[-1])
return None
except Exception as e:
logger.warning(f"Error extracting current price: {e}")
return None
def _outcome_validation_worker(self):
"""Background worker for validating training outcomes"""
logger.info("Outcome validation worker started")
while self.is_collecting:
try:
self._validate_pending_outcomes()
threading.Event().wait(60) # Check every minute
except Exception as e:
logger.error(f"Error in outcome validation worker: {e}")
threading.Event().wait(30) # Wait before retrying
logger.info("Outcome validation worker stopped")
def _validate_pending_outcomes(self):
"""Validate outcomes for pending training data"""
current_time = datetime.now()
with self.data_lock:
for symbol in list(self.pending_outcomes.keys()):
if symbol not in self.pending_outcomes:
continue
validated_packages = []
remaining_packages = []
for package in self.pending_outcomes[symbol]:
# Check if enough time has passed for outcome validation
if current_time - package.timestamp >= self.outcome_validation_delay:
outcome = self._calculate_training_outcome(package)
if outcome:
self._create_training_episode(package, outcome)
validated_packages.append(package)
else:
remaining_packages.append(package)
else:
remaining_packages.append(package)
# Update pending outcomes
self.pending_outcomes[symbol] = remaining_packages
if validated_packages:
logger.info(f"Validated {len(validated_packages)} outcomes for {symbol}")
def _calculate_training_outcome(self, input_package: ModelInputPackage) -> Optional[TrainingOutcome]:
"""Calculate training outcome based on future price movements"""
try:
# This would typically fetch recent price data to calculate outcomes
# For now, we'll create a placeholder implementation
# Extract base price from input package
base_price = self._extract_current_price(input_package.ohlcv_data)
if not base_price:
return None
# Simulate outcome calculation (in real implementation, fetch actual future prices)
# This is where you would integrate with your data provider to get actual outcomes
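# A minimal sketch of what the real calculation could look like, assuming a
# hypothetical helper such as `self.data_provider.get_price_at(symbol, ts)`
# is available (it is not part of this module):
#
#   future_price_1m = self.data_provider.get_price_at(
#       input_package.symbol, input_package.timestamp + timedelta(minutes=1))
#   price_change_1m = (future_price_1m - base_price) / base_price
#
# with the same pattern repeated for the 5m / 15m / 1h horizons.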
# Check for rapid change
is_rapid, velocity, volatility_spike = self.rapid_change_detector.detect_rapid_change(
input_package.symbol
)
# Create outcome (placeholder values - replace with actual calculation)
outcome = TrainingOutcome(
input_package_hash=input_package.data_hash,
timestamp=input_package.timestamp,
symbol=input_package.symbol,
price_change_1m=0.0, # Calculate from actual future data
price_change_5m=0.0,
price_change_15m=0.0,
price_change_1h=0.0,
max_profit_potential=0.0,
max_loss_potential=0.0,
optimal_entry_price=base_price,
optimal_exit_price=base_price,
optimal_holding_time=timedelta(minutes=5),
is_profitable=False, # Determine from actual outcomes
profitability_score=0.0,
risk_reward_ratio=1.0,
is_rapid_change=is_rapid,
change_velocity=velocity,
volatility_spike=volatility_spike,
outcome_validated=True
)
return outcome
except Exception as e:
logger.error(f"Error calculating training outcome: {e}")
return None
def _create_training_episode(self, input_package: ModelInputPackage, outcome: TrainingOutcome):
"""Create complete training episode"""
try:
episode_id = f"{input_package.symbol}_{input_package.timestamp.strftime('%Y%m%d_%H%M%S')}_{input_package.data_hash[:8]}"
# Determine episode type
episode_type = 'normal'
if outcome.is_rapid_change:
episode_type = 'rapid_change'
self.collection_stats['rapid_change_episodes'] += 1
elif outcome.profitability_score > 0.8:
episode_type = 'high_profit'
if outcome.is_profitable:
self.collection_stats['profitable_episodes'] += 1
# Create training episode
episode = TrainingEpisode(
episode_id=episode_id,
input_package=input_package,
model_predictions={}, # Will be filled when models make predictions
actual_outcome=outcome,
episode_type=episode_type,
profitability_rank=0.0, # Will be calculated later
training_priority=0.0
)
# Calculate training priority
episode.training_priority = episode.calculate_training_priority()
# Store episode
symbol = input_package.symbol
if symbol not in self.training_episodes:
self.training_episodes[symbol] = []
self.training_episodes[symbol].append(episode)
# Limit episodes per symbol
if len(self.training_episodes[symbol]) > self.max_episodes_per_symbol:
# Keep highest priority episodes
self.training_episodes[symbol].sort(key=lambda x: x.training_priority, reverse=True)
self.training_episodes[symbol] = self.training_episodes[symbol][:self.max_episodes_per_symbol]
# Save episode to disk
self._save_episode_to_disk(episode)
logger.debug(f"Created training episode: {episode_id}")
logger.debug(f"Episode type: {episode_type}, Priority: {episode.training_priority:.3f}")
except Exception as e:
logger.error(f"Error creating training episode: {e}")
def _save_episode_to_disk(self, episode: TrainingEpisode):
"""Save training episode to disk for persistence"""
try:
symbol_dir = self.storage_dir / episode.input_package.symbol
symbol_dir.mkdir(parents=True, exist_ok=True)
# Save episode data
episode_file = symbol_dir / f"{episode.episode_id}.pkl"
with open(episode_file, 'wb') as f:
pickle.dump(episode, f)
# Save episode metadata for quick access
metadata = {
'episode_id': episode.episode_id,
'timestamp': episode.input_package.timestamp.isoformat(),
'episode_type': episode.episode_type,
'training_priority': episode.training_priority,
'profitability_score': episode.actual_outcome.profitability_score,
'is_profitable': episode.actual_outcome.is_profitable,
'is_rapid_change': episode.actual_outcome.is_rapid_change,
'data_completeness': episode.input_package.completeness_score
}
metadata_file = symbol_dir / f"{episode.episode_id}_metadata.json"
with open(metadata_file, 'w') as f:
json.dump(metadata, f, indent=2)
except Exception as e:
logger.error(f"Error saving episode to disk: {e}")
def get_high_priority_episodes(self,
symbol: str,
limit: int = 100,
min_priority: float = 0.5) -> List[TrainingEpisode]:
"""Get high-priority training episodes for replay training"""
try:
if symbol not in self.training_episodes:
return []
# Filter and sort by priority
high_priority = [
ep for ep in self.training_episodes[symbol]
if ep.training_priority >= min_priority
]
high_priority.sort(key=lambda x: x.training_priority, reverse=True)
return high_priority[:limit]
except Exception as e:
logger.error(f"Error getting high priority episodes for {symbol}: {e}")
return []
def get_collection_statistics(self) -> Dict[str, Any]:
"""Get comprehensive collection statistics"""
stats = self.collection_stats.copy()
# Add per-symbol statistics
stats['episodes_per_symbol'] = {
symbol: len(episodes)
for symbol, episodes in self.training_episodes.items()
}
# Add pending outcomes count
stats['pending_outcomes'] = {
symbol: len(packages)
for symbol, packages in self.pending_outcomes.items()
}
# Calculate profitability rate
if stats['total_episodes'] > 0:
stats['profitability_rate'] = stats['profitable_episodes'] / stats['total_episodes']
stats['rapid_change_rate'] = stats['rapid_change_episodes'] / stats['total_episodes']
else:
stats['profitability_rate'] = 0.0
stats['rapid_change_rate'] = 0.0
return stats
def validate_data_integrity(self) -> Dict[str, Any]:
"""Comprehensive data integrity validation"""
validation_results = {
'total_episodes_checked': 0,
'hash_mismatches': 0,
'completeness_issues': 0,
'validation_flag_failures': 0,
'corrupted_episodes': [],
'integrity_score': 1.0
}
try:
for symbol, episodes in self.training_episodes.items():
for episode in episodes:
validation_results['total_episodes_checked'] += 1
# Check data hash
expected_hash = episode.input_package._calculate_hash()
if expected_hash != episode.input_package.data_hash:
validation_results['hash_mismatches'] += 1
validation_results['corrupted_episodes'].append(episode.episode_id)
# Check completeness
if episode.input_package.completeness_score < 0.7:
validation_results['completeness_issues'] += 1
# Check validation flags
if not episode.input_package.validation_flags.get('data_consistent', False):
validation_results['validation_flag_failures'] += 1
# Calculate integrity score
total_issues = (
validation_results['hash_mismatches'] +
validation_results['completeness_issues'] +
validation_results['validation_flag_failures']
)
if validation_results['total_episodes_checked'] > 0:
validation_results['integrity_score'] = 1.0 - (
total_issues / validation_results['total_episodes_checked']
)
logger.info(f"Data integrity validation completed")
logger.info(f"Integrity score: {validation_results['integrity_score']:.3f}")
except Exception as e:
logger.error(f"Error during data integrity validation: {e}")
validation_results['validation_error'] = str(e)
return validation_results
# Global instance for easy access
training_data_collector = None
def get_training_data_collector() -> TrainingDataCollector:
"""Get global training data collector instance"""
global training_data_collector
if training_data_collector is None:
training_data_collector = TrainingDataCollector()
return training_data_collector
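# ---------------------------------------------------------------------------
# Example usage: an illustrative sketch only. The symbol, frame contents and
# feature sizes below are assumptions for demonstration, not production wiring;
# `pd` and `np` refer to this module's pandas/numpy imports.
if __name__ == "__main__":
    collector = get_training_data_collector()
    collector.start_collection()

    # Synthetic single-bar OHLCV frame just to exercise the collection path
    sample_ohlcv = {
        '1m': pd.DataFrame({
            'open': [100.0], 'high': [101.0], 'low': [99.5],
            'close': [100.5], 'volume': [12.3]
        })
    }

    episode_id = collector.collect_training_data(
        symbol='ETH/USDT',
        ohlcv_data=sample_ohlcv,
        tick_data=[],
        cob_data={},
        technical_indicators={'rsi_14': 55.0},
        pivot_points=[],
        cnn_features=np.zeros(128, dtype=np.float32),
        rl_state=np.zeros(64, dtype=np.float32),
        orchestrator_context={}
    )
    # episode_id may be None if the package's completeness score is below 0.5
    print(f"Collected episode: {episode_id}")

    # Episodes become available for replay once their outcomes are validated
    print(collector.get_collection_statistics())
    print(collector.get_high_priority_episodes('ETH/USDT', limit=10))

    collector.stop_collection()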