BIG CLEANUP
This commit is contained in:
@@ -1,276 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Compare COB data quality between DataProvider and COBIntegration
|
||||
|
||||
This test compares:
|
||||
1. DataProvider COB collection (used in our test)
|
||||
2. COBIntegration direct access (used in cob_realtime_dashboard.py)
|
||||
|
||||
To understand why cob_realtime_dashboard.py gets more stable data.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from core.data_provider import DataProvider, MarketTick
|
||||
from core.config import get_config
|
||||
|
||||
# Try to import COBIntegration like cob_realtime_dashboard does
|
||||
try:
|
||||
from core.cob_integration import COBIntegration
|
||||
COB_INTEGRATION_AVAILABLE = True
|
||||
except ImportError:
|
||||
COB_INTEGRATION_AVAILABLE = False
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class COBComparisonTester:
|
||||
def __init__(self, symbol='ETH/USDT', duration_seconds=15):
|
||||
self.symbol = symbol
|
||||
self.duration = timedelta(seconds=duration_seconds)
|
||||
|
||||
# Data storage for both methods
|
||||
self.dp_ticks = deque() # DataProvider ticks
|
||||
self.cob_data = deque() # COBIntegration data
|
||||
|
||||
# Initialize DataProvider (method 1)
|
||||
logger.info("Initializing DataProvider...")
|
||||
self.data_provider = DataProvider()
|
||||
self.dp_cob_received = 0
|
||||
|
||||
# Initialize COBIntegration (method 2)
|
||||
self.cob_integration = None
|
||||
self.cob_received = 0
|
||||
if COB_INTEGRATION_AVAILABLE:
|
||||
logger.info("Initializing COBIntegration...")
|
||||
self.cob_integration = COBIntegration(symbols=[self.symbol])
|
||||
else:
|
||||
logger.warning("COBIntegration not available - will only test DataProvider")
|
||||
|
||||
self.start_time = None
|
||||
self.subscriber_id = None
|
||||
|
||||
def _dp_cob_callback(self, symbol: str, cob_data: dict):
|
||||
"""Callback for DataProvider COB data"""
|
||||
self.dp_cob_received += 1
|
||||
|
||||
if 'stats' in cob_data and 'mid_price' in cob_data['stats']:
|
||||
mid_price = cob_data['stats']['mid_price']
|
||||
if mid_price > 0:
|
||||
synthetic_tick = MarketTick(
|
||||
symbol=symbol,
|
||||
timestamp=cob_data.get('timestamp', datetime.now()),
|
||||
price=mid_price,
|
||||
volume=cob_data.get('stats', {}).get('total_volume', 0),
|
||||
quantity=0,
|
||||
side='dp_cob',
|
||||
trade_id=f"dp_{self.dp_cob_received}",
|
||||
is_buyer_maker=False,
|
||||
raw_data=cob_data
|
||||
)
|
||||
self.dp_ticks.append(synthetic_tick)
|
||||
|
||||
if self.dp_cob_received % 20 == 0:
|
||||
logger.info(f"[DataProvider] Update #{self.dp_cob_received}: {symbol} @ ${mid_price:.2f}")
|
||||
|
||||
def _cob_integration_callback(self, symbol: str, data: dict):
|
||||
"""Callback for COBIntegration data"""
|
||||
self.cob_received += 1
|
||||
|
||||
# Store COBIntegration data directly
|
||||
cob_record = {
|
||||
'symbol': symbol,
|
||||
'timestamp': datetime.now(),
|
||||
'data': data,
|
||||
'source': 'cob_integration'
|
||||
}
|
||||
self.cob_data.append(cob_record)
|
||||
|
||||
if self.cob_received % 20 == 0:
|
||||
stats = data.get('stats', {})
|
||||
mid_price = stats.get('mid_price', 0)
|
||||
logger.info(f"[COBIntegration] Update #{self.cob_received}: {symbol} @ ${mid_price:.2f}")
|
||||
|
||||
async def run_comparison_test(self):
|
||||
"""Run the comparison test"""
|
||||
logger.info(f"Starting COB comparison test for {self.symbol} for {self.duration.total_seconds()} seconds...")
|
||||
|
||||
# Start DataProvider COB collection
|
||||
try:
|
||||
logger.info("Starting DataProvider COB collection...")
|
||||
self.data_provider.start_cob_collection()
|
||||
self.data_provider.subscribe_to_cob(self._dp_cob_callback)
|
||||
await self.data_provider.start_real_time_streaming()
|
||||
logger.info("DataProvider streaming started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start DataProvider: {e}")
|
||||
|
||||
# Start COBIntegration if available
|
||||
if self.cob_integration:
|
||||
try:
|
||||
logger.info("Starting COBIntegration...")
|
||||
self.cob_integration.add_dashboard_callback(self._cob_integration_callback)
|
||||
await self.cob_integration.start()
|
||||
logger.info("COBIntegration started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start COBIntegration: {e}")
|
||||
|
||||
# Collect data for specified duration
|
||||
self.start_time = datetime.now()
|
||||
while datetime.now() - self.start_time < self.duration:
|
||||
await asyncio.sleep(1)
|
||||
logger.info(f"DataProvider: {len(self.dp_ticks)} ticks | COBIntegration: {len(self.cob_data)} updates")
|
||||
|
||||
# Stop data collection
|
||||
try:
|
||||
await self.data_provider.stop_real_time_streaming()
|
||||
if self.cob_integration:
|
||||
await self.cob_integration.stop()
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping data collection: {e}")
|
||||
|
||||
logger.info(f"Comparison complete:")
|
||||
logger.info(f" DataProvider: {len(self.dp_ticks)} ticks received")
|
||||
logger.info(f" COBIntegration: {len(self.cob_data)} updates received")
|
||||
|
||||
# Analyze and plot the differences
|
||||
self.analyze_differences()
|
||||
self.create_comparison_plots()
|
||||
|
||||
def analyze_differences(self):
|
||||
"""Analyze the differences between the two data sources"""
|
||||
logger.info("Analyzing data quality differences...")
|
||||
|
||||
# Analyze DataProvider data
|
||||
dp_order_book_count = 0
|
||||
dp_mid_prices = []
|
||||
|
||||
for tick in self.dp_ticks:
|
||||
if hasattr(tick, 'raw_data') and tick.raw_data:
|
||||
if 'bids' in tick.raw_data and 'asks' in tick.raw_data:
|
||||
dp_order_book_count += 1
|
||||
if 'stats' in tick.raw_data and 'mid_price' in tick.raw_data['stats']:
|
||||
dp_mid_prices.append(tick.raw_data['stats']['mid_price'])
|
||||
|
||||
# Analyze COBIntegration data
|
||||
cob_order_book_count = 0
|
||||
cob_mid_prices = []
|
||||
|
||||
for record in self.cob_data:
|
||||
data = record['data']
|
||||
if 'bids' in data and 'asks' in data:
|
||||
cob_order_book_count += 1
|
||||
if 'stats' in data and 'mid_price' in data['stats']:
|
||||
cob_mid_prices.append(data['stats']['mid_price'])
|
||||
|
||||
logger.info("Data Quality Analysis:")
|
||||
logger.info(f" DataProvider:")
|
||||
logger.info(f" Total updates: {len(self.dp_ticks)}")
|
||||
logger.info(f" With order book data: {dp_order_book_count}")
|
||||
logger.info(f" Mid prices collected: {len(dp_mid_prices)}")
|
||||
if dp_mid_prices:
|
||||
logger.info(f" Price range: ${min(dp_mid_prices):.2f} - ${max(dp_mid_prices):.2f}")
|
||||
|
||||
logger.info(f" COBIntegration:")
|
||||
logger.info(f" Total updates: {len(self.cob_data)}")
|
||||
logger.info(f" With order book data: {cob_order_book_count}")
|
||||
logger.info(f" Mid prices collected: {len(cob_mid_prices)}")
|
||||
if cob_mid_prices:
|
||||
logger.info(f" Price range: ${min(cob_mid_prices):.2f} - ${max(cob_mid_prices):.2f}")
|
||||
|
||||
def create_comparison_plots(self):
|
||||
"""Create comparison plots showing the difference"""
|
||||
logger.info("Creating comparison plots...")
|
||||
|
||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 12))
|
||||
|
||||
# Plot 1: Price comparison
|
||||
dp_times = []
|
||||
dp_prices = []
|
||||
for tick in self.dp_ticks:
|
||||
if tick.price > 0:
|
||||
dp_times.append(tick.timestamp)
|
||||
dp_prices.append(tick.price)
|
||||
|
||||
cob_times = []
|
||||
cob_prices = []
|
||||
for record in self.cob_data:
|
||||
data = record['data']
|
||||
if 'stats' in data and 'mid_price' in data['stats']:
|
||||
cob_times.append(record['timestamp'])
|
||||
cob_prices.append(data['stats']['mid_price'])
|
||||
|
||||
if dp_times:
|
||||
ax1.plot(pd.to_datetime(dp_times), dp_prices, 'b-', alpha=0.7, label='DataProvider COB', linewidth=1)
|
||||
if cob_times:
|
||||
ax1.plot(pd.to_datetime(cob_times), cob_prices, 'r-', alpha=0.7, label='COBIntegration', linewidth=1)
|
||||
|
||||
ax1.set_title('Price Comparison: DataProvider vs COBIntegration')
|
||||
ax1.set_ylabel('Price (USDT)')
|
||||
ax1.legend()
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# Plot 2: Data quality comparison (order book depth)
|
||||
dp_bid_counts = []
|
||||
dp_ask_counts = []
|
||||
dp_ob_times = []
|
||||
|
||||
for tick in self.dp_ticks:
|
||||
if hasattr(tick, 'raw_data') and tick.raw_data:
|
||||
if 'bids' in tick.raw_data and 'asks' in tick.raw_data:
|
||||
dp_bid_counts.append(len(tick.raw_data['bids']))
|
||||
dp_ask_counts.append(len(tick.raw_data['asks']))
|
||||
dp_ob_times.append(tick.timestamp)
|
||||
|
||||
cob_bid_counts = []
|
||||
cob_ask_counts = []
|
||||
cob_ob_times = []
|
||||
|
||||
for record in self.cob_data:
|
||||
data = record['data']
|
||||
if 'bids' in data and 'asks' in data:
|
||||
cob_bid_counts.append(len(data['bids']))
|
||||
cob_ask_counts.append(len(data['asks']))
|
||||
cob_ob_times.append(record['timestamp'])
|
||||
|
||||
if dp_ob_times:
|
||||
ax2.plot(pd.to_datetime(dp_ob_times), dp_bid_counts, 'b--', alpha=0.7, label='DP Bid Levels')
|
||||
ax2.plot(pd.to_datetime(dp_ob_times), dp_ask_counts, 'b:', alpha=0.7, label='DP Ask Levels')
|
||||
if cob_ob_times:
|
||||
ax2.plot(pd.to_datetime(cob_ob_times), cob_bid_counts, 'r--', alpha=0.7, label='COB Bid Levels')
|
||||
ax2.plot(pd.to_datetime(cob_ob_times), cob_ask_counts, 'r:', alpha=0.7, label='COB Ask Levels')
|
||||
|
||||
ax2.set_title('Order Book Depth Comparison')
|
||||
ax2.set_ylabel('Number of Levels')
|
||||
ax2.set_xlabel('Time')
|
||||
ax2.legend()
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
plot_filename = f"cob_comparison_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png"
|
||||
plt.savefig(plot_filename, dpi=150)
|
||||
logger.info(f"Comparison plot saved to {plot_filename}")
|
||||
plt.show()
|
||||
|
||||
|
||||
async def main():
|
||||
tester = COBComparisonTester()
|
||||
await tester.run_comparison_test()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Test interrupted by user.")
|
||||
@@ -1,502 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from matplotlib.colors import LogNorm
|
||||
|
||||
from core.data_provider import DataProvider, MarketTick
|
||||
from core.config import get_config
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class COBStabilityTester:
|
||||
def __init__(self, symbol='ETHUSDT', duration_seconds=10):
|
||||
self.symbol = symbol
|
||||
self.duration = timedelta(seconds=duration_seconds)
|
||||
self.ticks = deque()
|
||||
|
||||
# Set granularity (buckets) based on symbol
|
||||
if 'ETH' in symbol.upper():
|
||||
self.price_granularity = 1.0 # 1 USD for ETH
|
||||
elif 'BTC' in symbol.upper():
|
||||
self.price_granularity = 10.0 # 10 USD for BTC
|
||||
else:
|
||||
self.price_granularity = 1.0 # Default 1 USD
|
||||
|
||||
logger.info(f"Using price granularity: ${self.price_granularity} for {symbol}")
|
||||
|
||||
# Initialize DataProvider the same way as clean_dashboard
|
||||
logger.info("Initializing DataProvider like in clean_dashboard...")
|
||||
self.data_provider = DataProvider() # Use default constructor like clean_dashboard
|
||||
|
||||
# Initialize COB data collection like clean_dashboard does
|
||||
self.cob_data_received = 0
|
||||
self.latest_cob_data = {}
|
||||
|
||||
# Store all COB snapshots for heatmap generation
|
||||
self.cob_snapshots = deque()
|
||||
self.price_data = [] # For price line chart
|
||||
|
||||
self.start_time = None
|
||||
self.subscriber_id = None
|
||||
self.last_log_time = None
|
||||
|
||||
def _tick_callback(self, tick: MarketTick):
|
||||
"""Callback function to receive ticks from the DataProvider."""
|
||||
if self.start_time is None:
|
||||
self.start_time = datetime.now()
|
||||
logger.info(f"Started collecting ticks at {self.start_time}")
|
||||
|
||||
# Store all ticks
|
||||
self.ticks.append(tick)
|
||||
|
||||
def _cob_data_callback(self, symbol: str, cob_data: dict):
|
||||
"""Callback function to receive COB data from the DataProvider."""
|
||||
# Debug: Log first few callbacks to see what symbols we're getting
|
||||
if self.cob_data_received < 5:
|
||||
logger.info(f"DEBUG: Received COB data for symbol '{symbol}' (target: '{self.symbol}')")
|
||||
|
||||
# Filter to only our requested symbol - handle both formats (ETH/USDT and ETHUSDT)
|
||||
normalized_symbol = symbol.replace('/', '')
|
||||
normalized_target = self.symbol.replace('/', '')
|
||||
if normalized_symbol != normalized_target:
|
||||
if self.cob_data_received < 5:
|
||||
logger.info(f"DEBUG: Skipping symbol '{symbol}' (normalized: '{normalized_symbol}' vs target: '{normalized_target}')")
|
||||
return
|
||||
|
||||
self.cob_data_received += 1
|
||||
self.latest_cob_data[symbol] = cob_data
|
||||
|
||||
# Store the complete COB snapshot for heatmap generation
|
||||
if 'bids' in cob_data and 'asks' in cob_data:
|
||||
# Debug: Log structure of first few COB snapshots
|
||||
if len(self.cob_snapshots) < 3:
|
||||
logger.info(f"DEBUG: COB data structure - bids: {len(cob_data['bids'])} items, asks: {len(cob_data['asks'])} items")
|
||||
if cob_data['bids']:
|
||||
logger.info(f"DEBUG: First bid: {cob_data['bids'][0]}")
|
||||
if cob_data['asks']:
|
||||
logger.info(f"DEBUG: First ask: {cob_data['asks'][0]}")
|
||||
|
||||
# Use current time for timestamp consistency
|
||||
current_time = datetime.now()
|
||||
snapshot = {
|
||||
'timestamp': current_time,
|
||||
'bids': cob_data['bids'],
|
||||
'asks': cob_data['asks'],
|
||||
'stats': cob_data.get('stats', {})
|
||||
}
|
||||
self.cob_snapshots.append(snapshot)
|
||||
|
||||
# Log bucketed COB data every second
|
||||
now = datetime.now()
|
||||
if self.last_log_time is None or (now - self.last_log_time).total_seconds() >= 1.0:
|
||||
self.last_log_time = now
|
||||
self._log_bucketed_cob_data(cob_data)
|
||||
|
||||
# Convert COB data to tick-like format for analysis
|
||||
if 'stats' in cob_data and 'mid_price' in cob_data['stats']:
|
||||
mid_price = cob_data['stats']['mid_price']
|
||||
if mid_price > 0:
|
||||
# Filter out extreme price movements (±10% of recent average)
|
||||
if len(self.price_data) > 5:
|
||||
recent_prices = [p['price'] for p in self.price_data[-5:]]
|
||||
avg_recent_price = sum(recent_prices) / len(recent_prices)
|
||||
price_deviation = abs(mid_price - avg_recent_price) / avg_recent_price
|
||||
|
||||
if price_deviation > 0.10: # More than 10% deviation
|
||||
logger.warning(f"Filtering out extreme price: ${mid_price:.2f} (deviation: {price_deviation:.1%} from avg ${avg_recent_price:.2f})")
|
||||
return # Skip this data point
|
||||
|
||||
# Store price data for line chart with consistent timestamp
|
||||
current_time = datetime.now()
|
||||
self.price_data.append({
|
||||
'timestamp': current_time,
|
||||
'price': mid_price
|
||||
})
|
||||
|
||||
# Create a synthetic tick from COB data with consistent timestamp
|
||||
current_time = datetime.now()
|
||||
synthetic_tick = MarketTick(
|
||||
symbol=symbol,
|
||||
timestamp=current_time,
|
||||
price=mid_price,
|
||||
volume=cob_data.get('stats', {}).get('total_volume', 0),
|
||||
quantity=0, # Not available in COB data
|
||||
side='unknown', # COB data doesn't have side info
|
||||
trade_id=f"cob_{self.cob_data_received}",
|
||||
is_buyer_maker=False,
|
||||
raw_data=cob_data
|
||||
)
|
||||
self.ticks.append(synthetic_tick)
|
||||
|
||||
if self.cob_data_received % 10 == 0: # Log every 10th update
|
||||
logger.info(f"COB update #{self.cob_data_received}: {symbol} @ ${mid_price:.2f}")
|
||||
|
||||
def _log_bucketed_cob_data(self, cob_data: dict):
|
||||
"""Log bucketed COB data every second"""
|
||||
try:
|
||||
if 'bids' not in cob_data or 'asks' not in cob_data:
|
||||
logger.info("COB-1s: No order book data available")
|
||||
return
|
||||
|
||||
if 'stats' not in cob_data or 'mid_price' not in cob_data['stats']:
|
||||
logger.info("COB-1s: No mid price available")
|
||||
return
|
||||
|
||||
mid_price = cob_data['stats']['mid_price']
|
||||
if mid_price <= 0:
|
||||
return
|
||||
|
||||
# Bucket the order book data
|
||||
bid_buckets = {}
|
||||
ask_buckets = {}
|
||||
|
||||
# Process bids (top 10)
|
||||
for bid in cob_data['bids'][:10]:
|
||||
try:
|
||||
if isinstance(bid, dict):
|
||||
price = float(bid['price'])
|
||||
size = float(bid['size'])
|
||||
elif isinstance(bid, (list, tuple)) and len(bid) >= 2:
|
||||
price = float(bid[0])
|
||||
size = float(bid[1])
|
||||
else:
|
||||
continue
|
||||
|
||||
bucketed_price = round(price / self.price_granularity) * self.price_granularity
|
||||
bid_buckets[bucketed_price] = bid_buckets.get(bucketed_price, 0) + size
|
||||
except (ValueError, TypeError, IndexError):
|
||||
continue
|
||||
|
||||
# Process asks (top 10)
|
||||
for ask in cob_data['asks'][:10]:
|
||||
try:
|
||||
if isinstance(ask, dict):
|
||||
price = float(ask['price'])
|
||||
size = float(ask['size'])
|
||||
elif isinstance(ask, (list, tuple)) and len(ask) >= 2:
|
||||
price = float(ask[0])
|
||||
size = float(ask[1])
|
||||
else:
|
||||
continue
|
||||
|
||||
bucketed_price = round(price / self.price_granularity) * self.price_granularity
|
||||
ask_buckets[bucketed_price] = ask_buckets.get(bucketed_price, 0) + size
|
||||
except (ValueError, TypeError, IndexError):
|
||||
continue
|
||||
|
||||
# Format for log output
|
||||
bid_str = ", ".join([f"${p:.0f}:{s:.3f}" for p, s in sorted(bid_buckets.items(), reverse=True)])
|
||||
ask_str = ", ".join([f"${p:.0f}:{s:.3f}" for p, s in sorted(ask_buckets.items())])
|
||||
|
||||
logger.info(f"COB-1s @ ${mid_price:.2f} | BIDS: {bid_str} | ASKS: {ask_str}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error logging bucketed COB data: {e}")
|
||||
|
||||
async def run_test(self):
|
||||
"""Run the data collection and plotting test."""
|
||||
logger.info(f"Starting COB stability test for {self.symbol} for {self.duration.total_seconds()} seconds...")
|
||||
|
||||
# Initialize COB collection like clean_dashboard does
|
||||
try:
|
||||
logger.info("Starting COB collection in data provider...")
|
||||
self.data_provider.start_cob_collection()
|
||||
logger.info("Started COB collection in data provider")
|
||||
|
||||
# Subscribe to COB updates
|
||||
logger.info("Subscribing to COB data updates...")
|
||||
self.data_provider.subscribe_to_cob(self._cob_data_callback)
|
||||
logger.info("Subscribed to COB data updates from data provider")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start COB collection or subscribe: {e}")
|
||||
|
||||
# Subscribe to ticks as fallback
|
||||
try:
|
||||
self.subscriber_id = self.data_provider.subscribe_to_ticks(self._tick_callback, symbols=[self.symbol])
|
||||
logger.info("Subscribed to tick data as fallback")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to subscribe to ticks: {e}")
|
||||
|
||||
# Start the data provider's real-time streaming
|
||||
try:
|
||||
await self.data_provider.start_real_time_streaming()
|
||||
logger.info("Started real-time streaming")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start real-time streaming: {e}")
|
||||
|
||||
# Collect data for the specified duration
|
||||
self.start_time = datetime.now()
|
||||
while datetime.now() - self.start_time < self.duration:
|
||||
await asyncio.sleep(1)
|
||||
logger.info(f"Collected {len(self.ticks)} ticks so far...")
|
||||
|
||||
# Stop streaming and unsubscribe
|
||||
await self.data_provider.stop_real_time_streaming()
|
||||
self.data_provider.unsubscribe_from_ticks(self.subscriber_id)
|
||||
|
||||
logger.info(f"Finished collecting data. Total ticks: {len(self.ticks)}")
|
||||
|
||||
# Plot the results
|
||||
if self.price_data and self.cob_snapshots:
|
||||
self.create_price_heatmap_chart()
|
||||
elif self.ticks:
|
||||
self._create_simple_price_chart()
|
||||
else:
|
||||
logger.warning("No data was collected. Cannot generate plot.")
|
||||
|
||||
def create_price_heatmap_chart(self):
|
||||
"""Create a visualization with price chart and order book scatter plot."""
|
||||
if not self.price_data or not self.cob_snapshots:
|
||||
logger.warning("Insufficient data to plot.")
|
||||
return
|
||||
|
||||
logger.info(f"Creating price and order book chart...")
|
||||
logger.info(f"Data summary: {len(self.price_data)} price points, {len(self.cob_snapshots)} COB snapshots")
|
||||
|
||||
# Prepare price data
|
||||
price_df = pd.DataFrame(self.price_data)
|
||||
price_df['timestamp'] = pd.to_datetime(price_df['timestamp'])
|
||||
|
||||
logger.info(f"Price data time range: {price_df['timestamp'].min()} to {price_df['timestamp'].max()}")
|
||||
logger.info(f"Price range: ${price_df['price'].min():.2f} to ${price_df['price'].max():.2f}")
|
||||
|
||||
# Create figure with subplots
|
||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 12), height_ratios=[3, 2])
|
||||
|
||||
# Top plot: Price chart with order book levels
|
||||
ax1.plot(price_df['timestamp'], price_df['price'], 'yellow', linewidth=2, label='Mid Price', zorder=10)
|
||||
|
||||
# Plot order book levels as scatter points
|
||||
bid_times, bid_prices, bid_sizes = [], [], []
|
||||
ask_times, ask_prices, ask_sizes = [], [], []
|
||||
|
||||
# Calculate average price for filtering
|
||||
avg_price = price_df['price'].mean() if not price_df.empty else 3500 # Fallback price
|
||||
price_lower = avg_price * 0.9 # -10%
|
||||
price_upper = avg_price * 1.1 # +10%
|
||||
|
||||
logger.info(f"Filtering order book data to price range: ${price_lower:.2f} - ${price_upper:.2f} (±10% of ${avg_price:.2f})")
|
||||
|
||||
for snapshot in list(self.cob_snapshots)[-50:]: # Use last 50 snapshots for clarity
|
||||
timestamp = pd.to_datetime(snapshot['timestamp'])
|
||||
|
||||
# Process bids (top 10)
|
||||
for order in snapshot.get('bids', [])[:10]:
|
||||
try:
|
||||
if isinstance(order, dict):
|
||||
price = float(order['price'])
|
||||
size = float(order['size'])
|
||||
elif isinstance(order, (list, tuple)) and len(order) >= 2:
|
||||
price = float(order[0])
|
||||
size = float(order[1])
|
||||
else:
|
||||
continue
|
||||
|
||||
# Filter out prices outside ±10% range
|
||||
if price < price_lower or price > price_upper:
|
||||
continue
|
||||
|
||||
bid_times.append(timestamp)
|
||||
bid_prices.append(price)
|
||||
bid_sizes.append(size)
|
||||
except (ValueError, TypeError, IndexError):
|
||||
continue
|
||||
|
||||
# Process asks (top 10)
|
||||
for order in snapshot.get('asks', [])[:10]:
|
||||
try:
|
||||
if isinstance(order, dict):
|
||||
price = float(order['price'])
|
||||
size = float(order['size'])
|
||||
elif isinstance(order, (list, tuple)) and len(order) >= 2:
|
||||
price = float(order[0])
|
||||
size = float(order[1])
|
||||
else:
|
||||
continue
|
||||
|
||||
# Filter out prices outside ±10% range
|
||||
if price < price_lower or price > price_upper:
|
||||
continue
|
||||
|
||||
ask_times.append(timestamp)
|
||||
ask_prices.append(price)
|
||||
ask_sizes.append(size)
|
||||
except (ValueError, TypeError, IndexError):
|
||||
continue
|
||||
|
||||
# Plot order book data as scatter with size indicating volume
|
||||
if bid_times:
|
||||
bid_sizes_normalized = np.array(bid_sizes) * 3 # Scale for visibility
|
||||
ax1.scatter(bid_times, bid_prices, s=bid_sizes_normalized, c='green', alpha=0.3, label='Bids')
|
||||
logger.info(f"Plotted {len(bid_times)} bid levels")
|
||||
|
||||
if ask_times:
|
||||
ask_sizes_normalized = np.array(ask_sizes) * 3 # Scale for visibility
|
||||
ax1.scatter(ask_times, ask_prices, s=ask_sizes_normalized, c='red', alpha=0.3, label='Asks')
|
||||
logger.info(f"Plotted {len(ask_times)} ask levels")
|
||||
|
||||
ax1.set_title(f'Real-time Price and Order Book - {self.symbol}\nGranularity: ${self.price_granularity} | Duration: {self.duration.total_seconds()}s')
|
||||
ax1.set_ylabel('Price (USDT)')
|
||||
ax1.legend()
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# Set proper time range (X-axis) - use actual data collection period
|
||||
time_min = price_df['timestamp'].min()
|
||||
time_max = price_df['timestamp'].max()
|
||||
actual_duration = (time_max - time_min).total_seconds()
|
||||
logger.info(f"Actual data collection duration: {actual_duration:.1f} seconds")
|
||||
|
||||
ax1.set_xlim(time_min, time_max)
|
||||
|
||||
# Set tight price range (Y-axis) - use ±2% of price range for better visibility
|
||||
price_min = price_df['price'].min()
|
||||
price_max = price_df['price'].max()
|
||||
price_center = (price_min + price_max) / 2
|
||||
price_range = price_max - price_min
|
||||
|
||||
# If price range is very small, use a minimum range of $5
|
||||
if price_range < 5:
|
||||
price_range = 5
|
||||
|
||||
# Add 20% padding to the price range for better visualization
|
||||
y_padding = price_range * 0.2
|
||||
y_min = price_min - y_padding
|
||||
y_max = price_max + y_padding
|
||||
|
||||
ax1.set_ylim(y_min, y_max)
|
||||
logger.info(f"Chart Y-axis range: ${y_min:.2f} - ${y_max:.2f} (center: ${price_center:.2f}, range: ${price_range:.2f})")
|
||||
|
||||
# Bottom plot: Order book depth over time (aggregated)
|
||||
time_buckets = []
|
||||
bid_depths = []
|
||||
ask_depths = []
|
||||
|
||||
# Create time buckets (every few snapshots)
|
||||
snapshots_list = list(self.cob_snapshots)
|
||||
bucket_size = max(1, len(snapshots_list) // 20) # ~20 buckets
|
||||
for i in range(0, len(snapshots_list), bucket_size):
|
||||
bucket_snapshots = snapshots_list[i:i+bucket_size]
|
||||
if not bucket_snapshots:
|
||||
continue
|
||||
|
||||
# Use middle timestamp of bucket
|
||||
mid_snapshot = bucket_snapshots[len(bucket_snapshots)//2]
|
||||
time_buckets.append(pd.to_datetime(mid_snapshot['timestamp']))
|
||||
|
||||
# Calculate average depths
|
||||
total_bid_depth = 0
|
||||
total_ask_depth = 0
|
||||
snapshot_count = 0
|
||||
|
||||
for snapshot in bucket_snapshots:
|
||||
bid_depth = sum([float(order[1]) if isinstance(order, (list, tuple)) else float(order.get('size', 0))
|
||||
for order in snapshot.get('bids', [])[:10]])
|
||||
ask_depth = sum([float(order[1]) if isinstance(order, (list, tuple)) else float(order.get('size', 0))
|
||||
for order in snapshot.get('asks', [])[:10]])
|
||||
total_bid_depth += bid_depth
|
||||
total_ask_depth += ask_depth
|
||||
snapshot_count += 1
|
||||
|
||||
if snapshot_count > 0:
|
||||
bid_depths.append(total_bid_depth / snapshot_count)
|
||||
ask_depths.append(total_ask_depth / snapshot_count)
|
||||
else:
|
||||
bid_depths.append(0)
|
||||
ask_depths.append(0)
|
||||
|
||||
if time_buckets:
|
||||
ax2.plot(time_buckets, bid_depths, 'green', linewidth=2, label='Bid Depth', alpha=0.7)
|
||||
ax2.plot(time_buckets, ask_depths, 'red', linewidth=2, label='Ask Depth', alpha=0.7)
|
||||
ax2.fill_between(time_buckets, bid_depths, alpha=0.3, color='green')
|
||||
ax2.fill_between(time_buckets, ask_depths, alpha=0.3, color='red')
|
||||
|
||||
ax2.set_title('Order Book Depth Over Time')
|
||||
ax2.set_xlabel('Time')
|
||||
ax2.set_ylabel('Depth (Volume)')
|
||||
ax2.legend()
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
# Set same time range for bottom chart
|
||||
ax2.set_xlim(time_min, time_max)
|
||||
|
||||
# Format time axes
|
||||
fig.autofmt_xdate()
|
||||
plt.tight_layout()
|
||||
|
||||
plot_filename = f"price_heatmap_chart_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png"
|
||||
plt.savefig(plot_filename, dpi=150, bbox_inches='tight')
|
||||
logger.info(f"Price and order book chart saved to {plot_filename}")
|
||||
plt.show()
|
||||
|
||||
def _create_simple_price_chart(self):
|
||||
"""Create a simple price chart as fallback"""
|
||||
logger.info("Creating simple price chart as fallback...")
|
||||
|
||||
prices = []
|
||||
times = []
|
||||
|
||||
for tick in self.ticks:
|
||||
if tick.price > 0:
|
||||
prices.append(tick.price)
|
||||
times.append(tick.timestamp)
|
||||
|
||||
if not prices:
|
||||
logger.warning("No price data to plot")
|
||||
return
|
||||
|
||||
fig, ax = plt.subplots(figsize=(15, 8))
|
||||
ax.plot(pd.to_datetime(times), prices, 'cyan', linewidth=1)
|
||||
ax.set_title(f'Price Chart - {self.symbol}')
|
||||
ax.set_xlabel('Time')
|
||||
ax.set_ylabel('Price (USDT)')
|
||||
fig.autofmt_xdate()
|
||||
|
||||
plot_filename = f"cob_price_chart_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png"
|
||||
plt.savefig(plot_filename)
|
||||
logger.info(f"Price chart saved to {plot_filename}")
|
||||
plt.show()
|
||||
|
||||
|
||||
async def main(symbol='ETHUSDT', duration_seconds=10):
|
||||
"""Main function to run the COB test with configurable parameters.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (default: ETHUSDT)
|
||||
duration_seconds: Test duration in seconds (default: 10)
|
||||
"""
|
||||
logger.info(f"Starting COB test with symbol={symbol}, duration={duration_seconds}s")
|
||||
tester = COBStabilityTester(symbol=symbol, duration_seconds=duration_seconds)
|
||||
await tester.run_test()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
# Parse command line arguments
|
||||
symbol = 'ETHUSDT' # Default
|
||||
duration = 10 # Default
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
symbol = sys.argv[1]
|
||||
if len(sys.argv) > 2:
|
||||
try:
|
||||
duration = int(sys.argv[2])
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid duration '{sys.argv[2]}', using default 10 seconds")
|
||||
|
||||
logger.info(f"Configuration: Symbol={symbol}, Duration={duration}s")
|
||||
logger.info(f"Granularity: {'1 USD for ETH' if 'ETH' in symbol.upper() else '10 USD for BTC' if 'BTC' in symbol.upper() else '1 USD default'}")
|
||||
|
||||
try:
|
||||
asyncio.run(main(symbol, duration))
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Test interrupted by user.")
|
||||
@@ -1,310 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Training Script for AI Trading Models
|
||||
|
||||
This script tests the training functionality of our CNN and RL models
|
||||
and demonstrates the learning capabilities.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from safe_logging import setup_safe_logging
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.config import setup_logging
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from models import get_model_registry, CNNModelWrapper, RLAgentWrapper
|
||||
|
||||
# Setup logging
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_model_loading():
|
||||
"""Test that models load correctly"""
|
||||
logger.info("=== TESTING MODEL LOADING ===")
|
||||
|
||||
try:
|
||||
# Get model registry
|
||||
registry = get_model_registry()
|
||||
|
||||
# Check loaded models
|
||||
logger.info(f"Loaded models: {list(registry.models.keys())}")
|
||||
|
||||
# Test each model
|
||||
for name, model in registry.models.items():
|
||||
logger.info(f"Testing {name} model...")
|
||||
|
||||
# Test prediction
|
||||
import numpy as np
|
||||
test_features = np.random.random((20, 5)) # 20 timesteps, 5 features
|
||||
|
||||
try:
|
||||
predictions, confidence = model.predict(test_features)
|
||||
logger.info(f" ✅ {name} prediction: {predictions} (confidence: {confidence:.3f})")
|
||||
except Exception as e:
|
||||
logger.error(f" ❌ {name} prediction failed: {e}")
|
||||
|
||||
# Memory stats
|
||||
stats = registry.get_memory_stats()
|
||||
logger.info(f"Memory usage: {stats['total_used_mb']:.1f}MB / {stats['total_limit_mb']:.1f}MB")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Model loading test failed: {e}")
|
||||
return False
|
||||
|
||||
async def test_orchestrator_integration():
|
||||
"""Test orchestrator integration with models"""
|
||||
logger.info("=== TESTING ORCHESTRATOR INTEGRATION ===")
|
||||
|
||||
try:
|
||||
# Initialize components
|
||||
data_provider = DataProvider()
|
||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||
|
||||
# Test coordinated decisions
|
||||
logger.info("Testing coordinated decision making...")
|
||||
decisions = await orchestrator.make_coordinated_decisions()
|
||||
|
||||
if decisions:
|
||||
for symbol, decision in decisions.items():
|
||||
if decision:
|
||||
logger.info(f" ✅ {symbol}: {decision.action} (confidence: {decision.confidence:.3f})")
|
||||
else:
|
||||
logger.info(f" ⏸️ {symbol}: No decision (waiting)")
|
||||
else:
|
||||
logger.warning(" ❌ No decisions made")
|
||||
|
||||
# Test RL evaluation
|
||||
logger.info("Testing RL evaluation...")
|
||||
await orchestrator.evaluate_actions_with_rl()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Orchestrator integration test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_rl_learning():
|
||||
"""Test RL learning functionality"""
|
||||
logger.info("=== TESTING RL LEARNING ===")
|
||||
|
||||
try:
|
||||
registry = get_model_registry()
|
||||
rl_agent = registry.get_model('RL')
|
||||
|
||||
if not rl_agent:
|
||||
logger.error("RL agent not found")
|
||||
return False
|
||||
|
||||
# Simulate some experiences
|
||||
import numpy as np
|
||||
|
||||
logger.info("Simulating trading experiences...")
|
||||
for i in range(50):
|
||||
state = np.random.random(10)
|
||||
action = np.random.randint(0, 3)
|
||||
reward = np.random.uniform(-0.1, 0.1) # Random P&L
|
||||
next_state = np.random.random(10)
|
||||
done = False
|
||||
|
||||
# Store experience
|
||||
rl_agent.remember(state, action, reward, next_state, done)
|
||||
|
||||
logger.info(f"Stored {len(rl_agent.experience_buffer)} experiences")
|
||||
|
||||
# Test replay training
|
||||
logger.info("Testing replay training...")
|
||||
loss = rl_agent.replay()
|
||||
|
||||
if loss is not None:
|
||||
logger.info(f" ✅ Training loss: {loss:.4f}")
|
||||
else:
|
||||
logger.info(" ⏸️ Not enough experiences for training")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RL learning test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_cnn_training():
|
||||
"""Test CNN training functionality"""
|
||||
logger.info("=== TESTING CNN TRAINING ===")
|
||||
|
||||
try:
|
||||
registry = get_model_registry()
|
||||
cnn_model = registry.get_model('CNN')
|
||||
|
||||
if not cnn_model:
|
||||
logger.error("CNN model not found")
|
||||
return False
|
||||
|
||||
# Test training with mock perfect moves
|
||||
training_data = {
|
||||
'perfect_moves': [],
|
||||
'market_data': {},
|
||||
'symbols': ['ETH/USDT', 'BTC/USDT'],
|
||||
'timeframes': ['1m', '1h']
|
||||
}
|
||||
|
||||
# Mock some perfect moves
|
||||
for i in range(10):
|
||||
perfect_move = {
|
||||
'symbol': 'ETH/USDT',
|
||||
'timeframe': '1m',
|
||||
'timestamp': datetime.now() - timedelta(hours=i),
|
||||
'optimal_action': 'BUY' if i % 2 == 0 else 'SELL',
|
||||
'confidence_should_have_been': 0.8 + i * 0.01,
|
||||
'actual_outcome': 0.02 if i % 2 == 0 else -0.015
|
||||
}
|
||||
training_data['perfect_moves'].append(perfect_move)
|
||||
|
||||
logger.info(f"Testing training with {len(training_data['perfect_moves'])} perfect moves...")
|
||||
|
||||
# Test training
|
||||
result = cnn_model.train(training_data)
|
||||
|
||||
if result and result.get('status') == 'training_simulated':
|
||||
logger.info(f" ✅ Training completed: {result}")
|
||||
else:
|
||||
logger.warning(f" ⚠️ Training result: {result}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"CNN training test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_prediction_tracking():
|
||||
"""Test prediction tracking and learning feedback"""
|
||||
logger.info("=== TESTING PREDICTION TRACKING ===")
|
||||
|
||||
try:
|
||||
# Initialize components
|
||||
data_provider = DataProvider()
|
||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||
|
||||
# Get some market data for testing
|
||||
test_data = data_provider.get_historical_data('ETH/USDT', '1m', limit=100)
|
||||
|
||||
if test_data is None or test_data.empty:
|
||||
logger.warning("No market data available for testing")
|
||||
return True
|
||||
|
||||
logger.info(f"Testing with {len(test_data)} candles of ETH/USDT 1m data")
|
||||
|
||||
# Simulate some predictions and outcomes
|
||||
correct_predictions = 0
|
||||
total_predictions = 0
|
||||
|
||||
for i in range(min(10, len(test_data) - 5)):
|
||||
# Get a slice of data
|
||||
current_data = test_data.iloc[i:i+20]
|
||||
future_data = test_data.iloc[i+20:i+25]
|
||||
|
||||
if len(current_data) < 20 or len(future_data) < 5:
|
||||
continue
|
||||
|
||||
# Make prediction
|
||||
current_price = current_data['close'].iloc[-1]
|
||||
future_price = future_data['close'].iloc[-1]
|
||||
actual_change = (future_price - current_price) / current_price
|
||||
|
||||
# Simulate model prediction
|
||||
predicted_action = 'BUY' if actual_change > 0.001 else 'SELL' if actual_change < -0.001 else 'HOLD'
|
||||
|
||||
# Check if prediction was correct
|
||||
if predicted_action == 'BUY' and actual_change > 0:
|
||||
correct_predictions += 1
|
||||
logger.info(f" ✅ Correct BUY prediction: {actual_change:.4f}")
|
||||
elif predicted_action == 'SELL' and actual_change < 0:
|
||||
correct_predictions += 1
|
||||
logger.info(f" ✅ Correct SELL prediction: {actual_change:.4f}")
|
||||
elif predicted_action == 'HOLD' and abs(actual_change) < 0.001:
|
||||
correct_predictions += 1
|
||||
logger.info(f" ✅ Correct HOLD prediction: {actual_change:.4f}")
|
||||
else:
|
||||
logger.info(f" ❌ Wrong {predicted_action} prediction: {actual_change:.4f}")
|
||||
|
||||
total_predictions += 1
|
||||
|
||||
if total_predictions > 0:
|
||||
accuracy = correct_predictions / total_predictions
|
||||
logger.info(f"Prediction accuracy: {accuracy:.1%} ({correct_predictions}/{total_predictions})")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Prediction tracking test failed: {e}")
|
||||
return False
|
||||
|
||||
async def main():
|
||||
"""Main test function"""
|
||||
logger.info("🧪 STARTING AI TRADING MODEL TESTS")
|
||||
logger.info("Testing model loading, training, and learning capabilities")
|
||||
|
||||
tests = [
|
||||
("Model Loading", test_model_loading),
|
||||
("Orchestrator Integration", test_orchestrator_integration),
|
||||
("RL Learning", test_rl_learning),
|
||||
("CNN Training", test_cnn_training),
|
||||
("Prediction Tracking", test_prediction_tracking)
|
||||
]
|
||||
|
||||
results = {}
|
||||
|
||||
for test_name, test_func in tests:
|
||||
logger.info(f"\n{'='*50}")
|
||||
logger.info(f"Running: {test_name}")
|
||||
logger.info(f"{'='*50}")
|
||||
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(test_func):
|
||||
result = await test_func()
|
||||
else:
|
||||
result = test_func()
|
||||
|
||||
results[test_name] = result
|
||||
|
||||
if result:
|
||||
logger.info(f"✅ {test_name}: PASSED")
|
||||
else:
|
||||
logger.error(f"❌ {test_name}: FAILED")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ {test_name}: ERROR - {e}")
|
||||
results[test_name] = False
|
||||
|
||||
# Summary
|
||||
logger.info(f"\n{'='*50}")
|
||||
logger.info("TEST SUMMARY")
|
||||
logger.info(f"{'='*50}")
|
||||
|
||||
passed = sum(1 for result in results.values() if result)
|
||||
total = len(results)
|
||||
|
||||
for test_name, result in results.items():
|
||||
status = "✅ PASSED" if result else "❌ FAILED"
|
||||
logger.info(f"{test_name}: {status}")
|
||||
|
||||
logger.info(f"\nOverall: {passed}/{total} tests passed ({passed/total:.1%})")
|
||||
|
||||
if passed == total:
|
||||
logger.info("🎉 All tests passed! The AI trading system is working correctly.")
|
||||
else:
|
||||
logger.warning(f"⚠️ {total-passed} tests failed. Please check the logs above.")
|
||||
|
||||
return 0 if passed == total else 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit_code = asyncio.run(main())
|
||||
sys.exit(exit_code)
|
||||
@@ -1,204 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Training Integration with Dashboard
|
||||
|
||||
This script tests the enhanced dashboard's ability to:
|
||||
1. Stream training data to CNN and DQN models
|
||||
2. Display real-time training metrics and progress
|
||||
3. Show model learning curves and performance
|
||||
4. Integrate with the continuous training system
|
||||
"""
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import time
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_training_integration():
|
||||
"""Test the training integration functionality"""
|
||||
try:
|
||||
print("="*60)
|
||||
print("TESTING TRAINING INTEGRATION WITH DASHBOARD")
|
||||
print("="*60)
|
||||
|
||||
# Import dashboard
|
||||
from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
|
||||
from core.data_provider import DataProvider
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
|
||||
# Create components
|
||||
data_provider = DataProvider()
|
||||
orchestrator = TradingOrchestrator(data_provider)
|
||||
dashboard = TradingDashboard(data_provider, orchestrator)
|
||||
|
||||
print(f"✓ Dashboard created with training integration")
|
||||
print(f"✓ Continuous training active: {getattr(dashboard, 'training_active', False)}")
|
||||
|
||||
# Test 1: Simulate tick data for training
|
||||
print("\n📊 TEST 1: Simulating Tick Data")
|
||||
print("-" * 40)
|
||||
|
||||
# Add simulated tick data to cache
|
||||
base_price = 3500.0
|
||||
for i in range(1000):
|
||||
tick_data = {
|
||||
'timestamp': datetime.now() - timedelta(seconds=1000-i),
|
||||
'price': base_price + (i % 100) * 0.1,
|
||||
'volume': 100 + (i % 50),
|
||||
'side': 'buy' if i % 2 == 0 else 'sell'
|
||||
}
|
||||
dashboard.tick_cache.append(tick_data)
|
||||
|
||||
print(f"✓ Added {len(dashboard.tick_cache)} ticks to cache")
|
||||
|
||||
# Test 2: Prepare training data
|
||||
print("\n🔄 TEST 2: Preparing Training Data")
|
||||
print("-" * 40)
|
||||
|
||||
training_data = dashboard._prepare_training_data()
|
||||
if training_data:
|
||||
print(f"✓ Training data prepared successfully")
|
||||
print(f" - OHLCV bars: {len(training_data['ohlcv'])}")
|
||||
print(f" - Features: {training_data['features']}")
|
||||
print(f" - Symbol: {training_data['symbol']}")
|
||||
else:
|
||||
print("❌ Failed to prepare training data")
|
||||
|
||||
# Test 3: Format data for CNN
|
||||
print("\n🧠 TEST 3: CNN Data Formatting")
|
||||
print("-" * 40)
|
||||
|
||||
if training_data:
|
||||
cnn_data = dashboard._format_data_for_cnn(training_data)
|
||||
if cnn_data and 'sequences' in cnn_data:
|
||||
print(f"✓ CNN data formatted successfully")
|
||||
print(f" - Sequences shape: {cnn_data['sequences'].shape}")
|
||||
print(f" - Targets shape: {cnn_data['targets'].shape}")
|
||||
print(f" - Sequence length: {cnn_data['sequence_length']}")
|
||||
else:
|
||||
print("❌ Failed to format CNN data")
|
||||
|
||||
# Test 4: Format data for RL
|
||||
print("\n🤖 TEST 4: RL Data Formatting")
|
||||
print("-" * 40)
|
||||
|
||||
if training_data:
|
||||
rl_experiences = dashboard._format_data_for_rl(training_data)
|
||||
if rl_experiences:
|
||||
print(f"✓ RL experiences formatted successfully")
|
||||
print(f" - Number of experiences: {len(rl_experiences)}")
|
||||
print(f" - Experience format: (state, action, reward, next_state, done)")
|
||||
print(f" - Sample experience shapes: {[len(exp) for exp in rl_experiences[:3]]}")
|
||||
else:
|
||||
print("❌ Failed to format RL experiences")
|
||||
|
||||
# Test 5: Send training data to models
|
||||
print("\n📤 TEST 5: Sending Training Data to Models")
|
||||
print("-" * 40)
|
||||
|
||||
success = dashboard.send_training_data_to_models()
|
||||
print(f"✓ Training data sent: {success}")
|
||||
|
||||
if hasattr(dashboard, 'training_stats'):
|
||||
stats = dashboard.training_stats
|
||||
print(f" - Total training sessions: {stats.get('total_training_sessions', 0)}")
|
||||
print(f" - CNN training count: {stats.get('cnn_training_count', 0)}")
|
||||
print(f" - RL training count: {stats.get('rl_training_count', 0)}")
|
||||
print(f" - Training data points: {stats.get('training_data_points', 0)}")
|
||||
|
||||
# Test 6: Training metrics display
|
||||
print("\n📈 TEST 6: Training Metrics Display")
|
||||
print("-" * 40)
|
||||
|
||||
training_metrics = dashboard._create_training_metrics()
|
||||
print(f"✓ Training metrics created: {len(training_metrics)} components")
|
||||
|
||||
# Test 7: Model training status
|
||||
print("\n🔍 TEST 7: Model Training Status")
|
||||
print("-" * 40)
|
||||
|
||||
training_status = dashboard._get_model_training_status()
|
||||
print(f"✓ Training status retrieved")
|
||||
print(f" - CNN status: {training_status['cnn']['status']}")
|
||||
print(f" - CNN accuracy: {training_status['cnn']['accuracy']:.1%}")
|
||||
print(f" - RL status: {training_status['rl']['status']}")
|
||||
print(f" - RL win rate: {training_status['rl']['win_rate']:.1%}")
|
||||
|
||||
# Test 8: Training events log
|
||||
print("\n📝 TEST 8: Training Events Log")
|
||||
print("-" * 40)
|
||||
|
||||
training_events = dashboard._get_recent_training_events()
|
||||
print(f"✓ Training events retrieved: {len(training_events)} events")
|
||||
|
||||
# Test 9: Mini training chart
|
||||
print("\n📊 TEST 9: Mini Training Chart")
|
||||
print("-" * 40)
|
||||
|
||||
try:
|
||||
training_chart = dashboard._create_mini_training_chart(training_status)
|
||||
print(f"✓ Mini training chart created")
|
||||
print(f" - Chart type: {type(training_chart)}")
|
||||
except Exception as e:
|
||||
print(f"❌ Error creating training chart: {e}")
|
||||
|
||||
# Test 10: Continuous training loop
|
||||
print("\n🔄 TEST 10: Continuous Training Loop")
|
||||
print("-" * 40)
|
||||
|
||||
print(f"✓ Continuous training active: {getattr(dashboard, 'training_active', False)}")
|
||||
if hasattr(dashboard, 'training_thread'):
|
||||
print(f"✓ Training thread alive: {dashboard.training_thread.is_alive()}")
|
||||
|
||||
# Test 11: Integration with existing continuous training system
|
||||
print("\n🔗 TEST 11: Integration with Continuous Training System")
|
||||
print("-" * 40)
|
||||
|
||||
try:
|
||||
# Check if we can get tick cache for external training
|
||||
tick_cache = dashboard.get_tick_cache_for_training()
|
||||
print(f"✓ Tick cache accessible: {len(tick_cache)} ticks")
|
||||
|
||||
# Check if we can get 1-second bars
|
||||
one_second_bars = dashboard.get_one_second_bars()
|
||||
print(f"✓ 1-second bars accessible: {len(one_second_bars)} bars")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error accessing training data: {e}")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("TRAINING INTEGRATION TEST COMPLETED")
|
||||
print("="*60)
|
||||
|
||||
# Summary
|
||||
print("\n📋 SUMMARY:")
|
||||
print(f"✓ Dashboard with training integration: WORKING")
|
||||
print(f"✓ Training data preparation: WORKING")
|
||||
print(f"✓ CNN data formatting: WORKING")
|
||||
print(f"✓ RL data formatting: WORKING")
|
||||
print(f"✓ Training metrics display: WORKING")
|
||||
print(f"✓ Continuous training: ACTIVE")
|
||||
print(f"✓ Model status tracking: WORKING")
|
||||
print(f"✓ Training events logging: WORKING")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training integration test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_training_integration()
|
||||
if success:
|
||||
print("\n🎉 All training integration tests passed!")
|
||||
else:
|
||||
print("\n❌ Some training integration tests failed!")
|
||||
sys.exit(1)
|
||||
@@ -1,59 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to check training status functionality
|
||||
"""
|
||||
|
||||
import logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
print("Testing training status functionality...")
|
||||
|
||||
try:
|
||||
from web.old_archived.scalping_dashboard import create_scalping_dashboard
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
|
||||
print("✅ Imports successful")
|
||||
|
||||
# Create components
|
||||
data_provider = DataProvider()
|
||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||
dashboard = create_scalping_dashboard(data_provider, orchestrator)
|
||||
|
||||
print("✅ Dashboard created successfully")
|
||||
|
||||
# Test training status
|
||||
training_status = dashboard._get_model_training_status()
|
||||
print("\n📊 Training Status:")
|
||||
print(f"CNN Status: {training_status['cnn']['status']}")
|
||||
print(f"CNN Accuracy: {training_status['cnn']['accuracy']:.1%}")
|
||||
print(f"CNN Loss: {training_status['cnn']['loss']:.4f}")
|
||||
print(f"CNN Epochs: {training_status['cnn']['epochs']}")
|
||||
|
||||
print(f"RL Status: {training_status['rl']['status']}")
|
||||
print(f"RL Win Rate: {training_status['rl']['win_rate']:.1%}")
|
||||
print(f"RL Episodes: {training_status['rl']['episodes']}")
|
||||
print(f"RL Memory: {training_status['rl']['memory_size']}")
|
||||
|
||||
# Test extrema stats
|
||||
if hasattr(orchestrator, 'get_extrema_stats'):
|
||||
extrema_stats = orchestrator.get_extrema_stats()
|
||||
print(f"\n🎯 Extrema Stats:")
|
||||
print(f"Total extrema detected: {extrema_stats.get('total_extrema_detected', 0)}")
|
||||
print(f"Training queue size: {extrema_stats.get('training_queue_size', 0)}")
|
||||
print("✅ Extrema stats available")
|
||||
else:
|
||||
print("❌ Extrema stats not available")
|
||||
|
||||
# Test tick cache
|
||||
print(f"\n📈 Training Data:")
|
||||
print(f"Tick cache size: {len(dashboard.tick_cache)}")
|
||||
print(f"1s bars cache size: {len(dashboard.one_second_bars)}")
|
||||
print(f"Streaming status: {dashboard.is_streaming}")
|
||||
|
||||
print("\n✅ All tests completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
@@ -1,262 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Universal Data Format Compliance
|
||||
|
||||
This script verifies that our enhanced trading system properly feeds
|
||||
the 5 required timeseries streams to all models:
|
||||
- ETH/USDT: ticks (1s), 1m, 1h, 1d
|
||||
- BTC/USDT: ticks (1s) as reference
|
||||
|
||||
This is our universal trading system input format.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.config import get_config
|
||||
from core.data_provider import DataProvider
|
||||
from core.universal_data_adapter import UniversalDataAdapter, UniversalDataStream
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from training.enhanced_cnn_trainer import EnhancedCNNTrainer
|
||||
from training.enhanced_rl_trainer import EnhancedRLTrainer
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def test_universal_data_format():
|
||||
"""Test that all components properly use the universal 5-timeseries format"""
|
||||
logger.info("="*80)
|
||||
logger.info("🧪 TESTING UNIVERSAL DATA FORMAT COMPLIANCE")
|
||||
logger.info("="*80)
|
||||
|
||||
try:
|
||||
# Initialize components
|
||||
config = get_config()
|
||||
data_provider = DataProvider(config)
|
||||
|
||||
# Test 1: Universal Data Adapter
|
||||
logger.info("\n📊 TEST 1: Universal Data Adapter")
|
||||
logger.info("-" * 40)
|
||||
|
||||
adapter = UniversalDataAdapter(data_provider)
|
||||
universal_stream = adapter.get_universal_data_stream()
|
||||
|
||||
if universal_stream is None:
|
||||
logger.error("❌ Failed to get universal data stream")
|
||||
return False
|
||||
|
||||
# Validate format
|
||||
is_valid, issues = adapter.validate_universal_format(universal_stream)
|
||||
if not is_valid:
|
||||
logger.error(f"❌ Universal format validation failed: {issues}")
|
||||
return False
|
||||
|
||||
logger.info("✅ Universal Data Adapter: PASSED")
|
||||
logger.info(f" ETH ticks: {len(universal_stream.eth_ticks)} samples")
|
||||
logger.info(f" ETH 1m: {len(universal_stream.eth_1m)} candles")
|
||||
logger.info(f" ETH 1h: {len(universal_stream.eth_1h)} candles")
|
||||
logger.info(f" ETH 1d: {len(universal_stream.eth_1d)} candles")
|
||||
logger.info(f" BTC reference: {len(universal_stream.btc_ticks)} samples")
|
||||
logger.info(f" Data quality: {universal_stream.metadata['data_quality']['overall_score']:.2f}")
|
||||
|
||||
# Test 2: Enhanced Orchestrator
|
||||
logger.info("\n🎯 TEST 2: Enhanced Orchestrator")
|
||||
logger.info("-" * 40)
|
||||
|
||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||
|
||||
# Test that orchestrator uses universal adapter
|
||||
if not hasattr(orchestrator, 'universal_adapter'):
|
||||
logger.error("❌ Orchestrator missing universal_adapter")
|
||||
return False
|
||||
|
||||
# Test coordinated decisions
|
||||
decisions = await orchestrator.make_coordinated_decisions()
|
||||
|
||||
logger.info("✅ Enhanced Orchestrator: PASSED")
|
||||
logger.info(f" Generated {len(decisions)} decisions")
|
||||
logger.info(f" Universal adapter: {type(orchestrator.universal_adapter).__name__}")
|
||||
|
||||
for symbol, decision in decisions.items():
|
||||
if decision:
|
||||
logger.info(f" {symbol}: {decision.action} (confidence: {decision.confidence:.2f})")
|
||||
|
||||
# Test 3: CNN Model Data Format
|
||||
logger.info("\n🧠 TEST 3: CNN Model Data Format")
|
||||
logger.info("-" * 40)
|
||||
|
||||
# Format data for CNN
|
||||
cnn_data = adapter.format_for_model(universal_stream, 'cnn')
|
||||
|
||||
required_cnn_keys = ['eth_ticks', 'eth_1m', 'eth_1h', 'eth_1d', 'btc_ticks']
|
||||
missing_keys = [key for key in required_cnn_keys if key not in cnn_data]
|
||||
|
||||
if missing_keys:
|
||||
logger.error(f"❌ CNN data missing keys: {missing_keys}")
|
||||
return False
|
||||
|
||||
logger.info("✅ CNN Model Data Format: PASSED")
|
||||
for key, data in cnn_data.items():
|
||||
if isinstance(data, np.ndarray):
|
||||
logger.info(f" {key}: shape {data.shape}")
|
||||
else:
|
||||
logger.info(f" {key}: {type(data)}")
|
||||
|
||||
# Test 4: RL Model Data Format
|
||||
logger.info("\n🤖 TEST 4: RL Model Data Format")
|
||||
logger.info("-" * 40)
|
||||
|
||||
# Format data for RL
|
||||
rl_data = adapter.format_for_model(universal_stream, 'rl')
|
||||
|
||||
if 'state_vector' not in rl_data:
|
||||
logger.error("❌ RL data missing state_vector")
|
||||
return False
|
||||
|
||||
state_vector = rl_data['state_vector']
|
||||
if not isinstance(state_vector, np.ndarray):
|
||||
logger.error("❌ RL state_vector is not numpy array")
|
||||
return False
|
||||
|
||||
logger.info("✅ RL Model Data Format: PASSED")
|
||||
logger.info(f" State vector shape: {state_vector.shape}")
|
||||
logger.info(f" State vector size: {len(state_vector)} features")
|
||||
|
||||
# Test 5: CNN Trainer Integration
|
||||
logger.info("\n🎓 TEST 5: CNN Trainer Integration")
|
||||
logger.info("-" * 40)
|
||||
|
||||
try:
|
||||
cnn_trainer = EnhancedCNNTrainer(config, orchestrator)
|
||||
logger.info("✅ CNN Trainer Integration: PASSED")
|
||||
logger.info(f" Model timeframes: {cnn_trainer.model.timeframes}")
|
||||
logger.info(f" Model device: {cnn_trainer.model.device}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ CNN Trainer Integration failed: {e}")
|
||||
return False
|
||||
|
||||
# Test 6: RL Trainer Integration
|
||||
logger.info("\n🎮 TEST 6: RL Trainer Integration")
|
||||
logger.info("-" * 40)
|
||||
|
||||
try:
|
||||
rl_trainer = EnhancedRLTrainer(config, orchestrator)
|
||||
logger.info("✅ RL Trainer Integration: PASSED")
|
||||
logger.info(f" RL agents: {len(rl_trainer.agents)}")
|
||||
for symbol, agent in rl_trainer.agents.items():
|
||||
logger.info(f" {symbol} agent: {type(agent).__name__}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ RL Trainer Integration failed: {e}")
|
||||
return False
|
||||
|
||||
# Test 7: Data Flow Verification
|
||||
logger.info("\n🔄 TEST 7: Data Flow Verification")
|
||||
logger.info("-" * 40)
|
||||
|
||||
# Verify that models receive the correct data format
|
||||
test_predictions = await orchestrator._get_enhanced_predictions_universal(
|
||||
'ETH/USDT',
|
||||
list(orchestrator.market_states['ETH/USDT'])[-1] if orchestrator.market_states['ETH/USDT'] else None,
|
||||
universal_stream
|
||||
)
|
||||
|
||||
if test_predictions:
|
||||
logger.info("✅ Data Flow Verification: PASSED")
|
||||
for pred in test_predictions:
|
||||
logger.info(f" Model: {pred.model_name}")
|
||||
logger.info(f" Action: {pred.overall_action}")
|
||||
logger.info(f" Confidence: {pred.overall_confidence:.2f}")
|
||||
logger.info(f" Timeframes: {len(pred.timeframe_predictions)}")
|
||||
else:
|
||||
logger.warning("⚠️ No predictions generated (may be normal if no models loaded)")
|
||||
|
||||
# Test 8: Configuration Compliance
|
||||
logger.info("\n⚙️ TEST 8: Configuration Compliance")
|
||||
logger.info("-" * 40)
|
||||
|
||||
# Check that config matches universal format
|
||||
expected_symbols = ['ETH/USDT', 'BTC/USDT']
|
||||
expected_timeframes = ['1s', '1m', '1h', '1d']
|
||||
|
||||
config_symbols = config.symbols
|
||||
config_timeframes = config.timeframes
|
||||
|
||||
symbols_match = all(symbol in config_symbols for symbol in expected_symbols)
|
||||
timeframes_match = all(tf in config_timeframes for tf in expected_timeframes)
|
||||
|
||||
if not symbols_match:
|
||||
logger.warning(f"⚠️ Config symbols may not match universal format")
|
||||
logger.warning(f" Expected: {expected_symbols}")
|
||||
logger.warning(f" Config: {config_symbols}")
|
||||
|
||||
if not timeframes_match:
|
||||
logger.warning(f"⚠️ Config timeframes may not match universal format")
|
||||
logger.warning(f" Expected: {expected_timeframes}")
|
||||
logger.warning(f" Config: {config_timeframes}")
|
||||
|
||||
if symbols_match and timeframes_match:
|
||||
logger.info("✅ Configuration Compliance: PASSED")
|
||||
else:
|
||||
logger.info("⚠️ Configuration Compliance: PARTIAL")
|
||||
|
||||
logger.info(f" Symbols: {config_symbols}")
|
||||
logger.info(f" Timeframes: {config_timeframes}")
|
||||
|
||||
# Final Summary
|
||||
logger.info("\n" + "="*80)
|
||||
logger.info("🎉 UNIVERSAL DATA FORMAT TEST SUMMARY")
|
||||
logger.info("="*80)
|
||||
logger.info("✅ All core tests PASSED!")
|
||||
logger.info("")
|
||||
logger.info("📋 VERIFIED COMPLIANCE:")
|
||||
logger.info(" ✓ Universal Data Adapter working")
|
||||
logger.info(" ✓ Enhanced Orchestrator using universal format")
|
||||
logger.info(" ✓ CNN models receive 5 timeseries streams")
|
||||
logger.info(" ✓ RL models receive combined state vector")
|
||||
logger.info(" ✓ Trainers properly integrated")
|
||||
logger.info(" ✓ Data flow verified")
|
||||
logger.info("")
|
||||
logger.info("🎯 UNIVERSAL FORMAT ACTIVE:")
|
||||
logger.info(" 1. ETH/USDT ticks (1s) ✓")
|
||||
logger.info(" 2. ETH/USDT 1m ✓")
|
||||
logger.info(" 3. ETH/USDT 1h ✓")
|
||||
logger.info(" 4. ETH/USDT 1d ✓")
|
||||
logger.info(" 5. BTC/USDT reference ticks ✓")
|
||||
logger.info("")
|
||||
logger.info("🚀 Your enhanced trading system is ready with universal data format!")
|
||||
logger.info("="*80)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Universal data format test failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
async def main():
|
||||
"""Main test function"""
|
||||
logger.info("🚀 Starting Universal Data Format Compliance Test...")
|
||||
|
||||
success = await test_universal_data_format()
|
||||
|
||||
if success:
|
||||
logger.info("\n🎉 All tests passed! Universal data format is properly implemented.")
|
||||
logger.info("Your enhanced trading system respects the 5-timeseries input format.")
|
||||
else:
|
||||
logger.error("\n💥 Tests failed! Please check the universal data format implementation.")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,177 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Universal Data Stream Integration with Dashboard
|
||||
|
||||
This script validates that:
|
||||
1. CleanTradingDashboard properly subscribes to UnifiedDataStream
|
||||
2. All 5 timeseries are properly received and processed
|
||||
3. Data flows correctly from provider -> adapter -> stream -> dashboard
|
||||
4. Consumer callback functions work as expected
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.config import get_config
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from core.trading_executor import TradingExecutor
|
||||
from web.clean_dashboard import CleanTradingDashboard
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def test_universal_stream_integration():
|
||||
"""Test Universal Data Stream integration with dashboard"""
|
||||
logger.info("="*80)
|
||||
logger.info("🧪 TESTING UNIVERSAL DATA STREAM INTEGRATION")
|
||||
logger.info("="*80)
|
||||
|
||||
try:
|
||||
# Initialize components
|
||||
logger.info("\n📦 STEP 1: Initialize Components")
|
||||
logger.info("-" * 40)
|
||||
|
||||
config = get_config()
|
||||
data_provider = DataProvider()
|
||||
orchestrator = EnhancedTradingOrchestrator(
|
||||
data_provider=data_provider,
|
||||
symbols=['ETH/USDT', 'BTC/USDT'],
|
||||
enhanced_rl_training=True
|
||||
)
|
||||
trading_executor = TradingExecutor()
|
||||
|
||||
logger.info("✅ Core components initialized")
|
||||
|
||||
# Initialize dashboard with Universal Data Stream
|
||||
logger.info("\n📊 STEP 2: Initialize Dashboard with Universal Stream")
|
||||
logger.info("-" * 40)
|
||||
|
||||
dashboard = CleanTradingDashboard(
|
||||
data_provider=data_provider,
|
||||
orchestrator=orchestrator,
|
||||
trading_executor=trading_executor
|
||||
)
|
||||
|
||||
# Check Universal Stream initialization
|
||||
if hasattr(dashboard, 'unified_stream') and dashboard.unified_stream:
|
||||
logger.info("✅ Universal Data Stream initialized successfully")
|
||||
logger.info(f"📋 Consumer ID: {dashboard.stream_consumer_id}")
|
||||
else:
|
||||
logger.error("❌ Universal Data Stream not initialized")
|
||||
return False
|
||||
|
||||
# Test consumer registration
|
||||
logger.info("\n🔗 STEP 3: Validate Consumer Registration")
|
||||
logger.info("-" * 40)
|
||||
|
||||
stream_stats = dashboard.unified_stream.get_stream_stats()
|
||||
logger.info(f"📊 Stream Stats: {stream_stats}")
|
||||
|
||||
if stream_stats['total_consumers'] > 0:
|
||||
logger.info(f"✅ {stream_stats['total_consumers']} consumers registered")
|
||||
else:
|
||||
logger.warning("⚠️ No consumers registered")
|
||||
|
||||
# Test data callback
|
||||
logger.info("\n📡 STEP 4: Test Data Callback")
|
||||
logger.info("-" * 40)
|
||||
|
||||
# Create test data packet
|
||||
test_data = {
|
||||
'timestamp': time.time(),
|
||||
'consumer_id': dashboard.stream_consumer_id,
|
||||
'consumer_name': 'CleanTradingDashboard',
|
||||
'ticks': [
|
||||
{'symbol': 'ETHUSDT', 'price': 3000.0, 'volume': 1.5, 'timestamp': time.time()},
|
||||
{'symbol': 'ETHUSDT', 'price': 3001.0, 'volume': 2.0, 'timestamp': time.time()},
|
||||
],
|
||||
'ohlcv': {'one_second_bars': [], 'multi_timeframe': {
|
||||
'ETH/USDT': {
|
||||
'1s': [{'timestamp': time.time(), 'open': 3000, 'high': 3002, 'low': 2999, 'close': 3001, 'volume': 10}],
|
||||
'1m': [{'timestamp': time.time(), 'open': 2990, 'high': 3010, 'low': 2985, 'close': 3001, 'volume': 100}],
|
||||
'1h': [{'timestamp': time.time(), 'open': 2900, 'high': 3050, 'low': 2880, 'close': 3001, 'volume': 1000}],
|
||||
'1d': [{'timestamp': time.time(), 'open': 2800, 'high': 3200, 'low': 2750, 'close': 3001, 'volume': 10000}]
|
||||
},
|
||||
'BTC/USDT': {
|
||||
'1s': [{'timestamp': time.time(), 'open': 65000, 'high': 65020, 'low': 64980, 'close': 65010, 'volume': 0.5}]
|
||||
}
|
||||
}},
|
||||
'training_data': {'market_state': 'test', 'features': []},
|
||||
'ui_data': {'formatted_data': 'test_ui_data'}
|
||||
}
|
||||
|
||||
# Test callback manually
|
||||
try:
|
||||
dashboard._handle_unified_stream_data(test_data)
|
||||
logger.info("✅ Data callback executed successfully")
|
||||
|
||||
# Check if data was processed
|
||||
if hasattr(dashboard, 'current_prices') and 'ETH/USDT' in dashboard.current_prices:
|
||||
logger.info(f"✅ Price updated: ETH/USDT = ${dashboard.current_prices['ETH/USDT']}")
|
||||
else:
|
||||
logger.warning("⚠️ Prices not updated in dashboard")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Data callback failed: {e}")
|
||||
return False
|
||||
|
||||
# Test Universal Data Adapter
|
||||
logger.info("\n🔄 STEP 5: Test Universal Data Adapter")
|
||||
logger.info("-" * 40)
|
||||
|
||||
if hasattr(orchestrator, 'universal_adapter'):
|
||||
universal_stream = orchestrator.universal_adapter.get_universal_data_stream()
|
||||
if universal_stream:
|
||||
logger.info("✅ Universal Data Adapter working")
|
||||
logger.info(f"📊 ETH ticks: {len(universal_stream.eth_ticks)} samples")
|
||||
logger.info(f"📊 ETH 1m: {len(universal_stream.eth_1m)} candles")
|
||||
logger.info(f"📊 ETH 1h: {len(universal_stream.eth_1h)} candles")
|
||||
logger.info(f"📊 ETH 1d: {len(universal_stream.eth_1d)} candles")
|
||||
logger.info(f"📊 BTC ticks: {len(universal_stream.btc_ticks)} samples")
|
||||
|
||||
# Validate format
|
||||
is_valid, issues = orchestrator.universal_adapter.validate_universal_format(universal_stream)
|
||||
if is_valid:
|
||||
logger.info("✅ Universal format validation passed")
|
||||
else:
|
||||
logger.warning(f"⚠️ Format issues: {issues}")
|
||||
else:
|
||||
logger.error("❌ Universal Data Adapter failed to get stream")
|
||||
return False
|
||||
else:
|
||||
logger.error("❌ Universal Data Adapter not found in orchestrator")
|
||||
return False
|
||||
|
||||
# Summary
|
||||
logger.info("\n🎯 SUMMARY")
|
||||
logger.info("-" * 40)
|
||||
logger.info("✅ Universal Data Stream properly integrated")
|
||||
logger.info("✅ Dashboard subscribes as consumer")
|
||||
logger.info("✅ All 5 timeseries format validated")
|
||||
logger.info("✅ Data callback processing works")
|
||||
logger.info("✅ Universal Data Adapter functional")
|
||||
|
||||
logger.info("\n🏆 INTEGRATION TEST PASSED")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Integration test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = asyncio.run(test_universal_stream_integration())
|
||||
sys.exit(0 if success else 1)
|
||||
Reference in New Issue
Block a user