BIG CLEANUP

Dobromir Popov
2025-08-08 14:58:55 +03:00
parent e39e9ee95a
commit 2b0d2679c6
162 changed files with 455 additions and 42814 deletions


@@ -1,276 +0,0 @@
#!/usr/bin/env python3
"""
Compare COB data quality between DataProvider and COBIntegration.

This test compares:
1. DataProvider COB collection (used in our tests)
2. COBIntegration direct access (used in cob_realtime_dashboard.py)

The goal is to understand why cob_realtime_dashboard.py gets more stable data.
"""
import asyncio
import logging
from collections import deque
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import pandas as pd
from core.data_provider import DataProvider, MarketTick
# Try to import COBIntegration like cob_realtime_dashboard does
try:
from core.cob_integration import COBIntegration
COB_INTEGRATION_AVAILABLE = True
except ImportError:
COB_INTEGRATION_AVAILABLE = False
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class COBComparisonTester:
def __init__(self, symbol='ETH/USDT', duration_seconds=15):
self.symbol = symbol
self.duration = timedelta(seconds=duration_seconds)
# Data storage for both methods
self.dp_ticks = deque() # DataProvider ticks
self.cob_data = deque() # COBIntegration data
# Initialize DataProvider (method 1)
logger.info("Initializing DataProvider...")
self.data_provider = DataProvider()
self.dp_cob_received = 0
# Initialize COBIntegration (method 2)
self.cob_integration = None
self.cob_received = 0
if COB_INTEGRATION_AVAILABLE:
logger.info("Initializing COBIntegration...")
self.cob_integration = COBIntegration(symbols=[self.symbol])
else:
logger.warning("COBIntegration not available - will only test DataProvider")
self.start_time = None
self.subscriber_id = None
def _dp_cob_callback(self, symbol: str, cob_data: dict):
"""Callback for DataProvider COB data"""
self.dp_cob_received += 1
if 'stats' in cob_data and 'mid_price' in cob_data['stats']:
mid_price = cob_data['stats']['mid_price']
if mid_price > 0:
synthetic_tick = MarketTick(
symbol=symbol,
timestamp=cob_data.get('timestamp', datetime.now()),
price=mid_price,
volume=cob_data.get('stats', {}).get('total_volume', 0),
quantity=0,
side='dp_cob',
trade_id=f"dp_{self.dp_cob_received}",
is_buyer_maker=False,
raw_data=cob_data
)
self.dp_ticks.append(synthetic_tick)
if self.dp_cob_received % 20 == 0:
logger.info(f"[DataProvider] Update #{self.dp_cob_received}: {symbol} @ ${mid_price:.2f}")
def _cob_integration_callback(self, symbol: str, data: dict):
"""Callback for COBIntegration data"""
self.cob_received += 1
# Store COBIntegration data directly
cob_record = {
'symbol': symbol,
'timestamp': datetime.now(),
'data': data,
'source': 'cob_integration'
}
self.cob_data.append(cob_record)
if self.cob_received % 20 == 0:
stats = data.get('stats', {})
mid_price = stats.get('mid_price', 0)
logger.info(f"[COBIntegration] Update #{self.cob_received}: {symbol} @ ${mid_price:.2f}")
async def run_comparison_test(self):
"""Run the comparison test"""
logger.info(f"Starting COB comparison test for {self.symbol} for {self.duration.total_seconds()} seconds...")
# Start DataProvider COB collection
try:
logger.info("Starting DataProvider COB collection...")
self.data_provider.start_cob_collection()
self.data_provider.subscribe_to_cob(self._dp_cob_callback)
await self.data_provider.start_real_time_streaming()
logger.info("DataProvider streaming started")
except Exception as e:
logger.error(f"Failed to start DataProvider: {e}")
# Start COBIntegration if available
if self.cob_integration:
try:
logger.info("Starting COBIntegration...")
self.cob_integration.add_dashboard_callback(self._cob_integration_callback)
await self.cob_integration.start()
logger.info("COBIntegration started")
except Exception as e:
logger.error(f"Failed to start COBIntegration: {e}")
# Collect data for specified duration
self.start_time = datetime.now()
while datetime.now() - self.start_time < self.duration:
await asyncio.sleep(1)
logger.info(f"DataProvider: {len(self.dp_ticks)} ticks | COBIntegration: {len(self.cob_data)} updates")
# Stop data collection
try:
await self.data_provider.stop_real_time_streaming()
if self.cob_integration:
await self.cob_integration.stop()
except Exception as e:
logger.error(f"Error stopping data collection: {e}")
logger.info(f"Comparison complete:")
logger.info(f" DataProvider: {len(self.dp_ticks)} ticks received")
logger.info(f" COBIntegration: {len(self.cob_data)} updates received")
# Analyze and plot the differences
self.analyze_differences()
self.create_comparison_plots()
def analyze_differences(self):
"""Analyze the differences between the two data sources"""
logger.info("Analyzing data quality differences...")
# Analyze DataProvider data
dp_order_book_count = 0
dp_mid_prices = []
for tick in self.dp_ticks:
if hasattr(tick, 'raw_data') and tick.raw_data:
if 'bids' in tick.raw_data and 'asks' in tick.raw_data:
dp_order_book_count += 1
if 'stats' in tick.raw_data and 'mid_price' in tick.raw_data['stats']:
dp_mid_prices.append(tick.raw_data['stats']['mid_price'])
# Analyze COBIntegration data
cob_order_book_count = 0
cob_mid_prices = []
for record in self.cob_data:
data = record['data']
if 'bids' in data and 'asks' in data:
cob_order_book_count += 1
if 'stats' in data and 'mid_price' in data['stats']:
cob_mid_prices.append(data['stats']['mid_price'])
logger.info("Data Quality Analysis:")
logger.info(f" DataProvider:")
logger.info(f" Total updates: {len(self.dp_ticks)}")
logger.info(f" With order book data: {dp_order_book_count}")
logger.info(f" Mid prices collected: {len(dp_mid_prices)}")
if dp_mid_prices:
logger.info(f" Price range: ${min(dp_mid_prices):.2f} - ${max(dp_mid_prices):.2f}")
logger.info(f" COBIntegration:")
logger.info(f" Total updates: {len(self.cob_data)}")
logger.info(f" With order book data: {cob_order_book_count}")
logger.info(f" Mid prices collected: {len(cob_mid_prices)}")
if cob_mid_prices:
logger.info(f" Price range: ${min(cob_mid_prices):.2f} - ${max(cob_mid_prices):.2f}")
def create_comparison_plots(self):
"""Create comparison plots showing the difference"""
logger.info("Creating comparison plots...")
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 12))
# Plot 1: Price comparison
dp_times = []
dp_prices = []
for tick in self.dp_ticks:
if tick.price > 0:
dp_times.append(tick.timestamp)
dp_prices.append(tick.price)
cob_times = []
cob_prices = []
for record in self.cob_data:
data = record['data']
if 'stats' in data and 'mid_price' in data['stats']:
cob_times.append(record['timestamp'])
cob_prices.append(data['stats']['mid_price'])
if dp_times:
ax1.plot(pd.to_datetime(dp_times), dp_prices, 'b-', alpha=0.7, label='DataProvider COB', linewidth=1)
if cob_times:
ax1.plot(pd.to_datetime(cob_times), cob_prices, 'r-', alpha=0.7, label='COBIntegration', linewidth=1)
ax1.set_title('Price Comparison: DataProvider vs COBIntegration')
ax1.set_ylabel('Price (USDT)')
ax1.legend()
ax1.grid(True, alpha=0.3)
# Plot 2: Data quality comparison (order book depth)
dp_bid_counts = []
dp_ask_counts = []
dp_ob_times = []
for tick in self.dp_ticks:
if hasattr(tick, 'raw_data') and tick.raw_data:
if 'bids' in tick.raw_data and 'asks' in tick.raw_data:
dp_bid_counts.append(len(tick.raw_data['bids']))
dp_ask_counts.append(len(tick.raw_data['asks']))
dp_ob_times.append(tick.timestamp)
cob_bid_counts = []
cob_ask_counts = []
cob_ob_times = []
for record in self.cob_data:
data = record['data']
if 'bids' in data and 'asks' in data:
cob_bid_counts.append(len(data['bids']))
cob_ask_counts.append(len(data['asks']))
cob_ob_times.append(record['timestamp'])
if dp_ob_times:
ax2.plot(pd.to_datetime(dp_ob_times), dp_bid_counts, 'b--', alpha=0.7, label='DP Bid Levels')
ax2.plot(pd.to_datetime(dp_ob_times), dp_ask_counts, 'b:', alpha=0.7, label='DP Ask Levels')
if cob_ob_times:
ax2.plot(pd.to_datetime(cob_ob_times), cob_bid_counts, 'r--', alpha=0.7, label='COB Bid Levels')
ax2.plot(pd.to_datetime(cob_ob_times), cob_ask_counts, 'r:', alpha=0.7, label='COB Ask Levels')
ax2.set_title('Order Book Depth Comparison')
ax2.set_ylabel('Number of Levels')
ax2.set_xlabel('Time')
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plot_filename = f"cob_comparison_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png"
plt.savefig(plot_filename, dpi=150)
logger.info(f"Comparison plot saved to {plot_filename}")
plt.show()
async def main():
tester = COBComparisonTester()
await tester.run_comparison_test()
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Test interrupted by user.")


@@ -1,502 +0,0 @@
import asyncio
import logging
from collections import deque
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from core.data_provider import DataProvider, MarketTick
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class COBStabilityTester:
def __init__(self, symbol='ETHUSDT', duration_seconds=10):
self.symbol = symbol
self.duration = timedelta(seconds=duration_seconds)
self.ticks = deque()
# Set granularity (buckets) based on symbol
if 'ETH' in symbol.upper():
self.price_granularity = 1.0 # 1 USD for ETH
elif 'BTC' in symbol.upper():
self.price_granularity = 10.0 # 10 USD for BTC
else:
self.price_granularity = 1.0 # Default 1 USD
logger.info(f"Using price granularity: ${self.price_granularity} for {symbol}")
# Initialize DataProvider the same way as clean_dashboard
logger.info("Initializing DataProvider like in clean_dashboard...")
self.data_provider = DataProvider() # Use default constructor like clean_dashboard
# Initialize COB data collection like clean_dashboard does
self.cob_data_received = 0
self.latest_cob_data = {}
# Store all COB snapshots for heatmap generation
self.cob_snapshots = deque()
self.price_data = [] # For price line chart
self.start_time = None
self.subscriber_id = None
self.last_log_time = None
def _tick_callback(self, tick: MarketTick):
"""Callback function to receive ticks from the DataProvider."""
if self.start_time is None:
self.start_time = datetime.now()
logger.info(f"Started collecting ticks at {self.start_time}")
# Store all ticks
self.ticks.append(tick)
def _cob_data_callback(self, symbol: str, cob_data: dict):
"""Callback function to receive COB data from the DataProvider."""
# Debug: Log first few callbacks to see what symbols we're getting
if self.cob_data_received < 5:
logger.info(f"DEBUG: Received COB data for symbol '{symbol}' (target: '{self.symbol}')")
# Filter to only our requested symbol - handle both formats (ETH/USDT and ETHUSDT)
normalized_symbol = symbol.replace('/', '')
normalized_target = self.symbol.replace('/', '')
if normalized_symbol != normalized_target:
if self.cob_data_received < 5:
logger.info(f"DEBUG: Skipping symbol '{symbol}' (normalized: '{normalized_symbol}' vs target: '{normalized_target}')")
return
self.cob_data_received += 1
self.latest_cob_data[symbol] = cob_data
# Store the complete COB snapshot for heatmap generation
if 'bids' in cob_data and 'asks' in cob_data:
# Debug: Log structure of first few COB snapshots
if len(self.cob_snapshots) < 3:
logger.info(f"DEBUG: COB data structure - bids: {len(cob_data['bids'])} items, asks: {len(cob_data['asks'])} items")
if cob_data['bids']:
logger.info(f"DEBUG: First bid: {cob_data['bids'][0]}")
if cob_data['asks']:
logger.info(f"DEBUG: First ask: {cob_data['asks'][0]}")
# Use current time for timestamp consistency
current_time = datetime.now()
snapshot = {
'timestamp': current_time,
'bids': cob_data['bids'],
'asks': cob_data['asks'],
'stats': cob_data.get('stats', {})
}
self.cob_snapshots.append(snapshot)
# Log bucketed COB data every second
now = datetime.now()
if self.last_log_time is None or (now - self.last_log_time).total_seconds() >= 1.0:
self.last_log_time = now
self._log_bucketed_cob_data(cob_data)
# Convert COB data to tick-like format for analysis
if 'stats' in cob_data and 'mid_price' in cob_data['stats']:
mid_price = cob_data['stats']['mid_price']
if mid_price > 0:
# Filter out extreme price movements (±10% of recent average)
if len(self.price_data) > 5:
recent_prices = [p['price'] for p in self.price_data[-5:]]
avg_recent_price = sum(recent_prices) / len(recent_prices)
price_deviation = abs(mid_price - avg_recent_price) / avg_recent_price
if price_deviation > 0.10: # More than 10% deviation
logger.warning(f"Filtering out extreme price: ${mid_price:.2f} (deviation: {price_deviation:.1%} from avg ${avg_recent_price:.2f})")
return # Skip this data point
# Store price data for line chart with consistent timestamp
current_time = datetime.now()
self.price_data.append({
'timestamp': current_time,
'price': mid_price
})
# Create a synthetic tick from COB data, reusing the same timestamp as the price data point above
synthetic_tick = MarketTick(
symbol=symbol,
timestamp=current_time,
price=mid_price,
volume=cob_data.get('stats', {}).get('total_volume', 0),
quantity=0, # Not available in COB data
side='unknown', # COB data doesn't have side info
trade_id=f"cob_{self.cob_data_received}",
is_buyer_maker=False,
raw_data=cob_data
)
self.ticks.append(synthetic_tick)
if self.cob_data_received % 10 == 0: # Log every 10th update
logger.info(f"COB update #{self.cob_data_received}: {symbol} @ ${mid_price:.2f}")
def _log_bucketed_cob_data(self, cob_data: dict):
"""Log bucketed COB data every second"""
try:
if 'bids' not in cob_data or 'asks' not in cob_data:
logger.info("COB-1s: No order book data available")
return
if 'stats' not in cob_data or 'mid_price' not in cob_data['stats']:
logger.info("COB-1s: No mid price available")
return
mid_price = cob_data['stats']['mid_price']
if mid_price <= 0:
return
# Bucket the order book data
bid_buckets = {}
ask_buckets = {}
# Process bids (top 10)
for bid in cob_data['bids'][:10]:
try:
if isinstance(bid, dict):
price = float(bid['price'])
size = float(bid['size'])
elif isinstance(bid, (list, tuple)) and len(bid) >= 2:
price = float(bid[0])
size = float(bid[1])
else:
continue
bucketed_price = round(price / self.price_granularity) * self.price_granularity
bid_buckets[bucketed_price] = bid_buckets.get(bucketed_price, 0) + size
except (ValueError, TypeError, IndexError):
continue
# Process asks (top 10)
for ask in cob_data['asks'][:10]:
try:
if isinstance(ask, dict):
price = float(ask['price'])
size = float(ask['size'])
elif isinstance(ask, (list, tuple)) and len(ask) >= 2:
price = float(ask[0])
size = float(ask[1])
else:
continue
bucketed_price = round(price / self.price_granularity) * self.price_granularity
ask_buckets[bucketed_price] = ask_buckets.get(bucketed_price, 0) + size
except (ValueError, TypeError, IndexError):
continue
# Format for log output
bid_str = ", ".join([f"${p:.0f}:{s:.3f}" for p, s in sorted(bid_buckets.items(), reverse=True)])
ask_str = ", ".join([f"${p:.0f}:{s:.3f}" for p, s in sorted(ask_buckets.items())])
logger.info(f"COB-1s @ ${mid_price:.2f} | BIDS: {bid_str} | ASKS: {ask_str}")
except Exception as e:
logger.warning(f"Error logging bucketed COB data: {e}")
async def run_test(self):
"""Run the data collection and plotting test."""
logger.info(f"Starting COB stability test for {self.symbol} for {self.duration.total_seconds()} seconds...")
# Initialize COB collection like clean_dashboard does
try:
logger.info("Starting COB collection in data provider...")
self.data_provider.start_cob_collection()
logger.info("Started COB collection in data provider")
# Subscribe to COB updates
logger.info("Subscribing to COB data updates...")
self.data_provider.subscribe_to_cob(self._cob_data_callback)
logger.info("Subscribed to COB data updates from data provider")
except Exception as e:
logger.error(f"Failed to start COB collection or subscribe: {e}")
# Subscribe to ticks as fallback
try:
self.subscriber_id = self.data_provider.subscribe_to_ticks(self._tick_callback, symbols=[self.symbol])
logger.info("Subscribed to tick data as fallback")
except Exception as e:
logger.warning(f"Failed to subscribe to ticks: {e}")
# Start the data provider's real-time streaming
try:
await self.data_provider.start_real_time_streaming()
logger.info("Started real-time streaming")
except Exception as e:
logger.error(f"Failed to start real-time streaming: {e}")
# Collect data for the specified duration
self.start_time = datetime.now()
while datetime.now() - self.start_time < self.duration:
await asyncio.sleep(1)
logger.info(f"Collected {len(self.ticks)} ticks so far...")
# Stop streaming and unsubscribe (the tick subscription above may have failed, leaving no subscriber id)
await self.data_provider.stop_real_time_streaming()
if self.subscriber_id is not None:
    self.data_provider.unsubscribe_from_ticks(self.subscriber_id)
logger.info(f"Finished collecting data. Total ticks: {len(self.ticks)}")
# Plot the results
if self.price_data and self.cob_snapshots:
self.create_price_heatmap_chart()
elif self.ticks:
self._create_simple_price_chart()
else:
logger.warning("No data was collected. Cannot generate plot.")
def create_price_heatmap_chart(self):
"""Create a visualization with price chart and order book scatter plot."""
if not self.price_data or not self.cob_snapshots:
logger.warning("Insufficient data to plot.")
return
logger.info(f"Creating price and order book chart...")
logger.info(f"Data summary: {len(self.price_data)} price points, {len(self.cob_snapshots)} COB snapshots")
# Prepare price data
price_df = pd.DataFrame(self.price_data)
price_df['timestamp'] = pd.to_datetime(price_df['timestamp'])
logger.info(f"Price data time range: {price_df['timestamp'].min()} to {price_df['timestamp'].max()}")
logger.info(f"Price range: ${price_df['price'].min():.2f} to ${price_df['price'].max():.2f}")
# Create figure with subplots
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 12), height_ratios=[3, 2])
# Top plot: Price chart with order book levels
ax1.plot(price_df['timestamp'], price_df['price'], 'yellow', linewidth=2, label='Mid Price', zorder=10)
# Plot order book levels as scatter points
bid_times, bid_prices, bid_sizes = [], [], []
ask_times, ask_prices, ask_sizes = [], [], []
# Calculate average price for filtering
avg_price = price_df['price'].mean() if not price_df.empty else 3500 # Fallback price
price_lower = avg_price * 0.9 # -10%
price_upper = avg_price * 1.1 # +10%
logger.info(f"Filtering order book data to price range: ${price_lower:.2f} - ${price_upper:.2f} (±10% of ${avg_price:.2f})")
for snapshot in list(self.cob_snapshots)[-50:]: # Use last 50 snapshots for clarity
timestamp = pd.to_datetime(snapshot['timestamp'])
# Process bids (top 10)
for order in snapshot.get('bids', [])[:10]:
try:
if isinstance(order, dict):
price = float(order['price'])
size = float(order['size'])
elif isinstance(order, (list, tuple)) and len(order) >= 2:
price = float(order[0])
size = float(order[1])
else:
continue
# Filter out prices outside ±10% range
if price < price_lower or price > price_upper:
continue
bid_times.append(timestamp)
bid_prices.append(price)
bid_sizes.append(size)
except (ValueError, TypeError, IndexError):
continue
# Process asks (top 10)
for order in snapshot.get('asks', [])[:10]:
try:
if isinstance(order, dict):
price = float(order['price'])
size = float(order['size'])
elif isinstance(order, (list, tuple)) and len(order) >= 2:
price = float(order[0])
size = float(order[1])
else:
continue
# Filter out prices outside ±10% range
if price < price_lower or price > price_upper:
continue
ask_times.append(timestamp)
ask_prices.append(price)
ask_sizes.append(size)
except (ValueError, TypeError, IndexError):
continue
# Plot order book data as scatter with size indicating volume
if bid_times:
bid_sizes_normalized = np.array(bid_sizes) * 3 # Scale for visibility
ax1.scatter(bid_times, bid_prices, s=bid_sizes_normalized, c='green', alpha=0.3, label='Bids')
logger.info(f"Plotted {len(bid_times)} bid levels")
if ask_times:
ask_sizes_normalized = np.array(ask_sizes) * 3 # Scale for visibility
ax1.scatter(ask_times, ask_prices, s=ask_sizes_normalized, c='red', alpha=0.3, label='Asks')
logger.info(f"Plotted {len(ask_times)} ask levels")
ax1.set_title(f'Real-time Price and Order Book - {self.symbol}\nGranularity: ${self.price_granularity} | Duration: {self.duration.total_seconds()}s')
ax1.set_ylabel('Price (USDT)')
ax1.legend()
ax1.grid(True, alpha=0.3)
# Set proper time range (X-axis) - use actual data collection period
time_min = price_df['timestamp'].min()
time_max = price_df['timestamp'].max()
actual_duration = (time_max - time_min).total_seconds()
logger.info(f"Actual data collection duration: {actual_duration:.1f} seconds")
ax1.set_xlim(time_min, time_max)
# Set tight price range (Y-axis) - use ±2% of price range for better visibility
price_min = price_df['price'].min()
price_max = price_df['price'].max()
price_center = (price_min + price_max) / 2
price_range = price_max - price_min
# If price range is very small, use a minimum range of $5
if price_range < 5:
price_range = 5
# Add 20% padding to the price range for better visualization
y_padding = price_range * 0.2
y_min = price_min - y_padding
y_max = price_max + y_padding
ax1.set_ylim(y_min, y_max)
logger.info(f"Chart Y-axis range: ${y_min:.2f} - ${y_max:.2f} (center: ${price_center:.2f}, range: ${price_range:.2f})")
# Bottom plot: Order book depth over time (aggregated)
time_buckets = []
bid_depths = []
ask_depths = []
# Create time buckets (every few snapshots)
snapshots_list = list(self.cob_snapshots)
bucket_size = max(1, len(snapshots_list) // 20) # ~20 buckets
for i in range(0, len(snapshots_list), bucket_size):
bucket_snapshots = snapshots_list[i:i+bucket_size]
if not bucket_snapshots:
continue
# Use middle timestamp of bucket
mid_snapshot = bucket_snapshots[len(bucket_snapshots)//2]
time_buckets.append(pd.to_datetime(mid_snapshot['timestamp']))
# Calculate average depths
total_bid_depth = 0
total_ask_depth = 0
snapshot_count = 0
for snapshot in bucket_snapshots:
bid_depth = sum([float(order[1]) if isinstance(order, (list, tuple)) else float(order.get('size', 0))
for order in snapshot.get('bids', [])[:10]])
ask_depth = sum([float(order[1]) if isinstance(order, (list, tuple)) else float(order.get('size', 0))
for order in snapshot.get('asks', [])[:10]])
total_bid_depth += bid_depth
total_ask_depth += ask_depth
snapshot_count += 1
if snapshot_count > 0:
bid_depths.append(total_bid_depth / snapshot_count)
ask_depths.append(total_ask_depth / snapshot_count)
else:
bid_depths.append(0)
ask_depths.append(0)
if time_buckets:
ax2.plot(time_buckets, bid_depths, 'green', linewidth=2, label='Bid Depth', alpha=0.7)
ax2.plot(time_buckets, ask_depths, 'red', linewidth=2, label='Ask Depth', alpha=0.7)
ax2.fill_between(time_buckets, bid_depths, alpha=0.3, color='green')
ax2.fill_between(time_buckets, ask_depths, alpha=0.3, color='red')
ax2.set_title('Order Book Depth Over Time')
ax2.set_xlabel('Time')
ax2.set_ylabel('Depth (Volume)')
ax2.legend()
ax2.grid(True, alpha=0.3)
# Set same time range for bottom chart
ax2.set_xlim(time_min, time_max)
# Format time axes
fig.autofmt_xdate()
plt.tight_layout()
plot_filename = f"price_heatmap_chart_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png"
plt.savefig(plot_filename, dpi=150, bbox_inches='tight')
logger.info(f"Price and order book chart saved to {plot_filename}")
plt.show()
def _create_simple_price_chart(self):
"""Create a simple price chart as fallback"""
logger.info("Creating simple price chart as fallback...")
prices = []
times = []
for tick in self.ticks:
if tick.price > 0:
prices.append(tick.price)
times.append(tick.timestamp)
if not prices:
logger.warning("No price data to plot")
return
fig, ax = plt.subplots(figsize=(15, 8))
ax.plot(pd.to_datetime(times), prices, 'cyan', linewidth=1)
ax.set_title(f'Price Chart - {self.symbol}')
ax.set_xlabel('Time')
ax.set_ylabel('Price (USDT)')
fig.autofmt_xdate()
plot_filename = f"cob_price_chart_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png"
plt.savefig(plot_filename)
logger.info(f"Price chart saved to {plot_filename}")
plt.show()
async def main(symbol='ETHUSDT', duration_seconds=10):
"""Main function to run the COB test with configurable parameters.
Args:
symbol: Trading symbol (default: ETHUSDT)
duration_seconds: Test duration in seconds (default: 10)
"""
logger.info(f"Starting COB test with symbol={symbol}, duration={duration_seconds}s")
tester = COBStabilityTester(symbol=symbol, duration_seconds=duration_seconds)
await tester.run_test()
if __name__ == "__main__":
import sys
# Parse command line arguments
symbol = 'ETHUSDT' # Default
duration = 10 # Default
if len(sys.argv) > 1:
symbol = sys.argv[1]
if len(sys.argv) > 2:
try:
duration = int(sys.argv[2])
except ValueError:
logger.warning(f"Invalid duration '{sys.argv[2]}', using default 10 seconds")
logger.info(f"Configuration: Symbol={symbol}, Duration={duration}s")
logger.info(f"Granularity: {'1 USD for ETH' if 'ETH' in symbol.upper() else '10 USD for BTC' if 'BTC' in symbol.upper() else '1 USD default'}")
try:
asyncio.run(main(symbol, duration))
except KeyboardInterrupt:
logger.info("Test interrupted by user.")


@@ -1,310 +0,0 @@
#!/usr/bin/env python3
"""
Test Training Script for AI Trading Models
This script tests the training functionality of our CNN and RL models
and demonstrates the learning capabilities.
"""
import logging
import sys
import asyncio
from pathlib import Path
from datetime import datetime, timedelta
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import setup_logging
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from models import get_model_registry, CNNModelWrapper, RLAgentWrapper
# Setup logging
setup_logging()
logger = logging.getLogger(__name__)
def test_model_loading():
"""Test that models load correctly"""
logger.info("=== TESTING MODEL LOADING ===")
try:
# Get model registry
registry = get_model_registry()
# Check loaded models
logger.info(f"Loaded models: {list(registry.models.keys())}")
# Test each model
for name, model in registry.models.items():
logger.info(f"Testing {name} model...")
# Test prediction
import numpy as np
test_features = np.random.random((20, 5)) # 20 timesteps, 5 features
try:
predictions, confidence = model.predict(test_features)
logger.info(f"{name} prediction: {predictions} (confidence: {confidence:.3f})")
except Exception as e:
logger.error(f"{name} prediction failed: {e}")
# Memory stats
stats = registry.get_memory_stats()
logger.info(f"Memory usage: {stats['total_used_mb']:.1f}MB / {stats['total_limit_mb']:.1f}MB")
return True
except Exception as e:
logger.error(f"Model loading test failed: {e}")
return False
async def test_orchestrator_integration():
"""Test orchestrator integration with models"""
logger.info("=== TESTING ORCHESTRATOR INTEGRATION ===")
try:
# Initialize components
data_provider = DataProvider()
orchestrator = EnhancedTradingOrchestrator(data_provider)
# Test coordinated decisions
logger.info("Testing coordinated decision making...")
decisions = await orchestrator.make_coordinated_decisions()
if decisions:
for symbol, decision in decisions.items():
if decision:
logger.info(f"{symbol}: {decision.action} (confidence: {decision.confidence:.3f})")
else:
logger.info(f" ⏸️ {symbol}: No decision (waiting)")
else:
logger.warning(" ❌ No decisions made")
# Test RL evaluation
logger.info("Testing RL evaluation...")
await orchestrator.evaluate_actions_with_rl()
return True
except Exception as e:
logger.error(f"Orchestrator integration test failed: {e}")
return False
def test_rl_learning():
"""Test RL learning functionality"""
logger.info("=== TESTING RL LEARNING ===")
try:
registry = get_model_registry()
rl_agent = registry.get_model('RL')
if not rl_agent:
logger.error("RL agent not found")
return False
# Simulate some experiences
import numpy as np
logger.info("Simulating trading experiences...")
for i in range(50):
state = np.random.random(10)
action = np.random.randint(0, 3)
reward = np.random.uniform(-0.1, 0.1) # Random P&L
next_state = np.random.random(10)
done = False
# Store experience
rl_agent.remember(state, action, reward, next_state, done)
logger.info(f"Stored {len(rl_agent.experience_buffer)} experiences")
# Test replay training
logger.info("Testing replay training...")
loss = rl_agent.replay()
if loss is not None:
logger.info(f" ✅ Training loss: {loss:.4f}")
else:
logger.info(" ⏸️ Not enough experiences for training")
return True
except Exception as e:
logger.error(f"RL learning test failed: {e}")
return False
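# Agent interface assumed by test_rl_learning above, inferred from the calls
# it makes (the actual wrapper registered under 'RL' may expose different
# names):
#   rl_agent.remember(state, action, reward, next_state, done)  # store one experience
#   rl_agent.replay() -> loss or None                           # None if buffer too small
#   rl_agent.experience_buffer                                  # sized experience store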
def test_cnn_training():
"""Test CNN training functionality"""
logger.info("=== TESTING CNN TRAINING ===")
try:
registry = get_model_registry()
cnn_model = registry.get_model('CNN')
if not cnn_model:
logger.error("CNN model not found")
return False
# Test training with mock perfect moves
training_data = {
'perfect_moves': [],
'market_data': {},
'symbols': ['ETH/USDT', 'BTC/USDT'],
'timeframes': ['1m', '1h']
}
# Mock some perfect moves
for i in range(10):
perfect_move = {
'symbol': 'ETH/USDT',
'timeframe': '1m',
'timestamp': datetime.now() - timedelta(hours=i),
'optimal_action': 'BUY' if i % 2 == 0 else 'SELL',
'confidence_should_have_been': 0.8 + i * 0.01,
'actual_outcome': 0.02 if i % 2 == 0 else -0.015
}
training_data['perfect_moves'].append(perfect_move)
logger.info(f"Testing training with {len(training_data['perfect_moves'])} perfect moves...")
# Test training
result = cnn_model.train(training_data)
if result and result.get('status') == 'training_simulated':
logger.info(f" ✅ Training completed: {result}")
else:
logger.warning(f" ⚠️ Training result: {result}")
return True
except Exception as e:
logger.error(f"CNN training test failed: {e}")
return False
def test_prediction_tracking():
"""Test prediction tracking and learning feedback"""
logger.info("=== TESTING PREDICTION TRACKING ===")
try:
# Initialize components
data_provider = DataProvider()
orchestrator = EnhancedTradingOrchestrator(data_provider)
# Get some market data for testing
test_data = data_provider.get_historical_data('ETH/USDT', '1m', limit=100)
if test_data is None or test_data.empty:
logger.warning("No market data available for testing")
return True
logger.info(f"Testing with {len(test_data)} candles of ETH/USDT 1m data")
# Simulate some predictions and outcomes
correct_predictions = 0
total_predictions = 0
for i in range(min(10, len(test_data) - 5)):
# Get a slice of data
current_data = test_data.iloc[i:i+20]
future_data = test_data.iloc[i+20:i+25]
if len(current_data) < 20 or len(future_data) < 5:
continue
# Make prediction
current_price = current_data['close'].iloc[-1]
future_price = future_data['close'].iloc[-1]
actual_change = (future_price - current_price) / current_price
# Simulate model prediction
predicted_action = 'BUY' if actual_change > 0.001 else 'SELL' if actual_change < -0.001 else 'HOLD'
# Check if prediction was correct
if predicted_action == 'BUY' and actual_change > 0:
correct_predictions += 1
logger.info(f" ✅ Correct BUY prediction: {actual_change:.4f}")
elif predicted_action == 'SELL' and actual_change < 0:
correct_predictions += 1
logger.info(f" ✅ Correct SELL prediction: {actual_change:.4f}")
elif predicted_action == 'HOLD' and abs(actual_change) < 0.001:
correct_predictions += 1
logger.info(f" ✅ Correct HOLD prediction: {actual_change:.4f}")
else:
logger.info(f" ❌ Wrong {predicted_action} prediction: {actual_change:.4f}")
total_predictions += 1
if total_predictions > 0:
accuracy = correct_predictions / total_predictions
logger.info(f"Prediction accuracy: {accuracy:.1%} ({correct_predictions}/{total_predictions})")
return True
except Exception as e:
logger.error(f"Prediction tracking test failed: {e}")
return False
async def main():
"""Main test function"""
logger.info("🧪 STARTING AI TRADING MODEL TESTS")
logger.info("Testing model loading, training, and learning capabilities")
tests = [
("Model Loading", test_model_loading),
("Orchestrator Integration", test_orchestrator_integration),
("RL Learning", test_rl_learning),
("CNN Training", test_cnn_training),
("Prediction Tracking", test_prediction_tracking)
]
results = {}
for test_name, test_func in tests:
logger.info(f"\n{'='*50}")
logger.info(f"Running: {test_name}")
logger.info(f"{'='*50}")
try:
if asyncio.iscoroutinefunction(test_func):
result = await test_func()
else:
result = test_func()
results[test_name] = result
if result:
logger.info(f"{test_name}: PASSED")
else:
logger.error(f"{test_name}: FAILED")
except Exception as e:
logger.error(f"{test_name}: ERROR - {e}")
results[test_name] = False
# Summary
logger.info(f"\n{'='*50}")
logger.info("TEST SUMMARY")
logger.info(f"{'='*50}")
passed = sum(1 for result in results.values() if result)
total = len(results)
for test_name, result in results.items():
status = "✅ PASSED" if result else "❌ FAILED"
logger.info(f"{test_name}: {status}")
logger.info(f"\nOverall: {passed}/{total} tests passed ({passed/total:.1%})")
if passed == total:
logger.info("🎉 All tests passed! The AI trading system is working correctly.")
else:
logger.warning(f"⚠️ {total-passed} tests failed. Please check the logs above.")
return 0 if passed == total else 1
if __name__ == "__main__":
exit_code = asyncio.run(main())
sys.exit(exit_code)


@@ -1,204 +0,0 @@
#!/usr/bin/env python3
"""
Test Training Integration with Dashboard
This script tests the enhanced dashboard's ability to:
1. Stream training data to CNN and DQN models
2. Display real-time training metrics and progress
3. Show model learning curves and performance
4. Integrate with the continuous training system
"""
import sys
import logging
from datetime import datetime, timedelta
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_training_integration():
"""Test the training integration functionality"""
try:
print("="*60)
print("TESTING TRAINING INTEGRATION WITH DASHBOARD")
print("="*60)
# Import dashboard
from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator
# Create components
data_provider = DataProvider()
orchestrator = TradingOrchestrator(data_provider)
dashboard = TradingDashboard(data_provider, orchestrator)
print(f"✓ Dashboard created with training integration")
print(f"✓ Continuous training active: {getattr(dashboard, 'training_active', False)}")
# Test 1: Simulate tick data for training
print("\n📊 TEST 1: Simulating Tick Data")
print("-" * 40)
# Add simulated tick data to cache
base_price = 3500.0
for i in range(1000):
tick_data = {
'timestamp': datetime.now() - timedelta(seconds=1000-i),
'price': base_price + (i % 100) * 0.1,
'volume': 100 + (i % 50),
'side': 'buy' if i % 2 == 0 else 'sell'
}
dashboard.tick_cache.append(tick_data)
print(f"✓ Added {len(dashboard.tick_cache)} ticks to cache")
# Test 2: Prepare training data
print("\n🔄 TEST 2: Preparing Training Data")
print("-" * 40)
training_data = dashboard._prepare_training_data()
if training_data:
print(f"✓ Training data prepared successfully")
print(f" - OHLCV bars: {len(training_data['ohlcv'])}")
print(f" - Features: {training_data['features']}")
print(f" - Symbol: {training_data['symbol']}")
else:
print("❌ Failed to prepare training data")
# Test 3: Format data for CNN
print("\n🧠 TEST 3: CNN Data Formatting")
print("-" * 40)
if training_data:
cnn_data = dashboard._format_data_for_cnn(training_data)
if cnn_data and 'sequences' in cnn_data:
print(f"✓ CNN data formatted successfully")
print(f" - Sequences shape: {cnn_data['sequences'].shape}")
print(f" - Targets shape: {cnn_data['targets'].shape}")
print(f" - Sequence length: {cnn_data['sequence_length']}")
else:
print("❌ Failed to format CNN data")
# Test 4: Format data for RL
print("\n🤖 TEST 4: RL Data Formatting")
print("-" * 40)
if training_data:
rl_experiences = dashboard._format_data_for_rl(training_data)
if rl_experiences:
print(f"✓ RL experiences formatted successfully")
print(f" - Number of experiences: {len(rl_experiences)}")
print(f" - Experience format: (state, action, reward, next_state, done)")
print(f" - Sample experience shapes: {[len(exp) for exp in rl_experiences[:3]]}")
else:
print("❌ Failed to format RL experiences")
# Test 5: Send training data to models
print("\n📤 TEST 5: Sending Training Data to Models")
print("-" * 40)
success = dashboard.send_training_data_to_models()
print(f"✓ Training data sent: {success}")
if hasattr(dashboard, 'training_stats'):
stats = dashboard.training_stats
print(f" - Total training sessions: {stats.get('total_training_sessions', 0)}")
print(f" - CNN training count: {stats.get('cnn_training_count', 0)}")
print(f" - RL training count: {stats.get('rl_training_count', 0)}")
print(f" - Training data points: {stats.get('training_data_points', 0)}")
# Test 6: Training metrics display
print("\n📈 TEST 6: Training Metrics Display")
print("-" * 40)
training_metrics = dashboard._create_training_metrics()
print(f"✓ Training metrics created: {len(training_metrics)} components")
# Test 7: Model training status
print("\n🔍 TEST 7: Model Training Status")
print("-" * 40)
training_status = dashboard._get_model_training_status()
print(f"✓ Training status retrieved")
print(f" - CNN status: {training_status['cnn']['status']}")
print(f" - CNN accuracy: {training_status['cnn']['accuracy']:.1%}")
print(f" - RL status: {training_status['rl']['status']}")
print(f" - RL win rate: {training_status['rl']['win_rate']:.1%}")
# Test 8: Training events log
print("\n📝 TEST 8: Training Events Log")
print("-" * 40)
training_events = dashboard._get_recent_training_events()
print(f"✓ Training events retrieved: {len(training_events)} events")
# Test 9: Mini training chart
print("\n📊 TEST 9: Mini Training Chart")
print("-" * 40)
try:
training_chart = dashboard._create_mini_training_chart(training_status)
print(f"✓ Mini training chart created")
print(f" - Chart type: {type(training_chart)}")
except Exception as e:
print(f"❌ Error creating training chart: {e}")
# Test 10: Continuous training loop
print("\n🔄 TEST 10: Continuous Training Loop")
print("-" * 40)
print(f"✓ Continuous training active: {getattr(dashboard, 'training_active', False)}")
if hasattr(dashboard, 'training_thread'):
print(f"✓ Training thread alive: {dashboard.training_thread.is_alive()}")
# Test 11: Integration with existing continuous training system
print("\n🔗 TEST 11: Integration with Continuous Training System")
print("-" * 40)
try:
# Check if we can get tick cache for external training
tick_cache = dashboard.get_tick_cache_for_training()
print(f"✓ Tick cache accessible: {len(tick_cache)} ticks")
# Check if we can get 1-second bars
one_second_bars = dashboard.get_one_second_bars()
print(f"✓ 1-second bars accessible: {len(one_second_bars)} bars")
except Exception as e:
print(f"❌ Error accessing training data: {e}")
print("\n" + "="*60)
print("TRAINING INTEGRATION TEST COMPLETED")
print("="*60)
# Summary
print("\n📋 SUMMARY:")
print(f"✓ Dashboard with training integration: WORKING")
print(f"✓ Training data preparation: WORKING")
print(f"✓ CNN data formatting: WORKING")
print(f"✓ RL data formatting: WORKING")
print(f"✓ Training metrics display: WORKING")
print(f"✓ Continuous training: ACTIVE")
print(f"✓ Model status tracking: WORKING")
print(f"✓ Training events logging: WORKING")
return True
except Exception as e:
logger.error(f"Training integration test failed: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
success = test_training_integration()
if success:
print("\n🎉 All training integration tests passed!")
else:
print("\n❌ Some training integration tests failed!")
sys.exit(1)


@@ -1,59 +0,0 @@
#!/usr/bin/env python3
"""
Test script to check training status functionality
"""
import logging
logging.basicConfig(level=logging.INFO)
print("Testing training status functionality...")
try:
from web.old_archived.scalping_dashboard import create_scalping_dashboard
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
print("✅ Imports successful")
# Create components
data_provider = DataProvider()
orchestrator = EnhancedTradingOrchestrator(data_provider)
dashboard = create_scalping_dashboard(data_provider, orchestrator)
print("✅ Dashboard created successfully")
# Test training status
training_status = dashboard._get_model_training_status()
print("\n📊 Training Status:")
print(f"CNN Status: {training_status['cnn']['status']}")
print(f"CNN Accuracy: {training_status['cnn']['accuracy']:.1%}")
print(f"CNN Loss: {training_status['cnn']['loss']:.4f}")
print(f"CNN Epochs: {training_status['cnn']['epochs']}")
print(f"RL Status: {training_status['rl']['status']}")
print(f"RL Win Rate: {training_status['rl']['win_rate']:.1%}")
print(f"RL Episodes: {training_status['rl']['episodes']}")
print(f"RL Memory: {training_status['rl']['memory_size']}")
# Test extrema stats
if hasattr(orchestrator, 'get_extrema_stats'):
extrema_stats = orchestrator.get_extrema_stats()
print(f"\n🎯 Extrema Stats:")
print(f"Total extrema detected: {extrema_stats.get('total_extrema_detected', 0)}")
print(f"Training queue size: {extrema_stats.get('training_queue_size', 0)}")
print("✅ Extrema stats available")
else:
print("❌ Extrema stats not available")
# Test tick cache
print(f"\n📈 Training Data:")
print(f"Tick cache size: {len(dashboard.tick_cache)}")
print(f"1s bars cache size: {len(dashboard.one_second_bars)}")
print(f"Streaming status: {dashboard.is_streaming}")
print("\n✅ All tests completed successfully!")
except Exception as e:
print(f"❌ Error: {e}")
import traceback
traceback.print_exc()


@@ -1,262 +0,0 @@
#!/usr/bin/env python3
"""
Test Universal Data Format Compliance
This script verifies that our enhanced trading system properly feeds
the 5 required timeseries streams to all models:
- ETH/USDT: ticks (1s), 1m, 1h, 1d
- BTC/USDT: ticks (1s) as reference
This is our universal trading system input format.
"""
import asyncio
import logging
import sys
from pathlib import Path
import numpy as np
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import get_config
from core.data_provider import DataProvider
from core.universal_data_adapter import UniversalDataAdapter, UniversalDataStream
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from training.enhanced_cnn_trainer import EnhancedCNNTrainer
from training.enhanced_rl_trainer import EnhancedRLTrainer
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
async def test_universal_data_format():
"""Test that all components properly use the universal 5-timeseries format"""
logger.info("="*80)
logger.info("🧪 TESTING UNIVERSAL DATA FORMAT COMPLIANCE")
logger.info("="*80)
try:
# Initialize components
config = get_config()
data_provider = DataProvider(config)
# Test 1: Universal Data Adapter
logger.info("\n📊 TEST 1: Universal Data Adapter")
logger.info("-" * 40)
adapter = UniversalDataAdapter(data_provider)
universal_stream = adapter.get_universal_data_stream()
if universal_stream is None:
logger.error("❌ Failed to get universal data stream")
return False
# Validate format
is_valid, issues = adapter.validate_universal_format(universal_stream)
if not is_valid:
logger.error(f"❌ Universal format validation failed: {issues}")
return False
logger.info("✅ Universal Data Adapter: PASSED")
logger.info(f" ETH ticks: {len(universal_stream.eth_ticks)} samples")
logger.info(f" ETH 1m: {len(universal_stream.eth_1m)} candles")
logger.info(f" ETH 1h: {len(universal_stream.eth_1h)} candles")
logger.info(f" ETH 1d: {len(universal_stream.eth_1d)} candles")
logger.info(f" BTC reference: {len(universal_stream.btc_ticks)} samples")
logger.info(f" Data quality: {universal_stream.metadata['data_quality']['overall_score']:.2f}")
# Test 2: Enhanced Orchestrator
logger.info("\n🎯 TEST 2: Enhanced Orchestrator")
logger.info("-" * 40)
orchestrator = EnhancedTradingOrchestrator(data_provider)
# Test that orchestrator uses universal adapter
if not hasattr(orchestrator, 'universal_adapter'):
logger.error("❌ Orchestrator missing universal_adapter")
return False
# Test coordinated decisions
decisions = await orchestrator.make_coordinated_decisions()
logger.info("✅ Enhanced Orchestrator: PASSED")
logger.info(f" Generated {len(decisions)} decisions")
logger.info(f" Universal adapter: {type(orchestrator.universal_adapter).__name__}")
for symbol, decision in decisions.items():
if decision:
logger.info(f" {symbol}: {decision.action} (confidence: {decision.confidence:.2f})")
# Test 3: CNN Model Data Format
logger.info("\n🧠 TEST 3: CNN Model Data Format")
logger.info("-" * 40)
# Format data for CNN
cnn_data = adapter.format_for_model(universal_stream, 'cnn')
required_cnn_keys = ['eth_ticks', 'eth_1m', 'eth_1h', 'eth_1d', 'btc_ticks']
missing_keys = [key for key in required_cnn_keys if key not in cnn_data]
if missing_keys:
logger.error(f"❌ CNN data missing keys: {missing_keys}")
return False
logger.info("✅ CNN Model Data Format: PASSED")
for key, data in cnn_data.items():
if isinstance(data, np.ndarray):
logger.info(f" {key}: shape {data.shape}")
else:
logger.info(f" {key}: {type(data)}")
# Test 4: RL Model Data Format
logger.info("\n🤖 TEST 4: RL Model Data Format")
logger.info("-" * 40)
# Format data for RL
rl_data = adapter.format_for_model(universal_stream, 'rl')
if 'state_vector' not in rl_data:
logger.error("❌ RL data missing state_vector")
return False
state_vector = rl_data['state_vector']
if not isinstance(state_vector, np.ndarray):
logger.error("❌ RL state_vector is not numpy array")
return False
logger.info("✅ RL Model Data Format: PASSED")
logger.info(f" State vector shape: {state_vector.shape}")
logger.info(f" State vector size: {len(state_vector)} features")
# Test 5: CNN Trainer Integration
logger.info("\n🎓 TEST 5: CNN Trainer Integration")
logger.info("-" * 40)
try:
cnn_trainer = EnhancedCNNTrainer(config, orchestrator)
logger.info("✅ CNN Trainer Integration: PASSED")
logger.info(f" Model timeframes: {cnn_trainer.model.timeframes}")
logger.info(f" Model device: {cnn_trainer.model.device}")
except Exception as e:
logger.error(f"❌ CNN Trainer Integration failed: {e}")
return False
# Test 6: RL Trainer Integration
logger.info("\n🎮 TEST 6: RL Trainer Integration")
logger.info("-" * 40)
try:
rl_trainer = EnhancedRLTrainer(config, orchestrator)
logger.info("✅ RL Trainer Integration: PASSED")
logger.info(f" RL agents: {len(rl_trainer.agents)}")
for symbol, agent in rl_trainer.agents.items():
logger.info(f" {symbol} agent: {type(agent).__name__}")
except Exception as e:
logger.error(f"❌ RL Trainer Integration failed: {e}")
return False
# Test 7: Data Flow Verification
logger.info("\n🔄 TEST 7: Data Flow Verification")
logger.info("-" * 40)
# Verify that models receive the correct data format
eth_states = orchestrator.market_states.get('ETH/USDT')
test_predictions = await orchestrator._get_enhanced_predictions_universal(
'ETH/USDT',
list(eth_states)[-1] if eth_states else None,
universal_stream
)
if test_predictions:
logger.info("✅ Data Flow Verification: PASSED")
for pred in test_predictions:
logger.info(f" Model: {pred.model_name}")
logger.info(f" Action: {pred.overall_action}")
logger.info(f" Confidence: {pred.overall_confidence:.2f}")
logger.info(f" Timeframes: {len(pred.timeframe_predictions)}")
else:
logger.warning("⚠️ No predictions generated (may be normal if no models loaded)")
# Test 8: Configuration Compliance
logger.info("\n⚙️ TEST 8: Configuration Compliance")
logger.info("-" * 40)
# Check that config matches universal format
expected_symbols = ['ETH/USDT', 'BTC/USDT']
expected_timeframes = ['1s', '1m', '1h', '1d']
config_symbols = config.symbols
config_timeframes = config.timeframes
symbols_match = all(symbol in config_symbols for symbol in expected_symbols)
timeframes_match = all(tf in config_timeframes for tf in expected_timeframes)
if not symbols_match:
logger.warning(f"⚠️ Config symbols may not match universal format")
logger.warning(f" Expected: {expected_symbols}")
logger.warning(f" Config: {config_symbols}")
if not timeframes_match:
logger.warning(f"⚠️ Config timeframes may not match universal format")
logger.warning(f" Expected: {expected_timeframes}")
logger.warning(f" Config: {config_timeframes}")
if symbols_match and timeframes_match:
logger.info("✅ Configuration Compliance: PASSED")
else:
logger.info("⚠️ Configuration Compliance: PARTIAL")
logger.info(f" Symbols: {config_symbols}")
logger.info(f" Timeframes: {config_timeframes}")
# Final Summary
logger.info("\n" + "="*80)
logger.info("🎉 UNIVERSAL DATA FORMAT TEST SUMMARY")
logger.info("="*80)
logger.info("✅ All core tests PASSED!")
logger.info("")
logger.info("📋 VERIFIED COMPLIANCE:")
logger.info(" ✓ Universal Data Adapter working")
logger.info(" ✓ Enhanced Orchestrator using universal format")
logger.info(" ✓ CNN models receive 5 timeseries streams")
logger.info(" ✓ RL models receive combined state vector")
logger.info(" ✓ Trainers properly integrated")
logger.info(" ✓ Data flow verified")
logger.info("")
logger.info("🎯 UNIVERSAL FORMAT ACTIVE:")
logger.info(" 1. ETH/USDT ticks (1s) ✓")
logger.info(" 2. ETH/USDT 1m ✓")
logger.info(" 3. ETH/USDT 1h ✓")
logger.info(" 4. ETH/USDT 1d ✓")
logger.info(" 5. BTC/USDT reference ticks ✓")
logger.info("")
logger.info("🚀 Your enhanced trading system is ready with universal data format!")
logger.info("="*80)
return True
except Exception as e:
logger.error(f"❌ Universal data format test failed: {e}")
import traceback
logger.error(traceback.format_exc())
return False
async def main():
"""Main test function"""
logger.info("🚀 Starting Universal Data Format Compliance Test...")
success = await test_universal_data_format()
if success:
logger.info("\n🎉 All tests passed! Universal data format is properly implemented.")
logger.info("Your enhanced trading system respects the 5-timeseries input format.")
else:
logger.error("\n💥 Tests failed! Please check the universal data format implementation.")
sys.exit(1)
if __name__ == "__main__":
asyncio.run(main())


@@ -1,177 +0,0 @@
#!/usr/bin/env python3
"""
Test Universal Data Stream Integration with Dashboard
This script validates that:
1. CleanTradingDashboard properly subscribes to UnifiedDataStream
2. All 5 timeseries are properly received and processed
3. Data flows correctly from provider -> adapter -> stream -> dashboard
4. Consumer callback functions work as expected
"""
import asyncio
import logging
import sys
import time
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from core.trading_executor import TradingExecutor
from web.clean_dashboard import CleanTradingDashboard
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
async def test_universal_stream_integration():
"""Test Universal Data Stream integration with dashboard"""
logger.info("="*80)
logger.info("🧪 TESTING UNIVERSAL DATA STREAM INTEGRATION")
logger.info("="*80)
try:
# Initialize components
logger.info("\n📦 STEP 1: Initialize Components")
logger.info("-" * 40)
config = get_config()
data_provider = DataProvider()
orchestrator = EnhancedTradingOrchestrator(
data_provider=data_provider,
symbols=['ETH/USDT', 'BTC/USDT'],
enhanced_rl_training=True
)
trading_executor = TradingExecutor()
logger.info("✅ Core components initialized")
# Initialize dashboard with Universal Data Stream
logger.info("\n📊 STEP 2: Initialize Dashboard with Universal Stream")
logger.info("-" * 40)
dashboard = CleanTradingDashboard(
data_provider=data_provider,
orchestrator=orchestrator,
trading_executor=trading_executor
)
# Check Universal Stream initialization
if hasattr(dashboard, 'unified_stream') and dashboard.unified_stream:
logger.info("✅ Universal Data Stream initialized successfully")
logger.info(f"📋 Consumer ID: {dashboard.stream_consumer_id}")
else:
logger.error("❌ Universal Data Stream not initialized")
return False
# Test consumer registration
logger.info("\n🔗 STEP 3: Validate Consumer Registration")
logger.info("-" * 40)
stream_stats = dashboard.unified_stream.get_stream_stats()
logger.info(f"📊 Stream Stats: {stream_stats}")
if stream_stats['total_consumers'] > 0:
logger.info(f"{stream_stats['total_consumers']} consumers registered")
else:
logger.warning("⚠️ No consumers registered")
# Test data callback
logger.info("\n📡 STEP 4: Test Data Callback")
logger.info("-" * 40)
# Create test data packet
test_data = {
'timestamp': time.time(),
'consumer_id': dashboard.stream_consumer_id,
'consumer_name': 'CleanTradingDashboard',
'ticks': [
{'symbol': 'ETHUSDT', 'price': 3000.0, 'volume': 1.5, 'timestamp': time.time()},
{'symbol': 'ETHUSDT', 'price': 3001.0, 'volume': 2.0, 'timestamp': time.time()},
],
'ohlcv': {'one_second_bars': [], 'multi_timeframe': {
'ETH/USDT': {
'1s': [{'timestamp': time.time(), 'open': 3000, 'high': 3002, 'low': 2999, 'close': 3001, 'volume': 10}],
'1m': [{'timestamp': time.time(), 'open': 2990, 'high': 3010, 'low': 2985, 'close': 3001, 'volume': 100}],
'1h': [{'timestamp': time.time(), 'open': 2900, 'high': 3050, 'low': 2880, 'close': 3001, 'volume': 1000}],
'1d': [{'timestamp': time.time(), 'open': 2800, 'high': 3200, 'low': 2750, 'close': 3001, 'volume': 10000}]
},
'BTC/USDT': {
'1s': [{'timestamp': time.time(), 'open': 65000, 'high': 65020, 'low': 64980, 'close': 65010, 'volume': 0.5}]
}
}},
'training_data': {'market_state': 'test', 'features': []},
'ui_data': {'formatted_data': 'test_ui_data'}
}
# Test callback manually
try:
dashboard._handle_unified_stream_data(test_data)
logger.info("✅ Data callback executed successfully")
# Check if data was processed
if hasattr(dashboard, 'current_prices') and 'ETH/USDT' in dashboard.current_prices:
logger.info(f"✅ Price updated: ETH/USDT = ${dashboard.current_prices['ETH/USDT']}")
else:
logger.warning("⚠️ Prices not updated in dashboard")
except Exception as e:
logger.error(f"❌ Data callback failed: {e}")
return False
# Test Universal Data Adapter
logger.info("\n🔄 STEP 5: Test Universal Data Adapter")
logger.info("-" * 40)
if hasattr(orchestrator, 'universal_adapter'):
universal_stream = orchestrator.universal_adapter.get_universal_data_stream()
if universal_stream:
logger.info("✅ Universal Data Adapter working")
logger.info(f"📊 ETH ticks: {len(universal_stream.eth_ticks)} samples")
logger.info(f"📊 ETH 1m: {len(universal_stream.eth_1m)} candles")
logger.info(f"📊 ETH 1h: {len(universal_stream.eth_1h)} candles")
logger.info(f"📊 ETH 1d: {len(universal_stream.eth_1d)} candles")
logger.info(f"📊 BTC ticks: {len(universal_stream.btc_ticks)} samples")
# Validate format
is_valid, issues = orchestrator.universal_adapter.validate_universal_format(universal_stream)
if is_valid:
logger.info("✅ Universal format validation passed")
else:
logger.warning(f"⚠️ Format issues: {issues}")
else:
logger.error("❌ Universal Data Adapter failed to get stream")
return False
else:
logger.error("❌ Universal Data Adapter not found in orchestrator")
return False
# Summary
logger.info("\n🎯 SUMMARY")
logger.info("-" * 40)
logger.info("✅ Universal Data Stream properly integrated")
logger.info("✅ Dashboard subscribes as consumer")
logger.info("✅ All 5 timeseries format validated")
logger.info("✅ Data callback processing works")
logger.info("✅ Universal Data Adapter functional")
logger.info("\n🏆 INTEGRATION TEST PASSED")
return True
except Exception as e:
logger.error(f"❌ Integration test failed: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
success = asyncio.run(test_universal_stream_integration())
sys.exit(0 if success else 1)