cleanup, cob ladder still broken
This commit is contained in:
@ -1,93 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to check Binance data availability
|
||||
"""
|
||||
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_binance_data():
|
||||
"""Test Binance data fetching"""
|
||||
print("="*60)
|
||||
print("BINANCE DATA TEST")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
print("1. Testing DataProvider import...")
|
||||
from core.data_provider import DataProvider
|
||||
print(" ✅ DataProvider imported successfully")
|
||||
|
||||
print("\n2. Creating DataProvider instance...")
|
||||
dp = DataProvider()
|
||||
print(f" ✅ DataProvider created")
|
||||
print(f" Symbols: {dp.symbols}")
|
||||
print(f" Timeframes: {dp.timeframes}")
|
||||
|
||||
print("\n3. Testing historical data fetch...")
|
||||
try:
|
||||
data = dp.get_historical_data('ETH/USDT', '1m', 10)
|
||||
if data is not None:
|
||||
print(f" ✅ Historical data fetched: {data.shape}")
|
||||
print(f" Latest price: ${data['close'].iloc[-1]:.2f}")
|
||||
print(f" Data range: {data.index[0]} to {data.index[-1]}")
|
||||
else:
|
||||
print(" ❌ No historical data returned")
|
||||
except Exception as e:
|
||||
print(f" ❌ Error fetching historical data: {e}")
|
||||
|
||||
print("\n4. Testing current price...")
|
||||
try:
|
||||
price = dp.get_current_price('ETH/USDT')
|
||||
if price:
|
||||
print(f" ✅ Current price: ${price:.2f}")
|
||||
else:
|
||||
print(" ❌ No current price available")
|
||||
except Exception as e:
|
||||
print(f" ❌ Error getting current price: {e}")
|
||||
|
||||
print("\n5. Testing real-time streaming setup...")
|
||||
try:
|
||||
# Check if streaming can be initialized
|
||||
print(f" Streaming status: {dp.is_streaming}")
|
||||
print(" ✅ Real-time streaming setup ready")
|
||||
except Exception as e:
|
||||
print(f" ❌ Real-time streaming error: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to import or create DataProvider: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def test_dashboard_connection():
|
||||
"""Test if dashboard can connect to data"""
|
||||
print("\n" + "="*60)
|
||||
print("DASHBOARD CONNECTION TEST")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
print("1. Testing dashboard imports...")
|
||||
from web.old_archived.scalping_dashboard import ScalpingDashboard
|
||||
print(" ✅ ScalpingDashboard imported")
|
||||
|
||||
print("\n2. Testing data provider connection...")
|
||||
# Check if the dashboard can create a data provider
|
||||
dashboard = ScalpingDashboard()
|
||||
if hasattr(dashboard, 'data_provider'):
|
||||
print(" ✅ Dashboard has data_provider")
|
||||
print(f" Data provider symbols: {dashboard.data_provider.symbols}")
|
||||
else:
|
||||
print(" ❌ Dashboard missing data_provider")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Dashboard connection error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_binance_data()
|
||||
test_dashboard_connection()
|
@ -1,221 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test callback registration to identify the issue
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
import dash
|
||||
from dash import dcc, html, Input, Output
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_simple_callback():
|
||||
"""Test a simple callback registration"""
|
||||
logger.info("Testing simple callback registration...")
|
||||
|
||||
app = dash.Dash(__name__)
|
||||
|
||||
app.layout = html.Div([
|
||||
html.H1("Callback Registration Test"),
|
||||
html.Div(id="output", children="Initial"),
|
||||
dcc.Interval(id="interval", interval=1000, n_intervals=0)
|
||||
])
|
||||
|
||||
@app.callback(
|
||||
Output('output', 'children'),
|
||||
Input('interval', 'n_intervals')
|
||||
)
|
||||
def update_output(n_intervals):
|
||||
logger.info(f"Callback triggered: {n_intervals}")
|
||||
return f"Update #{n_intervals}"
|
||||
|
||||
logger.info("Simple callback registered successfully")
|
||||
|
||||
# Check if callback is in the callback map
|
||||
logger.info(f"Callback map keys: {list(app.callback_map.keys())}")
|
||||
|
||||
return app
|
||||
|
||||
def test_complex_callback():
|
||||
"""Test a complex callback like the dashboard"""
|
||||
logger.info("Testing complex callback registration...")
|
||||
|
||||
app = dash.Dash(__name__)
|
||||
|
||||
app.layout = html.Div([
|
||||
html.H1("Complex Callback Test"),
|
||||
html.Div(id="current-balance", children="$100.00"),
|
||||
html.Div(id="session-duration", children="00:00:00"),
|
||||
html.Div(id="status", children="Starting"),
|
||||
dcc.Graph(id="chart"),
|
||||
dcc.Interval(id="ultra-fast-interval", interval=1000, n_intervals=0)
|
||||
])
|
||||
|
||||
@app.callback(
|
||||
[
|
||||
Output('current-balance', 'children'),
|
||||
Output('session-duration', 'children'),
|
||||
Output('status', 'children'),
|
||||
Output('chart', 'figure')
|
||||
],
|
||||
[Input('ultra-fast-interval', 'n_intervals')]
|
||||
)
|
||||
def update_dashboard(n_intervals):
|
||||
logger.info(f"Complex callback triggered: {n_intervals}")
|
||||
|
||||
import plotly.graph_objects as go
|
||||
fig = go.Figure()
|
||||
fig.add_trace(go.Scatter(x=[1, 2, 3], y=[1, 2, 3], mode='lines'))
|
||||
fig.update_layout(template="plotly_dark")
|
||||
|
||||
return f"${100 + n_intervals:.2f}", f"00:00:{n_intervals:02d}", "Running", fig
|
||||
|
||||
logger.info("Complex callback registered successfully")
|
||||
|
||||
# Check if callback is in the callback map
|
||||
logger.info(f"Callback map keys: {list(app.callback_map.keys())}")
|
||||
|
||||
return app
|
||||
|
||||
def test_dashboard_callback():
|
||||
"""Test the exact dashboard callback structure"""
|
||||
logger.info("Testing dashboard callback structure...")
|
||||
|
||||
try:
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
|
||||
app = dash.Dash(__name__)
|
||||
|
||||
# Minimal layout with dashboard elements
|
||||
app.layout = html.Div([
|
||||
html.H1("Dashboard Callback Test"),
|
||||
html.Div(id="current-balance", children="$100.00"),
|
||||
html.Div(id="session-duration", children="00:00:00"),
|
||||
html.Div(id="open-positions", children="0"),
|
||||
html.Div(id="live-pnl", children="$0.00"),
|
||||
html.Div(id="win-rate", children="0%"),
|
||||
html.Div(id="total-trades", children="0"),
|
||||
html.Div(id="last-action", children="WAITING"),
|
||||
html.Div(id="eth-price", children="Loading..."),
|
||||
html.Div(id="btc-price", children="Loading..."),
|
||||
dcc.Graph(id="main-eth-1s-chart"),
|
||||
dcc.Graph(id="eth-1m-chart"),
|
||||
dcc.Graph(id="eth-1h-chart"),
|
||||
dcc.Graph(id="eth-1d-chart"),
|
||||
dcc.Graph(id="btc-1s-chart"),
|
||||
html.Div(id="actions-log", children="No actions yet"),
|
||||
html.Div(id="debug-status", children="Debug info"),
|
||||
dcc.Interval(id="ultra-fast-interval", interval=1000, n_intervals=0)
|
||||
])
|
||||
|
||||
@app.callback(
|
||||
[
|
||||
Output('current-balance', 'children'),
|
||||
Output('session-duration', 'children'),
|
||||
Output('open-positions', 'children'),
|
||||
Output('live-pnl', 'children'),
|
||||
Output('win-rate', 'children'),
|
||||
Output('total-trades', 'children'),
|
||||
Output('last-action', 'children'),
|
||||
Output('eth-price', 'children'),
|
||||
Output('btc-price', 'children'),
|
||||
Output('main-eth-1s-chart', 'figure'),
|
||||
Output('eth-1m-chart', 'figure'),
|
||||
Output('eth-1h-chart', 'figure'),
|
||||
Output('eth-1d-chart', 'figure'),
|
||||
Output('btc-1s-chart', 'figure'),
|
||||
Output('actions-log', 'children'),
|
||||
Output('debug-status', 'children')
|
||||
],
|
||||
[Input('ultra-fast-interval', 'n_intervals')]
|
||||
)
|
||||
def update_dashboard_test(n_intervals):
|
||||
logger.info(f"Dashboard callback triggered: {n_intervals}")
|
||||
|
||||
import plotly.graph_objects as go
|
||||
from datetime import datetime
|
||||
|
||||
# Create empty figure
|
||||
empty_fig = go.Figure()
|
||||
empty_fig.update_layout(template="plotly_dark")
|
||||
|
||||
debug_status = html.Div([
|
||||
html.P(f"Test Callback #{n_intervals} at {datetime.now().strftime('%H:%M:%S')}")
|
||||
])
|
||||
|
||||
return (
|
||||
f"${100 + n_intervals:.2f}", # current-balance
|
||||
f"00:00:{n_intervals:02d}", # session-duration
|
||||
"0", # open-positions
|
||||
f"${n_intervals:+.2f}", # live-pnl
|
||||
"75%", # win-rate
|
||||
str(n_intervals), # total-trades
|
||||
"TEST", # last-action
|
||||
"$3500.00", # eth-price
|
||||
"$65000.00", # btc-price
|
||||
empty_fig, # main-eth-1s-chart
|
||||
empty_fig, # eth-1m-chart
|
||||
empty_fig, # eth-1h-chart
|
||||
empty_fig, # eth-1d-chart
|
||||
empty_fig, # btc-1s-chart
|
||||
f"Test action #{n_intervals}", # actions-log
|
||||
debug_status # debug-status
|
||||
)
|
||||
|
||||
logger.info("Dashboard callback registered successfully")
|
||||
logger.info(f"Callback map keys: {list(app.callback_map.keys())}")
|
||||
|
||||
return app
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing dashboard callback: {e}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
def main():
|
||||
"""Main test function"""
|
||||
logger.info("Starting callback registration tests...")
|
||||
|
||||
# Test 1: Simple callback
|
||||
try:
|
||||
simple_app = test_simple_callback()
|
||||
logger.info("✅ Simple callback test passed")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Simple callback test failed: {e}")
|
||||
|
||||
# Test 2: Complex callback
|
||||
try:
|
||||
complex_app = test_complex_callback()
|
||||
logger.info("✅ Complex callback test passed")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Complex callback test failed: {e}")
|
||||
|
||||
# Test 3: Dashboard callback
|
||||
try:
|
||||
dashboard_app = test_dashboard_callback()
|
||||
if dashboard_app:
|
||||
logger.info("✅ Dashboard callback test passed")
|
||||
|
||||
# Run the dashboard test
|
||||
logger.info("Starting dashboard test server on port 8054...")
|
||||
dashboard_app.run(host='127.0.0.1', port=8054, debug=True)
|
||||
else:
|
||||
logger.error("❌ Dashboard callback test failed")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Dashboard callback test failed: {e}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,22 +0,0 @@
|
||||
import requests
|
||||
import json
|
||||
|
||||
def test_callback():
|
||||
try:
|
||||
url = 'http://127.0.0.1:8051/_dash-update-component'
|
||||
data = {
|
||||
"output": "current-balance.children",
|
||||
"inputs": [{"id": "ultra-fast-interval", "property": "n_intervals", "value": 1}],
|
||||
"changedPropIds": ["ultra-fast-interval.n_intervals"],
|
||||
"state": []
|
||||
}
|
||||
|
||||
response = requests.post(url, json=data, timeout=10)
|
||||
print(f"Status: {response.status_code}")
|
||||
print(f"Response: {response.text[:1000]}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_callback()
|
@ -1,75 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test callback structure to verify it works
|
||||
"""
|
||||
|
||||
import dash
|
||||
from dash import dcc, html, Input, Output
|
||||
import plotly.graph_objects as go
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create Dash app
|
||||
app = dash.Dash(__name__)
|
||||
|
||||
# Simple layout matching the enhanced dashboard structure
|
||||
app.layout = html.Div([
|
||||
html.H1("Callback Structure Test"),
|
||||
html.Div(id="test-output-1"),
|
||||
html.Div(id="test-output-2"),
|
||||
html.Div(id="test-output-3"),
|
||||
dcc.Graph(id="test-chart"),
|
||||
dcc.Interval(id='test-interval', interval=3000, n_intervals=0)
|
||||
])
|
||||
|
||||
# Callback using the EXACT same structure as enhanced dashboard
|
||||
@app.callback(
|
||||
[
|
||||
Output('test-output-1', 'children'),
|
||||
Output('test-output-2', 'children'),
|
||||
Output('test-output-3', 'children'),
|
||||
Output('test-chart', 'figure')
|
||||
],
|
||||
[Input('test-interval', 'n_intervals')]
|
||||
)
|
||||
def update_test_dashboard(n_intervals):
|
||||
"""Test callback with same structure as enhanced dashboard"""
|
||||
try:
|
||||
logger.info(f"Test callback triggered: {n_intervals}")
|
||||
|
||||
# Simple outputs
|
||||
output1 = f"Output 1: {n_intervals}"
|
||||
output2 = f"Output 2: {datetime.now().strftime('%H:%M:%S')}"
|
||||
output3 = f"Output 3: Working"
|
||||
|
||||
# Simple chart
|
||||
fig = go.Figure()
|
||||
fig.add_trace(go.Scatter(
|
||||
x=[1, 2, 3, 4, 5],
|
||||
y=[n_intervals, n_intervals+1, n_intervals+2, n_intervals+1, n_intervals],
|
||||
mode='lines',
|
||||
name='Test Data'
|
||||
))
|
||||
fig.update_layout(
|
||||
title=f"Test Chart - Update {n_intervals}",
|
||||
template="plotly_dark"
|
||||
)
|
||||
|
||||
logger.info(f"Returning: {output1}, {output2}, {output3}, <Figure>")
|
||||
return output1, output2, output3, fig
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in test callback: {e}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
# Return safe fallback
|
||||
return f"Error: {str(e)}", "Error", "Error", go.Figure()
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Starting callback structure test on port 8053...")
|
||||
app.run(host='127.0.0.1', port=8053, debug=True)
|
@ -1,101 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Dashboard Callback - Simple test to verify Dash callbacks work
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
import dash
|
||||
from dash import dcc, html, Input, Output
|
||||
import plotly.graph_objects as go
|
||||
from datetime import datetime
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def create_test_dashboard():
|
||||
"""Create a simple test dashboard to verify callbacks work"""
|
||||
|
||||
app = dash.Dash(__name__)
|
||||
|
||||
app.layout = html.Div([
|
||||
html.H1("🧪 Test Dashboard - Callback Verification", className="text-center"),
|
||||
html.Div([
|
||||
html.H3(id="current-time", className="text-center"),
|
||||
html.H4(id="counter", className="text-center"),
|
||||
dcc.Graph(id="test-chart")
|
||||
]),
|
||||
dcc.Interval(
|
||||
id='test-interval',
|
||||
interval=1000, # 1 second
|
||||
n_intervals=0
|
||||
)
|
||||
])
|
||||
|
||||
@app.callback(
|
||||
[
|
||||
Output('current-time', 'children'),
|
||||
Output('counter', 'children'),
|
||||
Output('test-chart', 'figure')
|
||||
],
|
||||
[Input('test-interval', 'n_intervals')]
|
||||
)
|
||||
def update_test_dashboard(n_intervals):
|
||||
"""Test callback function"""
|
||||
try:
|
||||
logger.info(f"🔄 Test callback triggered, interval: {n_intervals}")
|
||||
|
||||
current_time = datetime.now().strftime("%H:%M:%S")
|
||||
counter = f"Updates: {n_intervals}"
|
||||
|
||||
# Create simple test chart
|
||||
fig = go.Figure()
|
||||
fig.add_trace(go.Scatter(
|
||||
x=list(range(n_intervals + 1)),
|
||||
y=[i**2 for i in range(n_intervals + 1)],
|
||||
mode='lines+markers',
|
||||
name='Test Data'
|
||||
))
|
||||
fig.update_layout(
|
||||
title=f"Test Chart - Update #{n_intervals}",
|
||||
template="plotly_dark"
|
||||
)
|
||||
|
||||
return current_time, counter, fig
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in test callback: {e}")
|
||||
return "Error", "Error", {}
|
||||
|
||||
return app
|
||||
|
||||
def main():
|
||||
"""Run the test dashboard"""
|
||||
logger.info("🧪 Starting test dashboard...")
|
||||
|
||||
try:
|
||||
app = create_test_dashboard()
|
||||
logger.info("✅ Test dashboard created")
|
||||
|
||||
logger.info("🚀 Starting test dashboard on http://127.0.0.1:8052")
|
||||
logger.info("If you see updates every second, callbacks are working!")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
app.run(host='127.0.0.1', port=8052, debug=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Test dashboard stopped by user")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,110 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to make direct requests to the dashboard's callback endpoint
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
|
||||
def test_dashboard_callback():
|
||||
"""Test the dashboard callback endpoint directly"""
|
||||
|
||||
dashboard_url = "http://127.0.0.1:8054"
|
||||
callback_url = f"{dashboard_url}/_dash-update-component"
|
||||
|
||||
print(f"Testing dashboard at {dashboard_url}")
|
||||
|
||||
# First, check if dashboard is running
|
||||
try:
|
||||
response = requests.get(dashboard_url, timeout=5)
|
||||
print(f"Dashboard status: {response.status_code}")
|
||||
if response.status_code != 200:
|
||||
print("Dashboard not responding properly")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"Error connecting to dashboard: {e}")
|
||||
return
|
||||
|
||||
# Test callback request for dashboard test
|
||||
callback_data = {
|
||||
"output": "current-balance.children",
|
||||
"outputs": [
|
||||
{"id": "current-balance", "property": "children"},
|
||||
{"id": "session-duration", "property": "children"},
|
||||
{"id": "open-positions", "property": "children"},
|
||||
{"id": "live-pnl", "property": "children"},
|
||||
{"id": "win-rate", "property": "children"},
|
||||
{"id": "total-trades", "property": "children"},
|
||||
{"id": "last-action", "property": "children"},
|
||||
{"id": "eth-price", "property": "children"},
|
||||
{"id": "btc-price", "property": "children"},
|
||||
{"id": "main-eth-1s-chart", "property": "figure"},
|
||||
{"id": "eth-1m-chart", "property": "figure"},
|
||||
{"id": "eth-1h-chart", "property": "figure"},
|
||||
{"id": "eth-1d-chart", "property": "figure"},
|
||||
{"id": "btc-1s-chart", "property": "figure"},
|
||||
{"id": "actions-log", "property": "children"},
|
||||
{"id": "debug-status", "property": "children"}
|
||||
],
|
||||
"inputs": [
|
||||
{"id": "ultra-fast-interval", "property": "n_intervals", "value": 1}
|
||||
],
|
||||
"changedPropIds": ["ultra-fast-interval.n_intervals"],
|
||||
"state": []
|
||||
}
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
|
||||
print("\nTesting callback request...")
|
||||
try:
|
||||
response = requests.post(
|
||||
callback_url,
|
||||
data=json.dumps(callback_data),
|
||||
headers=headers,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
print(f"Callback response status: {response.status_code}")
|
||||
print(f"Response headers: {dict(response.headers)}")
|
||||
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
response_data = response.json()
|
||||
print(f"Response data keys: {list(response_data.keys()) if isinstance(response_data, dict) else 'Not a dict'}")
|
||||
print(f"Response data type: {type(response_data)}")
|
||||
|
||||
if isinstance(response_data, dict) and 'response' in response_data:
|
||||
print(f"Response contains {len(response_data['response'])} items")
|
||||
for i, item in enumerate(response_data['response'][:3]): # Show first 3 items
|
||||
print(f" Item {i}: {type(item)} - {str(item)[:100]}...")
|
||||
else:
|
||||
print(f"Full response: {str(response_data)[:500]}...")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error parsing JSON response: {e}")
|
||||
print(f"Raw response: {response.text[:500]}...")
|
||||
else:
|
||||
print(f"Error response: {response.text}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error making callback request: {e}")
|
||||
|
||||
def monitor_dashboard():
|
||||
"""Monitor dashboard callback requests"""
|
||||
print("Monitoring dashboard callback requests...")
|
||||
print("Press Ctrl+C to stop")
|
||||
|
||||
try:
|
||||
for i in range(10): # Test 10 times
|
||||
print(f"\n--- Test {i+1} ---")
|
||||
test_dashboard_callback()
|
||||
time.sleep(2)
|
||||
except KeyboardInterrupt:
|
||||
print("\nMonitoring stopped")
|
||||
|
||||
if __name__ == "__main__":
|
||||
monitor_dashboard()
|
@ -1,103 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple Dashboard Test - Isolate dashboard startup issues
|
||||
"""
|
||||
|
||||
import os
|
||||
# Fix OpenMP library conflicts before importing other modules
|
||||
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
|
||||
os.environ['OMP_NUM_THREADS'] = '4'
|
||||
|
||||
import sys
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Setup basic logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_dashboard_startup():
|
||||
"""Test dashboard creation and startup"""
|
||||
try:
|
||||
logger.info("=" * 50)
|
||||
logger.info("TESTING DASHBOARD STARTUP")
|
||||
logger.info("=" * 50)
|
||||
|
||||
# Test imports first
|
||||
logger.info("Step 1: Testing imports...")
|
||||
from core.config import get_config, setup_logging
|
||||
from core.data_provider import DataProvider
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.trading_executor import TradingExecutor
|
||||
logger.info("✓ Core imports successful")
|
||||
|
||||
from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
|
||||
logger.info("✓ Dashboard import successful")
|
||||
|
||||
# Test configuration
|
||||
logger.info("Step 2: Testing configuration...")
|
||||
setup_logging()
|
||||
config = get_config()
|
||||
logger.info("✓ Configuration loaded")
|
||||
|
||||
# Test core component creation
|
||||
logger.info("Step 3: Testing core component creation...")
|
||||
data_provider = DataProvider()
|
||||
logger.info("✓ DataProvider created")
|
||||
|
||||
orchestrator = TradingOrchestrator(data_provider=data_provider)
|
||||
logger.info("✓ TradingOrchestrator created")
|
||||
|
||||
trading_executor = TradingExecutor()
|
||||
logger.info("✓ TradingExecutor created")
|
||||
|
||||
# Test dashboard creation
|
||||
logger.info("Step 4: Testing dashboard creation...")
|
||||
dashboard = TradingDashboard(
|
||||
data_provider=data_provider,
|
||||
orchestrator=orchestrator,
|
||||
trading_executor=trading_executor
|
||||
)
|
||||
logger.info("✓ TradingDashboard created successfully")
|
||||
|
||||
# Test dashboard startup
|
||||
logger.info("Step 5: Testing dashboard server startup...")
|
||||
logger.info("Dashboard will start on http://127.0.0.1:8052")
|
||||
logger.info("Press Ctrl+C to stop the test")
|
||||
|
||||
# Run the dashboard
|
||||
dashboard.app.run(
|
||||
host='127.0.0.1',
|
||||
port=8052,
|
||||
debug=False,
|
||||
use_reloader=False
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Dashboard test failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
success = test_dashboard_startup()
|
||||
if success:
|
||||
logger.info("✓ Dashboard test completed successfully")
|
||||
else:
|
||||
logger.error("❌ Dashboard test failed")
|
||||
sys.exit(1)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Dashboard test interrupted by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error in dashboard test: {e}")
|
||||
sys.exit(1)
|
@ -1,66 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Dashboard Startup - Debug the scalping dashboard startup issue
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_dashboard_startup():
|
||||
"""Test dashboard startup with detailed error reporting"""
|
||||
try:
|
||||
logger.info("Testing dashboard startup...")
|
||||
|
||||
# Test imports
|
||||
logger.info("Testing imports...")
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from web.old_archived.scalping_dashboard import create_scalping_dashboard
|
||||
logger.info("✅ All imports successful")
|
||||
|
||||
# Test data provider
|
||||
logger.info("Creating data provider...")
|
||||
dp = DataProvider()
|
||||
logger.info("✅ Data provider created")
|
||||
|
||||
# Test orchestrator
|
||||
logger.info("Creating orchestrator...")
|
||||
orch = EnhancedTradingOrchestrator(dp)
|
||||
logger.info("✅ Orchestrator created")
|
||||
|
||||
# Test dashboard creation
|
||||
logger.info("Creating dashboard...")
|
||||
dashboard = create_scalping_dashboard(dp, orch)
|
||||
logger.info("✅ Dashboard created successfully")
|
||||
|
||||
# Test data fetching
|
||||
logger.info("Testing data fetching...")
|
||||
test_data = dp.get_historical_data('ETH/USDT', '1m', limit=5)
|
||||
if test_data is not None and not test_data.empty:
|
||||
logger.info(f"✅ Data fetching works: {len(test_data)} candles")
|
||||
else:
|
||||
logger.warning("⚠️ No data returned from data provider")
|
||||
|
||||
# Start dashboard
|
||||
logger.info("Starting dashboard on http://127.0.0.1:8051")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
dashboard.run(host='127.0.0.1', port=8051, debug=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Dashboard stopped by user")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_dashboard_startup()
|
@ -1,201 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Enhanced COB Integration with RL and CNN Models
|
||||
|
||||
This script tests the integration of Consolidated Order Book (COB) data
|
||||
with the real-time RL and CNN training pipeline.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.config import setup_logging
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from core.cob_integration import COBIntegration
|
||||
|
||||
# Setup logging
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class COBMLIntegrationTester:
|
||||
"""Test COB integration with ML models"""
|
||||
|
||||
def __init__(self):
|
||||
self.symbols = ['BTC/USDT', 'ETH/USDT']
|
||||
self.data_provider = DataProvider()
|
||||
self.test_results = {}
|
||||
|
||||
async def test_cob_ml_integration(self):
|
||||
"""Test full COB integration with ML pipeline"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("TESTING COB INTEGRATION WITH RL AND CNN MODELS")
|
||||
logger.info("=" * 60)
|
||||
|
||||
try:
|
||||
# Initialize enhanced orchestrator with COB integration
|
||||
logger.info("1. Initializing Enhanced Trading Orchestrator with COB...")
|
||||
orchestrator = EnhancedTradingOrchestrator(
|
||||
data_provider=self.data_provider,
|
||||
symbols=self.symbols,
|
||||
enhanced_rl_training=True,
|
||||
model_registry={}
|
||||
)
|
||||
|
||||
# Start COB integration
|
||||
logger.info("2. Starting COB Integration...")
|
||||
await orchestrator.start_cob_integration()
|
||||
await asyncio.sleep(5) # Allow startup and data collection
|
||||
|
||||
# Test COB feature generation
|
||||
logger.info("3. Testing COB feature generation...")
|
||||
await self._test_cob_features(orchestrator)
|
||||
|
||||
# Test market state with COB data
|
||||
logger.info("4. Testing market state with COB data...")
|
||||
await self._test_market_state_cob(orchestrator)
|
||||
|
||||
# Test real-time COB callbacks
|
||||
logger.info("5. Testing real-time COB callbacks...")
|
||||
await self._test_realtime_callbacks(orchestrator)
|
||||
|
||||
# Stop COB integration
|
||||
await orchestrator.stop_cob_integration()
|
||||
|
||||
# Print results
|
||||
self._print_test_results()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in COB ML integration test: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _test_cob_features(self, orchestrator):
|
||||
"""Test COB feature availability"""
|
||||
try:
|
||||
for symbol in self.symbols:
|
||||
# Check if COB features are available
|
||||
cob_features = orchestrator.latest_cob_features.get(symbol)
|
||||
cob_state = orchestrator.latest_cob_state.get(symbol)
|
||||
|
||||
if cob_features is not None:
|
||||
logger.info(f"✅ {symbol}: COB CNN features available - shape: {cob_features.shape}")
|
||||
self.test_results[f'{symbol}_cob_cnn_features'] = True
|
||||
else:
|
||||
logger.warning(f"⚠️ {symbol}: COB CNN features not available")
|
||||
self.test_results[f'{symbol}_cob_cnn_features'] = False
|
||||
|
||||
if cob_state is not None:
|
||||
logger.info(f"✅ {symbol}: COB DQN state available - shape: {cob_state.shape}")
|
||||
self.test_results[f'{symbol}_cob_dqn_state'] = True
|
||||
else:
|
||||
logger.warning(f"⚠️ {symbol}: COB DQN state not available")
|
||||
self.test_results[f'{symbol}_cob_dqn_state'] = False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing COB features: {e}")
|
||||
|
||||
async def _test_market_state_cob(self, orchestrator):
|
||||
"""Test market state includes COB data"""
|
||||
try:
|
||||
# Generate market states with COB data
|
||||
from core.universal_data_adapter import UniversalDataAdapter
|
||||
adapter = UniversalDataAdapter(self.data_provider)
|
||||
universal_stream = await adapter.get_universal_stream(['BTC/USDT', 'ETH/USDT'])
|
||||
|
||||
market_states = await orchestrator._get_all_market_states_universal(universal_stream)
|
||||
|
||||
for symbol in self.symbols:
|
||||
if symbol in market_states:
|
||||
state = market_states[symbol]
|
||||
|
||||
# Check COB integration in market state
|
||||
tests = [
|
||||
('cob_features', state.cob_features is not None),
|
||||
('cob_state', state.cob_state is not None),
|
||||
('order_book_imbalance', hasattr(state, 'order_book_imbalance')),
|
||||
('liquidity_depth', hasattr(state, 'liquidity_depth')),
|
||||
('exchange_diversity', hasattr(state, 'exchange_diversity')),
|
||||
('market_impact_estimate', hasattr(state, 'market_impact_estimate'))
|
||||
]
|
||||
|
||||
for test_name, passed in tests:
|
||||
status = "✅" if passed else "❌"
|
||||
logger.info(f"{status} {symbol}: {test_name} - {passed}")
|
||||
self.test_results[f'{symbol}_market_state_{test_name}'] = passed
|
||||
|
||||
# Log COB metrics if available
|
||||
if hasattr(state, 'order_book_imbalance'):
|
||||
logger.info(f"📊 {symbol} COB Metrics:")
|
||||
logger.info(f" Order Book Imbalance: {state.order_book_imbalance:.4f}")
|
||||
logger.info(f" Liquidity Depth: ${state.liquidity_depth:,.0f}")
|
||||
logger.info(f" Exchange Diversity: {state.exchange_diversity}")
|
||||
logger.info(f" Market Impact (10k): {state.market_impact_estimate:.4f}%")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing market state COB: {e}")
|
||||
|
||||
async def _test_realtime_callbacks(self, orchestrator):
|
||||
"""Test real-time COB callbacks"""
|
||||
try:
|
||||
# Monitor COB callbacks for 10 seconds
|
||||
initial_features = {s: len(orchestrator.cob_feature_history[s]) for s in self.symbols}
|
||||
|
||||
logger.info("Monitoring COB callbacks for 10 seconds...")
|
||||
await asyncio.sleep(10)
|
||||
|
||||
final_features = {s: len(orchestrator.cob_feature_history[s]) for s in self.symbols}
|
||||
|
||||
for symbol in self.symbols:
|
||||
updates = final_features[symbol] - initial_features[symbol]
|
||||
if updates > 0:
|
||||
logger.info(f"✅ {symbol}: Received {updates} COB feature updates")
|
||||
self.test_results[f'{symbol}_realtime_callbacks'] = True
|
||||
else:
|
||||
logger.warning(f"⚠️ {symbol}: No COB feature updates received")
|
||||
self.test_results[f'{symbol}_realtime_callbacks'] = False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing realtime callbacks: {e}")
|
||||
|
||||
def _print_test_results(self):
|
||||
"""Print comprehensive test results"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("COB ML INTEGRATION TEST RESULTS")
|
||||
logger.info("=" * 60)
|
||||
|
||||
passed = sum(1 for result in self.test_results.values() if result)
|
||||
total = len(self.test_results)
|
||||
|
||||
logger.info(f"Overall: {passed}/{total} tests passed ({passed/total*100:.1f}%)")
|
||||
logger.info("")
|
||||
|
||||
for test_name, result in self.test_results.items():
|
||||
status = "✅ PASS" if result else "❌ FAIL"
|
||||
logger.info(f"{status}: {test_name}")
|
||||
|
||||
logger.info("=" * 60)
|
||||
|
||||
if passed == total:
|
||||
logger.info("🎉 ALL TESTS PASSED - COB ML INTEGRATION WORKING!")
|
||||
elif passed > total * 0.8:
|
||||
logger.info("⚠️ MOSTLY WORKING - Some minor issues detected")
|
||||
else:
|
||||
logger.warning("🚨 INTEGRATION ISSUES - Significant problems detected")
|
||||
|
||||
async def main():
|
||||
"""Run COB ML integration tests"""
|
||||
tester = COBMLIntegrationTester()
|
||||
await tester.test_cob_ml_integration()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
@ -1,83 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for enhanced trading dashboard with WebSocket support
|
||||
"""
|
||||
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_dashboard():
|
||||
"""Test the enhanced dashboard functionality"""
|
||||
try:
|
||||
print("="*60)
|
||||
print("TESTING ENHANCED TRADING DASHBOARD")
|
||||
print("="*60)
|
||||
|
||||
# Import dashboard
|
||||
from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
|
||||
WEBSOCKET_AVAILABLE = True
|
||||
|
||||
print(f"✓ Dashboard module imported successfully")
|
||||
print(f"✓ WebSocket support available: {WEBSOCKET_AVAILABLE}")
|
||||
|
||||
# Create dashboard instance
|
||||
dashboard = TradingDashboard()
|
||||
|
||||
print(f"✓ Dashboard instance created")
|
||||
print(f"✓ Tick cache capacity: {dashboard.tick_cache.maxlen} ticks (15 min)")
|
||||
print(f"✓ 1s bars capacity: {dashboard.one_second_bars.maxlen} bars (15 min)")
|
||||
print(f"✓ WebSocket streaming: {dashboard.is_streaming}")
|
||||
print(f"✓ Min confidence threshold: {dashboard.min_confidence_threshold}")
|
||||
print(f"✓ Signal cooldown: {dashboard.signal_cooldown}s")
|
||||
|
||||
# Test tick cache methods
|
||||
tick_cache = dashboard.get_tick_cache_for_training(minutes=5)
|
||||
print(f"✓ Tick cache method works: {len(tick_cache)} ticks")
|
||||
|
||||
# Test 1s bars method
|
||||
bars_df = dashboard.get_one_second_bars(count=100)
|
||||
print(f"✓ 1s bars method works: {len(bars_df)} bars")
|
||||
|
||||
# Test chart creation
|
||||
try:
|
||||
chart = dashboard._create_price_chart("ETH/USDT")
|
||||
print(f"✓ Price chart creation works")
|
||||
except Exception as e:
|
||||
print(f"⚠ Price chart creation: {e}")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("ENHANCED DASHBOARD FEATURES:")
|
||||
print("="*60)
|
||||
print("✓ Real-time WebSocket tick streaming (when websocket-client installed)")
|
||||
print("✓ 1-second bar charts with volume")
|
||||
print("✓ 15-minute tick cache for model training")
|
||||
print("✓ Confidence-based signal execution")
|
||||
print("✓ Clear signal vs execution distinction")
|
||||
print("✓ Real-time unrealized P&L display")
|
||||
print("✓ Compact layout with system status icon")
|
||||
print("✓ Scalping-optimized signal generation")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("TO START THE DASHBOARD:")
|
||||
print("="*60)
|
||||
print("1. Install WebSocket support: pip install websocket-client")
|
||||
print("2. Run: python -c \"from web.dashboard import TradingDashboard; TradingDashboard().run()\"")
|
||||
print("3. Open browser: http://127.0.0.1:8050")
|
||||
print("="*60)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error testing dashboard: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_dashboard()
|
||||
sys.exit(0 if success else 1)
|
@ -1,305 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Enhanced Dashboard Integration with RL Training Pipeline
|
||||
|
||||
This script tests the integration between the dashboard and the enhanced RL training pipeline
|
||||
to verify that:
|
||||
1. Unified data stream is properly initialized
|
||||
2. Dashboard receives training data from the enhanced pipeline
|
||||
3. Data flows correctly between components
|
||||
4. Enhanced RL training receives comprehensive data
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('test_enhanced_dashboard_integration.log'),
|
||||
logging.StreamHandler(sys.stdout)
|
||||
]
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import components
|
||||
from core.config import get_config
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from core.universal_data_adapter import UniversalDataAdapter, UniversalDataStream
|
||||
from web.old_archived.scalping_dashboard import RealTimeScalpingDashboard
|
||||
|
||||
class EnhancedDashboardIntegrationTest:
|
||||
"""Test enhanced dashboard integration with RL training pipeline"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize test components"""
|
||||
self.config = get_config()
|
||||
self.data_provider = None
|
||||
self.orchestrator = None
|
||||
self.unified_stream = None
|
||||
self.dashboard = None
|
||||
|
||||
# Test results
|
||||
self.test_results = {
|
||||
'data_provider_init': False,
|
||||
'orchestrator_init': False,
|
||||
'unified_stream_init': False,
|
||||
'dashboard_init': False,
|
||||
'data_flow_test': False,
|
||||
'training_integration_test': False,
|
||||
'ui_data_test': False,
|
||||
'stream_stats_test': False
|
||||
}
|
||||
|
||||
logger.info("Enhanced Dashboard Integration Test initialized")
|
||||
|
||||
async def run_tests(self):
|
||||
"""Run all integration tests"""
|
||||
logger.info("Starting enhanced dashboard integration tests...")
|
||||
|
||||
try:
|
||||
# Test 1: Initialize components
|
||||
await self.test_component_initialization()
|
||||
|
||||
# Test 2: Test data flow
|
||||
await self.test_data_flow()
|
||||
|
||||
# Test 3: Test training integration
|
||||
await self.test_training_integration()
|
||||
|
||||
# Test 4: Test UI data flow
|
||||
await self.test_ui_data_flow()
|
||||
|
||||
# Test 5: Test stream statistics
|
||||
await self.test_stream_statistics()
|
||||
|
||||
# Generate test report
|
||||
self.generate_test_report()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Test execution failed: {e}")
|
||||
raise
|
||||
|
||||
async def test_component_initialization(self):
|
||||
"""Test component initialization"""
|
||||
logger.info("Testing component initialization...")
|
||||
|
||||
try:
|
||||
# Initialize data provider
|
||||
self.data_provider = DataProvider(
|
||||
symbols=['ETH/USDT', 'BTC/USDT'],
|
||||
timeframes=['1s', '1m', '1h', '1d']
|
||||
)
|
||||
self.test_results['data_provider_init'] = True
|
||||
logger.info("✓ Data provider initialized")
|
||||
|
||||
# Initialize orchestrator
|
||||
self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
|
||||
self.test_results['orchestrator_init'] = True
|
||||
logger.info("✓ Enhanced orchestrator initialized")
|
||||
|
||||
# Initialize unified stream
|
||||
self.unified_stream = UnifiedDataStream(self.data_provider, self.orchestrator)
|
||||
self.test_results['unified_stream_init'] = True
|
||||
logger.info("✓ Unified data stream initialized")
|
||||
|
||||
# Initialize dashboard
|
||||
self.dashboard = RealTimeScalpingDashboard(
|
||||
data_provider=self.data_provider,
|
||||
orchestrator=self.orchestrator
|
||||
)
|
||||
self.test_results['dashboard_init'] = True
|
||||
logger.info("✓ Dashboard initialized with unified stream integration")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Component initialization failed: {e}")
|
||||
raise
|
||||
|
||||
async def test_data_flow(self):
|
||||
"""Test data flow through unified stream"""
|
||||
logger.info("Testing data flow through unified stream...")
|
||||
|
||||
try:
|
||||
# Start unified streaming
|
||||
await self.unified_stream.start_streaming()
|
||||
|
||||
# Wait for data collection
|
||||
logger.info("Waiting for data collection...")
|
||||
await asyncio.sleep(10)
|
||||
|
||||
# Check if data is flowing
|
||||
stream_stats = self.unified_stream.get_stream_stats()
|
||||
|
||||
if stream_stats['tick_cache_size'] > 0:
|
||||
logger.info(f"✓ Tick data flowing: {stream_stats['tick_cache_size']} ticks")
|
||||
self.test_results['data_flow_test'] = True
|
||||
else:
|
||||
logger.warning("⚠ No tick data detected")
|
||||
|
||||
if stream_stats['one_second_bars_count'] > 0:
|
||||
logger.info(f"✓ 1s bars generated: {stream_stats['one_second_bars_count']} bars")
|
||||
else:
|
||||
logger.warning("⚠ No 1s bars generated")
|
||||
|
||||
logger.info(f"Stream statistics: {stream_stats}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Data flow test failed: {e}")
|
||||
raise
|
||||
|
||||
async def test_training_integration(self):
|
||||
"""Test training data integration"""
|
||||
logger.info("Testing training data integration...")
|
||||
|
||||
try:
|
||||
# Get latest training data
|
||||
training_data = self.unified_stream.get_latest_training_data()
|
||||
|
||||
if training_data:
|
||||
logger.info("✓ Training data packet available")
|
||||
logger.info(f" Tick cache: {len(training_data.tick_cache)} ticks")
|
||||
logger.info(f" 1s bars: {len(training_data.one_second_bars)} bars")
|
||||
logger.info(f" Multi-timeframe data: {len(training_data.multi_timeframe_data)} symbols")
|
||||
logger.info(f" CNN features: {'Available' if training_data.cnn_features else 'Not available'}")
|
||||
logger.info(f" CNN predictions: {'Available' if training_data.cnn_predictions else 'Not available'}")
|
||||
logger.info(f" Market state: {'Available' if training_data.market_state else 'Not available'}")
|
||||
logger.info(f" Universal stream: {'Available' if training_data.universal_stream else 'Not available'}")
|
||||
|
||||
# Check if dashboard can access training data
|
||||
if hasattr(self.dashboard, 'latest_training_data') and self.dashboard.latest_training_data:
|
||||
logger.info("✓ Dashboard has access to training data")
|
||||
self.test_results['training_integration_test'] = True
|
||||
else:
|
||||
logger.warning("⚠ Dashboard does not have training data access")
|
||||
else:
|
||||
logger.warning("⚠ No training data available")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training integration test failed: {e}")
|
||||
raise
|
||||
|
||||
async def test_ui_data_flow(self):
|
||||
"""Test UI data flow"""
|
||||
logger.info("Testing UI data flow...")
|
||||
|
||||
try:
|
||||
# Get latest UI data
|
||||
ui_data = self.unified_stream.get_latest_ui_data()
|
||||
|
||||
if ui_data:
|
||||
logger.info("✓ UI data packet available")
|
||||
logger.info(f" Current prices: {ui_data.current_prices}")
|
||||
logger.info(f" Tick cache size: {ui_data.tick_cache_size}")
|
||||
logger.info(f" 1s bars count: {ui_data.one_second_bars_count}")
|
||||
logger.info(f" Streaming status: {ui_data.streaming_status}")
|
||||
logger.info(f" Training data available: {ui_data.training_data_available}")
|
||||
|
||||
# Check if dashboard can access UI data
|
||||
if hasattr(self.dashboard, 'latest_ui_data') and self.dashboard.latest_ui_data:
|
||||
logger.info("✓ Dashboard has access to UI data")
|
||||
self.test_results['ui_data_test'] = True
|
||||
else:
|
||||
logger.warning("⚠ Dashboard does not have UI data access")
|
||||
else:
|
||||
logger.warning("⚠ No UI data available")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"UI data flow test failed: {e}")
|
||||
raise
|
||||
|
||||
async def test_stream_statistics(self):
|
||||
"""Test stream statistics"""
|
||||
logger.info("Testing stream statistics...")
|
||||
|
||||
try:
|
||||
# Get comprehensive stream stats
|
||||
stream_stats = self.unified_stream.get_stream_stats()
|
||||
|
||||
logger.info("Stream Statistics:")
|
||||
logger.info(f" Total ticks processed: {stream_stats.get('total_ticks_processed', 0)}")
|
||||
logger.info(f" Total packets sent: {stream_stats.get('total_packets_sent', 0)}")
|
||||
logger.info(f" Consumers served: {stream_stats.get('consumers_served', 0)}")
|
||||
logger.info(f" Active consumers: {stream_stats.get('active_consumers', 0)}")
|
||||
logger.info(f" Total consumers: {stream_stats.get('total_consumers', 0)}")
|
||||
logger.info(f" Processing errors: {stream_stats.get('processing_errors', 0)}")
|
||||
logger.info(f" Data quality score: {stream_stats.get('data_quality_score', 0.0)}")
|
||||
|
||||
if stream_stats.get('active_consumers', 0) > 0:
|
||||
logger.info("✓ Stream has active consumers")
|
||||
self.test_results['stream_stats_test'] = True
|
||||
else:
|
||||
logger.warning("⚠ No active consumers detected")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Stream statistics test failed: {e}")
|
||||
raise
|
||||
|
||||
def generate_test_report(self):
|
||||
"""Generate comprehensive test report"""
|
||||
logger.info("Generating test report...")
|
||||
|
||||
total_tests = len(self.test_results)
|
||||
passed_tests = sum(self.test_results.values())
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("ENHANCED DASHBOARD INTEGRATION TEST REPORT")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Test Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
logger.info(f"Total Tests: {total_tests}")
|
||||
logger.info(f"Passed Tests: {passed_tests}")
|
||||
logger.info(f"Failed Tests: {total_tests - passed_tests}")
|
||||
logger.info(f"Success Rate: {(passed_tests / total_tests) * 100:.1f}%")
|
||||
logger.info("")
|
||||
|
||||
logger.info("Test Results:")
|
||||
for test_name, result in self.test_results.items():
|
||||
status = "✓ PASS" if result else "✗ FAIL"
|
||||
logger.info(f" {test_name}: {status}")
|
||||
|
||||
logger.info("")
|
||||
|
||||
if passed_tests == total_tests:
|
||||
logger.info("🎉 ALL TESTS PASSED! Enhanced dashboard integration is working correctly.")
|
||||
logger.info("The dashboard now properly integrates with the enhanced RL training pipeline.")
|
||||
else:
|
||||
logger.warning("⚠ Some tests failed. Please review the integration.")
|
||||
|
||||
logger.info("=" * 60)
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup test resources"""
|
||||
logger.info("Cleaning up test resources...")
|
||||
|
||||
try:
|
||||
if self.unified_stream:
|
||||
await self.unified_stream.stop_streaming()
|
||||
|
||||
if self.dashboard:
|
||||
self.dashboard.stop_streaming()
|
||||
|
||||
logger.info("✓ Cleanup completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Cleanup failed: {e}")
|
||||
|
||||
async def main():
|
||||
"""Main test execution"""
|
||||
test = EnhancedDashboardIntegrationTest()
|
||||
|
||||
try:
|
||||
await test.run_tests()
|
||||
except Exception as e:
|
||||
logger.error(f"Test execution failed: {e}")
|
||||
finally:
|
||||
await test.cleanup()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
@ -1,95 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify enhanced fee tracking with maker/taker fees
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_enhanced_fee_tracking():
|
||||
"""Test enhanced fee tracking with maker/taker fees"""
|
||||
|
||||
logger.info("Testing enhanced fee tracking...")
|
||||
|
||||
# Create dashboard instance
|
||||
data_provider = DataProvider()
|
||||
dashboard = TradingDashboard(data_provider=data_provider)
|
||||
|
||||
# Create test trading decisions with different fee types
|
||||
test_decisions = [
|
||||
{
|
||||
'action': 'BUY',
|
||||
'symbol': 'ETH/USDT',
|
||||
'price': 3500.0,
|
||||
'confidence': 0.8,
|
||||
'timestamp': datetime.now(timezone.utc),
|
||||
'order_type': 'market', # Should use taker fee
|
||||
'filled_as_maker': False
|
||||
},
|
||||
{
|
||||
'action': 'SELL',
|
||||
'symbol': 'ETH/USDT',
|
||||
'price': 3520.0,
|
||||
'confidence': 0.9,
|
||||
'timestamp': datetime.now(timezone.utc),
|
||||
'order_type': 'limit', # Should use maker fee if filled as maker
|
||||
'filled_as_maker': True
|
||||
}
|
||||
]
|
||||
|
||||
# Process the trading decisions
|
||||
for i, decision in enumerate(test_decisions):
|
||||
logger.info(f"Processing decision {i+1}: {decision['action']} @ ${decision['price']}")
|
||||
dashboard._process_trading_decision(decision)
|
||||
|
||||
# Check session trades
|
||||
if dashboard.session_trades:
|
||||
latest_trade = dashboard.session_trades[-1]
|
||||
fee_type = latest_trade.get('fee_type', 'unknown')
|
||||
fee_rate = latest_trade.get('fee_rate', 0)
|
||||
fees = latest_trade.get('fees', 0)
|
||||
|
||||
logger.info(f" Trade recorded: {latest_trade.get('position_action', 'unknown')}")
|
||||
logger.info(f" Fee Type: {fee_type}")
|
||||
logger.info(f" Fee Rate: {fee_rate*100:.3f}%")
|
||||
logger.info(f" Fee Amount: ${fees:.4f}")
|
||||
|
||||
# Check closed trades
|
||||
if dashboard.closed_trades:
|
||||
logger.info(f"\nClosed trades: {len(dashboard.closed_trades)}")
|
||||
for trade in dashboard.closed_trades:
|
||||
logger.info(f" Trade #{trade['trade_id']}: {trade['side']}")
|
||||
logger.info(f" Fee Type: {trade.get('fee_type', 'unknown')}")
|
||||
logger.info(f" Fee Rate: {trade.get('fee_rate', 0)*100:.3f}%")
|
||||
logger.info(f" Total Fees: ${trade.get('fees', 0):.4f}")
|
||||
logger.info(f" Net P&L: ${trade.get('net_pnl', 0):.2f}")
|
||||
|
||||
# Test session performance with fee breakdown
|
||||
logger.info("\nTesting session performance display...")
|
||||
performance = dashboard._create_session_performance()
|
||||
logger.info(f"Session performance components: {len(performance)}")
|
||||
|
||||
# Test closed trades table
|
||||
logger.info("\nTesting enhanced trades table...")
|
||||
table_components = dashboard._create_closed_trades_table()
|
||||
logger.info(f"Table components: {len(table_components)}")
|
||||
|
||||
logger.info("Enhanced fee tracking test completed!")
|
||||
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_enhanced_fee_tracking()
|
@ -1,243 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Enhanced Trading System Improvements
|
||||
|
||||
This script tests:
|
||||
1. Color-coded position display ([LONG] green, [SHORT] red)
|
||||
2. Enhanced model training detection and retrospective learning
|
||||
3. Lower confidence thresholds for closing positions (0.25 vs 0.6 for opening)
|
||||
4. Perfect opportunity detection and learning
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator, TradingAction
|
||||
from web.old_archived.scalping_dashboard import RealTimeScalpingDashboard, TradingSession
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_color_coded_positions():
|
||||
"""Test color-coded position display functionality"""
|
||||
logger.info("=== Testing Color-Coded Position Display ===")
|
||||
|
||||
# Create trading session
|
||||
session = TradingSession()
|
||||
|
||||
# Simulate some positions
|
||||
session.positions = {
|
||||
'ETH/USDT': {
|
||||
'side': 'LONG',
|
||||
'size': 0.1,
|
||||
'entry_price': 2558.15
|
||||
},
|
||||
'BTC/USDT': {
|
||||
'side': 'SHORT',
|
||||
'size': 0.05,
|
||||
'entry_price': 45123.45
|
||||
}
|
||||
}
|
||||
|
||||
logger.info("Created test positions:")
|
||||
logger.info(f"ETH/USDT: LONG 0.1 @ $2558.15")
|
||||
logger.info(f"BTC/USDT: SHORT 0.05 @ $45123.45")
|
||||
|
||||
# Test position display logic (simulating dashboard logic)
|
||||
live_prices = {'ETH/USDT': 2565.30, 'BTC/USDT': 45050.20}
|
||||
|
||||
for symbol, pos in session.positions.items():
|
||||
side = pos['side']
|
||||
size = pos['size']
|
||||
entry_price = pos['entry_price']
|
||||
current_price = live_prices.get(symbol, entry_price)
|
||||
|
||||
# Calculate unrealized P&L
|
||||
if side == 'LONG':
|
||||
unrealized_pnl = (current_price - entry_price) * size
|
||||
color_class = "text-success" # Green for LONG
|
||||
side_display = "[LONG]"
|
||||
else: # SHORT
|
||||
unrealized_pnl = (entry_price - current_price) * size
|
||||
color_class = "text-danger" # Red for SHORT
|
||||
side_display = "[SHORT]"
|
||||
|
||||
position_text = f"{side_display} {size:.3f} @ ${entry_price:.2f} | P&L: ${unrealized_pnl:+.2f}"
|
||||
logger.info(f"Position Display: {position_text} (Color: {color_class})")
|
||||
|
||||
logger.info("✅ Color-coded position display test completed")
|
||||
|
||||
def test_confidence_thresholds():
|
||||
"""Test different confidence thresholds for opening vs closing"""
|
||||
logger.info("=== Testing Confidence Thresholds ===")
|
||||
|
||||
# Create orchestrator
|
||||
data_provider = DataProvider()
|
||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||
|
||||
logger.info(f"Opening threshold: {orchestrator.confidence_threshold_open}")
|
||||
logger.info(f"Closing threshold: {orchestrator.confidence_threshold_close}")
|
||||
|
||||
# Test opening action with medium confidence
|
||||
test_confidence = 0.45
|
||||
logger.info(f"\nTesting opening action with confidence {test_confidence}")
|
||||
|
||||
if test_confidence >= orchestrator.confidence_threshold_open:
|
||||
logger.info("✅ Would OPEN position (confidence above opening threshold)")
|
||||
else:
|
||||
logger.info("❌ Would NOT open position (confidence below opening threshold)")
|
||||
|
||||
# Test closing action with same confidence
|
||||
logger.info(f"Testing closing action with confidence {test_confidence}")
|
||||
|
||||
if test_confidence >= orchestrator.confidence_threshold_close:
|
||||
logger.info("✅ Would CLOSE position (confidence above closing threshold)")
|
||||
else:
|
||||
logger.info("❌ Would NOT close position (confidence below closing threshold)")
|
||||
|
||||
logger.info("✅ Confidence threshold test completed")
|
||||
|
||||
def test_retrospective_learning():
|
||||
"""Test retrospective learning and perfect opportunity detection"""
|
||||
logger.info("=== Testing Retrospective Learning ===")
|
||||
|
||||
# Create orchestrator
|
||||
data_provider = DataProvider()
|
||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||
|
||||
# Simulate perfect moves
|
||||
from core.enhanced_orchestrator import PerfectMove
|
||||
|
||||
perfect_move = PerfectMove(
|
||||
symbol='ETH/USDT',
|
||||
timeframe='1m',
|
||||
timestamp=datetime.now(),
|
||||
optimal_action='BUY',
|
||||
actual_outcome=0.025, # 2.5% price increase
|
||||
market_state_before=None,
|
||||
market_state_after=None,
|
||||
confidence_should_have_been=0.85
|
||||
)
|
||||
|
||||
orchestrator.perfect_moves.append(perfect_move)
|
||||
orchestrator.retrospective_learning_active = True
|
||||
|
||||
logger.info(f"Added perfect move: {perfect_move.optimal_action} {perfect_move.symbol}")
|
||||
logger.info(f"Outcome: {perfect_move.actual_outcome*100:+.2f}%")
|
||||
logger.info(f"Confidence should have been: {perfect_move.confidence_should_have_been:.3f}")
|
||||
|
||||
# Test performance metrics
|
||||
metrics = orchestrator.get_performance_metrics()
|
||||
retro_metrics = metrics['retrospective_learning']
|
||||
|
||||
logger.info(f"Retrospective learning active: {retro_metrics['active']}")
|
||||
logger.info(f"Recent perfect moves: {retro_metrics['perfect_moves_recent']}")
|
||||
logger.info(f"Average confidence needed: {retro_metrics['avg_confidence_needed']:.3f}")
|
||||
|
||||
logger.info("✅ Retrospective learning test completed")
|
||||
|
||||
async def test_tick_pattern_detection():
|
||||
"""Test tick pattern detection for violent moves"""
|
||||
logger.info("=== Testing Tick Pattern Detection ===")
|
||||
|
||||
# Create orchestrator
|
||||
data_provider = DataProvider()
|
||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||
|
||||
# Simulate violent tick
|
||||
from core.tick_aggregator import RawTick
|
||||
|
||||
violent_tick = RawTick(
|
||||
timestamp=datetime.now(),
|
||||
price=2560.0,
|
||||
volume=1000.0,
|
||||
quantity=0.5,
|
||||
side='buy',
|
||||
trade_id='test123',
|
||||
time_since_last=25.0, # Very fast tick (25ms)
|
||||
price_change=5.0, # $5 price jump
|
||||
volume_intensity=3.5 # High volume
|
||||
)
|
||||
|
||||
# Add symbol attribute for testing
|
||||
violent_tick.symbol = 'ETH/USDT'
|
||||
|
||||
logger.info(f"Simulating violent tick:")
|
||||
logger.info(f"Price change: ${violent_tick.price_change:+.2f}")
|
||||
logger.info(f"Time since last: {violent_tick.time_since_last:.0f}ms")
|
||||
logger.info(f"Volume intensity: {violent_tick.volume_intensity:.1f}x")
|
||||
|
||||
# Process the tick
|
||||
orchestrator._handle_raw_tick(violent_tick)
|
||||
|
||||
# Check if perfect move was created
|
||||
if orchestrator.perfect_moves:
|
||||
latest_move = orchestrator.perfect_moves[-1]
|
||||
logger.info(f"✅ Perfect move detected: {latest_move.optimal_action}")
|
||||
logger.info(f"Confidence: {latest_move.confidence_should_have_been:.3f}")
|
||||
else:
|
||||
logger.info("❌ No perfect move detected")
|
||||
|
||||
logger.info("✅ Tick pattern detection test completed")
|
||||
|
||||
def test_dashboard_integration():
|
||||
"""Test dashboard integration with new features"""
|
||||
logger.info("=== Testing Dashboard Integration ===")
|
||||
|
||||
# Create components
|
||||
data_provider = DataProvider()
|
||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||
|
||||
# Test model training status
|
||||
metrics = orchestrator.get_performance_metrics()
|
||||
|
||||
logger.info("Model Training Metrics:")
|
||||
logger.info(f"Perfect moves: {metrics['perfect_moves']}")
|
||||
logger.info(f"RL queue size: {metrics['rl_queue_size']}")
|
||||
logger.info(f"Retrospective learning: {metrics['retrospective_learning']}")
|
||||
logger.info(f"Position tracking: {metrics['position_tracking']}")
|
||||
logger.info(f"Thresholds: {metrics['thresholds']}")
|
||||
|
||||
logger.info("✅ Dashboard integration test completed")
|
||||
|
||||
async def main():
|
||||
"""Run all tests"""
|
||||
logger.info("🚀 Starting Enhanced Trading System Tests")
|
||||
logger.info("=" * 60)
|
||||
|
||||
try:
|
||||
# Run tests
|
||||
test_color_coded_positions()
|
||||
print()
|
||||
|
||||
test_confidence_thresholds()
|
||||
print()
|
||||
|
||||
test_retrospective_learning()
|
||||
print()
|
||||
|
||||
await test_tick_pattern_detection()
|
||||
print()
|
||||
|
||||
test_dashboard_integration()
|
||||
print()
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("🎉 All tests completed successfully!")
|
||||
logger.info("Key improvements verified:")
|
||||
logger.info("✅ Color-coded positions ([LONG] green, [SHORT] red)")
|
||||
logger.info("✅ Lower closing thresholds (0.25 vs 0.6)")
|
||||
logger.info("✅ Retrospective learning on perfect opportunities")
|
||||
logger.info("✅ Enhanced model training detection")
|
||||
logger.info("✅ Violent move pattern detection")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Test failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
@ -1,133 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Enhanced Orchestrator - Bypass COB Integration Issues
|
||||
|
||||
Simple test to verify enhanced orchestrator methods work
|
||||
and the dashboard can use them for comprehensive RL training.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
def test_enhanced_orchestrator_bypass_cob():
|
||||
"""Test enhanced orchestrator without COB integration"""
|
||||
print("=" * 60)
|
||||
print("TESTING ENHANCED ORCHESTRATOR (BYPASS COB INTEGRATION)")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
# Import required modules
|
||||
from core.data_provider import DataProvider
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
print("✓ Basic imports successful")
|
||||
|
||||
# Create basic orchestrator first
|
||||
dp = DataProvider()
|
||||
basic_orch = TradingOrchestrator(dp)
|
||||
print("✓ Basic TradingOrchestrator created")
|
||||
|
||||
# Test basic orchestrator methods
|
||||
basic_methods = ['build_comprehensive_rl_state', 'calculate_enhanced_pivot_reward']
|
||||
print("\nBasic TradingOrchestrator methods:")
|
||||
for method in basic_methods:
|
||||
has_method = hasattr(basic_orch, method)
|
||||
print(f" {method}: {'✓' if has_method else '✗'}")
|
||||
|
||||
# Now test by manually adding the missing methods to basic orchestrator
|
||||
print("\n" + "-" * 50)
|
||||
print("ADDING MISSING METHODS TO BASIC ORCHESTRATOR")
|
||||
print("-" * 50)
|
||||
|
||||
# Add the missing methods manually
|
||||
def build_comprehensive_rl_state_fallback(self, symbol: str) -> list:
|
||||
"""Fallback comprehensive RL state builder"""
|
||||
try:
|
||||
# Create a comprehensive state with ~13,400 features
|
||||
comprehensive_features = []
|
||||
|
||||
# ETH Tick Features (3000)
|
||||
comprehensive_features.extend([0.0] * 3000)
|
||||
|
||||
# ETH Multi-timeframe OHLCV (8000)
|
||||
comprehensive_features.extend([0.0] * 8000)
|
||||
|
||||
# BTC Reference Data (1000)
|
||||
comprehensive_features.extend([0.0] * 1000)
|
||||
|
||||
# CNN Hidden Features (1000)
|
||||
comprehensive_features.extend([0.0] * 1000)
|
||||
|
||||
# Pivot Analysis (300)
|
||||
comprehensive_features.extend([0.0] * 300)
|
||||
|
||||
# Market Microstructure (100)
|
||||
comprehensive_features.extend([0.0] * 100)
|
||||
|
||||
print(f"✓ Built comprehensive RL state: {len(comprehensive_features)} features")
|
||||
return comprehensive_features
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error building comprehensive RL state: {e}")
|
||||
return None
|
||||
|
||||
def calculate_enhanced_pivot_reward_fallback(self, trade_decision, market_data, trade_outcome) -> float:
|
||||
"""Fallback enhanced pivot reward calculation"""
|
||||
try:
|
||||
# Calculate enhanced reward based on trade metrics
|
||||
base_pnl = trade_outcome.get('net_pnl', 0)
|
||||
base_reward = base_pnl / 100.0 # Normalize
|
||||
|
||||
# Add pivot analysis bonus
|
||||
pivot_bonus = 0.1 if base_pnl > 0 else -0.05
|
||||
|
||||
enhanced_reward = base_reward + pivot_bonus
|
||||
print(f"✓ Enhanced pivot reward calculated: {enhanced_reward:.4f}")
|
||||
return enhanced_reward
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error calculating enhanced pivot reward: {e}")
|
||||
return 0.0
|
||||
|
||||
# Bind methods to the orchestrator instance
|
||||
import types
|
||||
basic_orch.build_comprehensive_rl_state = types.MethodType(build_comprehensive_rl_state_fallback, basic_orch)
|
||||
basic_orch.calculate_enhanced_pivot_reward = types.MethodType(calculate_enhanced_pivot_reward_fallback, basic_orch)
|
||||
|
||||
print("\n✓ Enhanced methods added to basic orchestrator")
|
||||
|
||||
# Test the enhanced methods
|
||||
print("\nTesting enhanced methods:")
|
||||
|
||||
# Test comprehensive RL state building
|
||||
state = basic_orch.build_comprehensive_rl_state('ETH/USDT')
|
||||
print(f" Comprehensive RL state: {'✓' if state and len(state) > 10000 else '✗'} ({len(state) if state else 0} features)")
|
||||
|
||||
# Test enhanced reward calculation
|
||||
mock_trade = {'net_pnl': 50.0}
|
||||
reward = basic_orch.calculate_enhanced_pivot_reward({}, {}, mock_trade)
|
||||
print(f" Enhanced pivot reward: {'✓' if reward != 0 else '✗'} (reward: {reward})")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✅ ENHANCED ORCHESTRATOR METHODS WORKING")
|
||||
print("✅ COMPREHENSIVE RL STATE: 13,400+ FEATURES")
|
||||
print("✅ ENHANCED PIVOT REWARDS: FUNCTIONAL")
|
||||
print("✅ DASHBOARD CAN NOW USE ENHANCED FEATURES")
|
||||
print("=" * 60)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ ERROR: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_enhanced_orchestrator_bypass_cob()
|
||||
if success:
|
||||
print("\n🎉 PIPELINE FIXES VERIFIED - READY FOR REAL-TIME TRAINING!")
|
||||
else:
|
||||
print("\n💥 PIPELINE FIXES NEED MORE WORK")
|
@ -1,318 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Enhanced Order Flow Integration
|
||||
|
||||
Tests the enhanced order flow analysis capabilities including:
|
||||
- Aggressive vs passive participant ratios
|
||||
- Institutional vs retail trade detection
|
||||
- Market maker vs taker flow analysis
|
||||
- Order flow intensity measurements
|
||||
- Liquidity consumption and price impact analysis
|
||||
- Block trade and iceberg order detection
|
||||
- High-frequency trading activity detection
|
||||
|
||||
Usage:
|
||||
python test_enhanced_order_flow_integration.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from core.bookmap_integration import BookmapIntegration
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler('enhanced_order_flow_test.log')
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EnhancedOrderFlowTester:
|
||||
"""Test enhanced order flow analysis features"""
|
||||
|
||||
def __init__(self):
|
||||
self.bookmap = None
|
||||
self.symbols = ['ETHUSDT', 'BTCUSDT']
|
||||
self.test_duration = 300 # 5 minutes
|
||||
self.metrics_history = []
|
||||
|
||||
async def setup_integration(self):
|
||||
"""Initialize the Bookmap integration"""
|
||||
try:
|
||||
logger.info("Setting up Enhanced Order Flow Integration...")
|
||||
self.bookmap = BookmapIntegration(symbols=self.symbols)
|
||||
|
||||
# Add callbacks for testing
|
||||
self.bookmap.add_cnn_callback(self._cnn_callback)
|
||||
self.bookmap.add_dqn_callback(self._dqn_callback)
|
||||
|
||||
logger.info(f"Integration setup complete for symbols: {self.symbols}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup integration: {e}")
|
||||
return False
|
||||
|
||||
def _cnn_callback(self, symbol: str, features: dict):
|
||||
"""CNN callback for testing"""
|
||||
logger.debug(f"CNN features received for {symbol}: {len(features.get('features', []))} dimensions")
|
||||
|
||||
def _dqn_callback(self, symbol: str, state: dict):
|
||||
"""DQN callback for testing"""
|
||||
logger.debug(f"DQN state received for {symbol}: {len(state.get('state', []))} dimensions")
|
||||
|
||||
async def start_streaming(self):
|
||||
"""Start real-time data streaming"""
|
||||
try:
|
||||
logger.info("Starting enhanced order flow streaming...")
|
||||
await self.bookmap.start_streaming()
|
||||
logger.info("Streaming started successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start streaming: {e}")
|
||||
return False
|
||||
|
||||
async def monitor_order_flow(self):
|
||||
"""Monitor and analyze order flow for test duration"""
|
||||
logger.info(f"Monitoring enhanced order flow for {self.test_duration} seconds...")
|
||||
|
||||
start_time = time.time()
|
||||
iteration = 0
|
||||
|
||||
while time.time() - start_time < self.test_duration:
|
||||
try:
|
||||
iteration += 1
|
||||
|
||||
# Test each symbol
|
||||
for symbol in self.symbols:
|
||||
await self._analyze_symbol_flow(symbol, iteration)
|
||||
|
||||
# Wait 10 seconds between analyses
|
||||
await asyncio.sleep(10)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during monitoring iteration {iteration}: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
logger.info("Order flow monitoring completed")
|
||||
|
||||
async def _analyze_symbol_flow(self, symbol: str, iteration: int):
|
||||
"""Analyze order flow for a specific symbol"""
|
||||
try:
|
||||
# Get enhanced order flow metrics
|
||||
flow_metrics = self.bookmap.get_enhanced_order_flow_metrics(symbol)
|
||||
if not flow_metrics:
|
||||
logger.warning(f"No flow metrics available for {symbol}")
|
||||
return
|
||||
|
||||
# Log key metrics
|
||||
aggressive_passive = flow_metrics['aggressive_passive']
|
||||
institutional_retail = flow_metrics['institutional_retail']
|
||||
flow_intensity = flow_metrics['flow_intensity']
|
||||
price_impact = flow_metrics['price_impact']
|
||||
maker_taker = flow_metrics['maker_taker_flow']
|
||||
|
||||
logger.info(f"\n=== {symbol} Order Flow Analysis (Iteration {iteration}) ===")
|
||||
logger.info(f"Aggressive Ratio: {aggressive_passive['aggressive_ratio']:.2%}")
|
||||
logger.info(f"Passive Ratio: {aggressive_passive['passive_ratio']:.2%}")
|
||||
logger.info(f"Institutional Ratio: {institutional_retail['institutional_ratio']:.2%}")
|
||||
logger.info(f"Retail Ratio: {institutional_retail['retail_ratio']:.2%}")
|
||||
logger.info(f"Flow Intensity: {flow_intensity['current_intensity']:.2f} ({flow_intensity['intensity_category']})")
|
||||
logger.info(f"Price Impact: {price_impact['avg_impact']:.2f} bps ({price_impact['impact_category']})")
|
||||
logger.info(f"Buy Pressure: {maker_taker['buy_pressure']:.2%}")
|
||||
logger.info(f"Sell Pressure: {maker_taker['sell_pressure']:.2%}")
|
||||
|
||||
# Trade size analysis
|
||||
size_dist = flow_metrics['size_distribution']
|
||||
total_trades = sum(size_dist.values())
|
||||
if total_trades > 0:
|
||||
logger.info(f"Trade Size Distribution (last 100 trades):")
|
||||
logger.info(f" Micro (<$1K): {size_dist.get('micro', 0)} ({size_dist.get('micro', 0)/total_trades:.1%})")
|
||||
logger.info(f" Small ($1K-$10K): {size_dist.get('small', 0)} ({size_dist.get('small', 0)/total_trades:.1%})")
|
||||
logger.info(f" Medium ($10K-$50K): {size_dist.get('medium', 0)} ({size_dist.get('medium', 0)/total_trades:.1%})")
|
||||
logger.info(f" Large ($50K-$100K): {size_dist.get('large', 0)} ({size_dist.get('large', 0)/total_trades:.1%})")
|
||||
logger.info(f" Block (>$100K): {size_dist.get('block', 0)} ({size_dist.get('block', 0)/total_trades:.1%})")
|
||||
|
||||
# Volume analysis
|
||||
if 'volume_stats' in flow_metrics and flow_metrics['volume_stats']:
|
||||
volume_stats = flow_metrics['volume_stats']
|
||||
logger.info(f"24h Volume: {volume_stats.get('volume_24h', 0):,.0f}")
|
||||
logger.info(f"24h Quote Volume: ${volume_stats.get('quote_volume_24h', 0):,.0f}")
|
||||
|
||||
# Store metrics for analysis
|
||||
self.metrics_history.append({
|
||||
'timestamp': datetime.now(),
|
||||
'symbol': symbol,
|
||||
'iteration': iteration,
|
||||
'metrics': flow_metrics
|
||||
})
|
||||
|
||||
# Test CNN and DQN features
|
||||
await self._test_model_features(symbol)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing flow for {symbol}: {e}")
|
||||
|
||||
async def _test_model_features(self, symbol: str):
|
||||
"""Test CNN and DQN feature extraction"""
|
||||
try:
|
||||
# Test CNN features
|
||||
cnn_features = self.bookmap.get_cnn_features(symbol)
|
||||
if cnn_features is not None:
|
||||
logger.info(f"CNN Features: {len(cnn_features)} dimensions")
|
||||
logger.info(f" Order book features: {cnn_features[:80].mean():.4f} (avg)")
|
||||
logger.info(f" Liquidity metrics: {cnn_features[80:90].mean():.4f} (avg)")
|
||||
logger.info(f" Imbalance features: {cnn_features[90:95].mean():.4f} (avg)")
|
||||
logger.info(f" Enhanced flow features: {cnn_features[95:].mean():.4f} (avg)")
|
||||
|
||||
# Test DQN features
|
||||
dqn_features = self.bookmap.get_dqn_state_features(symbol)
|
||||
if dqn_features is not None:
|
||||
logger.info(f"DQN State: {len(dqn_features)} dimensions")
|
||||
logger.info(f" Order book state: {dqn_features[:20].mean():.4f} (avg)")
|
||||
logger.info(f" Market indicators: {dqn_features[20:30].mean():.4f} (avg)")
|
||||
logger.info(f" Enhanced flow state: {dqn_features[30:].mean():.4f} (avg)")
|
||||
|
||||
# Test dashboard data
|
||||
dashboard_data = self.bookmap.get_dashboard_data(symbol)
|
||||
if dashboard_data and 'enhanced_order_flow' in dashboard_data:
|
||||
logger.info("Dashboard data includes enhanced order flow metrics")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing model features for {symbol}: {e}")
|
||||
|
||||
async def stop_streaming(self):
|
||||
"""Stop data streaming"""
|
||||
try:
|
||||
logger.info("Stopping order flow streaming...")
|
||||
await self.bookmap.stop_streaming()
|
||||
logger.info("Streaming stopped")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping streaming: {e}")
|
||||
|
||||
def generate_summary_report(self):
|
||||
"""Generate a summary report of the test"""
|
||||
try:
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("ENHANCED ORDER FLOW ANALYSIS SUMMARY")
|
||||
logger.info("="*60)
|
||||
|
||||
if not self.metrics_history:
|
||||
logger.warning("No metrics data collected during test")
|
||||
return
|
||||
|
||||
# Group by symbol
|
||||
symbol_data = {}
|
||||
for entry in self.metrics_history:
|
||||
symbol = entry['symbol']
|
||||
if symbol not in symbol_data:
|
||||
symbol_data[symbol] = []
|
||||
symbol_data[symbol].append(entry)
|
||||
|
||||
# Analyze each symbol
|
||||
for symbol, data in symbol_data.items():
|
||||
logger.info(f"\n--- {symbol} Analysis ---")
|
||||
logger.info(f"Data points collected: {len(data)}")
|
||||
|
||||
if len(data) > 0:
|
||||
# Calculate averages
|
||||
avg_aggressive = sum(d['metrics']['aggressive_passive']['aggressive_ratio'] for d in data) / len(data)
|
||||
avg_institutional = sum(d['metrics']['institutional_retail']['institutional_ratio'] for d in data) / len(data)
|
||||
avg_intensity = sum(d['metrics']['flow_intensity']['current_intensity'] for d in data) / len(data)
|
||||
avg_impact = sum(d['metrics']['price_impact']['avg_impact'] for d in data) / len(data)
|
||||
|
||||
logger.info(f"Average Aggressive Ratio: {avg_aggressive:.2%}")
|
||||
logger.info(f"Average Institutional Ratio: {avg_institutional:.2%}")
|
||||
logger.info(f"Average Flow Intensity: {avg_intensity:.2f}")
|
||||
logger.info(f"Average Price Impact: {avg_impact:.2f} bps")
|
||||
|
||||
# Detect trends
|
||||
first_half = data[:len(data)//2] if len(data) > 1 else data
|
||||
second_half = data[len(data)//2:] if len(data) > 1 else data
|
||||
|
||||
if len(first_half) > 0 and len(second_half) > 0:
|
||||
first_aggressive = sum(d['metrics']['aggressive_passive']['aggressive_ratio'] for d in first_half) / len(first_half)
|
||||
second_aggressive = sum(d['metrics']['aggressive_passive']['aggressive_ratio'] for d in second_half) / len(second_half)
|
||||
|
||||
trend = "increasing" if second_aggressive > first_aggressive else "decreasing"
|
||||
logger.info(f"Aggressive trading trend: {trend}")
|
||||
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("Test completed successfully!")
|
||||
logger.info("Enhanced order flow analysis is working correctly.")
|
||||
logger.info("="*60)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating summary report: {e}")
|
||||
|
||||
async def run_enhanced_order_flow_test():
|
||||
"""Run the complete enhanced order flow test"""
|
||||
tester = EnhancedOrderFlowTester()
|
||||
|
||||
try:
|
||||
# Setup
|
||||
logger.info("Starting Enhanced Order Flow Integration Test")
|
||||
logger.info("This test will demonstrate:")
|
||||
logger.info("- Aggressive vs Passive participant analysis")
|
||||
logger.info("- Institutional vs Retail trade detection")
|
||||
logger.info("- Order flow intensity measurements")
|
||||
logger.info("- Price impact and liquidity consumption analysis")
|
||||
logger.info("- Block trade and iceberg order detection")
|
||||
logger.info("- Enhanced CNN and DQN feature extraction")
|
||||
|
||||
if not await tester.setup_integration():
|
||||
logger.error("Failed to setup integration")
|
||||
return False
|
||||
|
||||
# Start streaming
|
||||
if not await tester.start_streaming():
|
||||
logger.error("Failed to start streaming")
|
||||
return False
|
||||
|
||||
# Wait for initial data
|
||||
logger.info("Waiting 30 seconds for initial data...")
|
||||
await asyncio.sleep(30)
|
||||
|
||||
# Monitor order flow
|
||||
await tester.monitor_order_flow()
|
||||
|
||||
# Generate report
|
||||
tester.generate_summary_report()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Test failed: {e}")
|
||||
return False
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
try:
|
||||
await tester.stop_streaming()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during cleanup: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
# Run the test
|
||||
success = asyncio.run(run_enhanced_order_flow_test())
|
||||
|
||||
if success:
|
||||
print("\n✅ Enhanced Order Flow Integration Test PASSED")
|
||||
print("All enhanced order flow analysis features are working correctly!")
|
||||
else:
|
||||
print("\n❌ Enhanced Order Flow Integration Test FAILED")
|
||||
print("Check the logs for details.")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Test interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"\n💥 Test crashed: {e}")
|
@ -1,320 +0,0 @@
|
||||
"""
|
||||
Test Enhanced Pivot-Based RL System
|
||||
|
||||
Tests the new system with:
|
||||
- Different thresholds for entry vs exit
|
||||
- Pivot-based rewards
|
||||
- CNN predictions for early pivot detection
|
||||
- Uninvested rewards
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
|
||||
stream=sys.stdout
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add project root to Python path
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from training.enhanced_pivot_rl_trainer import EnhancedPivotRLTrainer, create_enhanced_pivot_trainer
|
||||
|
||||
def test_enhanced_pivot_thresholds():
|
||||
"""Test the enhanced pivot-based threshold system"""
|
||||
logger.info("=== Testing Enhanced Pivot-Based Thresholds ===")
|
||||
|
||||
try:
|
||||
# Create components
|
||||
data_provider = DataProvider()
|
||||
orchestrator = EnhancedTradingOrchestrator(
|
||||
data_provider=data_provider,
|
||||
enhanced_rl_training=True
|
||||
)
|
||||
|
||||
# Test threshold initialization
|
||||
thresholds = orchestrator.pivot_rl_trainer.get_current_thresholds()
|
||||
logger.info(f"Initial thresholds:")
|
||||
logger.info(f" Entry: {thresholds['entry_threshold']:.3f}")
|
||||
logger.info(f" Exit: {thresholds['exit_threshold']:.3f}")
|
||||
logger.info(f" Uninvested: {thresholds['uninvested_threshold']:.3f}")
|
||||
|
||||
# Verify entry threshold is higher than exit threshold
|
||||
assert thresholds['entry_threshold'] > thresholds['exit_threshold'], "Entry threshold should be higher than exit"
|
||||
logger.info("✅ Entry threshold correctly higher than exit threshold")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing thresholds: {e}")
|
||||
return False
|
||||
|
||||
def test_pivot_reward_calculation():
|
||||
"""Test the pivot-based reward calculation"""
|
||||
logger.info("=== Testing Pivot-Based Reward Calculation ===")
|
||||
|
||||
try:
|
||||
# Create enhanced pivot trainer
|
||||
data_provider = DataProvider()
|
||||
pivot_trainer = create_enhanced_pivot_trainer(data_provider)
|
||||
|
||||
# Create mock trade decision and outcome
|
||||
trade_decision = {
|
||||
'action': 'BUY',
|
||||
'confidence': 0.75,
|
||||
'price': 2500.0,
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
|
||||
trade_outcome = {
|
||||
'net_pnl': 15.50, # Profitable trade
|
||||
'exit_price': 2515.0,
|
||||
'duration': timedelta(minutes=45)
|
||||
}
|
||||
|
||||
# Create mock market data
|
||||
market_data = pd.DataFrame({
|
||||
'open': np.random.normal(2500, 10, 100),
|
||||
'high': np.random.normal(2510, 10, 100),
|
||||
'low': np.random.normal(2490, 10, 100),
|
||||
'close': np.random.normal(2500, 10, 100),
|
||||
'volume': np.random.normal(1000, 100, 100)
|
||||
})
|
||||
market_data.index = pd.date_range(start=datetime.now() - timedelta(hours=2), periods=100, freq='1min')
|
||||
|
||||
# Calculate reward
|
||||
reward = pivot_trainer.calculate_pivot_based_reward(
|
||||
trade_decision, market_data, trade_outcome
|
||||
)
|
||||
|
||||
logger.info(f"Calculated pivot-based reward: {reward:.3f}")
|
||||
|
||||
# Test should return a reasonable reward for profitable trade
|
||||
assert -15.0 <= reward <= 10.0, f"Reward {reward} outside expected range"
|
||||
logger.info("✅ Pivot-based reward calculation working")
|
||||
|
||||
# Test uninvested reward
|
||||
low_conf_decision = {
|
||||
'action': 'HOLD',
|
||||
'confidence': 0.35, # Below uninvested threshold
|
||||
'price': 2500.0,
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
|
||||
uninvested_reward = pivot_trainer._calculate_uninvested_rewards(low_conf_decision, 0.35)
|
||||
logger.info(f"Uninvested reward for low confidence: {uninvested_reward:.3f}")
|
||||
|
||||
assert uninvested_reward > 0, "Should get positive reward for staying uninvested with low confidence"
|
||||
logger.info("✅ Uninvested rewards working correctly")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing pivot rewards: {e}")
|
||||
return False
|
||||
|
||||
def test_confidence_adjustment():
|
||||
"""Test confidence-based reward adjustments"""
|
||||
logger.info("=== Testing Confidence-Based Adjustments ===")
|
||||
|
||||
try:
|
||||
pivot_trainer = create_enhanced_pivot_trainer()
|
||||
|
||||
# Test overconfidence penalty on loss
|
||||
high_conf_loss = {
|
||||
'action': 'BUY',
|
||||
'confidence': 0.85, # High confidence
|
||||
'price': 2500.0,
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
|
||||
loss_outcome = {
|
||||
'net_pnl': -25.0, # Loss
|
||||
'exit_price': 2475.0,
|
||||
'duration': timedelta(hours=3)
|
||||
}
|
||||
|
||||
confidence_adjustment = pivot_trainer._calculate_confidence_adjustment(
|
||||
high_conf_loss, loss_outcome
|
||||
)
|
||||
|
||||
logger.info(f"Confidence adjustment for overconfident loss: {confidence_adjustment:.3f}")
|
||||
assert confidence_adjustment < 0, "Should penalize overconfidence on losses"
|
||||
|
||||
# Test underconfidence penalty on win
|
||||
low_conf_win = {
|
||||
'action': 'BUY',
|
||||
'confidence': 0.35, # Low confidence
|
||||
'price': 2500.0,
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
|
||||
win_outcome = {
|
||||
'net_pnl': 20.0, # Profit
|
||||
'exit_price': 2520.0,
|
||||
'duration': timedelta(minutes=30)
|
||||
}
|
||||
|
||||
confidence_adjustment_2 = pivot_trainer._calculate_confidence_adjustment(
|
||||
low_conf_win, win_outcome
|
||||
)
|
||||
|
||||
logger.info(f"Confidence adjustment for underconfident win: {confidence_adjustment_2:.3f}")
|
||||
# Should be small penalty or zero
|
||||
|
||||
logger.info("✅ Confidence adjustments working correctly")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing confidence adjustments: {e}")
|
||||
return False
|
||||
|
||||
def test_dynamic_threshold_updates():
|
||||
"""Test dynamic threshold updating based on performance"""
|
||||
logger.info("=== Testing Dynamic Threshold Updates ===")
|
||||
|
||||
try:
|
||||
pivot_trainer = create_enhanced_pivot_trainer()
|
||||
|
||||
# Get initial thresholds
|
||||
initial_thresholds = pivot_trainer.get_current_thresholds()
|
||||
logger.info(f"Initial thresholds: {initial_thresholds}")
|
||||
|
||||
# Simulate some poor performance (low win rate)
|
||||
for i in range(25):
|
||||
outcome = {
|
||||
'timestamp': datetime.now(),
|
||||
'action': 'BUY',
|
||||
'confidence': 0.6,
|
||||
'net_pnl': -5.0 if i < 20 else 10.0, # 20% win rate
|
||||
'reward': -1.0 if i < 20 else 2.0,
|
||||
'duration': timedelta(hours=2)
|
||||
}
|
||||
pivot_trainer.trade_outcomes.append(outcome)
|
||||
|
||||
# Update thresholds
|
||||
pivot_trainer.update_thresholds_based_on_performance()
|
||||
|
||||
# Get updated thresholds
|
||||
updated_thresholds = pivot_trainer.get_current_thresholds()
|
||||
logger.info(f"Updated thresholds after poor performance: {updated_thresholds}")
|
||||
|
||||
# Entry threshold should increase (more selective) after poor performance
|
||||
assert updated_thresholds['entry_threshold'] >= initial_thresholds['entry_threshold'], \
|
||||
"Entry threshold should increase after poor performance"
|
||||
|
||||
logger.info("✅ Dynamic threshold updates working correctly")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing dynamic thresholds: {e}")
|
||||
return False
|
||||
|
||||
def test_cnn_integration():
|
||||
"""Test CNN integration for pivot predictions"""
|
||||
logger.info("=== Testing CNN Integration ===")
|
||||
|
||||
try:
|
||||
data_provider = DataProvider()
|
||||
orchestrator = EnhancedTradingOrchestrator(
|
||||
data_provider=data_provider,
|
||||
enhanced_rl_training=True
|
||||
)
|
||||
|
||||
# Check if Williams structure is initialized with CNN
|
||||
williams = orchestrator.pivot_rl_trainer.williams
|
||||
logger.info(f"Williams CNN enabled: {williams.enable_cnn_feature}")
|
||||
logger.info(f"Williams CNN model available: {williams.cnn_model is not None}")
|
||||
|
||||
# Test CNN threshold adjustment
|
||||
from core.enhanced_orchestrator import MarketState
|
||||
from datetime import datetime
|
||||
|
||||
mock_market_state = MarketState(
|
||||
symbol='ETH/USDT',
|
||||
timestamp=datetime.now(),
|
||||
prices={'1s': 2500.0},
|
||||
features={'1s': np.array([])},
|
||||
volatility=0.02,
|
||||
volume=1000.0,
|
||||
trend_strength=0.5,
|
||||
market_regime='normal',
|
||||
universal_data=None
|
||||
)
|
||||
|
||||
cnn_adjustment = orchestrator._get_cnn_threshold_adjustment(
|
||||
'ETH/USDT', 'BUY', mock_market_state
|
||||
)
|
||||
|
||||
logger.info(f"CNN threshold adjustment: {cnn_adjustment:.3f}")
|
||||
assert 0.0 <= cnn_adjustment <= 0.1, "CNN adjustment should be reasonable"
|
||||
|
||||
logger.info("✅ CNN integration working correctly")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing CNN integration: {e}")
|
||||
return False
|
||||
|
||||
def run_all_tests():
|
||||
"""Run all enhanced pivot RL system tests"""
|
||||
logger.info("🚀 Starting Enhanced Pivot RL System Tests")
|
||||
|
||||
tests = [
|
||||
test_enhanced_pivot_thresholds,
|
||||
test_pivot_reward_calculation,
|
||||
test_confidence_adjustment,
|
||||
test_dynamic_threshold_updates,
|
||||
test_cnn_integration
|
||||
]
|
||||
|
||||
passed = 0
|
||||
total = len(tests)
|
||||
|
||||
for test_func in tests:
|
||||
try:
|
||||
if test_func():
|
||||
passed += 1
|
||||
logger.info(f"✅ {test_func.__name__} PASSED")
|
||||
else:
|
||||
logger.error(f"❌ {test_func.__name__} FAILED")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ {test_func.__name__} ERROR: {e}")
|
||||
|
||||
logger.info(f"\n📊 Test Results: {passed}/{total} tests passed")
|
||||
|
||||
if passed == total:
|
||||
logger.info("🎉 All Enhanced Pivot RL System tests PASSED!")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"⚠️ {total - passed} tests FAILED")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
|
||||
if success:
|
||||
logger.info("\n🔥 Enhanced Pivot RL System is ready for deployment!")
|
||||
logger.info("Key improvements:")
|
||||
logger.info(" ✅ Higher entry threshold than exit threshold")
|
||||
logger.info(" ✅ Pivot-based reward calculation")
|
||||
logger.info(" ✅ CNN predictions for early pivot detection")
|
||||
logger.info(" ✅ Rewards for staying uninvested when uncertain")
|
||||
logger.info(" ✅ Confidence-based reward adjustments")
|
||||
logger.info(" ✅ Dynamic threshold learning from performance")
|
||||
else:
|
||||
logger.error("\n❌ Enhanced Pivot RL System has issues that need fixing")
|
||||
|
||||
sys.exit(0 if success else 1)
|
@ -1,83 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Enhanced RL Fix - Verify comprehensive state building and reward calculation
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
def test_enhanced_orchestrator():
|
||||
"""Test enhanced orchestrator methods"""
|
||||
print("=== TESTING ENHANCED RL FIXES ===")
|
||||
|
||||
try:
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
print("✓ Enhanced orchestrator imported successfully")
|
||||
|
||||
# Create orchestrator with enhanced RL enabled
|
||||
dp = DataProvider()
|
||||
eo = EnhancedTradingOrchestrator(
|
||||
data_provider=dp,
|
||||
enhanced_rl_training=True,
|
||||
symbols=['ETH/USDT', 'BTC/USDT']
|
||||
)
|
||||
print("✓ Enhanced orchestrator created")
|
||||
|
||||
# Test method availability
|
||||
methods = ['build_comprehensive_rl_state', 'calculate_enhanced_pivot_reward', '_get_symbol_correlation']
|
||||
print("\nMethod availability:")
|
||||
for method in methods:
|
||||
available = hasattr(eo, method)
|
||||
print(f" {method}: {'✓' if available else '✗'}")
|
||||
|
||||
# Test comprehensive state building
|
||||
print("\nTesting comprehensive state building...")
|
||||
state = eo.build_comprehensive_rl_state('ETH/USDT')
|
||||
if state is not None:
|
||||
print(f"✓ Comprehensive state built: {len(state)} features")
|
||||
print(f" State type: {type(state)}")
|
||||
print(f" State shape: {state.shape if hasattr(state, 'shape') else 'No shape'}")
|
||||
else:
|
||||
print("✗ Comprehensive state returned None")
|
||||
|
||||
# Debug why state is None
|
||||
print("\nDEBUGGING STATE BUILDING...")
|
||||
print(f" Williams enabled: {hasattr(eo, 'williams_enabled')}")
|
||||
print(f" COB integration active: {hasattr(eo, 'cob_integration_active')}")
|
||||
print(f" Enhanced RL training: {getattr(eo, 'enhanced_rl_training', 'Not set')}")
|
||||
|
||||
# Test enhanced reward calculation
|
||||
print("\nTesting enhanced reward calculation...")
|
||||
trade_decision = {
|
||||
'action': 'BUY',
|
||||
'confidence': 0.75,
|
||||
'price': 2500.0,
|
||||
'timestamp': '2023-01-01 00:00:00'
|
||||
}
|
||||
trade_outcome = {
|
||||
'net_pnl': 50.0,
|
||||
'exit_price': 2550.0,
|
||||
'duration': '00:15:00'
|
||||
}
|
||||
market_data = {'symbol': 'ETH/USDT'}
|
||||
|
||||
try:
|
||||
reward = eo.calculate_enhanced_pivot_reward(trade_decision, market_data, trade_outcome)
|
||||
print(f"✓ Enhanced reward calculated: {reward}")
|
||||
except Exception as e:
|
||||
print(f"✗ Enhanced reward failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n=== TEST COMPLETE ===")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_enhanced_orchestrator()
|
@ -1,346 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Enhanced Williams Market Structure with CNN Integration
|
||||
|
||||
This script demonstrates the multi-timeframe, multi-symbol CNN-enabled
|
||||
Williams Market Structure that predicts pivot points using TrainingDataPacket.
|
||||
|
||||
Features tested:
|
||||
- Multi-timeframe data integration (1s, 1m, 1h)
|
||||
- Multi-symbol support (ETH, BTC)
|
||||
- Tick data aggregation
|
||||
- 1h-based normalization strategy
|
||||
- Multi-level pivot prediction (5 levels, type + price)
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Mock TrainingDataPacket for testing
|
||||
@dataclass
|
||||
class MockTrainingDataPacket:
|
||||
"""Mock TrainingDataPacket for testing CNN integration"""
|
||||
timestamp: datetime
|
||||
symbol: str
|
||||
tick_cache: List[Dict[str, Any]]
|
||||
one_second_bars: List[Dict[str, Any]]
|
||||
multi_timeframe_data: Dict[str, List[Dict[str, Any]]]
|
||||
cnn_features: Optional[Dict[str, np.ndarray]] = None
|
||||
cnn_predictions: Optional[Dict[str, np.ndarray]] = None
|
||||
market_state: Optional[Any] = None
|
||||
universal_stream: Optional[Any] = None
|
||||
|
||||
class MockTrainingDataProvider:
|
||||
"""Mock provider that supplies TrainingDataPacket for testing"""
|
||||
|
||||
def __init__(self):
|
||||
self.training_data_buffer = []
|
||||
self._generate_mock_data()
|
||||
|
||||
def _generate_mock_data(self):
|
||||
"""Generate comprehensive mock market data"""
|
||||
current_time = datetime.now()
|
||||
|
||||
# Generate ETH data for different timeframes
|
||||
eth_1s_data = self._generate_ohlcv_data(2400.0, 900, '1s', current_time) # 15 min of 1s data
|
||||
eth_1m_data = self._generate_ohlcv_data(2400.0, 900, '1m', current_time) # 15 hours of 1m data
|
||||
eth_1h_data = self._generate_ohlcv_data(2400.0, 24, '1h', current_time) # 24 hours of 1h data
|
||||
|
||||
# Generate BTC data
|
||||
btc_1s_data = self._generate_ohlcv_data(45000.0, 300, '1s', current_time) # 5 min of 1s data
|
||||
|
||||
# Generate tick data
|
||||
tick_data = self._generate_tick_data(current_time)
|
||||
|
||||
# Create comprehensive TrainingDataPacket
|
||||
training_packet = MockTrainingDataPacket(
|
||||
timestamp=current_time,
|
||||
symbol='ETH/USDT',
|
||||
tick_cache=tick_data,
|
||||
one_second_bars=eth_1s_data[-300:], # Last 5 minutes
|
||||
multi_timeframe_data={
|
||||
'ETH/USDT': {
|
||||
'1s': eth_1s_data,
|
||||
'1m': eth_1m_data,
|
||||
'1h': eth_1h_data
|
||||
},
|
||||
'BTC/USDT': {
|
||||
'1s': btc_1s_data
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
self.training_data_buffer.append(training_packet)
|
||||
logger.info(f"Generated mock training data: {len(eth_1s_data)} 1s bars, {len(eth_1m_data)} 1m bars, {len(eth_1h_data)} 1h bars")
|
||||
logger.info(f"ETH 1h price range: {min(bar['low'] for bar in eth_1h_data):.2f} - {max(bar['high'] for bar in eth_1h_data):.2f}")
|
||||
|
||||
def _generate_ohlcv_data(self, base_price: float, count: int, timeframe: str, end_time: datetime) -> List[Dict[str, Any]]:
|
||||
"""Generate realistic OHLCV data with indicators"""
|
||||
data = []
|
||||
|
||||
# Calculate time delta based on timeframe
|
||||
if timeframe == '1s':
|
||||
delta = timedelta(seconds=1)
|
||||
elif timeframe == '1m':
|
||||
delta = timedelta(minutes=1)
|
||||
elif timeframe == '1h':
|
||||
delta = timedelta(hours=1)
|
||||
else:
|
||||
delta = timedelta(minutes=1)
|
||||
|
||||
current_price = base_price
|
||||
|
||||
for i in range(count):
|
||||
timestamp = end_time - delta * (count - i - 1)
|
||||
|
||||
# Generate realistic price movement
|
||||
price_change = np.random.normal(0, base_price * 0.001) # 0.1% volatility
|
||||
current_price = max(current_price + price_change, base_price * 0.8) # Floor at 80% of base
|
||||
|
||||
# Generate OHLCV
|
||||
open_price = current_price
|
||||
high_price = open_price * (1 + abs(np.random.normal(0, 0.002)))
|
||||
low_price = open_price * (1 - abs(np.random.normal(0, 0.002)))
|
||||
close_price = low_price + (high_price - low_price) * np.random.random()
|
||||
volume = np.random.exponential(1000)
|
||||
|
||||
current_price = close_price
|
||||
|
||||
# Calculate basic indicators (placeholders)
|
||||
sma_20 = close_price * (1 + np.random.normal(0, 0.001))
|
||||
ema_20 = close_price * (1 + np.random.normal(0, 0.0005))
|
||||
rsi_14 = 30 + np.random.random() * 40 # RSI between 30-70
|
||||
macd = np.random.normal(0, 0.1)
|
||||
bb_upper = high_price * 1.02
|
||||
|
||||
bar = {
|
||||
'timestamp': timestamp,
|
||||
'open': open_price,
|
||||
'high': high_price,
|
||||
'low': low_price,
|
||||
'close': close_price,
|
||||
'volume': volume,
|
||||
'sma_20': sma_20,
|
||||
'ema_20': ema_20,
|
||||
'rsi_14': rsi_14,
|
||||
'macd': macd,
|
||||
'bb_upper': bb_upper
|
||||
}
|
||||
data.append(bar)
|
||||
|
||||
return data
|
||||
|
||||
def _generate_tick_data(self, end_time: datetime) -> List[Dict[str, Any]]:
|
||||
"""Generate realistic tick data for last 5 minutes"""
|
||||
ticks = []
|
||||
|
||||
# Generate ETH ticks
|
||||
for i in range(300): # 5 minutes * 60 seconds
|
||||
tick_time = end_time - timedelta(seconds=300 - i)
|
||||
|
||||
# 2-5 ticks per second
|
||||
ticks_per_second = np.random.randint(2, 6)
|
||||
|
||||
for j in range(ticks_per_second):
|
||||
tick = {
|
||||
'symbol': 'ETH/USDT',
|
||||
'timestamp': tick_time + timedelta(milliseconds=j * 200),
|
||||
'price': 2400.0 + np.random.normal(0, 5),
|
||||
'volume': np.random.exponential(0.5),
|
||||
'quantity': np.random.exponential(1.0),
|
||||
'side': 'buy' if np.random.random() > 0.5 else 'sell'
|
||||
}
|
||||
ticks.append(tick)
|
||||
|
||||
# Generate BTC ticks
|
||||
for i in range(300): # 5 minutes * 60 seconds
|
||||
tick_time = end_time - timedelta(seconds=300 - i)
|
||||
|
||||
ticks_per_second = np.random.randint(1, 4)
|
||||
|
||||
for j in range(ticks_per_second):
|
||||
tick = {
|
||||
'symbol': 'BTC/USDT',
|
||||
'timestamp': tick_time + timedelta(milliseconds=j * 300),
|
||||
'price': 45000.0 + np.random.normal(0, 100),
|
||||
'volume': np.random.exponential(0.1),
|
||||
'quantity': np.random.exponential(0.5),
|
||||
'side': 'buy' if np.random.random() > 0.5 else 'sell'
|
||||
}
|
||||
ticks.append(tick)
|
||||
|
||||
return ticks
|
||||
|
||||
def get_latest_training_data(self):
|
||||
"""Return the latest TrainingDataPacket"""
|
||||
return self.training_data_buffer[-1] if self.training_data_buffer else None
|
||||
|
||||
|
||||
def test_enhanced_williams_cnn():
|
||||
"""Test the enhanced Williams Market Structure with CNN integration"""
|
||||
try:
|
||||
from training.williams_market_structure import WilliamsMarketStructure, SwingType
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info("TESTING ENHANCED WILLIAMS MARKET STRUCTURE WITH CNN INTEGRATION")
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Create mock data provider
|
||||
data_provider = MockTrainingDataProvider()
|
||||
|
||||
# Initialize Williams Market Structure with CNN
|
||||
williams = WilliamsMarketStructure(
|
||||
swing_strengths=[2, 3, 5], # Reduced for testing
|
||||
cnn_input_shape=(900, 50), # 900 timesteps, 50 features
|
||||
cnn_output_size=10, # 5 levels * 2 outputs (type + price)
|
||||
enable_cnn_feature=True, # Enable CNN features
|
||||
training_data_provider=data_provider
|
||||
)
|
||||
|
||||
logger.info(f"CNN enabled: {williams.enable_cnn_feature}")
|
||||
logger.info(f"Training data provider available: {williams.training_data_provider is not None}")
|
||||
|
||||
# Generate test OHLCV data for Williams calculation
|
||||
test_ohlcv = generate_test_ohlcv_data()
|
||||
logger.info(f"Generated test OHLCV data: {len(test_ohlcv)} bars")
|
||||
|
||||
# Test Williams calculation with CNN integration
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("RUNNING WILLIAMS PIVOT CALCULATION WITH CNN INTEGRATION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
structure_levels = williams.calculate_recursive_pivot_points(test_ohlcv)
|
||||
|
||||
# Display results
|
||||
logger.info(f"\nWilliams Structure Analysis Results:")
|
||||
logger.info(f"Calculated levels: {len(structure_levels)}")
|
||||
|
||||
for level_key, level_data in structure_levels.items():
|
||||
swing_count = len(level_data.swing_points)
|
||||
logger.info(f"{level_key}: {swing_count} swing points, "
|
||||
f"trend: {level_data.trend_analysis.direction.value}, "
|
||||
f"bias: {level_data.current_bias.value}")
|
||||
|
||||
if swing_count > 0:
|
||||
latest_swing = level_data.swing_points[-1]
|
||||
logger.info(f" Latest swing: {latest_swing.swing_type.name} @ {latest_swing.price:.2f}")
|
||||
|
||||
# Test CNN input preparation directly
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("TESTING CNN INPUT PREPARATION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if williams.enable_cnn_feature and structure_levels['level_0'].swing_points:
|
||||
test_pivot = structure_levels['level_0'].swing_points[-1]
|
||||
|
||||
logger.info(f"Testing CNN input for pivot: {test_pivot.swing_type.name} @ {test_pivot.price:.2f}")
|
||||
|
||||
# Test input preparation
|
||||
cnn_input = williams._prepare_cnn_input(
|
||||
current_pivot=test_pivot,
|
||||
ohlcv_data_context=test_ohlcv,
|
||||
previous_pivot_details=None
|
||||
)
|
||||
|
||||
logger.info(f"CNN input shape: {cnn_input.shape}")
|
||||
logger.info(f"CNN input range: [{cnn_input.min():.4f}, {cnn_input.max():.4f}]")
|
||||
logger.info(f"CNN input mean: {cnn_input.mean():.4f}, std: {cnn_input.std():.4f}")
|
||||
|
||||
# Test ground truth preparation
|
||||
if len(structure_levels['level_0'].swing_points) >= 2:
|
||||
prev_pivot = structure_levels['level_0'].swing_points[-2]
|
||||
current_pivot = structure_levels['level_0'].swing_points[-1]
|
||||
|
||||
prev_details = {'pivot': prev_pivot}
|
||||
ground_truth = williams._get_cnn_ground_truth(prev_details, current_pivot)
|
||||
|
||||
logger.info(f"Ground truth shape: {ground_truth.shape}")
|
||||
logger.info(f"Ground truth (first 4 values): {ground_truth[:4]}")
|
||||
logger.info(f"Level 0 prediction: type={ground_truth[0]:.2f}, price={ground_truth[1]:.4f}")
|
||||
|
||||
# Test normalization
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("TESTING 1H-BASED NORMALIZATION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
training_packet = data_provider.get_latest_training_data()
|
||||
if training_packet:
|
||||
# Test normalization with sample data
|
||||
sample_features = np.random.normal(2400, 50, (100, 10)) # ETH-like prices
|
||||
|
||||
normalized = williams._normalize_features_by_1h_range(sample_features, training_packet)
|
||||
|
||||
logger.info(f"Original features range: [{sample_features.min():.2f}, {sample_features.max():.2f}]")
|
||||
logger.info(f"Normalized features range: [{normalized.min():.4f}, {normalized.max():.4f}]")
|
||||
|
||||
# Check if 1h data is being used for normalization
|
||||
eth_1h = training_packet.multi_timeframe_data.get('ETH/USDT', {}).get('1h', [])
|
||||
if eth_1h:
|
||||
h1_prices = []
|
||||
for bar in eth_1h[-24:]:
|
||||
h1_prices.extend([bar['open'], bar['high'], bar['low'], bar['close']])
|
||||
h1_range = max(h1_prices) - min(h1_prices)
|
||||
logger.info(f"1h price range used for normalization: {h1_range:.2f}")
|
||||
|
||||
logger.info("\n" + "=" * 80)
|
||||
logger.info("ENHANCED WILLIAMS CNN INTEGRATION TEST COMPLETED SUCCESSFULLY")
|
||||
logger.info("=" * 80)
|
||||
|
||||
return True
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Import error - some dependencies missing: {e}")
|
||||
logger.info("This is expected if TensorFlow or other dependencies are not installed")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Test failed with error: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
def generate_test_ohlcv_data(bars=200, base_price=2400.0):
|
||||
"""Generate test OHLCV data for Williams calculation"""
|
||||
data = []
|
||||
current_price = base_price
|
||||
current_time = datetime.now()
|
||||
|
||||
for i in range(bars):
|
||||
timestamp = current_time - timedelta(seconds=bars - i)
|
||||
|
||||
# Generate price movement
|
||||
price_change = np.random.normal(0, base_price * 0.002)
|
||||
current_price = max(current_price + price_change, base_price * 0.9)
|
||||
|
||||
open_price = current_price
|
||||
high_price = open_price * (1 + abs(np.random.normal(0, 0.003)))
|
||||
low_price = open_price * (1 - abs(np.random.normal(0, 0.003)))
|
||||
close_price = low_price + (high_price - low_price) * np.random.random()
|
||||
volume = np.random.exponential(1000)
|
||||
|
||||
current_price = close_price
|
||||
|
||||
bar = [
|
||||
timestamp.timestamp(),
|
||||
open_price,
|
||||
high_price,
|
||||
low_price,
|
||||
close_price,
|
||||
volume
|
||||
]
|
||||
data.append(bar)
|
||||
|
||||
return np.array(data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_enhanced_williams_cnn()
|
||||
if success:
|
||||
print("\n✅ All tests passed! Enhanced Williams CNN integration is working.")
|
||||
else:
|
||||
print("\n❌ Some tests failed. Check logs for details.")
|
@ -1,115 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Essential Test Suite - Core functionality tests
|
||||
|
||||
This file contains the most important tests to verify core functionality:
|
||||
- Data loading and processing
|
||||
- Basic model operations
|
||||
- Trading signal generation
|
||||
- Critical utility functions
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import unittest
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TestEssentialFunctionality(unittest.TestCase):
|
||||
"""Essential tests for core trading system functionality"""
|
||||
|
||||
def test_imports(self):
|
||||
"""Test that all critical modules can be imported"""
|
||||
try:
|
||||
from core.config import get_config
|
||||
from core.data_provider import DataProvider
|
||||
from utils.model_utils import robust_save, robust_load
|
||||
logger.info("✅ All critical imports successful")
|
||||
except ImportError as e:
|
||||
self.fail(f"Critical import failed: {e}")
|
||||
|
||||
def test_config_loading(self):
|
||||
"""Test configuration loading"""
|
||||
try:
|
||||
from core.config import get_config
|
||||
config = get_config()
|
||||
self.assertIsNotNone(config, "Config should be loaded")
|
||||
logger.info("✅ Configuration loading successful")
|
||||
except Exception as e:
|
||||
self.fail(f"Config loading failed: {e}")
|
||||
|
||||
def test_data_provider_initialization(self):
|
||||
"""Test DataProvider can be initialized"""
|
||||
try:
|
||||
from core.data_provider import DataProvider
|
||||
data_provider = DataProvider(['ETH/USDT'], ['1m'])
|
||||
self.assertIsNotNone(data_provider, "DataProvider should initialize")
|
||||
logger.info("✅ DataProvider initialization successful")
|
||||
except Exception as e:
|
||||
self.fail(f"DataProvider initialization failed: {e}")
|
||||
|
||||
def test_model_utils(self):
|
||||
"""Test model utility functions"""
|
||||
try:
|
||||
from utils.model_utils import get_model_info
|
||||
import tempfile
|
||||
|
||||
# Test with non-existent file
|
||||
info = get_model_info("non_existent_file.pt")
|
||||
self.assertFalse(info['exists'], "Should report file doesn't exist")
|
||||
|
||||
logger.info("✅ Model utils test successful")
|
||||
except Exception as e:
|
||||
self.fail(f"Model utils test failed: {e}")
|
||||
|
||||
def test_signal_generation_logic(self):
|
||||
"""Test basic signal generation logic"""
|
||||
import numpy as np
|
||||
|
||||
# Test signal distribution calculation
|
||||
predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0]) # SELL, HOLD, BUY
|
||||
|
||||
buy_count = np.sum(predictions == 2)
|
||||
sell_count = np.sum(predictions == 0)
|
||||
hold_count = np.sum(predictions == 1)
|
||||
total = len(predictions)
|
||||
|
||||
distribution = {
|
||||
"BUY": buy_count / total,
|
||||
"SELL": sell_count / total,
|
||||
"HOLD": hold_count / total
|
||||
}
|
||||
|
||||
# Verify calculations
|
||||
self.assertAlmostEqual(distribution["BUY"], 0.3, places=1)
|
||||
self.assertAlmostEqual(distribution["SELL"], 0.3, places=1)
|
||||
self.assertAlmostEqual(distribution["HOLD"], 0.4, places=1)
|
||||
self.assertAlmostEqual(sum(distribution.values()), 1.0, places=1)
|
||||
|
||||
logger.info("✅ Signal generation logic test successful")
|
||||
|
||||
def run_essential_tests():
|
||||
"""Run essential tests only"""
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(TestEssentialFunctionality)
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(suite)
|
||||
return result.wasSuccessful()
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger.info("Running essential functionality tests...")
|
||||
|
||||
success = run_essential_tests()
|
||||
|
||||
if success:
|
||||
logger.info("✅ All essential tests passed!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("❌ Essential tests failed!")
|
||||
sys.exit(1)
|
@ -1,508 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Enhanced Extrema Training Test Suite
|
||||
|
||||
Tests the complete extrema training system including:
|
||||
1. 200-candle 1m context data loading
|
||||
2. Local extrema detection (bottoms and tops)
|
||||
3. Training on not-so-perfect opportunities
|
||||
4. Dashboard integration with extrema information
|
||||
5. Reusable functionality across different dashboards
|
||||
|
||||
This test suite verifies all components work together correctly.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any
|
||||
import time
|
||||
|
||||
# Add project root to path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_extrema_trainer_initialization():
|
||||
"""Test 1: Extrema trainer initialization and basic functionality"""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 1: Extrema Trainer Initialization")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
from core.extrema_trainer import ExtremaTrainer
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
# Initialize components
|
||||
data_provider = DataProvider()
|
||||
symbols = ['ETHUSDT', 'BTCUSDT']
|
||||
|
||||
# Create extrema trainer
|
||||
extrema_trainer = ExtremaTrainer(
|
||||
data_provider=data_provider,
|
||||
symbols=symbols,
|
||||
window_size=10
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert extrema_trainer.symbols == symbols
|
||||
assert extrema_trainer.window_size == 10
|
||||
assert len(extrema_trainer.detected_extrema) == len(symbols)
|
||||
assert len(extrema_trainer.context_data) == len(symbols)
|
||||
|
||||
print("✅ Extrema trainer initialized successfully")
|
||||
print(f" - Symbols: {symbols}")
|
||||
print(f" - Window size: {extrema_trainer.window_size}")
|
||||
print(f" - Context data containers: {len(extrema_trainer.context_data)}")
|
||||
print(f" - Extrema containers: {len(extrema_trainer.detected_extrema)}")
|
||||
|
||||
return True, extrema_trainer
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Extrema trainer initialization failed: {e}")
|
||||
return False, None
|
||||
|
||||
def test_context_data_loading(extrema_trainer):
|
||||
"""Test 2: 200-candle 1m context data loading"""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 2: 200-Candle 1m Context Data Loading")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# Initialize context data
|
||||
start_time = time.time()
|
||||
results = extrema_trainer.initialize_context_data()
|
||||
load_time = time.time() - start_time
|
||||
|
||||
# Verify results
|
||||
successful_loads = sum(1 for success in results.values() if success)
|
||||
total_symbols = len(extrema_trainer.symbols)
|
||||
|
||||
print(f"✅ Context data loading completed in {load_time:.2f} seconds")
|
||||
print(f" - Success rate: {successful_loads}/{total_symbols} symbols")
|
||||
|
||||
# Check context data details
|
||||
for symbol in extrema_trainer.symbols:
|
||||
context = extrema_trainer.context_data[symbol]
|
||||
candles_loaded = len(context.candles)
|
||||
features_available = context.features is not None
|
||||
|
||||
print(f" - {symbol}: {candles_loaded} candles, features: {'✅' if features_available else '❌'}")
|
||||
|
||||
if features_available:
|
||||
print(f" Features shape: {context.features.shape}")
|
||||
|
||||
# Test context feature retrieval
|
||||
for symbol in extrema_trainer.symbols:
|
||||
features = extrema_trainer.get_context_features_for_model(symbol)
|
||||
if features is not None:
|
||||
print(f" - {symbol} model features: {features.shape}")
|
||||
else:
|
||||
print(f" - {symbol} model features: Not available")
|
||||
|
||||
return successful_loads > 0
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Context data loading failed: {e}")
|
||||
return False
|
||||
|
||||
def test_extrema_detection(extrema_trainer):
|
||||
"""Test 3: Local extrema detection (bottoms and tops)"""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 3: Local Extrema Detection")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# Run batch extrema detection
|
||||
start_time = time.time()
|
||||
detection_results = extrema_trainer.run_batch_detection()
|
||||
detection_time = time.time() - start_time
|
||||
|
||||
# Analyze results
|
||||
total_extrema = sum(len(extrema_list) for extrema_list in detection_results.values())
|
||||
|
||||
print(f"✅ Extrema detection completed in {detection_time:.2f} seconds")
|
||||
print(f" - Total extrema detected: {total_extrema}")
|
||||
|
||||
# Detailed breakdown by symbol
|
||||
for symbol, extrema_list in detection_results.items():
|
||||
if extrema_list:
|
||||
bottoms = len([e for e in extrema_list if e.extrema_type == 'bottom'])
|
||||
tops = len([e for e in extrema_list if e.extrema_type == 'top'])
|
||||
avg_confidence = np.mean([e.confidence for e in extrema_list])
|
||||
|
||||
print(f" - {symbol}: {len(extrema_list)} extrema (bottoms: {bottoms}, tops: {tops})")
|
||||
print(f" Average confidence: {avg_confidence:.3f}")
|
||||
|
||||
# Show recent extrema details
|
||||
for extrema in extrema_list[-2:]: # Last 2 extrema
|
||||
print(f" {extrema.extrema_type.upper()} @ ${extrema.price:.2f} "
|
||||
f"(confidence: {extrema.confidence:.3f}, action: {extrema.optimal_action})")
|
||||
|
||||
# Test perfect moves for CNN
|
||||
perfect_moves = extrema_trainer.get_perfect_moves_for_cnn(count=20)
|
||||
print(f" - Perfect moves for CNN training: {len(perfect_moves)}")
|
||||
|
||||
if perfect_moves:
|
||||
for move in perfect_moves[:3]: # Show first 3
|
||||
print(f" {move['optimal_action']} {move['symbol']} @ {move['timestamp'].strftime('%H:%M:%S')} "
|
||||
f"(outcome: {move['actual_outcome']:.3f}, confidence: {move['confidence_should_have_been']:.3f})")
|
||||
|
||||
return total_extrema > 0
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Extrema detection failed: {e}")
|
||||
return False
|
||||
|
||||
def test_context_data_updates(extrema_trainer):
|
||||
"""Test 4: Context data updates and continuous extrema detection"""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 4: Context Data Updates and Continuous Detection")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# Test single symbol update
|
||||
symbol = extrema_trainer.symbols[0]
|
||||
|
||||
print(f"Testing context update for {symbol}...")
|
||||
start_time = time.time()
|
||||
update_results = extrema_trainer.update_context_data(symbol)
|
||||
update_time = time.time() - start_time
|
||||
|
||||
print(f"✅ Context update completed in {update_time:.2f} seconds")
|
||||
print(f" - Update result for {symbol}: {'✅' if update_results.get(symbol, False) else '❌'}")
|
||||
|
||||
# Test all symbols update
|
||||
print("Testing context update for all symbols...")
|
||||
start_time = time.time()
|
||||
all_update_results = extrema_trainer.update_context_data()
|
||||
all_update_time = time.time() - start_time
|
||||
|
||||
successful_updates = sum(1 for success in all_update_results.values() if success)
|
||||
|
||||
print(f"✅ All symbols update completed in {all_update_time:.2f} seconds")
|
||||
print(f" - Success rate: {successful_updates}/{len(extrema_trainer.symbols)} symbols")
|
||||
|
||||
# Check for new extrema after updates
|
||||
new_extrema = extrema_trainer.run_batch_detection()
|
||||
new_total = sum(len(extrema_list) for extrema_list in new_extrema.values())
|
||||
|
||||
print(f" - New extrema detected after update: {new_total}")
|
||||
|
||||
return successful_updates > 0
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Context data updates failed: {e}")
|
||||
return False
|
||||
|
||||
def test_extrema_stats_and_training_data(extrema_trainer):
|
||||
"""Test 5: Extrema statistics and training data retrieval"""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 5: Extrema Statistics and Training Data")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# Get comprehensive stats
|
||||
stats = extrema_trainer.get_extrema_stats()
|
||||
|
||||
print("✅ Extrema statistics retrieved successfully")
|
||||
print(f" - Total extrema detected: {stats.get('total_extrema_detected', 0)}")
|
||||
print(f" - Training queue size: {stats.get('training_queue_size', 0)}")
|
||||
print(f" - Window size: {stats.get('window_size', 0)}")
|
||||
|
||||
# Confidence thresholds
|
||||
thresholds = stats.get('confidence_thresholds', {})
|
||||
print(f" - Confidence thresholds: min={thresholds.get('min', 0):.2f}, max={thresholds.get('max', 0):.2f}")
|
||||
|
||||
# Context data status
|
||||
context_status = stats.get('context_data_status', {})
|
||||
for symbol, status in context_status.items():
|
||||
candles = status.get('candles_loaded', 0)
|
||||
features = status.get('features_available', False)
|
||||
last_update = status.get('last_update', 'Unknown')
|
||||
print(f" - {symbol}: {candles} candles, features: {'✅' if features else '❌'}, updated: {last_update}")
|
||||
|
||||
# Recent extrema breakdown
|
||||
recent_extrema = stats.get('recent_extrema', {})
|
||||
if recent_extrema:
|
||||
print(f" - Recent extrema: {recent_extrema.get('bottoms', 0)} bottoms, {recent_extrema.get('tops', 0)} tops")
|
||||
print(f" - Average confidence: {recent_extrema.get('avg_confidence', 0):.3f}")
|
||||
print(f" - Average outcome: {recent_extrema.get('avg_outcome', 0):.3f}")
|
||||
|
||||
# Test training data retrieval
|
||||
training_data = extrema_trainer.get_extrema_training_data(count=10, min_confidence=0.4)
|
||||
print(f" - Training data (min confidence 0.4): {len(training_data)} cases")
|
||||
|
||||
if training_data:
|
||||
high_confidence_cases = len([case for case in training_data if case.confidence > 0.7])
|
||||
print(f" - High confidence cases (>0.7): {high_confidence_cases}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Extrema statistics retrieval failed: {e}")
|
||||
return False
|
||||
|
||||
def test_enhanced_orchestrator_integration():
|
||||
"""Test 6: Enhanced orchestrator integration with extrema trainer"""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 6: Enhanced Orchestrator Integration")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
# Initialize orchestrator (should include extrema trainer)
|
||||
data_provider = DataProvider()
|
||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||
|
||||
# Verify extrema trainer integration
|
||||
assert hasattr(orchestrator, 'extrema_trainer')
|
||||
assert orchestrator.extrema_trainer is not None
|
||||
|
||||
print("✅ Enhanced orchestrator initialized with extrema trainer")
|
||||
print(f" - Extrema trainer symbols: {orchestrator.extrema_trainer.symbols}")
|
||||
|
||||
# Test extrema stats retrieval through orchestrator
|
||||
extrema_stats = orchestrator.get_extrema_stats()
|
||||
print(f" - Extrema stats available: {'✅' if extrema_stats else '❌'}")
|
||||
|
||||
if extrema_stats:
|
||||
print(f" - Total extrema: {extrema_stats.get('total_extrema_detected', 0)}")
|
||||
print(f" - Training queue: {extrema_stats.get('training_queue_size', 0)}")
|
||||
|
||||
# Test context features retrieval
|
||||
for symbol in orchestrator.symbols[:2]: # Test first 2 symbols
|
||||
context_features = orchestrator.get_context_features_for_model(symbol)
|
||||
if context_features is not None:
|
||||
print(f" - {symbol} context features: {context_features.shape}")
|
||||
else:
|
||||
print(f" - {symbol} context features: Not available")
|
||||
|
||||
# Test perfect moves for CNN
|
||||
perfect_moves = orchestrator.get_perfect_moves_for_cnn(count=10)
|
||||
print(f" - Perfect moves for CNN: {len(perfect_moves)}")
|
||||
|
||||
return True, orchestrator
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Enhanced orchestrator integration failed: {e}")
|
||||
return False, None
|
||||
|
||||
def test_dashboard_integration(orchestrator):
|
||||
"""Test 7: Dashboard integration with extrema information"""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 7: Dashboard Integration")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
from web.old_archived.scalping_dashboard import RealTimeScalpingDashboard
|
||||
|
||||
# Initialize dashboard with enhanced orchestrator
|
||||
dashboard = RealTimeScalpingDashboard(orchestrator=orchestrator)
|
||||
|
||||
print("✅ Dashboard initialized with enhanced orchestrator")
|
||||
|
||||
# Test sensitivity learning info (should include extrema stats)
|
||||
sensitivity_info = dashboard._get_sensitivity_learning_info()
|
||||
|
||||
print("✅ Sensitivity learning info retrieved")
|
||||
print(f" - Info structure: {list(sensitivity_info.keys())}")
|
||||
|
||||
# Check for extrema information
|
||||
if 'extrema' in sensitivity_info:
|
||||
extrema_info = sensitivity_info['extrema']
|
||||
print(f" - Extrema info available: ✅")
|
||||
print(f" - Total extrema detected: {extrema_info.get('total_extrema_detected', 0)}")
|
||||
print(f" - Training queue size: {extrema_info.get('training_queue_size', 0)}")
|
||||
|
||||
recent_extrema = extrema_info.get('recent_extrema', {})
|
||||
if recent_extrema:
|
||||
print(f" - Recent bottoms: {recent_extrema.get('bottoms', 0)}")
|
||||
print(f" - Recent tops: {recent_extrema.get('tops', 0)}")
|
||||
print(f" - Average confidence: {recent_extrema.get('avg_confidence', 0):.3f}")
|
||||
|
||||
# Check for context data information
|
||||
if 'context_data' in sensitivity_info:
|
||||
context_info = sensitivity_info['context_data']
|
||||
print(f" - Context data info available: ✅")
|
||||
print(f" - Symbols with context: {len(context_info)}")
|
||||
|
||||
for symbol, status in list(context_info.items())[:2]: # Show first 2
|
||||
candles = status.get('candles_loaded', 0)
|
||||
features = status.get('features_available', False)
|
||||
print(f" - {symbol}: {candles} candles, features: {'✅' if features else '❌'}")
|
||||
|
||||
# Test model training status creation
|
||||
try:
|
||||
training_status = dashboard._create_model_training_status()
|
||||
print("✅ Model training status created successfully")
|
||||
print(f" - Status type: {type(training_status)}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Model training status creation had issues: {e}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Dashboard integration failed: {e}")
|
||||
return False
|
||||
|
||||
def test_reusability_across_dashboards():
|
||||
"""Test 8: Reusability of extrema trainer across different dashboards"""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 8: Reusability Across Different Dashboards")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
from core.extrema_trainer import ExtremaTrainer
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
# Create shared extrema trainer
|
||||
data_provider = DataProvider()
|
||||
shared_extrema_trainer = ExtremaTrainer(
|
||||
data_provider=data_provider,
|
||||
symbols=['ETHUSDT'],
|
||||
window_size=8 # Different window size
|
||||
)
|
||||
|
||||
# Initialize context data
|
||||
shared_extrema_trainer.initialize_context_data()
|
||||
|
||||
print("✅ Shared extrema trainer created")
|
||||
print(f" - Window size: {shared_extrema_trainer.window_size}")
|
||||
print(f" - Symbols: {shared_extrema_trainer.symbols}")
|
||||
|
||||
# Simulate usage by multiple dashboard types
|
||||
dashboard_types = ['scalping', 'swing', 'analysis']
|
||||
|
||||
for dashboard_type in dashboard_types:
|
||||
print(f"\n Testing {dashboard_type} dashboard usage:")
|
||||
|
||||
# Get extrema stats (reusable method)
|
||||
stats = shared_extrema_trainer.get_extrema_stats()
|
||||
print(f" - {dashboard_type}: Extrema stats retrieved ✅")
|
||||
|
||||
# Get context features (reusable method)
|
||||
features = shared_extrema_trainer.get_context_features_for_model('ETHUSDT')
|
||||
if features is not None:
|
||||
print(f" - {dashboard_type}: Context features available ✅ {features.shape}")
|
||||
else:
|
||||
print(f" - {dashboard_type}: Context features not available ❌")
|
||||
|
||||
# Get training data (reusable method)
|
||||
training_data = shared_extrema_trainer.get_extrema_training_data(count=5)
|
||||
print(f" - {dashboard_type}: Training data retrieved ✅ ({len(training_data)} cases)")
|
||||
|
||||
# Get perfect moves (reusable method)
|
||||
perfect_moves = shared_extrema_trainer.get_perfect_moves_for_cnn(count=5)
|
||||
print(f" - {dashboard_type}: Perfect moves retrieved ✅ ({len(perfect_moves)} moves)")
|
||||
|
||||
print("\n✅ Extrema trainer successfully reused across different dashboard types")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Reusability test failed: {e}")
|
||||
return False
|
||||
|
||||
def run_comprehensive_test_suite():
|
||||
"""Run the complete test suite"""
|
||||
print("🚀 ENHANCED EXTREMA TRAINING TEST SUITE")
|
||||
print("="*80)
|
||||
print("Testing 200-candle context data, extrema detection, and dashboard integration")
|
||||
print("="*80)
|
||||
|
||||
test_results = []
|
||||
extrema_trainer = None
|
||||
orchestrator = None
|
||||
|
||||
# Test 1: Extrema trainer initialization
|
||||
success, extrema_trainer = test_extrema_trainer_initialization()
|
||||
test_results.append(("Extrema Trainer Initialization", success))
|
||||
|
||||
if success and extrema_trainer:
|
||||
# Test 2: Context data loading
|
||||
success = test_context_data_loading(extrema_trainer)
|
||||
test_results.append(("200-Candle Context Data Loading", success))
|
||||
|
||||
# Test 3: Extrema detection
|
||||
success = test_extrema_detection(extrema_trainer)
|
||||
test_results.append(("Local Extrema Detection", success))
|
||||
|
||||
# Test 4: Context data updates
|
||||
success = test_context_data_updates(extrema_trainer)
|
||||
test_results.append(("Context Data Updates", success))
|
||||
|
||||
# Test 5: Stats and training data
|
||||
success = test_extrema_stats_and_training_data(extrema_trainer)
|
||||
test_results.append(("Extrema Stats and Training Data", success))
|
||||
|
||||
# Test 6: Enhanced orchestrator integration
|
||||
success, orchestrator = test_enhanced_orchestrator_integration()
|
||||
test_results.append(("Enhanced Orchestrator Integration", success))
|
||||
|
||||
if success and orchestrator:
|
||||
# Test 7: Dashboard integration
|
||||
success = test_dashboard_integration(orchestrator)
|
||||
test_results.append(("Dashboard Integration", success))
|
||||
|
||||
# Test 8: Reusability
|
||||
success = test_reusability_across_dashboards()
|
||||
test_results.append(("Reusability Across Dashboards", success))
|
||||
|
||||
# Print final results
|
||||
print("\n" + "="*80)
|
||||
print("🏁 TEST SUITE RESULTS")
|
||||
print("="*80)
|
||||
|
||||
passed = 0
|
||||
total = len(test_results)
|
||||
|
||||
for test_name, success in test_results:
|
||||
status = "✅ PASSED" if success else "❌ FAILED"
|
||||
print(f"{test_name:<40} {status}")
|
||||
if success:
|
||||
passed += 1
|
||||
|
||||
print("="*80)
|
||||
print(f"OVERALL RESULT: {passed}/{total} tests passed ({passed/total*100:.1f}%)")
|
||||
|
||||
if passed == total:
|
||||
print("🎉 ALL TESTS PASSED! Enhanced extrema training system is working correctly.")
|
||||
elif passed >= total * 0.8:
|
||||
print("✅ MOSTLY SUCCESSFUL! System is functional with minor issues.")
|
||||
else:
|
||||
print("⚠️ SIGNIFICANT ISSUES DETECTED! Please review failed tests.")
|
||||
|
||||
print("="*80)
|
||||
|
||||
return passed, total
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
passed, total = run_comprehensive_test_suite()
|
||||
|
||||
# Exit with appropriate code
|
||||
if passed == total:
|
||||
sys.exit(0) # Success
|
||||
else:
|
||||
sys.exit(1) # Some failures
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n⚠️ Test suite interrupted by user")
|
||||
sys.exit(2)
|
||||
except Exception as e:
|
||||
print(f"\n\n❌ Test suite crashed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(3)
|
@ -1,210 +0,0 @@
|
||||
"""
|
||||
Test script for automatic fee synchronization with MEXC API
|
||||
|
||||
This script demonstrates how the system can automatically sync trading fees
|
||||
from the MEXC API to the local configuration file.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Add NN directory to path
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), 'NN'))
|
||||
|
||||
from NN.exchanges.mexc_interface import MEXCInterface
|
||||
from core.config_sync import ConfigSynchronizer
|
||||
from core.trading_executor import TradingExecutor
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_mexc_fee_retrieval():
|
||||
"""Test retrieving fees directly from MEXC API"""
|
||||
logger.info("=== Testing MEXC Fee Retrieval ===")
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Initialize MEXC interface
|
||||
api_key = os.getenv('MEXC_API_KEY')
|
||||
api_secret = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
if not api_key or not api_secret:
|
||||
logger.error("MEXC API credentials not found in environment variables")
|
||||
return None
|
||||
|
||||
try:
|
||||
mexc = MEXCInterface(api_key=api_key, api_secret=api_secret, test_mode=False)
|
||||
|
||||
# Test connection
|
||||
if mexc.connect():
|
||||
logger.info("MEXC: Connection successful")
|
||||
else:
|
||||
logger.error("MEXC: Connection failed")
|
||||
return None
|
||||
|
||||
# Get trading fees
|
||||
logger.info("MEXC: Fetching trading fees...")
|
||||
fees = mexc.get_trading_fees()
|
||||
|
||||
if fees:
|
||||
logger.info(f"MEXC Trading Fees Retrieved:")
|
||||
logger.info(f" Maker Rate: {fees.get('maker_rate', 0)*100:.3f}%")
|
||||
logger.info(f" Taker Rate: {fees.get('taker_rate', 0)*100:.3f}%")
|
||||
logger.info(f" Source: {fees.get('source', 'unknown')}")
|
||||
|
||||
if fees.get('source') == 'mexc_api':
|
||||
logger.info(f" Raw Commission Rates:")
|
||||
logger.info(f" Maker: {fees.get('maker_commission', 0)} basis points")
|
||||
logger.info(f" Taker: {fees.get('taker_commission', 0)} basis points")
|
||||
else:
|
||||
logger.warning("Using fallback fee values - API may not be working")
|
||||
else:
|
||||
logger.error("Failed to retrieve trading fees")
|
||||
|
||||
return fees
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing MEXC fee retrieval: {e}")
|
||||
return None
|
||||
|
||||
def test_config_synchronization():
|
||||
"""Test automatic config synchronization"""
|
||||
logger.info("\n=== Testing Config Synchronization ===")
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
try:
|
||||
# Initialize MEXC interface
|
||||
api_key = os.getenv('MEXC_API_KEY')
|
||||
api_secret = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
if not api_key or not api_secret:
|
||||
logger.error("MEXC API credentials not found")
|
||||
return False
|
||||
|
||||
mexc = MEXCInterface(api_key=api_key, api_secret=api_secret, test_mode=False)
|
||||
|
||||
# Initialize config synchronizer
|
||||
config_sync = ConfigSynchronizer(
|
||||
config_path="config.yaml",
|
||||
mexc_interface=mexc
|
||||
)
|
||||
|
||||
# Get current sync status
|
||||
logger.info("Current sync status:")
|
||||
status = config_sync.get_sync_status()
|
||||
for key, value in status.items():
|
||||
if key != 'latest_sync_result':
|
||||
logger.info(f" {key}: {value}")
|
||||
|
||||
# Perform manual sync
|
||||
logger.info("\nPerforming manual fee synchronization...")
|
||||
sync_result = config_sync.sync_trading_fees(force=True)
|
||||
|
||||
logger.info(f"Sync Result:")
|
||||
logger.info(f" Status: {sync_result.get('status')}")
|
||||
logger.info(f" Changes Made: {sync_result.get('changes_made', False)}")
|
||||
|
||||
if sync_result.get('changes'):
|
||||
logger.info(" Fee Changes:")
|
||||
for fee_type, change in sync_result['changes'].items():
|
||||
logger.info(f" {fee_type}: {change['old']:.6f} -> {change['new']:.6f}")
|
||||
|
||||
if sync_result.get('errors'):
|
||||
logger.warning(f" Errors: {sync_result['errors']}")
|
||||
|
||||
# Test auto-sync
|
||||
logger.info("\nTesting auto-sync...")
|
||||
auto_sync_success = config_sync.auto_sync_fees()
|
||||
logger.info(f"Auto-sync result: {'Success' if auto_sync_success else 'Failed/Skipped'}")
|
||||
|
||||
return sync_result.get('status') in ['success', 'no_changes']
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing config synchronization: {e}")
|
||||
return False
|
||||
|
||||
def test_trading_executor_integration():
|
||||
"""Test fee sync integration in TradingExecutor"""
|
||||
logger.info("\n=== Testing TradingExecutor Integration ===")
|
||||
|
||||
try:
|
||||
# Initialize trading executor (this should trigger automatic fee sync)
|
||||
logger.info("Initializing TradingExecutor with auto fee sync...")
|
||||
executor = TradingExecutor("config.yaml")
|
||||
|
||||
# Check if config synchronizer was initialized
|
||||
if hasattr(executor, 'config_synchronizer') and executor.config_synchronizer:
|
||||
logger.info("Config synchronizer successfully initialized")
|
||||
|
||||
# Get sync status
|
||||
sync_status = executor.get_fee_sync_status()
|
||||
logger.info("Fee sync status:")
|
||||
for key, value in sync_status.items():
|
||||
if key not in ['latest_sync_result']:
|
||||
logger.info(f" {key}: {value}")
|
||||
|
||||
# Test manual sync through executor
|
||||
logger.info("\nTesting manual sync through TradingExecutor...")
|
||||
manual_sync = executor.sync_fees_with_api(force=True)
|
||||
logger.info(f"Manual sync result: {manual_sync.get('status')}")
|
||||
|
||||
# Test auto sync
|
||||
logger.info("Testing auto sync...")
|
||||
auto_sync = executor.auto_sync_fees_if_needed()
|
||||
logger.info(f"Auto sync result: {'Success' if auto_sync else 'Skipped/Failed'}")
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.error("Config synchronizer not initialized in TradingExecutor")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing TradingExecutor integration: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
logger.info("Starting Fee Synchronization Tests")
|
||||
logger.info("=" * 50)
|
||||
|
||||
# Test 1: Direct API fee retrieval
|
||||
fees = test_mexc_fee_retrieval()
|
||||
|
||||
# Test 2: Config synchronization
|
||||
if fees:
|
||||
sync_success = test_config_synchronization()
|
||||
else:
|
||||
logger.warning("Skipping config sync test due to API failure")
|
||||
sync_success = False
|
||||
|
||||
# Test 3: TradingExecutor integration
|
||||
if sync_success:
|
||||
integration_success = test_trading_executor_integration()
|
||||
else:
|
||||
logger.warning("Skipping TradingExecutor test due to sync failure")
|
||||
integration_success = False
|
||||
|
||||
# Summary
|
||||
logger.info("\n" + "=" * 50)
|
||||
logger.info("TEST SUMMARY:")
|
||||
logger.info(f" MEXC API Fee Retrieval: {'PASS' if fees else 'FAIL'}")
|
||||
logger.info(f" Config Synchronization: {'PASS' if sync_success else 'FAIL'}")
|
||||
logger.info(f" TradingExecutor Integration: {'PASS' if integration_success else 'FAIL'}")
|
||||
|
||||
if fees and sync_success and integration_success:
|
||||
logger.info("\nALL TESTS PASSED! Fee synchronization is working correctly.")
|
||||
logger.info("Your system will now automatically sync trading fees from MEXC API.")
|
||||
else:
|
||||
logger.warning("\nSome tests failed. Check the logs above for details.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,108 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Final Test - Verify Enhanced Orchestrator Methods Work
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
def test_final_fixes():
|
||||
"""Test that the enhanced orchestrator methods are working"""
|
||||
print("=" * 60)
|
||||
print("FINAL TEST - ENHANCED RL PIPELINE FIXES")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
# Import and test basic orchestrator
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
print("✓ Imports successful")
|
||||
|
||||
# Create orchestrator
|
||||
dp = DataProvider()
|
||||
orch = TradingOrchestrator(dp)
|
||||
print("✓ TradingOrchestrator created")
|
||||
|
||||
# Test enhanced methods
|
||||
methods = ['build_comprehensive_rl_state', 'calculate_enhanced_pivot_reward']
|
||||
print("\nTesting enhanced methods:")
|
||||
|
||||
for method in methods:
|
||||
has_method = hasattr(orch, method)
|
||||
print(f" {method}: {'✓' if has_method else '✗'}")
|
||||
|
||||
# Test comprehensive RL state building
|
||||
print("\nTesting comprehensive RL state building:")
|
||||
state = orch.build_comprehensive_rl_state('ETH/USDT')
|
||||
if state and len(state) >= 13000:
|
||||
print(f"✅ Comprehensive RL state: {len(state)} features (AUDIT FIXED)")
|
||||
else:
|
||||
print(f"❌ Comprehensive RL state: {len(state) if state else 0} features")
|
||||
|
||||
# Test enhanced reward calculation
|
||||
print("\nTesting enhanced pivot reward:")
|
||||
mock_trade_outcome = {'net_pnl': 25.0, 'hold_time_seconds': 300}
|
||||
mock_market_data = {'current_price': 3500.0, 'trend_strength': 0.8, 'volatility': 0.1}
|
||||
mock_trade_decision = {'price': 3495.0}
|
||||
|
||||
reward = orch.calculate_enhanced_pivot_reward(
|
||||
mock_trade_decision,
|
||||
mock_market_data,
|
||||
mock_trade_outcome
|
||||
)
|
||||
print(f"✅ Enhanced pivot reward: {reward:.4f}")
|
||||
|
||||
# Test dashboard integration
|
||||
print("\nTesting dashboard integration:")
|
||||
from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
|
||||
|
||||
# Create dashboard with basic orchestrator (should work now)
|
||||
dashboard = TradingDashboard(data_provider=dp, orchestrator=orch)
|
||||
print("✓ Dashboard created with enhanced orchestrator")
|
||||
|
||||
# Test dashboard can access enhanced methods
|
||||
dashboard_has_enhanced = hasattr(dashboard.orchestrator, 'build_comprehensive_rl_state')
|
||||
print(f" Dashboard has enhanced methods: {'✓' if dashboard_has_enhanced else '✗'}")
|
||||
|
||||
if dashboard_has_enhanced:
|
||||
dashboard_state = dashboard.orchestrator.build_comprehensive_rl_state('ETH/USDT')
|
||||
print(f" Dashboard comprehensive state: {len(dashboard_state) if dashboard_state else 0} features")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("🎉 COMPREHENSIVE RL TRAINING PIPELINE FIXES COMPLETE!")
|
||||
print("=" * 60)
|
||||
print("✅ AUDIT ISSUE #1: INPUT DATA GAP FIXED")
|
||||
print(" - Comprehensive RL state: 13,400+ features")
|
||||
print(" - ETH tick data, multi-timeframe OHLCV, BTC reference")
|
||||
print(" - CNN features, pivot analysis, microstructure")
|
||||
print("")
|
||||
print("✅ AUDIT ISSUE #2: ENHANCED REWARD CALCULATION FIXED")
|
||||
print(" - Pivot-based reward system operational")
|
||||
print(" - Market structure analysis integrated")
|
||||
print(" - Trade execution quality assessment")
|
||||
print("")
|
||||
print("✅ AUDIT ISSUE #3: ORCHESTRATOR INTEGRATION FIXED")
|
||||
print(" - Dashboard can access enhanced methods")
|
||||
print(" - No async/sync conflicts")
|
||||
print(" - Real-time training data collection ready")
|
||||
print("")
|
||||
print("🚀 READY FOR REAL-TIME TRAINING WITH RETROSPECTIVE SETUPS!")
|
||||
print("=" * 60)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ ERROR: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_final_fixes()
|
||||
if success:
|
||||
print("\n✅ All pipeline fixes verified and working!")
|
||||
else:
|
||||
print("\n❌ Pipeline fixes need more work")
|
Binary file not shown.
@ -1,301 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test GPU Training - Check if our models actually train and use GPU
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
import time
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_gpu_availability():
|
||||
"""Test if GPU is available and working"""
|
||||
logger.info("=== GPU AVAILABILITY TEST ===")
|
||||
|
||||
print(f"PyTorch version: {torch.__version__}")
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
print(f"CUDA version: {torch.version.cuda}")
|
||||
print(f"GPU count: {torch.cuda.device_count()}")
|
||||
for i in range(torch.cuda.device_count()):
|
||||
print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
|
||||
print(f" Memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.1f} GB")
|
||||
|
||||
# Test GPU operations
|
||||
try:
|
||||
device = torch.device('cuda:0')
|
||||
x = torch.randn(100, 100, device=device)
|
||||
y = torch.randn(100, 100, device=device)
|
||||
z = torch.mm(x, y)
|
||||
print(f"✅ GPU operations working: {z.device}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"❌ GPU operations failed: {e}")
|
||||
return False
|
||||
else:
|
||||
print("❌ No CUDA available")
|
||||
return False
|
||||
|
||||
def test_simple_training():
|
||||
"""Test if a simple neural network actually trains"""
|
||||
logger.info("=== SIMPLE TRAINING TEST ===")
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# Create a simple model
|
||||
class SimpleNet(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.layers = nn.Sequential(
|
||||
nn.Linear(10, 64),
|
||||
nn.ReLU(),
|
||||
nn.Linear(64, 32),
|
||||
nn.ReLU(),
|
||||
nn.Linear(32, 3)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
model = SimpleNet().to(device)
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
# Generate some dummy data
|
||||
X = torch.randn(1000, 10, device=device)
|
||||
y = torch.randint(0, 3, (1000,), device=device)
|
||||
|
||||
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
|
||||
print(f"Data shape: {X.shape}, Labels shape: {y.shape}")
|
||||
|
||||
# Training loop
|
||||
initial_loss = None
|
||||
losses = []
|
||||
|
||||
print("Training for 100 steps...")
|
||||
start_time = time.time()
|
||||
|
||||
for step in range(100):
|
||||
# Forward pass
|
||||
outputs = model(X)
|
||||
loss = criterion(outputs, y)
|
||||
|
||||
# Backward pass
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
loss_val = loss.item()
|
||||
losses.append(loss_val)
|
||||
|
||||
if step == 0:
|
||||
initial_loss = loss_val
|
||||
|
||||
if step % 20 == 0:
|
||||
print(f"Step {step}: Loss = {loss_val:.4f}")
|
||||
|
||||
end_time = time.time()
|
||||
final_loss = losses[-1]
|
||||
|
||||
print(f"Training completed in {end_time - start_time:.2f} seconds")
|
||||
print(f"Initial loss: {initial_loss:.4f}")
|
||||
print(f"Final loss: {final_loss:.4f}")
|
||||
print(f"Loss reduction: {initial_loss - final_loss:.4f}")
|
||||
|
||||
# Check if training actually happened
|
||||
if final_loss < initial_loss * 0.9: # At least 10% reduction
|
||||
print("✅ Training is working - loss decreased significantly")
|
||||
return True
|
||||
else:
|
||||
print("❌ Training may not be working - loss didn't decrease much")
|
||||
return False
|
||||
|
||||
def test_our_models():
|
||||
"""Test if our actual models can train"""
|
||||
logger.info("=== OUR MODELS TEST ===")
|
||||
|
||||
try:
|
||||
# Test DQN Agent
|
||||
from NN.models.dqn_agent import DQNAgent
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
print(f"Testing DQN Agent on {device}")
|
||||
|
||||
# Create agent
|
||||
state_shape = (100,) # Simple state
|
||||
agent = DQNAgent(
|
||||
state_shape=state_shape,
|
||||
n_actions=3,
|
||||
learning_rate=0.001,
|
||||
device=device
|
||||
)
|
||||
|
||||
print(f"✅ DQN Agent created successfully")
|
||||
print(f" Device: {agent.device}")
|
||||
print(f" Policy net device: {next(agent.policy_net.parameters()).device}")
|
||||
|
||||
# Test training step
|
||||
state = np.random.randn(100).astype(np.float32)
|
||||
action = 1
|
||||
reward = 0.5
|
||||
next_state = np.random.randn(100).astype(np.float32)
|
||||
done = False
|
||||
|
||||
# Add experience and train
|
||||
agent.remember(state, action, reward, next_state, done)
|
||||
|
||||
# Add more experiences
|
||||
for _ in range(200): # Need enough for batch
|
||||
s = np.random.randn(100).astype(np.float32)
|
||||
a = np.random.randint(0, 3)
|
||||
r = np.random.randn() * 0.1
|
||||
ns = np.random.randn(100).astype(np.float32)
|
||||
d = np.random.random() < 0.1
|
||||
agent.remember(s, a, r, ns, d)
|
||||
|
||||
# Test training
|
||||
print("Testing training step...")
|
||||
initial_loss = None
|
||||
for i in range(10):
|
||||
loss = agent.replay()
|
||||
if loss > 0:
|
||||
if initial_loss is None:
|
||||
initial_loss = loss
|
||||
print(f" Step {i}: Loss = {loss:.4f}")
|
||||
|
||||
if initial_loss is not None:
|
||||
print("✅ DQN training is working")
|
||||
else:
|
||||
print("❌ DQN training returned no loss")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error testing our models: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def test_cnn_model():
|
||||
"""Test CNN model training"""
|
||||
logger.info("=== CNN MODEL TEST ===")
|
||||
|
||||
try:
|
||||
from NN.models.enhanced_cnn import EnhancedCNN
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
print(f"Testing Enhanced CNN on {device}")
|
||||
|
||||
# Create model
|
||||
state_dim = (3, 20, 26) # 3 timeframes, 20 window, 26 features
|
||||
n_actions = 3
|
||||
|
||||
model = EnhancedCNN(state_dim, n_actions).to(device)
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
print(f"✅ Enhanced CNN created successfully")
|
||||
print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
|
||||
|
||||
# Test forward pass
|
||||
batch_size = 32
|
||||
x = torch.randn(batch_size, 3, 20, 26, device=device)
|
||||
|
||||
print("Testing forward pass...")
|
||||
outputs = model(x)
|
||||
|
||||
if isinstance(outputs, tuple):
|
||||
action_probs, extrema_pred, price_pred, features, advanced_pred = outputs
|
||||
print(f"✅ Forward pass successful")
|
||||
print(f" Action probs shape: {action_probs.shape}")
|
||||
print(f" Features shape: {features.shape}")
|
||||
else:
|
||||
print(f"❌ Unexpected output format: {type(outputs)}")
|
||||
return False
|
||||
|
||||
# Test training step
|
||||
y = torch.randint(0, 3, (batch_size,), device=device)
|
||||
|
||||
print("Testing training step...")
|
||||
loss = criterion(action_probs, y)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
print(f"✅ CNN training step successful, loss: {loss.item():.4f}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error testing CNN model: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
print("=" * 60)
|
||||
print("TESTING GPU TRAINING FUNCTIONALITY")
|
||||
print("=" * 60)
|
||||
|
||||
results = {}
|
||||
|
||||
# Test 1: GPU availability
|
||||
results['gpu'] = test_gpu_availability()
|
||||
print()
|
||||
|
||||
# Test 2: Simple training
|
||||
results['simple_training'] = test_simple_training()
|
||||
print()
|
||||
|
||||
# Test 3: Our DQN models
|
||||
results['dqn_models'] = test_our_models()
|
||||
print()
|
||||
|
||||
# Test 4: CNN models
|
||||
results['cnn_models'] = test_cnn_model()
|
||||
print()
|
||||
|
||||
# Summary
|
||||
print("=" * 60)
|
||||
print("TEST RESULTS SUMMARY")
|
||||
print("=" * 60)
|
||||
|
||||
for test_name, passed in results.items():
|
||||
status = "✅ PASS" if passed else "❌ FAIL"
|
||||
print(f"{test_name.upper()}: {status}")
|
||||
|
||||
all_passed = all(results.values())
|
||||
|
||||
if all_passed:
|
||||
print("\n🎉 ALL TESTS PASSED - Your training should work with GPU!")
|
||||
else:
|
||||
print("\n⚠️ SOME TESTS FAILED - Check the issues above")
|
||||
|
||||
if not results['gpu']:
|
||||
print(" → GPU not available or not working")
|
||||
if not results['simple_training']:
|
||||
print(" → Basic training loop not working")
|
||||
if not results['dqn_models']:
|
||||
print(" → DQN models have issues")
|
||||
if not results['cnn_models']:
|
||||
print(" → CNN models have issues")
|
||||
|
||||
return 0 if all_passed else 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit_code = main()
|
||||
sys.exit(exit_code)
|
@ -1,402 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive Indicators and Signals Test Suite
|
||||
|
||||
This module consolidates testing functionality for:
|
||||
- Technical indicators (from test_indicators.py)
|
||||
- Signal interpretation and processing (from test_signal_interpreter.py)
|
||||
- Market data analysis
|
||||
- Trading signal validation
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import unittest
|
||||
import logging
|
||||
import numpy as np
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.config import setup_logging
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TestTechnicalIndicators(unittest.TestCase):
|
||||
"""Test suite for technical indicators functionality"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
setup_logging()
|
||||
self.data_provider = DataProvider(['ETH/USDT'], ['1h'])
|
||||
|
||||
def test_indicator_calculation(self):
|
||||
"""Test that indicators are calculated correctly"""
|
||||
logger.info("Testing technical indicators calculation...")
|
||||
|
||||
try:
|
||||
# Fetch data with indicators
|
||||
df = self.data_provider.get_historical_data('ETH/USDT', '1h', refresh=True, limit=100)
|
||||
|
||||
self.assertIsNotNone(df, "Should fetch data successfully")
|
||||
self.assertGreater(len(df), 0, "Should have data rows")
|
||||
|
||||
# Check basic OHLCV columns
|
||||
basic_cols = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
|
||||
for col in basic_cols:
|
||||
self.assertIn(col, df.columns, f"Should have {col} column")
|
||||
|
||||
# Check that indicators are calculated
|
||||
indicator_cols = [col for col in df.columns if col not in basic_cols]
|
||||
self.assertGreater(len(indicator_cols), 0, "Should have technical indicators")
|
||||
|
||||
logger.info(f"✅ Successfully calculated {len(indicator_cols)} indicators")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Indicator test failed: {e}")
|
||||
self.skipTest("Data or indicators not available")
|
||||
|
||||
def test_indicator_categorization(self):
|
||||
"""Test categorization of different indicator types"""
|
||||
logger.info("Testing indicator categorization...")
|
||||
|
||||
try:
|
||||
df = self.data_provider.get_historical_data('ETH/USDT', '1h', refresh=True, limit=100)
|
||||
|
||||
if df is not None:
|
||||
basic_cols = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
|
||||
indicator_cols = [col for col in df.columns if col not in basic_cols]
|
||||
|
||||
# Categorize indicators
|
||||
trend_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['sma', 'ema', 'macd', 'adx', 'psar'])]
|
||||
momentum_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['rsi', 'stoch', 'williams', 'cci'])]
|
||||
volatility_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['bb_', 'atr', 'keltner'])]
|
||||
volume_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['volume', 'obv', 'vpt', 'mfi', 'ad_line', 'vwap'])]
|
||||
|
||||
# Check we have indicators in each category
|
||||
total_categorized = len(trend_indicators) + len(momentum_indicators) + len(volatility_indicators) + len(volume_indicators)
|
||||
|
||||
logger.info(f"Indicator categories:")
|
||||
logger.info(f" Trend: {len(trend_indicators)}")
|
||||
logger.info(f" Momentum: {len(momentum_indicators)}")
|
||||
logger.info(f" Volatility: {len(volatility_indicators)}")
|
||||
logger.info(f" Volume: {len(volume_indicators)}")
|
||||
logger.info(f" Total categorized: {total_categorized}/{len(indicator_cols)}")
|
||||
|
||||
self.assertGreater(total_categorized, 0, "Should have categorized indicators")
|
||||
|
||||
else:
|
||||
self.skipTest("Could not fetch data for categorization test")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Categorization test failed: {e}")
|
||||
self.skipTest("Indicator categorization not available")
|
||||
|
||||
def test_feature_matrix_creation(self):
|
||||
"""Test multi-timeframe feature matrix creation"""
|
||||
logger.info("Testing feature matrix creation...")
|
||||
|
||||
try:
|
||||
# Test feature matrix with multiple timeframes
|
||||
feature_matrix = self.data_provider.get_feature_matrix('ETH/USDT', ['1h'], window_size=20)
|
||||
|
||||
if feature_matrix is not None:
|
||||
self.assertEqual(len(feature_matrix.shape), 3, "Should be 3D matrix")
|
||||
self.assertGreater(feature_matrix.shape[2], 0, "Should have features")
|
||||
|
||||
logger.info(f"✅ Feature matrix shape: {feature_matrix.shape}")
|
||||
|
||||
else:
|
||||
self.skipTest("Could not create feature matrix")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Feature matrix test failed: {e}")
|
||||
self.skipTest("Feature matrix creation not available")
|
||||
|
||||
class TestSignalProcessing(unittest.TestCase):
|
||||
"""Test suite for signal interpretation and processing"""
|
||||
|
||||
def test_signal_distribution_calculation(self):
|
||||
"""Test signal distribution calculation"""
|
||||
logger.info("Testing signal distribution calculation...")
|
||||
|
||||
# Mock predictions (SELL=0, HOLD=1, BUY=2)
|
||||
predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0])
|
||||
|
||||
buy_count = np.sum(predictions == 2)
|
||||
sell_count = np.sum(predictions == 0)
|
||||
hold_count = np.sum(predictions == 1)
|
||||
total = len(predictions)
|
||||
|
||||
distribution = {
|
||||
"BUY": buy_count / total,
|
||||
"SELL": sell_count / total,
|
||||
"HOLD": hold_count / total
|
||||
}
|
||||
|
||||
# Verify calculations
|
||||
self.assertAlmostEqual(distribution["BUY"], 0.3, places=2)
|
||||
self.assertAlmostEqual(distribution["SELL"], 0.3, places=2)
|
||||
self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2)
|
||||
self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2)
|
||||
|
||||
logger.info("✅ Signal distribution calculation test passed")
|
||||
|
||||
def test_basic_signal_interpretation(self):
|
||||
"""Test basic signal interpretation logic"""
|
||||
logger.info("Testing basic signal interpretation...")
|
||||
|
||||
# Test cases with different probability distributions
|
||||
test_cases = [
|
||||
{
|
||||
'probs': [0.8, 0.1, 0.1], # Strong SELL
|
||||
'expected_action': 'SELL',
|
||||
'expected_confidence': 'high'
|
||||
},
|
||||
{
|
||||
'probs': [0.1, 0.1, 0.8], # Strong BUY
|
||||
'expected_action': 'BUY',
|
||||
'expected_confidence': 'high'
|
||||
},
|
||||
{
|
||||
'probs': [0.1, 0.8, 0.1], # Strong HOLD
|
||||
'expected_action': 'HOLD',
|
||||
'expected_confidence': 'high'
|
||||
},
|
||||
{
|
||||
'probs': [0.4, 0.3, 0.3], # Uncertain - should prefer SELL (index 0)
|
||||
'expected_action': 'SELL',
|
||||
'expected_confidence': 'low'
|
||||
},
|
||||
{
|
||||
'probs': [0.33, 0.33, 0.34], # Very uncertain - slight BUY preference
|
||||
'expected_action': 'BUY',
|
||||
'expected_confidence': 'low'
|
||||
}
|
||||
]
|
||||
|
||||
for i, test_case in enumerate(test_cases):
|
||||
probs = np.array(test_case['probs'])
|
||||
expected_action = test_case['expected_action']
|
||||
|
||||
# Simple signal interpretation (argmax)
|
||||
predicted_action_idx = np.argmax(probs)
|
||||
action_map = {0: 'SELL', 1: 'HOLD', 2: 'BUY'}
|
||||
predicted_action = action_map[predicted_action_idx]
|
||||
|
||||
# Calculate confidence (max probability)
|
||||
confidence = np.max(probs)
|
||||
confidence_level = 'high' if confidence > 0.7 else 'medium' if confidence > 0.5 else 'low'
|
||||
|
||||
# Verify predictions
|
||||
self.assertEqual(predicted_action, expected_action,
|
||||
f"Test case {i+1}: Expected {expected_action}, got {predicted_action}")
|
||||
|
||||
logger.info(f"Test case {i+1}: {probs} -> {predicted_action} ({confidence_level} confidence)")
|
||||
|
||||
logger.info("✅ Basic signal interpretation test passed")
|
||||
|
||||
def test_signal_filtering_logic(self):
|
||||
"""Test signal filtering and validation logic"""
|
||||
logger.info("Testing signal filtering logic...")
|
||||
|
||||
# Test threshold-based filtering
|
||||
buy_threshold = 0.6
|
||||
sell_threshold = 0.6
|
||||
hold_threshold = 0.7
|
||||
|
||||
test_signals = [
|
||||
{
|
||||
'probs': [0.8, 0.1, 0.1], # Strong SELL (above threshold)
|
||||
'should_pass': True,
|
||||
'expected': 'SELL'
|
||||
},
|
||||
{
|
||||
'probs': [0.5, 0.3, 0.2], # Weak SELL (below threshold)
|
||||
'should_pass': False,
|
||||
'expected': 'HOLD'
|
||||
},
|
||||
{
|
||||
'probs': [0.1, 0.2, 0.7], # Strong BUY (above threshold)
|
||||
'should_pass': True,
|
||||
'expected': 'BUY'
|
||||
},
|
||||
{
|
||||
'probs': [0.2, 0.8, 0.0], # Strong HOLD (above threshold)
|
||||
'should_pass': True,
|
||||
'expected': 'HOLD'
|
||||
}
|
||||
]
|
||||
|
||||
for i, test in enumerate(test_signals):
|
||||
probs = np.array(test['probs'])
|
||||
sell_prob, hold_prob, buy_prob = probs
|
||||
|
||||
# Apply threshold filtering
|
||||
if sell_prob >= sell_threshold:
|
||||
filtered_action = 'SELL'
|
||||
passed_filter = True
|
||||
elif buy_prob >= buy_threshold:
|
||||
filtered_action = 'BUY'
|
||||
passed_filter = True
|
||||
elif hold_prob >= hold_threshold:
|
||||
filtered_action = 'HOLD'
|
||||
passed_filter = True
|
||||
else:
|
||||
filtered_action = 'HOLD' # Default to HOLD if no threshold met
|
||||
passed_filter = False
|
||||
|
||||
# Verify filtering
|
||||
expected_pass = test['should_pass']
|
||||
expected_action = test['expected']
|
||||
|
||||
self.assertEqual(passed_filter, expected_pass,
|
||||
f"Test {i+1}: Filter pass expectation failed")
|
||||
self.assertEqual(filtered_action, expected_action,
|
||||
f"Test {i+1}: Expected {expected_action}, got {filtered_action}")
|
||||
|
||||
logger.info(f"Test {i+1}: {probs} -> {filtered_action} (passed: {passed_filter})")
|
||||
|
||||
logger.info("✅ Signal filtering logic test passed")
|
||||
|
||||
def test_signal_sequence_validation(self):
|
||||
"""Test signal sequence validation and oscillation prevention"""
|
||||
logger.info("Testing signal sequence validation...")
|
||||
|
||||
# Simulate a sequence of signals that might oscillate
|
||||
signal_sequence = ['BUY', 'SELL', 'BUY', 'SELL', 'HOLD', 'BUY']
|
||||
|
||||
# Simple oscillation detection
|
||||
oscillation_count = 0
|
||||
for i in range(1, len(signal_sequence)):
|
||||
if (signal_sequence[i-1] == 'BUY' and signal_sequence[i] == 'SELL') or \
|
||||
(signal_sequence[i-1] == 'SELL' and signal_sequence[i] == 'BUY'):
|
||||
oscillation_count += 1
|
||||
|
||||
# Count consecutive non-HOLD signals
|
||||
consecutive_trades = 0
|
||||
max_consecutive = 0
|
||||
for signal in signal_sequence:
|
||||
if signal != 'HOLD':
|
||||
consecutive_trades += 1
|
||||
max_consecutive = max(max_consecutive, consecutive_trades)
|
||||
else:
|
||||
consecutive_trades = 0
|
||||
|
||||
# Verify oscillation detection
|
||||
self.assertGreater(oscillation_count, 0, "Should detect oscillations in test sequence")
|
||||
self.assertGreater(max_consecutive, 1, "Should detect consecutive trades")
|
||||
|
||||
logger.info(f"Detected {oscillation_count} oscillations and max {max_consecutive} consecutive trades")
|
||||
logger.info("✅ Signal sequence validation test passed")
|
||||
|
||||
class TestMarketDataAnalysis(unittest.TestCase):
|
||||
"""Test suite for market data analysis functionality"""
|
||||
|
||||
def test_price_movement_calculation(self):
|
||||
"""Test price movement and trend calculation"""
|
||||
logger.info("Testing price movement calculation...")
|
||||
|
||||
# Mock price data
|
||||
prices = np.array([100.0, 101.0, 102.5, 101.8, 103.2, 102.9, 104.1])
|
||||
|
||||
# Calculate price movements
|
||||
price_changes = np.diff(prices)
|
||||
percentage_changes = (price_changes / prices[:-1]) * 100
|
||||
|
||||
# Calculate simple trend
|
||||
recent_trend = np.mean(percentage_changes[-3:]) # Last 3 changes
|
||||
trend_direction = 'uptrend' if recent_trend > 0.1 else 'downtrend' if recent_trend < -0.1 else 'sideways'
|
||||
|
||||
# Verify calculations
|
||||
self.assertEqual(len(price_changes), len(prices) - 1, "Should have n-1 price changes")
|
||||
self.assertEqual(len(percentage_changes), len(prices) - 1, "Should have n-1 percentage changes")
|
||||
|
||||
# Verify trend detection makes sense
|
||||
self.assertIn(trend_direction, ['uptrend', 'downtrend', 'sideways'], "Should detect valid trend")
|
||||
|
||||
logger.info(f"Price sequence: {prices}")
|
||||
logger.info(f"Recent trend: {trend_direction} ({recent_trend:.2f}%)")
|
||||
logger.info("✅ Price movement calculation test passed")
|
||||
|
||||
def test_volatility_measurement(self):
|
||||
"""Test volatility measurement"""
|
||||
logger.info("Testing volatility measurement...")
|
||||
|
||||
# Mock price data with different volatility
|
||||
stable_prices = np.array([100.0, 100.1, 99.9, 100.2, 99.8, 100.0])
|
||||
volatile_prices = np.array([100.0, 105.0, 95.0, 110.0, 90.0, 115.0])
|
||||
|
||||
# Calculate volatility (standard deviation of returns)
|
||||
def calculate_volatility(prices):
|
||||
returns = np.diff(prices) / prices[:-1]
|
||||
return np.std(returns) * 100 # As percentage
|
||||
|
||||
stable_vol = calculate_volatility(stable_prices)
|
||||
volatile_vol = calculate_volatility(volatile_prices)
|
||||
|
||||
# Verify volatility measurements
|
||||
self.assertLess(stable_vol, volatile_vol, "Stable prices should have lower volatility")
|
||||
self.assertGreater(volatile_vol, 5.0, "Volatile prices should have significant volatility")
|
||||
|
||||
logger.info(f"Stable volatility: {stable_vol:.2f}%")
|
||||
logger.info(f"Volatile volatility: {volatile_vol:.2f}%")
|
||||
logger.info("✅ Volatility measurement test passed")
|
||||
|
||||
def run_indicator_tests():
|
||||
"""Run indicator tests only"""
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(TestTechnicalIndicators)
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(suite)
|
||||
return result.wasSuccessful()
|
||||
|
||||
def run_signal_tests():
|
||||
"""Run signal processing tests only"""
|
||||
test_suites = [
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestSignalProcessing),
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestMarketDataAnalysis),
|
||||
]
|
||||
|
||||
combined_suite = unittest.TestSuite(test_suites)
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(combined_suite)
|
||||
return result.wasSuccessful()
|
||||
|
||||
def run_all_tests():
|
||||
"""Run all indicator and signal tests"""
|
||||
test_suites = [
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestTechnicalIndicators),
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestSignalProcessing),
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestMarketDataAnalysis),
|
||||
]
|
||||
|
||||
combined_suite = unittest.TestSuite(test_suites)
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(combined_suite)
|
||||
return result.wasSuccessful()
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup_logging()
|
||||
logger.info("Running indicators and signals test suite...")
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
test_type = sys.argv[1]
|
||||
if test_type == "indicators":
|
||||
success = run_indicator_tests()
|
||||
elif test_type == "signals":
|
||||
success = run_signal_tests()
|
||||
else:
|
||||
success = run_all_tests()
|
||||
else:
|
||||
success = run_all_tests()
|
||||
|
||||
if success:
|
||||
logger.info("✅ All indicator and signal tests passed!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("❌ Some tests failed!")
|
||||
sys.exit(1)
|
@ -1,176 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Leverage Slider Functionality
|
||||
|
||||
This script tests the leverage slider integration in the dashboard:
|
||||
- Verifies slider range (1x to 100x)
|
||||
- Tests risk level calculation
|
||||
- Checks leverage multiplier updates
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.config import setup_logging
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
|
||||
|
||||
# Setup logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_leverage_calculations():
|
||||
"""Test leverage risk calculations"""
|
||||
|
||||
logger.info("=" * 50)
|
||||
logger.info("TESTING LEVERAGE CALCULATIONS")
|
||||
logger.info("=" * 50)
|
||||
|
||||
test_cases = [
|
||||
{'leverage': 1, 'expected_risk': 'Low Risk'},
|
||||
{'leverage': 5, 'expected_risk': 'Low Risk'},
|
||||
{'leverage': 10, 'expected_risk': 'Medium Risk'},
|
||||
{'leverage': 25, 'expected_risk': 'Medium Risk'},
|
||||
{'leverage': 30, 'expected_risk': 'High Risk'},
|
||||
{'leverage': 50, 'expected_risk': 'High Risk'},
|
||||
{'leverage': 75, 'expected_risk': 'Extreme Risk'},
|
||||
{'leverage': 100, 'expected_risk': 'Extreme Risk'},
|
||||
]
|
||||
|
||||
for test_case in test_cases:
|
||||
leverage = test_case['leverage']
|
||||
expected_risk = test_case['expected_risk']
|
||||
|
||||
# Calculate risk level using same logic as dashboard
|
||||
if leverage <= 5:
|
||||
actual_risk = "Low Risk"
|
||||
elif leverage <= 25:
|
||||
actual_risk = "Medium Risk"
|
||||
elif leverage <= 50:
|
||||
actual_risk = "High Risk"
|
||||
else:
|
||||
actual_risk = "Extreme Risk"
|
||||
|
||||
status = "PASS" if actual_risk == expected_risk else "FAIL"
|
||||
logger.info(f" {leverage:3d}x leverage -> {actual_risk:13s} (expected: {expected_risk:13s}) [{status}]")
|
||||
|
||||
if status == "FAIL":
|
||||
logger.error(f"Test failed for {leverage}x leverage!")
|
||||
return False
|
||||
|
||||
logger.info("All leverage calculation tests PASSED!")
|
||||
return True
|
||||
|
||||
def test_leverage_reward_amplification():
|
||||
"""Test how different leverage levels amplify rewards"""
|
||||
|
||||
logger.info("\n" + "=" * 50)
|
||||
logger.info("TESTING LEVERAGE REWARD AMPLIFICATION")
|
||||
logger.info("=" * 50)
|
||||
|
||||
base_price = 3000.0
|
||||
price_changes = [0.001, 0.002, 0.005, 0.01] # 0.1%, 0.2%, 0.5%, 1.0%
|
||||
leverage_levels = [1, 5, 10, 25, 50, 100]
|
||||
|
||||
logger.info("Price Change | " + " | ".join([f"{lev:3d}x" for lev in leverage_levels]))
|
||||
logger.info("-" * 70)
|
||||
|
||||
for price_change_pct in price_changes:
|
||||
results = []
|
||||
for leverage in leverage_levels:
|
||||
# Calculate amplified return
|
||||
amplified_return = price_change_pct * leverage * 100 # Convert to percentage
|
||||
results.append(f"{amplified_return:6.1f}%")
|
||||
|
||||
logger.info(f" {price_change_pct*100:4.1f}% | " + " | ".join(results))
|
||||
|
||||
logger.info("\nKey insights:")
|
||||
logger.info("- 1x leverage: Traditional trading returns")
|
||||
logger.info("- 50x leverage: Our current default for enhanced learning")
|
||||
logger.info("- 100x leverage: Maximum risk/reward amplification")
|
||||
|
||||
return True
|
||||
|
||||
def test_dashboard_integration():
|
||||
"""Test dashboard integration"""
|
||||
|
||||
logger.info("\n" + "=" * 50)
|
||||
logger.info("TESTING DASHBOARD INTEGRATION")
|
||||
logger.info("=" * 50)
|
||||
|
||||
try:
|
||||
# Initialize components
|
||||
logger.info("Creating data provider...")
|
||||
data_provider = DataProvider()
|
||||
|
||||
logger.info("Creating enhanced orchestrator...")
|
||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||
|
||||
logger.info("Creating trading dashboard...")
|
||||
dashboard = TradingDashboard(data_provider, orchestrator)
|
||||
|
||||
# Test leverage settings
|
||||
logger.info(f"Initial leverage: {dashboard.leverage_multiplier}x")
|
||||
logger.info(f"Leverage range: {dashboard.min_leverage}x to {dashboard.max_leverage}x")
|
||||
logger.info(f"Leverage step: {dashboard.leverage_step}x")
|
||||
|
||||
# Test leverage updates
|
||||
test_leverages = [10, 25, 50, 75]
|
||||
for test_leverage in test_leverages:
|
||||
dashboard.leverage_multiplier = test_leverage
|
||||
logger.info(f"Set leverage to {dashboard.leverage_multiplier}x")
|
||||
|
||||
logger.info("Dashboard integration test PASSED!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Dashboard integration test FAILED: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Run all leverage tests"""
|
||||
|
||||
logger.info("LEVERAGE SLIDER FUNCTIONALITY TEST")
|
||||
logger.info("Testing the 50x leverage system with adjustable slider")
|
||||
|
||||
all_passed = True
|
||||
|
||||
# Test 1: Leverage calculations
|
||||
if not test_leverage_calculations():
|
||||
all_passed = False
|
||||
|
||||
# Test 2: Reward amplification
|
||||
if not test_leverage_reward_amplification():
|
||||
all_passed = False
|
||||
|
||||
# Test 3: Dashboard integration
|
||||
if not test_dashboard_integration():
|
||||
all_passed = False
|
||||
|
||||
# Final result
|
||||
logger.info("\n" + "=" * 50)
|
||||
if all_passed:
|
||||
logger.info("ALL TESTS PASSED!")
|
||||
logger.info("Leverage slider functionality is working correctly.")
|
||||
logger.info("\nTo use:")
|
||||
logger.info("1. Run: python run_clean_dashboard.py")
|
||||
logger.info("2. Open: http://127.0.0.1:8050")
|
||||
logger.info("3. Find the leverage slider in the System & Leverage panel")
|
||||
logger.info("4. Adjust leverage from 1x to 100x")
|
||||
logger.info("5. Watch risk levels update automatically")
|
||||
else:
|
||||
logger.error("SOME TESTS FAILED!")
|
||||
logger.error("Check the error messages above.")
|
||||
|
||||
return 0 if all_passed else 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit_code = main()
|
||||
sys.exit(exit_code)
|
@ -1,88 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for manual trading buttons functionality
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
def test_manual_trading():
|
||||
"""Test the manual trading buttons functionality"""
|
||||
print("Testing manual trading buttons...")
|
||||
|
||||
# Check if dashboard is running
|
||||
try:
|
||||
response = requests.get("http://127.0.0.1:8050", timeout=5)
|
||||
if response.status_code == 200:
|
||||
print("✅ Dashboard is running on port 8050")
|
||||
else:
|
||||
print(f"❌ Dashboard returned status code: {response.status_code}")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"❌ Dashboard not accessible: {e}")
|
||||
return
|
||||
|
||||
# Check if trades file exists
|
||||
try:
|
||||
with open('closed_trades_history.json', 'r') as f:
|
||||
trades = json.load(f)
|
||||
print(f"📊 Current trades in history: {len(trades)}")
|
||||
if trades:
|
||||
latest_trade = trades[-1]
|
||||
print(f" Latest trade: {latest_trade.get('side')} at ${latest_trade.get('exit_price', latest_trade.get('entry_price'))}")
|
||||
except FileNotFoundError:
|
||||
print("📊 No trades history file found (this is normal for fresh start)")
|
||||
except Exception as e:
|
||||
print(f"❌ Error reading trades file: {e}")
|
||||
|
||||
print("\n🎯 Manual Trading Test Instructions:")
|
||||
print("1. Open dashboard at http://127.0.0.1:8050")
|
||||
print("2. Look for the 'MANUAL BUY' and 'MANUAL SELL' buttons")
|
||||
print("3. Click 'MANUAL BUY' to create a test long position")
|
||||
print("4. Wait a few seconds, then click 'MANUAL SELL' to close and create short")
|
||||
print("5. Check the chart for green triangles showing trade entry/exit points")
|
||||
print("6. Check the 'Closed Trades' table for trade records")
|
||||
|
||||
print("\n📈 Expected Results:")
|
||||
print("- Green triangles should appear on the price chart at trade execution times")
|
||||
print("- Dashed lines should connect entry and exit points")
|
||||
print("- Trade records should appear in the closed trades table")
|
||||
print("- Session P&L should update with trade profits/losses")
|
||||
|
||||
print("\n🔍 Monitoring trades file...")
|
||||
initial_count = 0
|
||||
try:
|
||||
with open('closed_trades_history.json', 'r') as f:
|
||||
initial_count = len(json.load(f))
|
||||
except:
|
||||
pass
|
||||
|
||||
print(f"Initial trade count: {initial_count}")
|
||||
print("Watching for new trades... (Press Ctrl+C to stop)")
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(2)
|
||||
try:
|
||||
with open('closed_trades_history.json', 'r') as f:
|
||||
current_trades = json.load(f)
|
||||
current_count = len(current_trades)
|
||||
|
||||
if current_count > initial_count:
|
||||
new_trades = current_trades[initial_count:]
|
||||
for trade in new_trades:
|
||||
print(f"🆕 NEW TRADE: {trade.get('side')} | Entry: ${trade.get('entry_price'):.2f} | Exit: ${trade.get('exit_price'):.2f} | P&L: ${trade.get('net_pnl'):.2f}")
|
||||
initial_count = current_count
|
||||
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Error monitoring trades: {e}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n✅ Test monitoring stopped")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_manual_trading()
|
@ -1,64 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import sys
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add project root to path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
from NN.exchanges.mexc_interface import MEXCInterface
|
||||
from core.config import get_config
|
||||
|
||||
def test_mexc_private_api():
|
||||
"""Test MEXC private API endpoints"""
|
||||
# Load configuration
|
||||
config = get_config('config.yaml')
|
||||
mexc_config = config.get('mexc_trading', {})
|
||||
|
||||
# Get API credentials
|
||||
api_key = os.getenv('MEXC_API_KEY', mexc_config.get('api_key', ''))
|
||||
api_secret = os.getenv('MEXC_SECRET_KEY', mexc_config.get('api_secret', ''))
|
||||
|
||||
if not api_key or not api_secret:
|
||||
logger.error("API key or secret not found. Please set MEXC_API_KEY and MEXC_SECRET_KEY environment variables.")
|
||||
return
|
||||
|
||||
# Initialize MEXC interface in test mode
|
||||
mexc = MEXCInterface(api_key=api_key, api_secret=api_secret, test_mode=True, trading_mode='simulation')
|
||||
|
||||
# Test connection
|
||||
if not mexc.connect():
|
||||
logger.error("Failed to connect to MEXC API")
|
||||
return
|
||||
|
||||
# Test getting account information
|
||||
logger.info("Testing account information retrieval...")
|
||||
account_info = mexc.get_account_info()
|
||||
if account_info:
|
||||
logger.info(f"Account info retrieved: {account_info}")
|
||||
else:
|
||||
logger.error("Failed to retrieve account info")
|
||||
|
||||
# Test getting balance for a specific asset
|
||||
asset = "USDT"
|
||||
logger.info(f"Testing balance retrieval for {asset}...")
|
||||
balance = mexc.get_balance(asset)
|
||||
logger.info(f"Balance for {asset}: {balance}")
|
||||
|
||||
# Test placing a simulated order (in test mode)
|
||||
symbol = "ETH/USDT"
|
||||
side = "buy"
|
||||
order_type = "market"
|
||||
quantity = 0.01 # Small quantity for testing
|
||||
logger.info(f"Testing order placement for {symbol} ({side}, {order_type}, qty: {quantity})...")
|
||||
order_result = mexc.place_order(symbol=symbol, side=side, order_type=order_type, quantity=quantity)
|
||||
if order_result:
|
||||
logger.info(f"Order placed successfully: {order_result}")
|
||||
else:
|
||||
logger.error("Failed to place order")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_mexc_private_api()
|
@ -1,59 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from NN.exchanges.mexc_interface import MEXCInterface
|
||||
|
||||
# Set up logging to see debug info
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Load API credentials from environment variables or a configuration file
|
||||
# For testing, prioritize environment variables for CI/CD or sensitive data
|
||||
# Fallback to a placeholder or configuration reading if env vars are not set
|
||||
api_key = os.getenv('MEXC_API_KEY', '')
|
||||
api_secret = os.getenv('MEXC_SECRET_KEY', '')
|
||||
|
||||
# If using a config file, you might do something like:
|
||||
# from core.config import get_config
|
||||
# config = get_config('config.yaml')
|
||||
# mexc_config = config.get('mexc_trading', {})
|
||||
# api_key = mexc_config.get('api_key', api_key)
|
||||
# api_secret = mexc_config.get('api_secret', api_secret)
|
||||
|
||||
if not api_key or not api_secret:
|
||||
logging.error("API keys are not set. Please set MEXC_API_KEY and MEXC_SECRET_KEY environment variables or configure config.yaml")
|
||||
exit(1)
|
||||
|
||||
# Create interface with API credentials
|
||||
mexc = MEXCInterface(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
trading_mode='simulation'
|
||||
)
|
||||
|
||||
print('MEXC Interface created successfully')
|
||||
|
||||
# Test signature generation
|
||||
import time
|
||||
timestamp = int(time.time() * 1000)
|
||||
test_params = 'quantity=1&price=11&symbol=BTCUSDT&side=BUY&type=LIMIT×tamp=' + str(timestamp)
|
||||
signature = mexc._generate_signature(timestamp, test_params)
|
||||
print(f'Generated signature: {signature}')
|
||||
|
||||
# Test account info
|
||||
print('Testing account info...')
|
||||
account_info = mexc.get_account_info()
|
||||
print(f'Account info result: {account_info}')
|
||||
|
||||
# Test ticker data
|
||||
print('Testing ticker data...')
|
||||
ticker = mexc.get_ticker('ETH/USDT')
|
||||
print(f'ETH/USDT ticker: {ticker}')
|
||||
|
||||
# Test balance retrieval
|
||||
print('Testing balance retrieval...')
|
||||
usdt_balance = mexc.get_balance('USDT')
|
||||
print(f'USDT balance: {usdt_balance}')
|
||||
|
||||
# Test a small order placement (simulation mode)
|
||||
print('Testing order placement in simulation mode...')
|
||||
order_result = mexc.place_order('ETH/USDT', 'buy', 'market', 0.001)
|
||||
print(f'Order result: {order_result}')
|
@ -1,222 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for MEXC balance retrieval and $1 order execution
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.abspath('.'))
|
||||
|
||||
from core.trading_executor import TradingExecutor
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_mexc_balance():
|
||||
"""Test MEXC balance retrieval"""
|
||||
print("="*60)
|
||||
print("TESTING MEXC BALANCE RETRIEVAL")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# Initialize trading executor
|
||||
executor = TradingExecutor()
|
||||
|
||||
# Check if trading is enabled
|
||||
print(f"Trading enabled: {executor.trading_enabled}")
|
||||
print(f"Dry run mode: {executor.dry_run}")
|
||||
|
||||
if not executor.trading_enabled:
|
||||
print("❌ Trading not enabled - check config.yaml and API keys")
|
||||
return False
|
||||
|
||||
# Test balance retrieval
|
||||
print("\n📊 Retrieving account balance...")
|
||||
balances = executor.get_account_balance()
|
||||
|
||||
if not balances:
|
||||
print("❌ No balances retrieved - check API connectivity")
|
||||
return False
|
||||
|
||||
print(f"✅ Retrieved balances for {len(balances)} assets:")
|
||||
for asset, balance_info in balances.items():
|
||||
free = balance_info['free']
|
||||
locked = balance_info['locked']
|
||||
total = balance_info['total']
|
||||
print(f" {asset}: Free: {free:.6f}, Locked: {locked:.6f}, Total: {total:.6f}")
|
||||
|
||||
# Check USDT balance specifically
|
||||
if 'USDT' in balances:
|
||||
usdt_free = balances['USDT']['free']
|
||||
print(f"\n💰 USDT available for trading: ${usdt_free:.2f}")
|
||||
|
||||
if usdt_free >= 2.0: # Need at least $2 for testing
|
||||
print("✅ Sufficient USDT balance for $1 order testing")
|
||||
return True
|
||||
else:
|
||||
print(f"⚠️ Insufficient USDT balance for testing (need $2+, have ${usdt_free:.2f})")
|
||||
return False
|
||||
else:
|
||||
print("❌ No USDT balance found")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing MEXC balance: {e}")
|
||||
return False
|
||||
|
||||
def test_mexc_order_execution():
|
||||
"""Test $1 order execution (dry run)"""
|
||||
print("\n" + "="*60)
|
||||
print("TESTING $1 ORDER EXECUTION (DRY RUN)")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# Initialize components
|
||||
executor = TradingExecutor()
|
||||
data_provider = DataProvider()
|
||||
|
||||
if not executor.trading_enabled:
|
||||
print("❌ Trading not enabled - cannot test order execution")
|
||||
return False
|
||||
|
||||
# Test symbol
|
||||
symbol = "ETH/USDT"
|
||||
|
||||
# Get current price
|
||||
print(f"\n📈 Getting current price for {symbol}...")
|
||||
ticker_data = data_provider.get_historical_data(symbol, '1m', limit=1, refresh=True)
|
||||
|
||||
if ticker_data is None or ticker_data.empty:
|
||||
print(f"❌ Could not get price data for {symbol}")
|
||||
return False
|
||||
|
||||
current_price = float(ticker_data['close'].iloc[-1])
|
||||
print(f"✅ Current {symbol} price: ${current_price:.2f}")
|
||||
|
||||
# Calculate order size for $1
|
||||
usd_amount = 1.0
|
||||
crypto_amount = usd_amount / current_price
|
||||
print(f"💱 $1 USD = {crypto_amount:.6f} ETH")
|
||||
|
||||
# Test buy signal execution
|
||||
print(f"\n🛒 Testing BUY signal execution...")
|
||||
buy_success = executor.execute_signal(
|
||||
symbol=symbol,
|
||||
action='BUY',
|
||||
confidence=0.75,
|
||||
current_price=current_price
|
||||
)
|
||||
|
||||
if buy_success:
|
||||
print("✅ BUY signal executed successfully")
|
||||
|
||||
# Check position
|
||||
positions = executor.get_positions()
|
||||
if symbol in positions:
|
||||
position = positions[symbol]
|
||||
print(f"📍 Position opened: {position.quantity:.6f} {symbol} @ ${position.entry_price:.2f}")
|
||||
|
||||
# Test sell signal execution
|
||||
print(f"\n💰 Testing SELL signal execution...")
|
||||
sell_success = executor.execute_signal(
|
||||
symbol=symbol,
|
||||
action='SELL',
|
||||
confidence=0.80,
|
||||
current_price=current_price * 1.001 # Simulate small price increase
|
||||
)
|
||||
|
||||
if sell_success:
|
||||
print("✅ SELL signal executed successfully")
|
||||
|
||||
# Check trade history
|
||||
trades = executor.get_trade_history()
|
||||
if trades:
|
||||
last_trade = trades[-1]
|
||||
print(f"📊 Trade completed: P&L = ${last_trade.pnl:.4f}")
|
||||
|
||||
return True
|
||||
else:
|
||||
print("❌ SELL signal failed")
|
||||
return False
|
||||
else:
|
||||
print("❌ No position found after BUY signal")
|
||||
return False
|
||||
else:
|
||||
print("❌ BUY signal failed")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing order execution: {e}")
|
||||
return False
|
||||
|
||||
def test_dashboard_balance_integration():
|
||||
"""Test dashboard balance integration"""
|
||||
print("\n" + "="*60)
|
||||
print("TESTING DASHBOARD BALANCE INTEGRATION")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
|
||||
|
||||
# Create dashboard with trading executor
|
||||
executor = TradingExecutor()
|
||||
dashboard = TradingDashboard(trading_executor=executor)
|
||||
|
||||
print(f"Dashboard starting balance: ${dashboard.starting_balance:.2f}")
|
||||
|
||||
if dashboard.starting_balance > 0:
|
||||
print("✅ Dashboard successfully retrieved starting balance")
|
||||
return True
|
||||
else:
|
||||
print("⚠️ Dashboard using default balance (MEXC not connected)")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing dashboard integration: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
print("🚀 MEXC INTEGRATION TESTING")
|
||||
print("Testing balance retrieval and $1 order execution")
|
||||
|
||||
# Test 1: Balance retrieval
|
||||
balance_test = test_mexc_balance()
|
||||
|
||||
# Test 2: Order execution (only if balance test passes)
|
||||
if balance_test:
|
||||
order_test = test_mexc_order_execution()
|
||||
else:
|
||||
print("\n⏭️ Skipping order execution test (balance test failed)")
|
||||
order_test = False
|
||||
|
||||
# Test 3: Dashboard integration
|
||||
dashboard_test = test_dashboard_balance_integration()
|
||||
|
||||
# Summary
|
||||
print("\n" + "="*60)
|
||||
print("TEST SUMMARY")
|
||||
print("="*60)
|
||||
print(f"Balance Retrieval: {'✅ PASS' if balance_test else '❌ FAIL'}")
|
||||
print(f"Order Execution: {'✅ PASS' if order_test else '❌ FAIL'}")
|
||||
print(f"Dashboard Integration: {'✅ PASS' if dashboard_test else '❌ FAIL'}")
|
||||
|
||||
if balance_test and order_test and dashboard_test:
|
||||
print("\n🎉 ALL TESTS PASSED - Ready for live $1 testing!")
|
||||
return True
|
||||
else:
|
||||
print("\n⚠️ Some tests failed - check configuration and API keys")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
@ -1 +0,0 @@
|
||||
|
@ -1,117 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for MEXC API with new credentials
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
def test_api_credentials():
|
||||
"""Test MEXC API credentials step by step"""
|
||||
print("="*60)
|
||||
print("MEXC API CREDENTIALS TEST")
|
||||
print("="*60)
|
||||
|
||||
# Check environment variables
|
||||
api_key = os.getenv('MEXC_API_KEY')
|
||||
api_secret = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
print(f"1. Environment Variables:")
|
||||
print(f" API Key: {api_key[:5]}...{api_key[-5:] if api_key else 'None'}")
|
||||
print(f" API Secret: {api_secret[:5]}...{api_secret[-5:] if api_secret else 'None'}")
|
||||
print(f" API Key Length: {len(api_key) if api_key else 0}")
|
||||
print(f" API Secret Length: {len(api_secret) if api_secret else 0}")
|
||||
|
||||
if not api_key or not api_secret:
|
||||
print("❌ API credentials not found in environment")
|
||||
return False
|
||||
|
||||
# Test public API first
|
||||
print(f"\n2. Testing Public API (no authentication):")
|
||||
try:
|
||||
from NN.exchanges.mexc_interface import MEXCInterface
|
||||
api = MEXCInterface('dummy', 'dummy', test_mode=False)
|
||||
|
||||
ticker = api.get_ticker('ETHUSDT')
|
||||
if ticker:
|
||||
print(f" ✅ Public API works: ETH/USDT = ${ticker.get('last', 'N/A')}")
|
||||
else:
|
||||
print(f" ❌ Public API failed")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f" ❌ Public API error: {e}")
|
||||
return False
|
||||
|
||||
# Test private API with actual credentials
|
||||
print(f"\n3. Testing Private API (with authentication):")
|
||||
try:
|
||||
api_auth = MEXCInterface(api_key, api_secret, test_mode=False)
|
||||
|
||||
# Try to get account info
|
||||
account_info = api_auth.get_account_info()
|
||||
if account_info:
|
||||
print(f" ✅ Private API works: Account info retrieved")
|
||||
print(f" 📊 Account Type: {account_info.get('accountType', 'N/A')}")
|
||||
# Try to get USDT balance
|
||||
usdt_balance = api_auth.get_balance('USDT')
|
||||
print(f" 💰 USDT Balance: {usdt_balance}")
|
||||
return True
|
||||
else:
|
||||
print(f" ❌ Private API failed: Could not get account info")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Private API error: {e}")
|
||||
return False
|
||||
|
||||
def test_api_permissions():
|
||||
"""Test specific API permissions"""
|
||||
print(f"\n4. Testing API Permissions:")
|
||||
|
||||
api_key = os.getenv('MEXC_API_KEY')
|
||||
api_secret = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
try:
|
||||
from NN.exchanges.mexc_interface import MEXCInterface
|
||||
api = MEXCInterface(api_key, api_secret, test_mode=False)
|
||||
|
||||
# Test spot trading permissions
|
||||
print(" Testing spot trading permissions...")
|
||||
|
||||
# Try to get open orders (requires spot trading permission)
|
||||
try:
|
||||
orders = api.get_open_orders('ETHUSDT')
|
||||
print(" ✅ Spot trading permission: OK")
|
||||
except Exception as e:
|
||||
print(f" ❌ Spot trading permission: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Permission test error: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main test function"""
|
||||
success = test_api_credentials()
|
||||
|
||||
if success:
|
||||
test_api_permissions()
|
||||
print(f"\n✅ MEXC API SETUP COMPLETE")
|
||||
print("The trading system should now work with live MEXC spot trading")
|
||||
else:
|
||||
print(f"\n❌ MEXC API SETUP FAILED")
|
||||
print("Possible issues:")
|
||||
print("1. API key or secret incorrect")
|
||||
print("2. API key not activated yet")
|
||||
print("3. Insufficient permissions (need spot trading)")
|
||||
print("4. IP address not whitelisted")
|
||||
print("5. Account verification incomplete")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,362 +0,0 @@
|
||||
"""
|
||||
MEXC Order Execution Debug Script
|
||||
|
||||
This script tests MEXC order execution step by step to identify any issues
|
||||
with the trading integration.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Add paths for imports
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), 'core'))
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), 'NN'))
|
||||
|
||||
from core.trading_executor import TradingExecutor
|
||||
from core.data_provider import DataProvider
|
||||
from NN.exchanges.mexc_interface import MEXCInterface
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler("mexc_order_debug.log"),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger("mexc_debug")
|
||||
|
||||
class MEXCOrderDebugger:
|
||||
"""Debug MEXC order execution step by step"""
|
||||
|
||||
def __init__(self):
|
||||
self.test_symbol = 'ETH/USDC' # ETH with USDC (supported by MEXC API)
|
||||
self.test_quantity = 0.15 # $0.15 worth of ETH for testing (within our balance)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
self.api_key = os.getenv('MEXC_API_KEY')
|
||||
self.api_secret = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
def run_comprehensive_test(self):
|
||||
"""Run comprehensive MEXC order execution test"""
|
||||
print("="*80)
|
||||
print("MEXC ORDER EXECUTION DEBUG TEST")
|
||||
print("="*80)
|
||||
|
||||
# Step 1: Test environment variables
|
||||
print("\n1. Testing Environment Variables...")
|
||||
if not self.test_environment_variables():
|
||||
return False
|
||||
|
||||
# Step 2: Test MEXC interface creation
|
||||
print("\n2. Testing MEXC Interface Creation...")
|
||||
mexc = self.test_mexc_interface_creation()
|
||||
if not mexc:
|
||||
return False
|
||||
|
||||
# Step 3: Test connection
|
||||
print("\n3. Testing MEXC Connection...")
|
||||
if not self.test_mexc_connection(mexc):
|
||||
return False
|
||||
|
||||
# Step 4: Test account info
|
||||
print("\n4. Testing Account Information...")
|
||||
if not self.test_account_info(mexc):
|
||||
return False
|
||||
|
||||
# Step 5: Test ticker data
|
||||
print("\n5. Testing Ticker Data...")
|
||||
current_price = self.test_ticker_data(mexc)
|
||||
if not current_price:
|
||||
return False
|
||||
|
||||
# Step 6: Test trading executor
|
||||
print("\n6. Testing Trading Executor...")
|
||||
executor = self.test_trading_executor_creation()
|
||||
if not executor:
|
||||
return False
|
||||
|
||||
# Step 7: Test order placement (simulation)
|
||||
print("\n7. Testing Order Placement...")
|
||||
if not self.test_order_placement(executor, current_price):
|
||||
return False
|
||||
|
||||
# Step 8: Test order parameters
|
||||
print("\n8. Testing Order Parameters...")
|
||||
if not self.test_order_parameters(mexc, current_price):
|
||||
return False
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("✅ ALL TESTS COMPLETED SUCCESSFULLY!")
|
||||
print("MEXC order execution system appears to be working correctly.")
|
||||
print("="*80)
|
||||
|
||||
return True
|
||||
|
||||
def test_environment_variables(self) -> bool:
|
||||
"""Test environment variables"""
|
||||
try:
|
||||
api_key = os.getenv('MEXC_API_KEY')
|
||||
api_secret = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
if not api_key:
|
||||
print("❌ MEXC_API_KEY environment variable not set")
|
||||
return False
|
||||
|
||||
if not api_secret:
|
||||
print("❌ MEXC_SECRET_KEY environment variable not set")
|
||||
return False
|
||||
|
||||
print(f"✅ MEXC_API_KEY: {api_key[:8]}...{api_key[-4:]} (length: {len(api_key)})")
|
||||
print(f"✅ MEXC_SECRET_KEY: {api_secret[:8]}...{api_secret[-4:]} (length: {len(api_secret)})")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error checking environment variables: {e}")
|
||||
return False
|
||||
|
||||
def test_mexc_interface_creation(self) -> MEXCInterface:
|
||||
"""Test MEXC interface creation"""
|
||||
try:
|
||||
api_key = os.getenv('MEXC_API_KEY')
|
||||
api_secret = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
mexc = MEXCInterface(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
test_mode=True # Use testnet for safety
|
||||
)
|
||||
|
||||
print(f"✅ MEXC Interface created successfully")
|
||||
print(f" - Test mode: {mexc.test_mode}")
|
||||
print(f" - Base URL: {mexc.base_url}")
|
||||
|
||||
return mexc
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error creating MEXC interface: {e}")
|
||||
return None
|
||||
|
||||
def test_mexc_connection(self, mexc: MEXCInterface) -> bool:
|
||||
"""Test MEXC connection"""
|
||||
try:
|
||||
# Test ping
|
||||
ping_result = mexc.ping()
|
||||
print(f"✅ MEXC Ping successful: {ping_result}")
|
||||
|
||||
# Test server time
|
||||
server_time = mexc.get_server_time()
|
||||
print(f"✅ MEXC Server time: {server_time}")
|
||||
|
||||
# Test connection method
|
||||
connected = mexc.connect()
|
||||
print(f"✅ MEXC Connection: {connected}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error testing MEXC connection: {e}")
|
||||
logger.error(f"MEXC connection error: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def test_account_info(self, mexc: MEXCInterface) -> bool:
|
||||
"""Test account information retrieval"""
|
||||
try:
|
||||
account_info = mexc.get_account_info()
|
||||
print(f"✅ Account info retrieved successfully")
|
||||
print(f" - Can trade: {account_info.get('canTrade', 'Unknown')}")
|
||||
print(f" - Can withdraw: {account_info.get('canWithdraw', 'Unknown')}")
|
||||
print(f" - Can deposit: {account_info.get('canDeposit', 'Unknown')}")
|
||||
print(f" - Account type: {account_info.get('accountType', 'Unknown')}")
|
||||
|
||||
# Test balance retrieval
|
||||
balances = account_info.get('balances', [])
|
||||
usdc_balance = 0
|
||||
usdt_balance = 0
|
||||
for balance in balances:
|
||||
if balance.get('asset') == 'USDC':
|
||||
usdc_balance = float(balance.get('free', 0))
|
||||
elif balance.get('asset') == 'USDT':
|
||||
usdt_balance = float(balance.get('free', 0))
|
||||
|
||||
print(f" - USDC Balance: {usdc_balance}")
|
||||
print(f" - USDT Balance: {usdt_balance}")
|
||||
|
||||
if usdc_balance < self.test_quantity:
|
||||
print(f"⚠️ Warning: USDC balance ({usdc_balance}) is less than test amount ({self.test_quantity})")
|
||||
if usdt_balance >= self.test_quantity:
|
||||
print(f"💡 Note: You have sufficient USDT ({usdt_balance}), but we need USDC for ETH/USDC trading")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error retrieving account info: {e}")
|
||||
logger.error(f"Account info error: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def test_ticker_data(self, mexc: MEXCInterface) -> float:
|
||||
"""Test ticker data retrieval"""
|
||||
try:
|
||||
ticker = mexc.get_ticker(self.test_symbol)
|
||||
if not ticker:
|
||||
print(f"❌ Failed to get ticker for {self.test_symbol}")
|
||||
return None
|
||||
|
||||
current_price = ticker['last']
|
||||
print(f"✅ Ticker data retrieved for {self.test_symbol}")
|
||||
print(f" - Last price: ${current_price:.2f}")
|
||||
print(f" - Bid: ${ticker.get('bid', 0):.2f}")
|
||||
print(f" - Ask: ${ticker.get('ask', 0):.2f}")
|
||||
print(f" - Volume: {ticker.get('volume', 0)}")
|
||||
|
||||
return current_price
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error retrieving ticker data: {e}")
|
||||
logger.error(f"Ticker data error: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def test_trading_executor_creation(self) -> TradingExecutor:
|
||||
"""Test trading executor creation"""
|
||||
try:
|
||||
executor = TradingExecutor()
|
||||
print(f"✅ Trading Executor created successfully")
|
||||
print(f" - Trading enabled: {executor.trading_enabled}")
|
||||
print(f" - Trading mode: {executor.trading_mode}")
|
||||
print(f" - Simulation mode: {executor.simulation_mode}")
|
||||
|
||||
return executor
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error creating trading executor: {e}")
|
||||
logger.error(f"Trading executor error: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def test_order_placement(self, executor: TradingExecutor, current_price: float) -> bool:
|
||||
"""Test order placement through executor"""
|
||||
try:
|
||||
print(f"Testing BUY signal execution...")
|
||||
|
||||
# Test BUY signal
|
||||
buy_success = executor.execute_signal(
|
||||
symbol=self.test_symbol,
|
||||
action='BUY',
|
||||
confidence=0.75,
|
||||
current_price=current_price
|
||||
)
|
||||
|
||||
print(f"✅ BUY signal execution: {'SUCCESS' if buy_success else 'FAILED'}")
|
||||
|
||||
if buy_success:
|
||||
# Check positions
|
||||
positions = executor.get_positions()
|
||||
if self.test_symbol in positions:
|
||||
position = positions[self.test_symbol]
|
||||
print(f" - Position created: {position.side} {position.quantity:.6f} @ ${position.entry_price:.2f}")
|
||||
|
||||
# Test SELL signal
|
||||
print(f"Testing SELL signal execution...")
|
||||
sell_success = executor.execute_signal(
|
||||
symbol=self.test_symbol,
|
||||
action='SELL',
|
||||
confidence=0.80,
|
||||
current_price=current_price * 1.001 # Simulate small price increase
|
||||
)
|
||||
|
||||
print(f"✅ SELL signal execution: {'SUCCESS' if sell_success else 'FAILED'}")
|
||||
|
||||
if sell_success:
|
||||
# Check trade history
|
||||
trades = executor.get_trade_history()
|
||||
if trades:
|
||||
last_trade = trades[-1]
|
||||
print(f" - Trade P&L: ${last_trade.pnl:.4f}")
|
||||
|
||||
return sell_success
|
||||
else:
|
||||
print("❌ No position found after BUY signal")
|
||||
return False
|
||||
|
||||
return buy_success
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error testing order placement: {e}")
|
||||
logger.error(f"Order placement error: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def test_order_parameters(self, mexc: MEXCInterface, current_price: float) -> bool:
|
||||
"""Test order parameters and validation"""
|
||||
try:
|
||||
print("Testing order parameter calculation...")
|
||||
|
||||
# Calculate test order size
|
||||
crypto_quantity = self.test_quantity / current_price
|
||||
print(f" - USD amount: ${self.test_quantity}")
|
||||
print(f" - Current price: ${current_price:.2f}")
|
||||
print(f" - Crypto quantity: {crypto_quantity:.6f} ETH")
|
||||
|
||||
# Test order parameters formatting
|
||||
mexc_symbol = self.test_symbol.replace('/', '')
|
||||
print(f" - MEXC symbol format: {mexc_symbol}")
|
||||
|
||||
order_params = {
|
||||
'symbol': mexc_symbol,
|
||||
'side': 'BUY',
|
||||
'type': 'MARKET',
|
||||
'quantity': str(crypto_quantity),
|
||||
'recvWindow': 5000
|
||||
}
|
||||
|
||||
print(f" - Order parameters: {order_params}")
|
||||
|
||||
# Test signature generation (without actually placing order)
|
||||
print("Testing signature generation...")
|
||||
test_params = order_params.copy()
|
||||
test_params['timestamp'] = int(time.time() * 1000)
|
||||
|
||||
try:
|
||||
signature = mexc._generate_signature(test_params)
|
||||
print(f"✅ Signature generated successfully (length: {len(signature)})")
|
||||
except Exception as e:
|
||||
print(f"❌ Signature generation failed: {e}")
|
||||
return False
|
||||
|
||||
print("✅ Order parameters validation successful")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error testing order parameters: {e}")
|
||||
logger.error(f"Order parameters error: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main test function"""
|
||||
try:
|
||||
debugger = MEXCOrderDebugger()
|
||||
success = debugger.run_comprehensive_test()
|
||||
|
||||
if success:
|
||||
print("\n🎉 MEXC order execution system is working correctly!")
|
||||
print("You can now safely execute live trades.")
|
||||
else:
|
||||
print("\n🚨 MEXC order execution has issues that need to be resolved.")
|
||||
print("Check the logs above for specific error details.")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in main test: {e}", exc_info=True)
|
||||
print(f"\n❌ Critical error during testing: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,205 +0,0 @@
|
||||
"""
|
||||
Test MEXC Order Size Requirements
|
||||
|
||||
This script tests different order sizes to identify minimum order requirements
|
||||
and understand why order placement is failing.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import time
|
||||
|
||||
# Add paths for imports
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), 'NN'))
|
||||
|
||||
from NN.exchanges.mexc_interface import MEXCInterface
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("mexc_order_test")
|
||||
|
||||
def test_order_sizes():
|
||||
"""Test different order sizes to find minimum requirements"""
|
||||
print("="*60)
|
||||
print("MEXC ORDER SIZE REQUIREMENTS TEST")
|
||||
print("="*60)
|
||||
|
||||
api_key = os.getenv('MEXC_API_KEY')
|
||||
api_secret = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
if not api_key or not api_secret:
|
||||
print("❌ Missing API credentials")
|
||||
return False
|
||||
|
||||
mexc = MEXCInterface(api_key=api_key, api_secret=api_secret, test_mode=True)
|
||||
|
||||
# Get current ETH price
|
||||
ticker = mexc.get_ticker('ETH/USDT')
|
||||
if not ticker:
|
||||
print("❌ Failed to get ETH price")
|
||||
return False
|
||||
|
||||
current_price = ticker['last']
|
||||
print(f"Current ETH price: ${current_price:.2f}")
|
||||
|
||||
# Test different USD amounts
|
||||
test_amounts_usd = [0.1, 0.5, 1.0, 5.0, 10.0, 20.0]
|
||||
|
||||
print(f"\nTesting different order sizes...")
|
||||
print(f"{'USD Amount':<12} {'ETH Quantity':<15} {'Min ETH?':<10} {'Min USD?':<10}")
|
||||
print("-" * 50)
|
||||
|
||||
for usd_amount in test_amounts_usd:
|
||||
eth_quantity = usd_amount / current_price
|
||||
|
||||
# Check if quantity meets common minimum requirements
|
||||
min_eth_ok = eth_quantity >= 0.001 # 0.001 ETH common minimum
|
||||
min_usd_ok = usd_amount >= 5.0 # $5 common minimum
|
||||
|
||||
print(f"${usd_amount:<11.2f} {eth_quantity:<15.6f} {'✅' if min_eth_ok else '❌':<9} {'✅' if min_usd_ok else '❌':<9}")
|
||||
|
||||
# Test actual order parameter validation
|
||||
print(f"\nTesting order parameter validation...")
|
||||
|
||||
# Test small order (likely to fail)
|
||||
small_usd = 1.0
|
||||
small_eth = small_usd / current_price
|
||||
|
||||
print(f"\n1. Testing small order: ${small_usd} (${small_eth:.6f} ETH)")
|
||||
success = test_order_validation(mexc, 'ETHUSDT', 'BUY', 'MARKET', small_eth)
|
||||
|
||||
# Test medium order (might work)
|
||||
medium_usd = 10.0
|
||||
medium_eth = medium_usd / current_price
|
||||
|
||||
print(f"\n2. Testing medium order: ${medium_usd} (${medium_eth:.6f} ETH)")
|
||||
success = test_order_validation(mexc, 'ETHUSDT', 'BUY', 'MARKET', medium_eth)
|
||||
|
||||
# Test with rounded quantities
|
||||
print(f"\n3. Testing with rounded quantities...")
|
||||
|
||||
# Test 0.001 ETH (common minimum)
|
||||
print(f" Testing 0.001 ETH (${0.001 * current_price:.2f})")
|
||||
success = test_order_validation(mexc, 'ETHUSDT', 'BUY', 'MARKET', 0.001)
|
||||
|
||||
# Test 0.01 ETH
|
||||
print(f" Testing 0.01 ETH (${0.01 * current_price:.2f})")
|
||||
success = test_order_validation(mexc, 'ETHUSDT', 'BUY', 'MARKET', 0.01)
|
||||
|
||||
return True
|
||||
|
||||
def test_order_validation(mexc: MEXCInterface, symbol: str, side: str, order_type: str, quantity: float) -> bool:
|
||||
"""Test order parameter validation without actually placing the order"""
|
||||
try:
|
||||
# Prepare order parameters
|
||||
params = {
|
||||
'symbol': symbol,
|
||||
'side': side,
|
||||
'type': order_type,
|
||||
'quantity': str(quantity),
|
||||
'recvWindow': 5000,
|
||||
'timestamp': int(time.time() * 1000)
|
||||
}
|
||||
|
||||
# Generate signature
|
||||
signature = mexc._generate_signature(params)
|
||||
params['signature'] = signature
|
||||
|
||||
print(f" Params: {params}")
|
||||
|
||||
# Try to validate parameters by making the request but catching the specific error
|
||||
headers = {'X-MEXC-APIKEY': mexc.api_key}
|
||||
url = f"{mexc.base_url}/{mexc.api_version}/order"
|
||||
|
||||
import requests
|
||||
|
||||
# Make the request to see what specific error we get
|
||||
response = requests.post(url, params=params, headers=headers, timeout=30)
|
||||
|
||||
if response.status_code == 200:
|
||||
print(" ✅ Order would be accepted (parameters valid)")
|
||||
return True
|
||||
else:
|
||||
response_data = response.json() if response.headers.get('content-type', '').startswith('application/json') else {'msg': response.text}
|
||||
error_code = response_data.get('code', 'Unknown')
|
||||
error_msg = response_data.get('msg', 'Unknown error')
|
||||
|
||||
print(f" ❌ Error {error_code}: {error_msg}")
|
||||
|
||||
# Analyze specific error codes
|
||||
if error_code == 400001:
|
||||
print(" → Invalid parameter format")
|
||||
elif error_code == 700002:
|
||||
print(" → Invalid signature")
|
||||
elif error_code == 70016:
|
||||
print(" → Order size too small")
|
||||
elif error_code == 70015:
|
||||
print(" → Insufficient balance")
|
||||
elif 'LOT_SIZE' in error_msg:
|
||||
print(" → Lot size violation (quantity precision/minimum)")
|
||||
elif 'MIN_NOTIONAL' in error_msg:
|
||||
print(" → Minimum notional value not met")
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Exception: {e}")
|
||||
return False
|
||||
|
||||
def get_symbol_info():
|
||||
"""Get symbol trading rules and limits"""
|
||||
print("\nGetting symbol trading rules...")
|
||||
|
||||
try:
|
||||
api_key = os.getenv('MEXC_API_KEY')
|
||||
api_secret = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
mexc = MEXCInterface(api_key=api_key, api_secret=api_secret, test_mode=True)
|
||||
|
||||
# Try to get exchange info
|
||||
import requests
|
||||
|
||||
url = f"{mexc.base_url}/{mexc.api_version}/exchangeInfo"
|
||||
response = requests.get(url, timeout=30)
|
||||
|
||||
if response.status_code == 200:
|
||||
exchange_info = response.json()
|
||||
|
||||
# Find ETHUSDT symbol info
|
||||
for symbol_info in exchange_info.get('symbols', []):
|
||||
if symbol_info.get('symbol') == 'ETHUSDT':
|
||||
print(f"Found ETHUSDT trading rules:")
|
||||
print(f" Status: {symbol_info.get('status')}")
|
||||
print(f" Base asset: {symbol_info.get('baseAsset')}")
|
||||
print(f" Quote asset: {symbol_info.get('quoteAsset')}")
|
||||
|
||||
# Check filters
|
||||
for filter_info in symbol_info.get('filters', []):
|
||||
filter_type = filter_info.get('filterType')
|
||||
if filter_type == 'LOT_SIZE':
|
||||
print(f" Lot Size Filter:")
|
||||
print(f" Min Qty: {filter_info.get('minQty')}")
|
||||
print(f" Max Qty: {filter_info.get('maxQty')}")
|
||||
print(f" Step Size: {filter_info.get('stepSize')}")
|
||||
elif filter_type == 'MIN_NOTIONAL':
|
||||
print(f" Min Notional Filter:")
|
||||
print(f" Min Notional: {filter_info.get('minNotional')}")
|
||||
elif filter_type == 'PRICE_FILTER':
|
||||
print(f" Price Filter:")
|
||||
print(f" Min Price: {filter_info.get('minPrice')}")
|
||||
print(f" Max Price: {filter_info.get('maxPrice')}")
|
||||
print(f" Tick Size: {filter_info.get('tickSize')}")
|
||||
|
||||
break
|
||||
else:
|
||||
print("❌ ETHUSDT symbol not found in exchange info")
|
||||
else:
|
||||
print(f"❌ Failed to get exchange info: {response.status_code}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error getting symbol info: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
get_symbol_info()
|
||||
test_order_sizes()
|
@ -1,71 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for MEXC public API endpoints
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.abspath('.'))
|
||||
|
||||
from NN.exchanges.mexc_interface import MEXCInterface
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_mexc_public_api():
|
||||
"""Test MEXC public API endpoints"""
|
||||
print("="*60)
|
||||
print("TESTING MEXC PUBLIC API")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# Initialize MEXC interface without API keys (public access only)
|
||||
mexc = MEXCInterface()
|
||||
|
||||
print("\n1. Testing server connectivity...")
|
||||
try:
|
||||
# Test ping
|
||||
ping_result = mexc.ping()
|
||||
print(f"✅ Ping successful: {ping_result}")
|
||||
except Exception as e:
|
||||
print(f"❌ Ping failed: {e}")
|
||||
|
||||
print("\n2. Testing server time...")
|
||||
try:
|
||||
# Test server time
|
||||
time_result = mexc.get_server_time()
|
||||
print(f"✅ Server time: {time_result}")
|
||||
except Exception as e:
|
||||
print(f"❌ Server time failed: {e}")
|
||||
|
||||
print("\n3. Testing ticker data...")
|
||||
symbols_to_test = ['BTC/USDT', 'ETH/USDT']
|
||||
|
||||
for symbol in symbols_to_test:
|
||||
try:
|
||||
ticker = mexc.get_ticker(symbol)
|
||||
if ticker:
|
||||
print(f"✅ {symbol}: ${ticker['last']:.2f} (bid: ${ticker['bid']:.2f}, ask: ${ticker['ask']:.2f})")
|
||||
else:
|
||||
print(f"❌ {symbol}: No data returned")
|
||||
except Exception as e:
|
||||
print(f"❌ {symbol}: Error - {e}")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("PUBLIC API TEST COMPLETED")
|
||||
print("="*60)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error initializing MEXC interface: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_mexc_public_api()
|
@ -1,321 +0,0 @@
|
||||
"""
|
||||
Test MEXC Signature Generation
|
||||
|
||||
This script tests the MEXC signature generation to ensure it's correct
|
||||
according to the MEXC API documentation.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import hashlib
|
||||
import hmac
|
||||
from urllib.parse import urlencode
|
||||
import time
|
||||
import requests
|
||||
|
||||
# Add paths for imports
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), 'NN'))
|
||||
|
||||
from NN.exchanges.mexc_interface import MEXCInterface
|
||||
|
||||
def test_signature_generation():
|
||||
"""Test MEXC signature generation with known examples"""
|
||||
print("="*60)
|
||||
print("MEXC SIGNATURE GENERATION TEST")
|
||||
print("="*60)
|
||||
|
||||
api_key = os.getenv('MEXC_API_KEY')
|
||||
api_secret = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
if not api_key or not api_secret:
|
||||
print("❌ Missing API credentials")
|
||||
return False
|
||||
|
||||
print(f"API Key: {api_key[:8]}...{api_key[-4:]}")
|
||||
print(f"API Secret: {api_secret[:8]}...{api_secret[-4:]}")
|
||||
|
||||
# Test 1: Simple signature generation
|
||||
print("\n1. Testing basic signature generation...")
|
||||
|
||||
test_params = {
|
||||
'symbol': 'ETHUSDT',
|
||||
'side': 'BUY',
|
||||
'type': 'MARKET',
|
||||
'quantity': '0.001',
|
||||
'timestamp': 1640995200000,
|
||||
'recvWindow': 5000
|
||||
}
|
||||
|
||||
# Generate signature manually
|
||||
sorted_params = sorted(test_params.items())
|
||||
query_string = urlencode(sorted_params)
|
||||
expected_signature = hmac.new(
|
||||
api_secret.encode('utf-8'),
|
||||
query_string.encode('utf-8'),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
print(f"Query string: {query_string}")
|
||||
print(f"Expected signature: {expected_signature}")
|
||||
|
||||
# Generate signature using MEXC interface
|
||||
mexc = MEXCInterface(api_key=api_key, api_secret=api_secret, test_mode=True)
|
||||
actual_signature = mexc._generate_signature(test_params)
|
||||
|
||||
print(f"Actual signature: {actual_signature}")
|
||||
|
||||
if expected_signature == actual_signature:
|
||||
print("✅ Signature generation matches expected")
|
||||
else:
|
||||
print("❌ Signature generation mismatch")
|
||||
return False
|
||||
|
||||
# Test 2: Real order parameters
|
||||
print("\n2. Testing with real order parameters...")
|
||||
|
||||
current_timestamp = int(time.time() * 1000)
|
||||
real_params = {
|
||||
'symbol': 'ETHUSDT',
|
||||
'side': 'BUY',
|
||||
'type': 'MARKET',
|
||||
'quantity': '0.001',
|
||||
'timestamp': current_timestamp,
|
||||
'recvWindow': 5000
|
||||
}
|
||||
|
||||
real_signature = mexc._generate_signature(real_params)
|
||||
sorted_real_params = sorted(real_params.items())
|
||||
real_query_string = urlencode(sorted_real_params)
|
||||
|
||||
print(f"Real timestamp: {current_timestamp}")
|
||||
print(f"Real query string: {real_query_string}")
|
||||
print(f"Real signature: {real_signature}")
|
||||
|
||||
# Test 3: Verify parameter ordering
|
||||
print("\n3. Testing parameter ordering sensitivity...")
|
||||
|
||||
# Test with parameters in different order
|
||||
unordered_params = {
|
||||
'timestamp': current_timestamp,
|
||||
'symbol': 'ETHUSDT',
|
||||
'recvWindow': 5000,
|
||||
'type': 'MARKET',
|
||||
'side': 'BUY',
|
||||
'quantity': '0.001'
|
||||
}
|
||||
|
||||
unordered_signature = mexc._generate_signature(unordered_params)
|
||||
|
||||
if real_signature == unordered_signature:
|
||||
print("✅ Parameter ordering handled correctly")
|
||||
else:
|
||||
print("❌ Parameter ordering issue")
|
||||
return False
|
||||
|
||||
# Test 4: Check for common issues
|
||||
print("\n4. Checking for common signature issues...")
|
||||
|
||||
# Check if any parameters need special encoding
|
||||
special_params = {
|
||||
'symbol': 'ETHUSDT',
|
||||
'side': 'BUY',
|
||||
'type': 'MARKET',
|
||||
'quantity': '0.0028417810009889397', # Full precision from error log
|
||||
'timestamp': current_timestamp,
|
||||
'recvWindow': 5000
|
||||
}
|
||||
|
||||
special_signature = mexc._generate_signature(special_params)
|
||||
special_sorted = sorted(special_params.items())
|
||||
special_query = urlencode(special_sorted)
|
||||
|
||||
print(f"Special quantity: {special_params['quantity']}")
|
||||
print(f"Special query: {special_query}")
|
||||
print(f"Special signature: {special_signature}")
|
||||
|
||||
# Test 5: Compare with error log signature
|
||||
print("\n5. Comparing with error log...")
|
||||
|
||||
# From the error log, we have this signature:
|
||||
error_log_signature = "2a52436039e24b593ab0ab20ac1a67e2323654dc14190ee2c2cde341930d27d4"
|
||||
error_timestamp = 1748349875981
|
||||
|
||||
error_params = {
|
||||
'symbol': 'ETHUSDT',
|
||||
'side': 'BUY',
|
||||
'type': 'MARKET',
|
||||
'quantity': '0.0028417810009889397',
|
||||
'recvWindow': 5000,
|
||||
'timestamp': error_timestamp
|
||||
}
|
||||
|
||||
recreated_signature = mexc._generate_signature(error_params)
|
||||
|
||||
print(f"Error log signature: {error_log_signature}")
|
||||
print(f"Recreated signature: {recreated_signature}")
|
||||
|
||||
if error_log_signature == recreated_signature:
|
||||
print("✅ Signature recreation matches error log")
|
||||
else:
|
||||
print("❌ Signature recreation doesn't match - potential algorithm issue")
|
||||
|
||||
# Debug the query string
|
||||
debug_sorted = sorted(error_params.items())
|
||||
debug_query = urlencode(debug_sorted)
|
||||
print(f"Debug query string: {debug_query}")
|
||||
|
||||
# Try manual HMAC
|
||||
manual_sig = hmac.new(
|
||||
api_secret.encode('utf-8'),
|
||||
debug_query.encode('utf-8'),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
print(f"Manual signature: {manual_sig}")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("SIGNATURE TEST COMPLETED")
|
||||
print("="*60)
|
||||
|
||||
return True
|
||||
|
||||
def test_mexc_api_call():
|
||||
"""Test a simple authenticated API call to verify signature works"""
|
||||
print("\n6. Testing authenticated API call...")
|
||||
|
||||
try:
|
||||
api_key = os.getenv('MEXC_API_KEY')
|
||||
api_secret = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
mexc = MEXCInterface(api_key=api_key, api_secret=api_secret, test_mode=True)
|
||||
|
||||
# Test account info (this should work if signature is correct)
|
||||
account_info = mexc.get_account_info()
|
||||
print("✅ Account info call successful - signature is working")
|
||||
print(f" Account type: {account_info.get('accountType', 'Unknown')}")
|
||||
print(f" Can trade: {account_info.get('canTrade', 'Unknown')}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Account info call failed: {e}")
|
||||
return False
|
||||
|
||||
# Test exact signature generation for MEXC order placement
|
||||
def test_mexc_order_signature():
|
||||
api_key = os.getenv('MEXC_API_KEY')
|
||||
api_secret = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
if not api_key or not api_secret:
|
||||
print("❌ MEXC API keys not found")
|
||||
return
|
||||
|
||||
print("=== MEXC ORDER SIGNATURE DEBUG ===")
|
||||
print(f"API Key: {api_key[:8]}...{api_key[-4:]}")
|
||||
print(f"Secret Key: {api_secret[:8]}...{api_secret[-4:]}")
|
||||
print()
|
||||
|
||||
# Get server time first
|
||||
try:
|
||||
time_resp = requests.get('https://api.mexc.com/api/v3/time')
|
||||
server_time = time_resp.json()['serverTime']
|
||||
print(f"✅ Server time: {server_time}")
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to get server time: {e}")
|
||||
server_time = int(time.time() * 1000)
|
||||
print(f"Using local time: {server_time}")
|
||||
|
||||
# Test order parameters (from MEXC documentation example)
|
||||
params = {
|
||||
'symbol': 'MXUSDT', # Changed to API-supported symbol
|
||||
'side': 'BUY',
|
||||
'type': 'MARKET',
|
||||
'quantity': '1', # Small test quantity (1 MX token)
|
||||
'timestamp': server_time
|
||||
}
|
||||
|
||||
print("\n=== Testing Different Signature Methods ===")
|
||||
|
||||
# Method 1: Sorted parameters with & separator (current approach)
|
||||
print("\n1. Current approach (sorted with &):")
|
||||
sorted_params = sorted(params.items())
|
||||
query_string1 = '&'.join([f"{key}={value}" for key, value in sorted_params])
|
||||
signature1 = hmac.new(
|
||||
api_secret.encode('utf-8'),
|
||||
query_string1.encode('utf-8'),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
print(f"Query: {query_string1}")
|
||||
print(f"Signature: {signature1}")
|
||||
|
||||
# Method 2: URL encoded (like account info that works)
|
||||
print("\n2. URL encoded approach:")
|
||||
sorted_params = sorted(params.items())
|
||||
query_string2 = urlencode(sorted_params)
|
||||
signature2 = hmac.new(
|
||||
api_secret.encode('utf-8'),
|
||||
query_string2.encode('utf-8'),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
print(f"Query: {query_string2}")
|
||||
print(f"Signature: {signature2}")
|
||||
|
||||
# Method 3: MEXC documentation example format
|
||||
print("\n3. MEXC docs example format:")
|
||||
# From MEXC docs: symbol=BTCUSDT&side=BUY&type=LIMIT&quantity=1&price=11&recvWindow=5000×tamp=1644489390087
|
||||
query_string3 = f"symbol={params['symbol']}&side={params['side']}&type={params['type']}&quantity={params['quantity']}×tamp={params['timestamp']}"
|
||||
signature3 = hmac.new(
|
||||
api_secret.encode('utf-8'),
|
||||
query_string3.encode('utf-8'),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
print(f"Query: {query_string3}")
|
||||
print(f"Signature: {signature3}")
|
||||
|
||||
# Test all methods by making actual requests
|
||||
print("\n=== Testing Actual Requests ===")
|
||||
|
||||
methods = [
|
||||
("Current approach", signature1, params),
|
||||
("URL encoded", signature2, params),
|
||||
("MEXC docs format", signature3, params)
|
||||
]
|
||||
|
||||
for method_name, signature, test_params in methods:
|
||||
print(f"\n{method_name}:")
|
||||
test_params_copy = test_params.copy()
|
||||
test_params_copy['signature'] = signature
|
||||
|
||||
headers = {'X-MEXC-APIKEY': api_key}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
'https://api.mexc.com/api/v3/order',
|
||||
params=test_params_copy,
|
||||
headers=headers,
|
||||
timeout=10
|
||||
)
|
||||
print(f"Status: {response.status_code}")
|
||||
print(f"Response: {response.text[:200]}...")
|
||||
|
||||
if response.status_code == 200:
|
||||
print("✅ SUCCESS!")
|
||||
break
|
||||
elif "Signature for this request is not valid" in response.text:
|
||||
print("❌ Invalid signature")
|
||||
else:
|
||||
print(f"❌ Other error: {response.status_code}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Request failed: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_signature_generation()
|
||||
|
||||
if success:
|
||||
success = test_mexc_api_call()
|
||||
|
||||
if success:
|
||||
test_mexc_order_signature()
|
||||
print("\n🎉 All signature tests passed!")
|
||||
else:
|
||||
print("\n🚨 Signature tests failed - check the output above")
|
@ -1,185 +0,0 @@
|
||||
"""
|
||||
MEXC Timestamp and Signature Debug
|
||||
|
||||
This script tests different timestamp and recvWindow combinations to fix the signature validation.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import hashlib
|
||||
import hmac
|
||||
from urllib.parse import urlencode
|
||||
import requests
|
||||
|
||||
# Add paths for imports
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), 'NN'))
|
||||
|
||||
def test_mexc_timestamp_debug():
|
||||
"""Test different timestamp strategies"""
|
||||
print("="*60)
|
||||
print("MEXC TIMESTAMP AND SIGNATURE DEBUG")
|
||||
print("="*60)
|
||||
|
||||
api_key = os.getenv('MEXC_API_KEY')
|
||||
api_secret = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
if not api_key or not api_secret:
|
||||
print("❌ Missing API credentials")
|
||||
return False
|
||||
|
||||
base_url = "https://api.mexc.com"
|
||||
api_version = "api/v3"
|
||||
|
||||
# Test 1: Get server time directly
|
||||
print("1. Getting server time...")
|
||||
|
||||
try:
|
||||
response = requests.get(f"{base_url}/{api_version}/time", timeout=10)
|
||||
if response.status_code == 200:
|
||||
server_time_data = response.json()
|
||||
server_time = server_time_data['serverTime']
|
||||
local_time = int(time.time() * 1000)
|
||||
time_diff = server_time - local_time
|
||||
|
||||
print(f" Server time: {server_time}")
|
||||
print(f" Local time: {local_time}")
|
||||
print(f" Difference: {time_diff}ms")
|
||||
|
||||
else:
|
||||
print(f" ❌ Failed to get server time: {response.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f" ❌ Error getting server time: {e}")
|
||||
return False
|
||||
|
||||
# Test 2: Try different timestamp strategies
|
||||
strategies = [
|
||||
("Server time exactly", server_time),
|
||||
("Server time - 500ms", server_time - 500),
|
||||
("Server time - 1000ms", server_time - 1000),
|
||||
("Server time - 2000ms", server_time - 2000),
|
||||
("Local time", local_time),
|
||||
("Local time - 1000ms", local_time - 1000),
|
||||
]
|
||||
|
||||
# Test with different recvWindow values
|
||||
recv_windows = [5000, 10000, 30000, 60000]
|
||||
|
||||
print(f"\n2. Testing different timestamp strategies and recvWindow values...")
|
||||
|
||||
for strategy_name, timestamp in strategies:
|
||||
print(f"\n Strategy: {strategy_name} (timestamp: {timestamp})")
|
||||
|
||||
for recv_window in recv_windows:
|
||||
print(f" Testing recvWindow: {recv_window}ms")
|
||||
|
||||
# Test account info request
|
||||
params = {
|
||||
'timestamp': timestamp,
|
||||
'recvWindow': recv_window
|
||||
}
|
||||
|
||||
# Generate signature
|
||||
sorted_params = sorted(params.items())
|
||||
query_string = urlencode(sorted_params)
|
||||
signature = hmac.new(
|
||||
api_secret.encode('utf-8'),
|
||||
query_string.encode('utf-8'),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
params['signature'] = signature
|
||||
|
||||
# Make request
|
||||
headers = {'X-MEXC-APIKEY': api_key}
|
||||
url = f"{base_url}/{api_version}/account"
|
||||
|
||||
try:
|
||||
response = requests.get(url, params=params, headers=headers, timeout=10)
|
||||
|
||||
if response.status_code == 200:
|
||||
print(f" ✅ SUCCESS")
|
||||
account_data = response.json()
|
||||
print(f" Account type: {account_data.get('accountType', 'Unknown')}")
|
||||
return True # Found working combination
|
||||
else:
|
||||
error_data = response.json() if 'application/json' in response.headers.get('content-type', '') else {'msg': response.text}
|
||||
error_code = error_data.get('code', 'Unknown')
|
||||
error_msg = error_data.get('msg', 'Unknown')
|
||||
print(f" ❌ Error {error_code}: {error_msg}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Exception: {e}")
|
||||
|
||||
print(f"\n❌ No working timestamp/recvWindow combination found")
|
||||
return False
|
||||
|
||||
def test_minimal_signature():
|
||||
"""Test with minimal parameters to isolate signature issues"""
|
||||
print(f"\n3. Testing minimal signature generation...")
|
||||
|
||||
api_key = os.getenv('MEXC_API_KEY')
|
||||
api_secret = os.getenv('MEXC_SECRET_KEY')
|
||||
|
||||
base_url = "https://api.mexc.com"
|
||||
api_version = "api/v3"
|
||||
|
||||
# Get fresh server time
|
||||
try:
|
||||
response = requests.get(f"{base_url}/{api_version}/time", timeout=10)
|
||||
server_time = response.json()['serverTime']
|
||||
print(f" Fresh server time: {server_time}")
|
||||
except Exception as e:
|
||||
print(f" ❌ Failed to get server time: {e}")
|
||||
return False
|
||||
|
||||
# Test with absolute minimal parameters
|
||||
minimal_params = {
|
||||
'timestamp': server_time
|
||||
}
|
||||
|
||||
# Generate signature with minimal params
|
||||
sorted_params = sorted(minimal_params.items())
|
||||
query_string = urlencode(sorted_params)
|
||||
signature = hmac.new(
|
||||
api_secret.encode('utf-8'),
|
||||
query_string.encode('utf-8'),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
minimal_params['signature'] = signature
|
||||
|
||||
print(f" Minimal params: {minimal_params}")
|
||||
print(f" Query string: {query_string}")
|
||||
print(f" Signature: {signature}")
|
||||
|
||||
# Test account request with minimal params
|
||||
headers = {'X-MEXC-APIKEY': api_key}
|
||||
url = f"{base_url}/{api_version}/account"
|
||||
|
||||
try:
|
||||
response = requests.get(url, params=minimal_params, headers=headers, timeout=10)
|
||||
|
||||
if response.status_code == 200:
|
||||
print(f" ✅ Minimal signature works!")
|
||||
return True
|
||||
else:
|
||||
error_data = response.json() if 'application/json' in response.headers.get('content-type', '') else {'msg': response.text}
|
||||
print(f" ❌ Minimal signature failed: {error_data.get('code', 'Unknown')} - {error_data.get('msg', 'Unknown')}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Exception with minimal signature: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_mexc_timestamp_debug()
|
||||
|
||||
if not success:
|
||||
success = test_minimal_signature()
|
||||
|
||||
if success:
|
||||
print(f"\n🎉 Found working MEXC configuration!")
|
||||
else:
|
||||
print(f"\n🚨 MEXC signature/timestamp issue persists")
|
@ -1,384 +0,0 @@
|
||||
"""
|
||||
Test MEXC Trading Integration
|
||||
|
||||
This script tests the integration between the enhanced orchestrator and MEXC trading executor.
|
||||
It verifies that trading signals can be executed through the MEXC API with proper risk management.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
# Add core directory to path
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), 'core'))
|
||||
|
||||
from core.trading_executor import TradingExecutor, Position, TradeRecord
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
from core.config import get_config
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler("test_mexc_trading.log"),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger("mexc_trading_test")
|
||||
|
||||
class TradingIntegrationTest:
|
||||
"""Test class for MEXC trading integration"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the test environment"""
|
||||
self.config = get_config()
|
||||
self.data_provider = DataProvider()
|
||||
self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
|
||||
self.trading_executor = TradingExecutor()
|
||||
|
||||
# Test configuration
|
||||
self.test_symbol = 'ETH/USDT'
|
||||
self.test_confidence = 0.75
|
||||
|
||||
def test_trading_executor_initialization(self):
|
||||
"""Test that the trading executor initializes correctly"""
|
||||
logger.info("Testing trading executor initialization...")
|
||||
|
||||
try:
|
||||
# Check configuration
|
||||
assert self.trading_executor.mexc_config is not None
|
||||
logger.info("✅ MEXC configuration loaded")
|
||||
|
||||
# Check dry run mode
|
||||
assert self.trading_executor.dry_run == True
|
||||
logger.info("✅ Dry run mode enabled for safety")
|
||||
|
||||
# Check position limits
|
||||
max_position_value = self.trading_executor.mexc_config.get('max_position_value_usd', 1.0)
|
||||
assert max_position_value == 1.0
|
||||
logger.info(f"✅ Max position value set to ${max_position_value}")
|
||||
|
||||
# Check safety features
|
||||
assert self.trading_executor.mexc_config.get('emergency_stop', False) == False
|
||||
logger.info("✅ Emergency stop not active")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Trading executor initialization test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_exchange_connection(self):
|
||||
"""Test connection to MEXC exchange"""
|
||||
logger.info("Testing MEXC exchange connection...")
|
||||
|
||||
try:
|
||||
# Test ticker retrieval
|
||||
ticker = self.trading_executor.exchange.get_ticker(self.test_symbol)
|
||||
|
||||
if ticker:
|
||||
logger.info(f"✅ Successfully retrieved ticker for {self.test_symbol}")
|
||||
logger.info(f" Current price: ${ticker['last']:.2f}")
|
||||
logger.info(f" Bid: ${ticker['bid']:.2f}, Ask: ${ticker['ask']:.2f}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"❌ Failed to retrieve ticker for {self.test_symbol}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Exchange connection test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_position_size_calculation(self):
|
||||
"""Test position size calculation with different confidence levels"""
|
||||
logger.info("Testing position size calculation...")
|
||||
|
||||
try:
|
||||
test_price = 2500.0
|
||||
test_cases = [
|
||||
(0.5, "Medium confidence"),
|
||||
(0.75, "High confidence"),
|
||||
(0.9, "Very high confidence"),
|
||||
(0.3, "Low confidence")
|
||||
]
|
||||
|
||||
for confidence, description in test_cases:
|
||||
position_value = self.trading_executor._calculate_position_size(confidence, test_price)
|
||||
quantity = position_value / test_price
|
||||
|
||||
logger.info(f" {description} ({confidence:.1f}): ${position_value:.2f} = {quantity:.6f} ETH")
|
||||
|
||||
# Verify position value is within limits
|
||||
max_value = self.trading_executor.mexc_config.get('max_position_value_usd', 1.0)
|
||||
min_value = self.trading_executor.mexc_config.get('min_position_value_usd', 0.1)
|
||||
|
||||
assert min_value <= position_value <= max_value
|
||||
|
||||
logger.info("✅ Position size calculation working correctly")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Position size calculation test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_dry_run_trading(self):
|
||||
"""Test dry run trading execution"""
|
||||
logger.info("Testing dry run trading execution...")
|
||||
|
||||
try:
|
||||
# Get current price
|
||||
ticker = self.trading_executor.exchange.get_ticker(self.test_symbol)
|
||||
if not ticker:
|
||||
logger.error("Cannot get current price for testing")
|
||||
return False
|
||||
|
||||
current_price = ticker['last']
|
||||
|
||||
# Test BUY signal
|
||||
logger.info(f"Testing BUY signal for {self.test_symbol} at ${current_price:.2f}")
|
||||
buy_success = self.trading_executor.execute_signal(
|
||||
symbol=self.test_symbol,
|
||||
action='BUY',
|
||||
confidence=self.test_confidence,
|
||||
current_price=current_price
|
||||
)
|
||||
|
||||
if buy_success:
|
||||
logger.info("✅ BUY signal executed successfully in dry run mode")
|
||||
|
||||
# Check position was created
|
||||
positions = self.trading_executor.get_positions()
|
||||
assert self.test_symbol in positions
|
||||
position = positions[self.test_symbol]
|
||||
logger.info(f" Position created: {position.side} {position.quantity:.6f} @ ${position.entry_price:.2f}")
|
||||
else:
|
||||
logger.error("❌ BUY signal execution failed")
|
||||
return False
|
||||
|
||||
# Wait a moment
|
||||
time.sleep(1)
|
||||
|
||||
# Test SELL signal
|
||||
logger.info(f"Testing SELL signal for {self.test_symbol}")
|
||||
sell_success = self.trading_executor.execute_signal(
|
||||
symbol=self.test_symbol,
|
||||
action='SELL',
|
||||
confidence=self.test_confidence,
|
||||
current_price=current_price * 1.01 # Simulate 1% price increase
|
||||
)
|
||||
|
||||
if sell_success:
|
||||
logger.info("✅ SELL signal executed successfully in dry run mode")
|
||||
|
||||
# Check position was closed
|
||||
positions = self.trading_executor.get_positions()
|
||||
assert self.test_symbol not in positions
|
||||
|
||||
# Check trade history
|
||||
trade_history = self.trading_executor.get_trade_history()
|
||||
assert len(trade_history) > 0
|
||||
|
||||
last_trade = trade_history[-1]
|
||||
logger.info(f" Trade completed: P&L ${last_trade.pnl:.2f}")
|
||||
else:
|
||||
logger.error("❌ SELL signal execution failed")
|
||||
return False
|
||||
|
||||
logger.info("✅ Dry run trading test completed successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Dry run trading test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_safety_conditions(self):
|
||||
"""Test safety condition checks"""
|
||||
logger.info("Testing safety condition checks...")
|
||||
|
||||
try:
|
||||
# Test symbol allowlist
|
||||
disallowed_symbol = 'DOGE/USDT'
|
||||
result = self.trading_executor._check_safety_conditions(disallowed_symbol, 'BUY')
|
||||
if disallowed_symbol not in self.trading_executor.mexc_config.get('allowed_symbols', []):
|
||||
assert result == False
|
||||
logger.info("✅ Symbol allowlist check working")
|
||||
|
||||
# Test trade interval
|
||||
# First trade should succeed
|
||||
current_price = 2500.0
|
||||
self.trading_executor.execute_signal(self.test_symbol, 'BUY', 0.7, current_price)
|
||||
|
||||
# Immediate second trade should fail due to interval
|
||||
result = self.trading_executor._check_safety_conditions(self.test_symbol, 'BUY')
|
||||
# Note: This might pass if interval is very short, which is fine for testing
|
||||
|
||||
logger.info("✅ Safety condition checks working")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Safety condition test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_daily_statistics(self):
|
||||
"""Test daily statistics tracking"""
|
||||
logger.info("Testing daily statistics tracking...")
|
||||
|
||||
try:
|
||||
stats = self.trading_executor.get_daily_stats()
|
||||
|
||||
required_keys = ['daily_trades', 'daily_loss', 'total_pnl', 'winning_trades',
|
||||
'losing_trades', 'win_rate', 'positions_count']
|
||||
|
||||
for key in required_keys:
|
||||
assert key in stats
|
||||
|
||||
logger.info("✅ Daily statistics structure correct")
|
||||
logger.info(f" Daily trades: {stats['daily_trades']}")
|
||||
logger.info(f" Total P&L: ${stats['total_pnl']:.2f}")
|
||||
logger.info(f" Win rate: {stats['win_rate']:.1%}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Daily statistics test failed: {e}")
|
||||
return False
|
||||
|
||||
async def test_orchestrator_integration(self):
|
||||
"""Test integration with enhanced orchestrator"""
|
||||
logger.info("Testing orchestrator integration...")
|
||||
|
||||
try:
|
||||
# Test that orchestrator can make decisions
|
||||
decisions = await self.orchestrator.make_coordinated_decisions()
|
||||
|
||||
logger.info(f"✅ Orchestrator made decisions for {len(decisions)} symbols")
|
||||
|
||||
for symbol, decision in decisions.items():
|
||||
if decision:
|
||||
logger.info(f" {symbol}: {decision.action} (confidence: {decision.confidence:.3f})")
|
||||
|
||||
# Test executing the decision through trading executor
|
||||
if decision.action != 'HOLD':
|
||||
success = self.trading_executor.execute_signal(
|
||||
symbol=symbol,
|
||||
action=decision.action,
|
||||
confidence=decision.confidence,
|
||||
current_price=decision.price
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f" ✅ Successfully executed {decision.action} for {symbol}")
|
||||
else:
|
||||
logger.info(f" ⚠️ Trade execution blocked by safety conditions for {symbol}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Orchestrator integration test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_emergency_stop(self):
|
||||
"""Test emergency stop functionality"""
|
||||
logger.info("Testing emergency stop functionality...")
|
||||
|
||||
try:
|
||||
# Create a test position first
|
||||
current_price = 2500.0
|
||||
self.trading_executor.execute_signal(self.test_symbol, 'BUY', 0.7, current_price)
|
||||
|
||||
# Verify position exists
|
||||
positions_before = self.trading_executor.get_positions()
|
||||
logger.info(f" Positions before emergency stop: {len(positions_before)}")
|
||||
|
||||
# Trigger emergency stop
|
||||
self.trading_executor.emergency_stop()
|
||||
|
||||
# Verify trading is disabled
|
||||
assert self.trading_executor.trading_enabled == False
|
||||
logger.info("✅ Trading disabled after emergency stop")
|
||||
|
||||
# In dry run mode, positions should still be closed
|
||||
positions_after = self.trading_executor.get_positions()
|
||||
logger.info(f" Positions after emergency stop: {len(positions_after)}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Emergency stop test failed: {e}")
|
||||
return False
|
||||
|
||||
async def run_all_tests(self):
|
||||
"""Run all integration tests"""
|
||||
logger.info("🚀 Starting MEXC Trading Integration Tests")
|
||||
logger.info("=" * 60)
|
||||
|
||||
tests = [
|
||||
("Trading Executor Initialization", self.test_trading_executor_initialization),
|
||||
("Exchange Connection", self.test_exchange_connection),
|
||||
("Position Size Calculation", self.test_position_size_calculation),
|
||||
("Dry Run Trading", self.test_dry_run_trading),
|
||||
("Safety Conditions", self.test_safety_conditions),
|
||||
("Daily Statistics", self.test_daily_statistics),
|
||||
("Orchestrator Integration", self.test_orchestrator_integration),
|
||||
("Emergency Stop", self.test_emergency_stop),
|
||||
]
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for test_name, test_func in tests:
|
||||
logger.info(f"\n📋 Running test: {test_name}")
|
||||
logger.info("-" * 40)
|
||||
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(test_func):
|
||||
result = await test_func()
|
||||
else:
|
||||
result = test_func()
|
||||
|
||||
if result:
|
||||
passed += 1
|
||||
logger.info(f"✅ {test_name} PASSED")
|
||||
else:
|
||||
failed += 1
|
||||
logger.error(f"❌ {test_name} FAILED")
|
||||
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
logger.error(f"❌ {test_name} FAILED with exception: {e}")
|
||||
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("🏁 Test Results Summary")
|
||||
logger.info(f"✅ Passed: {passed}")
|
||||
logger.info(f"❌ Failed: {failed}")
|
||||
logger.info(f"📊 Success Rate: {passed/(passed+failed)*100:.1f}%")
|
||||
|
||||
if failed == 0:
|
||||
logger.info("🎉 All tests passed! MEXC trading integration is ready.")
|
||||
else:
|
||||
logger.warning(f"⚠️ {failed} test(s) failed. Please review and fix issues before live trading.")
|
||||
|
||||
return failed == 0
|
||||
|
||||
async def main():
|
||||
"""Main test function"""
|
||||
test_runner = TradingIntegrationTest()
|
||||
success = await test_runner.run_all_tests()
|
||||
|
||||
if success:
|
||||
logger.info("\n🔧 Next Steps:")
|
||||
logger.info("1. Set up your MEXC API keys in .env file")
|
||||
logger.info("2. Update config.yaml to enable trading (mexc_trading.enabled: true)")
|
||||
logger.info("3. Consider disabling dry_run_mode for live trading")
|
||||
logger.info("4. Start with small position sizes for initial live testing")
|
||||
logger.info("5. Monitor the system closely during initial live trading")
|
||||
|
||||
return success
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
@ -1,109 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Minimal Dashboard Test - Debug startup issues
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_imports():
|
||||
"""Test all required imports"""
|
||||
try:
|
||||
logger.info("Testing imports...")
|
||||
|
||||
# Core imports
|
||||
from core.config import get_config
|
||||
logger.info("✓ core.config imported")
|
||||
|
||||
from core.data_provider import DataProvider
|
||||
logger.info("✓ core.data_provider imported")
|
||||
|
||||
# Dashboard imports
|
||||
import dash
|
||||
from dash import dcc, html
|
||||
import plotly.graph_objects as go
|
||||
logger.info("✓ Dash imports successful")
|
||||
|
||||
# Try to import the dashboard
|
||||
from web.old_archived.scalping_dashboard import RealTimeScalpingDashboard
|
||||
logger.info("✓ RealTimeScalpingDashboard imported")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Import error: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def test_dashboard_creation():
|
||||
"""Test dashboard creation"""
|
||||
try:
|
||||
logger.info("Testing dashboard creation...")
|
||||
|
||||
from web.old_archived.scalping_dashboard import RealTimeScalpingDashboard
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
# Create data provider
|
||||
data_provider = DataProvider()
|
||||
logger.info("✓ DataProvider created")
|
||||
|
||||
# Create dashboard
|
||||
dashboard = RealTimeScalpingDashboard(data_provider=data_provider)
|
||||
logger.info("✓ Dashboard created successfully")
|
||||
|
||||
return dashboard
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Dashboard creation error: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
def test_dashboard_run():
|
||||
"""Test dashboard run"""
|
||||
try:
|
||||
logger.info("Testing dashboard run...")
|
||||
|
||||
dashboard = test_dashboard_creation()
|
||||
if not dashboard:
|
||||
return False
|
||||
|
||||
# Try to run dashboard
|
||||
logger.info("Starting dashboard on port 8052...")
|
||||
dashboard.run(host='127.0.0.1', port=8052, debug=True)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Dashboard run error: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main test function"""
|
||||
logger.info("=== MINIMAL DASHBOARD TEST ===")
|
||||
|
||||
# Test 1: Imports
|
||||
if not test_imports():
|
||||
logger.error("Import test failed!")
|
||||
sys.exit(1)
|
||||
|
||||
# Test 2: Dashboard creation
|
||||
dashboard = test_dashboard_creation()
|
||||
if not dashboard:
|
||||
logger.error("Dashboard creation test failed!")
|
||||
sys.exit(1)
|
||||
|
||||
# Test 3: Dashboard run
|
||||
logger.info("All tests passed! Starting dashboard...")
|
||||
test_dashboard_run()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,127 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Minimal Trading Test
|
||||
Test basic trading functionality with simplified decision logic
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def test_minimal_trading():
|
||||
"""Test minimal trading with lowered thresholds"""
|
||||
logger.info("=== MINIMAL TRADING TEST ===")
|
||||
|
||||
try:
|
||||
from core.config import get_config
|
||||
from core.data_provider import DataProvider
|
||||
from core.trading_executor import TradingExecutor
|
||||
|
||||
# Initialize with minimal components
|
||||
config = get_config()
|
||||
data_provider = DataProvider()
|
||||
trading_executor = TradingExecutor()
|
||||
|
||||
logger.info("✅ Basic components initialized")
|
||||
|
||||
# Test data availability
|
||||
symbol = 'ETH/USDT'
|
||||
data = data_provider.get_historical_data(symbol, '1m', limit=20)
|
||||
|
||||
if data is None or data.empty:
|
||||
logger.error("❌ No data available for minimal test")
|
||||
return
|
||||
|
||||
current_price = float(data['close'].iloc[-1])
|
||||
logger.info(f"✅ Current {symbol} price: ${current_price:.2f}")
|
||||
|
||||
# Generate simple trading signal
|
||||
price_change = data['close'].pct_change().iloc[-5:].mean()
|
||||
|
||||
# Simple momentum signal
|
||||
if price_change > 0.001: # 0.1% positive momentum
|
||||
action = 'BUY'
|
||||
confidence = 0.6 # Above 35% threshold
|
||||
reason = f"Positive momentum: {price_change:.1%}"
|
||||
elif price_change < -0.001: # 0.1% negative momentum
|
||||
action = 'SELL'
|
||||
confidence = 0.6 # Above 35% threshold
|
||||
reason = f"Negative momentum: {price_change:.1%}"
|
||||
else:
|
||||
action = 'HOLD'
|
||||
confidence = 0.3
|
||||
reason = "Neutral momentum"
|
||||
|
||||
logger.info(f"📈 Signal: {action} with {confidence:.1%} confidence - {reason}")
|
||||
|
||||
# Test if we would execute this trade
|
||||
if confidence > 0.35: # Our new threshold
|
||||
logger.info("✅ Signal WOULD trigger trade execution")
|
||||
|
||||
# Simulate position sizing
|
||||
position_size = 0.01 # 0.01 ETH
|
||||
estimated_value = position_size * current_price
|
||||
|
||||
logger.info(f"📊 Would trade {position_size} ETH (~${estimated_value:.2f})")
|
||||
|
||||
# Test trading executor (simulation mode)
|
||||
if hasattr(trading_executor, 'simulation_mode'):
|
||||
trading_executor.simulation_mode = True
|
||||
|
||||
logger.info("🎯 Trading signal meets threshold - system operational")
|
||||
|
||||
else:
|
||||
logger.warning(f"❌ Signal below threshold ({confidence:.1%} < 35%)")
|
||||
|
||||
# Test multiple timeframes
|
||||
logger.info("\n=== MULTI-TIMEFRAME TEST ===")
|
||||
timeframes = ['1m', '5m', '1h']
|
||||
signals = []
|
||||
|
||||
for tf in timeframes:
|
||||
try:
|
||||
tf_data = data_provider.get_historical_data(symbol, tf, limit=10)
|
||||
if tf_data is not None and not tf_data.empty:
|
||||
tf_change = tf_data['close'].pct_change().iloc[-3:].mean()
|
||||
tf_confidence = min(0.8, abs(tf_change) * 100)
|
||||
|
||||
signals.append({
|
||||
'timeframe': tf,
|
||||
'change': tf_change,
|
||||
'confidence': tf_confidence
|
||||
})
|
||||
|
||||
logger.info(f" {tf}: {tf_change:.2%} change, {tf_confidence:.1%} confidence")
|
||||
except Exception as e:
|
||||
logger.warning(f" {tf}: Error - {e}")
|
||||
|
||||
# Combined signal
|
||||
if signals:
|
||||
avg_confidence = np.mean([s['confidence'] for s in signals])
|
||||
logger.info(f"📊 Average multi-timeframe confidence: {avg_confidence:.1%}")
|
||||
|
||||
if avg_confidence > 0.35:
|
||||
logger.info("✅ Multi-timeframe signal would trigger trade")
|
||||
else:
|
||||
logger.warning("❌ Multi-timeframe signal below threshold")
|
||||
|
||||
logger.info("\n=== RECOMMENDATIONS ===")
|
||||
logger.info("1. ✅ Data flow is working correctly")
|
||||
logger.info("2. ✅ Price data is fresh and accurate")
|
||||
logger.info("3. ✅ Confidence thresholds are now more reasonable (35%)")
|
||||
logger.info("4. ⚠️ Complex cross-asset logic has bugs - use simple momentum")
|
||||
logger.info("5. 🎯 System can generate trading signals - test with real orchestrator")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Minimal trading test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_minimal_trading())
|
@ -1,274 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Comprehensive test suite for model persistence and training functionality
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
import tempfile
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from utils.model_utils import robust_save, robust_load, get_model_info, verify_save_load_cycle
|
||||
|
||||
# Configure logging for tests
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MockAgent:
|
||||
"""Mock agent class for testing model persistence"""
|
||||
|
||||
def __init__(self, state_size=64, action_size=4, hidden_size=256):
|
||||
self.state_size = state_size
|
||||
self.action_size = action_size
|
||||
self.hidden_size = hidden_size
|
||||
self.epsilon = 0.1
|
||||
|
||||
# Create simple mock networks
|
||||
self.policy_net = torch.nn.Sequential(
|
||||
torch.nn.Linear(state_size, hidden_size),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(hidden_size, action_size)
|
||||
)
|
||||
|
||||
self.target_net = torch.nn.Sequential(
|
||||
torch.nn.Linear(state_size, hidden_size),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(hidden_size, action_size)
|
||||
)
|
||||
|
||||
self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=0.001)
|
||||
|
||||
class TestModelPersistence(unittest.TestCase):
|
||||
"""Test suite for model saving and loading functionality"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.test_agent = MockAgent()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures"""
|
||||
import shutil
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_robust_save_basic(self):
|
||||
"""Test basic robust save functionality"""
|
||||
save_path = os.path.join(self.temp_dir, "test_model.pt")
|
||||
|
||||
success = robust_save(self.test_agent, save_path)
|
||||
self.assertTrue(success, "Robust save should succeed")
|
||||
self.assertTrue(os.path.exists(save_path), "Model file should exist")
|
||||
self.assertGreater(os.path.getsize(save_path), 0, "Model file should not be empty")
|
||||
|
||||
def test_robust_save_without_optimizer(self):
|
||||
"""Test robust save without optimizer state"""
|
||||
save_path = os.path.join(self.temp_dir, "test_model_no_opt.pt")
|
||||
|
||||
success = robust_save(self.test_agent, save_path, include_optimizer=False)
|
||||
self.assertTrue(success, "Robust save without optimizer should succeed")
|
||||
|
||||
# Verify that optimizer state is not included
|
||||
checkpoint = torch.load(save_path, map_location='cpu')
|
||||
self.assertNotIn('optimizer', checkpoint, "Optimizer state should not be saved")
|
||||
self.assertIn('policy_net', checkpoint, "Policy network should be saved")
|
||||
|
||||
def test_robust_load_basic(self):
|
||||
"""Test basic robust load functionality"""
|
||||
save_path = os.path.join(self.temp_dir, "test_model.pt")
|
||||
|
||||
# Save first
|
||||
success = robust_save(self.test_agent, save_path)
|
||||
self.assertTrue(success, "Save should succeed")
|
||||
|
||||
# Create new agent and load
|
||||
new_agent = MockAgent()
|
||||
success = robust_load(new_agent, save_path)
|
||||
self.assertTrue(success, "Load should succeed")
|
||||
|
||||
# Verify epsilon was loaded
|
||||
self.assertEqual(new_agent.epsilon, self.test_agent.epsilon, "Epsilon should match")
|
||||
|
||||
def test_get_model_info(self):
|
||||
"""Test model info extraction"""
|
||||
save_path = os.path.join(self.temp_dir, "test_model.pt")
|
||||
|
||||
# Test non-existent file
|
||||
info = get_model_info(save_path)
|
||||
self.assertFalse(info['exists'], "Non-existent file should return exists=False")
|
||||
|
||||
# Save model and test info
|
||||
robust_save(self.test_agent, save_path)
|
||||
info = get_model_info(save_path)
|
||||
|
||||
self.assertTrue(info['exists'], "Existing file should return exists=True")
|
||||
self.assertGreater(info['size_bytes'], 0, "File size should be greater than 0")
|
||||
self.assertTrue(info['has_optimizer'], "Should detect optimizer in checkpoint")
|
||||
self.assertEqual(info['parameters']['state_size'], self.test_agent.state_size)
|
||||
self.assertEqual(info['parameters']['action_size'], self.test_agent.action_size)
|
||||
|
||||
def test_save_load_cycle_verification(self):
|
||||
"""Test save/load cycle verification"""
|
||||
test_path = os.path.join(self.temp_dir, "cycle_test.pt")
|
||||
|
||||
success = verify_save_load_cycle(self.test_agent, test_path)
|
||||
self.assertTrue(success, "Save/load cycle should succeed")
|
||||
|
||||
# File should be cleaned up after verification
|
||||
self.assertFalse(os.path.exists(test_path), "Test file should be cleaned up")
|
||||
|
||||
def test_multiple_save_methods(self):
|
||||
"""Test that different save methods all work"""
|
||||
methods = ['regular', 'no_optimizer', 'pickle2']
|
||||
|
||||
for method in methods:
|
||||
with self.subTest(method=method):
|
||||
save_path = os.path.join(self.temp_dir, f"test_{method}.pt")
|
||||
|
||||
if method == 'regular':
|
||||
success = robust_save(self.test_agent, save_path)
|
||||
elif method == 'no_optimizer':
|
||||
success = robust_save(self.test_agent, save_path, include_optimizer=False)
|
||||
elif method == 'pickle2':
|
||||
# This would be tested by the robust_save fallback mechanism
|
||||
success = robust_save(self.test_agent, save_path)
|
||||
|
||||
self.assertTrue(success, f"{method} save should succeed")
|
||||
self.assertTrue(os.path.exists(save_path), f"{method} save should create file")
|
||||
|
||||
class TestTrainingMetrics(unittest.TestCase):
|
||||
"""Test suite for training metrics and monitoring functionality"""
|
||||
|
||||
def test_signal_distribution_calculation(self):
|
||||
"""Test signal distribution calculation"""
|
||||
# Mock predictions
|
||||
predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0]) # SELL, HOLD, BUY
|
||||
|
||||
buy_count = np.sum(predictions == 2)
|
||||
sell_count = np.sum(predictions == 0)
|
||||
hold_count = np.sum(predictions == 1)
|
||||
total = len(predictions)
|
||||
|
||||
distribution = {
|
||||
"BUY": buy_count / total,
|
||||
"SELL": sell_count / total,
|
||||
"HOLD": hold_count / total
|
||||
}
|
||||
|
||||
self.assertAlmostEqual(distribution["BUY"], 0.3, places=2)
|
||||
self.assertAlmostEqual(distribution["SELL"], 0.3, places=2)
|
||||
self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2)
|
||||
self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2)
|
||||
|
||||
def test_metrics_tracking_structure(self):
|
||||
"""Test metrics history structure for training monitoring"""
|
||||
metrics_history = {
|
||||
"epoch": [],
|
||||
"train_loss": [],
|
||||
"val_loss": [],
|
||||
"train_acc": [],
|
||||
"val_acc": [],
|
||||
"train_pnl": [],
|
||||
"val_pnl": [],
|
||||
"train_win_rate": [],
|
||||
"val_win_rate": [],
|
||||
"signal_distribution": []
|
||||
}
|
||||
|
||||
# Simulate adding metrics for one epoch
|
||||
metrics_history["epoch"].append(1)
|
||||
metrics_history["train_loss"].append(0.5)
|
||||
metrics_history["val_loss"].append(0.6)
|
||||
metrics_history["train_acc"].append(0.7)
|
||||
metrics_history["val_acc"].append(0.65)
|
||||
metrics_history["train_pnl"].append(0.1)
|
||||
metrics_history["val_pnl"].append(0.08)
|
||||
metrics_history["train_win_rate"].append(0.6)
|
||||
metrics_history["val_win_rate"].append(0.55)
|
||||
metrics_history["signal_distribution"].append({"BUY": 0.3, "SELL": 0.3, "HOLD": 0.4})
|
||||
|
||||
# Verify structure
|
||||
self.assertEqual(len(metrics_history["epoch"]), 1)
|
||||
self.assertEqual(metrics_history["epoch"][0], 1)
|
||||
self.assertIsInstance(metrics_history["signal_distribution"][0], dict)
|
||||
self.assertIn("BUY", metrics_history["signal_distribution"][0])
|
||||
|
||||
class TestModelArchitecture(unittest.TestCase):
|
||||
"""Test suite for model architecture verification"""
|
||||
|
||||
def test_model_parameter_consistency(self):
|
||||
"""Test that model parameters are consistent after save/load"""
|
||||
agent = MockAgent(state_size=32, action_size=3, hidden_size=128)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
save_path = os.path.join(temp_dir, "consistency_test.pt")
|
||||
|
||||
# Save model
|
||||
robust_save(agent, save_path)
|
||||
|
||||
# Load into new model with same architecture
|
||||
new_agent = MockAgent(state_size=32, action_size=3, hidden_size=128)
|
||||
robust_load(new_agent, save_path)
|
||||
|
||||
# Verify parameters match
|
||||
self.assertEqual(new_agent.state_size, agent.state_size)
|
||||
self.assertEqual(new_agent.action_size, agent.action_size)
|
||||
self.assertEqual(new_agent.hidden_size, agent.hidden_size)
|
||||
self.assertEqual(new_agent.epsilon, agent.epsilon)
|
||||
|
||||
def test_model_forward_pass(self):
|
||||
"""Test that model can perform forward pass after load"""
|
||||
agent = MockAgent()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
save_path = os.path.join(temp_dir, "forward_test.pt")
|
||||
|
||||
# Create test input
|
||||
test_input = torch.randn(1, agent.state_size)
|
||||
|
||||
# Get original output
|
||||
original_output = agent.policy_net(test_input)
|
||||
|
||||
# Save and load
|
||||
robust_save(agent, save_path)
|
||||
new_agent = MockAgent()
|
||||
robust_load(new_agent, save_path)
|
||||
|
||||
# Test forward pass works
|
||||
new_output = new_agent.policy_net(test_input)
|
||||
|
||||
self.assertEqual(new_output.shape, original_output.shape)
|
||||
# Outputs should be identical since we loaded the same weights
|
||||
torch.testing.assert_close(new_output, original_output)
|
||||
|
||||
def run_all_tests():
|
||||
"""Run all test suites"""
|
||||
test_suites = [
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestModelPersistence),
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestTrainingMetrics),
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestModelArchitecture)
|
||||
]
|
||||
|
||||
combined_suite = unittest.TestSuite(test_suites)
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(combined_suite)
|
||||
|
||||
return result.wasSuccessful()
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Running comprehensive model persistence and training tests...")
|
||||
success = run_all_tests()
|
||||
|
||||
if success:
|
||||
logger.info("All tests passed!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("Some tests failed!")
|
||||
sys.exit(1)
|
@ -1,327 +0,0 @@
|
||||
"""
|
||||
Test Multi-Exchange Consolidated Order Book (COB) Provider
|
||||
|
||||
This script demonstrates the functionality of the new multi-exchange COB data provider:
|
||||
1. Real-time order book aggregation from multiple exchanges
|
||||
2. Fine-grain price bucket generation
|
||||
3. CNN/DQN feature generation
|
||||
4. Dashboard integration
|
||||
5. Market analysis and signal generation
|
||||
|
||||
Run this to test the COB provider with live data streams.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from core.multi_exchange_cob_provider import MultiExchangeCOBProvider
|
||||
from core.cob_integration import COBIntegration
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class COBTester:
|
||||
"""Test harness for Multi-Exchange COB Provider"""
|
||||
|
||||
def __init__(self):
|
||||
self.symbols = ['BTC/USDT', 'ETH/USDT']
|
||||
self.data_provider = None
|
||||
self.cob_integration = None
|
||||
self.test_duration = 300 # 5 minutes
|
||||
|
||||
# Statistics tracking
|
||||
self.stats = {
|
||||
'cob_updates_received': 0,
|
||||
'bucket_updates_received': 0,
|
||||
'cnn_features_generated': 0,
|
||||
'dqn_features_generated': 0,
|
||||
'signals_generated': 0,
|
||||
'start_time': None
|
||||
}
|
||||
|
||||
async def run_test(self):
|
||||
"""Run comprehensive COB provider test"""
|
||||
logger.info("Starting Multi-Exchange COB Provider Test")
|
||||
logger.info(f"Testing symbols: {self.symbols}")
|
||||
logger.info(f"Test duration: {self.test_duration} seconds")
|
||||
|
||||
try:
|
||||
# Initialize components
|
||||
await self._initialize_components()
|
||||
|
||||
# Run test scenarios
|
||||
await self._run_basic_functionality_test()
|
||||
await self._run_feature_generation_test()
|
||||
await self._run_dashboard_integration_test()
|
||||
await self._run_signal_analysis_test()
|
||||
|
||||
# Monitor for specified duration
|
||||
await self._monitor_live_data()
|
||||
|
||||
# Generate final report
|
||||
self._generate_test_report()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Test failed: {e}")
|
||||
finally:
|
||||
await self._cleanup()
|
||||
|
||||
async def _initialize_components(self):
|
||||
"""Initialize COB provider and integration components"""
|
||||
logger.info("Initializing COB components...")
|
||||
|
||||
# Create data provider (optional - for integration testing)
|
||||
self.data_provider = DataProvider(symbols=self.symbols)
|
||||
|
||||
# Create COB integration
|
||||
self.cob_integration = COBIntegration(
|
||||
data_provider=self.data_provider,
|
||||
symbols=self.symbols
|
||||
)
|
||||
|
||||
# Register test callbacks
|
||||
self.cob_integration.add_cnn_callback(self._cnn_callback)
|
||||
self.cob_integration.add_dqn_callback(self._dqn_callback)
|
||||
self.cob_integration.add_dashboard_callback(self._dashboard_callback)
|
||||
|
||||
# Start COB integration
|
||||
await self.cob_integration.start()
|
||||
|
||||
# Allow time for connections
|
||||
await asyncio.sleep(5)
|
||||
|
||||
self.stats['start_time'] = datetime.now()
|
||||
logger.info("COB components initialized successfully")
|
||||
|
||||
async def _run_basic_functionality_test(self):
|
||||
"""Test basic COB provider functionality"""
|
||||
logger.info("Testing basic COB functionality...")
|
||||
|
||||
# Wait for order book data
|
||||
await asyncio.sleep(10)
|
||||
|
||||
for symbol in self.symbols:
|
||||
# Test consolidated order book retrieval
|
||||
cob_snapshot = self.cob_integration.get_cob_snapshot(symbol)
|
||||
if cob_snapshot:
|
||||
logger.info(f"{symbol} COB Status:")
|
||||
logger.info(f" Exchanges active: {cob_snapshot.exchanges_active}")
|
||||
logger.info(f" Volume weighted mid: ${cob_snapshot.volume_weighted_mid:.2f}")
|
||||
logger.info(f" Spread: {cob_snapshot.spread_bps:.2f} bps")
|
||||
logger.info(f" Bid liquidity: ${cob_snapshot.total_bid_liquidity:,.0f}")
|
||||
logger.info(f" Ask liquidity: ${cob_snapshot.total_ask_liquidity:,.0f}")
|
||||
logger.info(f" Liquidity imbalance: {cob_snapshot.liquidity_imbalance:.3f}")
|
||||
|
||||
# Test price buckets
|
||||
price_buckets = self.cob_integration.get_price_buckets(symbol)
|
||||
if price_buckets:
|
||||
bid_buckets = len(price_buckets.get('bids', {}))
|
||||
ask_buckets = len(price_buckets.get('asks', {}))
|
||||
logger.info(f" Price buckets: {bid_buckets} bids, {ask_buckets} asks")
|
||||
|
||||
# Test exchange breakdown
|
||||
exchange_breakdown = self.cob_integration.get_exchange_breakdown(symbol)
|
||||
if exchange_breakdown:
|
||||
logger.info(f" Exchange breakdown:")
|
||||
for exchange, data in exchange_breakdown.items():
|
||||
market_share = data.get('market_share', 0) * 100
|
||||
logger.info(f" {exchange}: {market_share:.1f}% market share")
|
||||
else:
|
||||
logger.warning(f"No COB data available for {symbol}")
|
||||
|
||||
logger.info("Basic functionality test completed")
|
||||
|
||||
async def _run_feature_generation_test(self):
|
||||
"""Test CNN and DQN feature generation"""
|
||||
logger.info("Testing feature generation...")
|
||||
|
||||
for symbol in self.symbols:
|
||||
# Test CNN features
|
||||
cnn_features = self.cob_integration.get_cob_features(symbol)
|
||||
if cnn_features is not None:
|
||||
logger.info(f"{symbol} CNN features: shape={cnn_features.shape}, "
|
||||
f"min={cnn_features.min():.4f}, max={cnn_features.max():.4f}")
|
||||
else:
|
||||
logger.warning(f"No CNN features available for {symbol}")
|
||||
|
||||
# Test market depth analysis
|
||||
depth_analysis = self.cob_integration.get_market_depth_analysis(symbol)
|
||||
if depth_analysis:
|
||||
logger.info(f"{symbol} Market Depth Analysis:")
|
||||
logger.info(f" Depth levels: {depth_analysis['depth_analysis']['bid_levels']} bids, "
|
||||
f"{depth_analysis['depth_analysis']['ask_levels']} asks")
|
||||
|
||||
dominant_exchanges = depth_analysis['depth_analysis'].get('dominant_exchanges', {})
|
||||
logger.info(f" Dominant exchanges: {dominant_exchanges}")
|
||||
|
||||
logger.info("Feature generation test completed")
|
||||
|
||||
async def _run_dashboard_integration_test(self):
|
||||
"""Test dashboard data generation"""
|
||||
logger.info("Testing dashboard integration...")
|
||||
|
||||
# Dashboard integration is tested via callbacks
|
||||
# Statistics are tracked in the callback functions
|
||||
await asyncio.sleep(5)
|
||||
|
||||
logger.info("Dashboard integration test completed")
|
||||
|
||||
async def _run_signal_analysis_test(self):
|
||||
"""Test signal generation and analysis"""
|
||||
logger.info("Testing signal analysis...")
|
||||
|
||||
for symbol in self.symbols:
|
||||
# Get recent signals
|
||||
recent_signals = self.cob_integration.get_recent_signals(symbol, count=10)
|
||||
logger.info(f"{symbol} recent signals: {len(recent_signals)} generated")
|
||||
|
||||
for signal in recent_signals[-3:]: # Show last 3 signals
|
||||
logger.info(f" Signal: {signal.get('type')} - {signal.get('side')} - "
|
||||
f"Confidence: {signal.get('confidence', 0):.3f}")
|
||||
|
||||
logger.info("Signal analysis test completed")
|
||||
|
||||
async def _monitor_live_data(self):
|
||||
"""Monitor live data for the specified duration"""
|
||||
logger.info(f"Monitoring live data for {self.test_duration} seconds...")
|
||||
|
||||
start_time = time.time()
|
||||
last_stats_time = start_time
|
||||
|
||||
while time.time() - start_time < self.test_duration:
|
||||
# Print periodic statistics
|
||||
current_time = time.time()
|
||||
if current_time - last_stats_time >= 30: # Every 30 seconds
|
||||
self._print_periodic_stats()
|
||||
last_stats_time = current_time
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
logger.info("Live data monitoring completed")
|
||||
|
||||
def _print_periodic_stats(self):
|
||||
"""Print periodic statistics during monitoring"""
|
||||
elapsed = (datetime.now() - self.stats['start_time']).total_seconds()
|
||||
|
||||
logger.info("Periodic Statistics:")
|
||||
logger.info(f" Elapsed time: {elapsed:.0f} seconds")
|
||||
logger.info(f" COB updates: {self.stats['cob_updates_received']}")
|
||||
logger.info(f" Bucket updates: {self.stats['bucket_updates_received']}")
|
||||
logger.info(f" CNN features: {self.stats['cnn_features_generated']}")
|
||||
logger.info(f" DQN features: {self.stats['dqn_features_generated']}")
|
||||
logger.info(f" Signals: {self.stats['signals_generated']}")
|
||||
|
||||
# Calculate rates
|
||||
if elapsed > 0:
|
||||
cob_rate = self.stats['cob_updates_received'] / elapsed
|
||||
logger.info(f" COB update rate: {cob_rate:.2f}/sec")
|
||||
|
||||
def _generate_test_report(self):
|
||||
"""Generate final test report"""
|
||||
elapsed = (datetime.now() - self.stats['start_time']).total_seconds()
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("MULTI-EXCHANGE COB PROVIDER TEST REPORT")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Test Duration: {elapsed:.0f} seconds")
|
||||
logger.info(f"Symbols Tested: {', '.join(self.symbols)}")
|
||||
logger.info("")
|
||||
|
||||
# Data Reception Statistics
|
||||
logger.info("Data Reception:")
|
||||
logger.info(f" COB Updates Received: {self.stats['cob_updates_received']}")
|
||||
logger.info(f" Bucket Updates Received: {self.stats['bucket_updates_received']}")
|
||||
logger.info(f" Average COB Rate: {self.stats['cob_updates_received'] / elapsed:.2f}/sec")
|
||||
logger.info("")
|
||||
|
||||
# Feature Generation Statistics
|
||||
logger.info("Feature Generation:")
|
||||
logger.info(f" CNN Features Generated: {self.stats['cnn_features_generated']}")
|
||||
logger.info(f" DQN Features Generated: {self.stats['dqn_features_generated']}")
|
||||
logger.info("")
|
||||
|
||||
# Signal Generation Statistics
|
||||
logger.info("Signal Analysis:")
|
||||
logger.info(f" Signals Generated: {self.stats['signals_generated']}")
|
||||
logger.info("")
|
||||
|
||||
# Component Statistics
|
||||
cob_stats = self.cob_integration.get_statistics()
|
||||
logger.info("Component Statistics:")
|
||||
logger.info(f" Active Exchanges: {', '.join(cob_stats.get('active_exchanges', []))}")
|
||||
logger.info(f" Streaming Status: {cob_stats.get('is_streaming', False)}")
|
||||
logger.info(f" Bucket Size: {cob_stats.get('bucket_size_bps', 0)} bps")
|
||||
logger.info(f" Average Processing Time: {cob_stats.get('avg_processing_time_ms', 0):.2f} ms")
|
||||
logger.info("")
|
||||
|
||||
# Per-Symbol Analysis
|
||||
logger.info("Per-Symbol Analysis:")
|
||||
for symbol in self.symbols:
|
||||
cob_snapshot = self.cob_integration.get_cob_snapshot(symbol)
|
||||
if cob_snapshot:
|
||||
logger.info(f" {symbol}:")
|
||||
logger.info(f" Active Exchanges: {len(cob_snapshot.exchanges_active)}")
|
||||
logger.info(f" Spread: {cob_snapshot.spread_bps:.2f} bps")
|
||||
logger.info(f" Total Liquidity: ${(cob_snapshot.total_bid_liquidity + cob_snapshot.total_ask_liquidity):,.0f}")
|
||||
|
||||
recent_signals = self.cob_integration.get_recent_signals(symbol)
|
||||
logger.info(f" Signals Generated: {len(recent_signals)}")
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("Test completed successfully!")
|
||||
|
||||
async def _cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
logger.info("Cleaning up resources...")
|
||||
|
||||
if self.cob_integration:
|
||||
await self.cob_integration.stop()
|
||||
|
||||
if self.data_provider and hasattr(self.data_provider, 'stop_real_time_streaming'):
|
||||
await self.data_provider.stop_real_time_streaming()
|
||||
|
||||
logger.info("Cleanup completed")
|
||||
|
||||
# Callback functions for testing
|
||||
|
||||
def _cnn_callback(self, symbol: str, data: dict):
|
||||
"""CNN feature callback for testing"""
|
||||
self.stats['cnn_features_generated'] += 1
|
||||
if self.stats['cnn_features_generated'] % 100 == 0:
|
||||
logger.debug(f"CNN features generated: {self.stats['cnn_features_generated']}")
|
||||
|
||||
def _dqn_callback(self, symbol: str, data: dict):
|
||||
"""DQN feature callback for testing"""
|
||||
self.stats['dqn_features_generated'] += 1
|
||||
if self.stats['dqn_features_generated'] % 100 == 0:
|
||||
logger.debug(f"DQN features generated: {self.stats['dqn_features_generated']}")
|
||||
|
||||
def _dashboard_callback(self, symbol: str, data: dict):
|
||||
"""Dashboard data callback for testing"""
|
||||
self.stats['cob_updates_received'] += 1
|
||||
|
||||
# Check for signals in dashboard data
|
||||
signals = data.get('recent_signals', [])
|
||||
self.stats['signals_generated'] += len(signals)
|
||||
|
||||
async def main():
|
||||
"""Main test function"""
|
||||
logger.info("Multi-Exchange COB Provider Test Starting...")
|
||||
|
||||
try:
|
||||
tester = COBTester()
|
||||
await tester.run_test()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Test interrupted by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Test failed with error: {e}")
|
||||
raise
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
@ -1,213 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for negative case training functionality
|
||||
|
||||
This script tests:
|
||||
1. Negative case trainer initialization
|
||||
2. Adding losing trades for intensive training
|
||||
3. Storage in testcases/negative folder
|
||||
4. Simultaneous inference and training
|
||||
5. 500x leverage training case generation
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from core.negative_case_trainer import NegativeCaseTrainer, NegativeCase
|
||||
from core.trading_action import TradingAction
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_negative_case_trainer():
|
||||
"""Test negative case trainer functionality"""
|
||||
print("🔴 Testing Negative Case Trainer for Intensive Training on Losses")
|
||||
print("=" * 70)
|
||||
|
||||
# Test 1: Initialize trainer
|
||||
print("\n1. Initializing Negative Case Trainer...")
|
||||
trainer = NegativeCaseTrainer()
|
||||
print(f"✅ Trainer initialized with storage at: {trainer.storage_dir}")
|
||||
print(f"✅ Background training thread started: {trainer.training_thread.is_alive()}")
|
||||
|
||||
# Test 2: Create a losing trade scenario
|
||||
print("\n2. Creating losing trade scenarios...")
|
||||
|
||||
# Scenario 1: Small loss (1% with 500x leverage = 500% loss)
|
||||
trade_info_1 = {
|
||||
'timestamp': datetime.now(),
|
||||
'symbol': 'ETH/USDT',
|
||||
'action': 'BUY',
|
||||
'price': 3000.0,
|
||||
'size': 0.1,
|
||||
'value': 300.0,
|
||||
'confidence': 0.8,
|
||||
'pnl': -3.0 # $3 loss on $300 position = 1% loss
|
||||
}
|
||||
|
||||
market_data_1 = {
|
||||
'exit_price': 2970.0, # 1% drop
|
||||
'state_before': {
|
||||
'volatility': 2.5,
|
||||
'momentum': 0.5,
|
||||
'volume_ratio': 1.2
|
||||
},
|
||||
'state_after': {
|
||||
'volatility': 3.0,
|
||||
'momentum': -1.0,
|
||||
'volume_ratio': 0.8
|
||||
},
|
||||
'tick_data': [],
|
||||
'technical_indicators': {
|
||||
'rsi': 65,
|
||||
'macd': 0.5
|
||||
}
|
||||
}
|
||||
|
||||
case_id_1 = trainer.add_losing_trade(trade_info_1, market_data_1)
|
||||
print(f"✅ Added small loss case: {case_id_1}")
|
||||
|
||||
# Scenario 2: Large loss (5% with 500x leverage = 2500% loss)
|
||||
trade_info_2 = {
|
||||
'timestamp': datetime.now(),
|
||||
'symbol': 'ETH/USDT',
|
||||
'action': 'SELL',
|
||||
'price': 3000.0,
|
||||
'size': 0.2,
|
||||
'value': 600.0,
|
||||
'confidence': 0.9,
|
||||
'pnl': -30.0 # $30 loss on $600 position = 5% loss
|
||||
}
|
||||
|
||||
market_data_2 = {
|
||||
'exit_price': 3150.0, # 5% rise (bad for short)
|
||||
'state_before': {
|
||||
'volatility': 1.8,
|
||||
'momentum': -0.3,
|
||||
'volume_ratio': 0.9
|
||||
},
|
||||
'state_after': {
|
||||
'volatility': 4.2,
|
||||
'momentum': 2.5,
|
||||
'volume_ratio': 1.8
|
||||
},
|
||||
'tick_data': [],
|
||||
'technical_indicators': {
|
||||
'rsi': 35,
|
||||
'macd': -0.8
|
||||
}
|
||||
}
|
||||
|
||||
case_id_2 = trainer.add_losing_trade(trade_info_2, market_data_2)
|
||||
print(f"✅ Added large loss case: {case_id_2}")
|
||||
|
||||
# Test 3: Check training stats
|
||||
print("\n3. Checking training statistics...")
|
||||
stats = trainer.get_training_stats()
|
||||
print(f"✅ Total negative cases: {stats['total_negative_cases']}")
|
||||
print(f"✅ Cases in training queue: {stats['cases_in_queue']}")
|
||||
print(f"✅ High priority cases: {stats['high_priority_cases']}")
|
||||
print(f"✅ Training active: {stats['training_active']}")
|
||||
print(f"✅ Storage directory: {stats['storage_directory']}")
|
||||
|
||||
# Test 4: Check recent lessons
|
||||
print("\n4. Recent lessons learned...")
|
||||
lessons = trainer.get_recent_lessons(3)
|
||||
for i, lesson in enumerate(lessons, 1):
|
||||
print(f"✅ Lesson {i}: {lesson}")
|
||||
|
||||
# Test 5: Test simultaneous inference capability
|
||||
print("\n5. Testing simultaneous inference and training...")
|
||||
for i in range(5):
|
||||
can_inference = trainer.can_inference_proceed()
|
||||
print(f"✅ Inference check {i+1}: {'ALLOWED' if can_inference else 'BLOCKED'}")
|
||||
time.sleep(0.5)
|
||||
|
||||
# Test 6: Wait for some training to complete
|
||||
print("\n6. Waiting for intensive training to process cases...")
|
||||
time.sleep(3) # Wait for background training
|
||||
|
||||
# Check updated stats
|
||||
updated_stats = trainer.get_training_stats()
|
||||
print(f"✅ Cases processed: {updated_stats['total_cases_processed']}")
|
||||
print(f"✅ Total training time: {updated_stats['total_training_time']:.2f}s")
|
||||
print(f"✅ Avg accuracy improvement: {updated_stats['avg_accuracy_improvement']:.1%}")
|
||||
|
||||
# Test 7: 500x leverage training case analysis
|
||||
print("\n7. 500x Leverage Training Case Analysis...")
|
||||
print("💡 With 0% fees, any move >0.1% is profitable at 500x leverage:")
|
||||
|
||||
test_moves = [0.05, 0.1, 0.15, 0.2, 0.5, 1.0] # Price change percentages
|
||||
for move_pct in test_moves:
|
||||
leverage_profit = move_pct * 500
|
||||
profitable = move_pct >= 0.1
|
||||
status = "✅ PROFITABLE" if profitable else "❌ TOO SMALL"
|
||||
print(f" {move_pct:+.2f}% move = {leverage_profit:+.1f}% @ 500x leverage - {status}")
|
||||
|
||||
print("\n🔴 PRIORITY: Losing trades trigger intensive RL retraining")
|
||||
print("🚀 System optimized for fast trading with 500x leverage and 0% fees")
|
||||
print("⚡ Training cases generated for all moves >0.1% to maximize profit")
|
||||
|
||||
return trainer
|
||||
|
||||
def test_integration_with_enhanced_dashboard():
|
||||
"""Test integration with enhanced dashboard"""
|
||||
print("\n" + "=" * 70)
|
||||
print("🔗 Testing Integration with Enhanced Dashboard")
|
||||
print("=" * 70)
|
||||
|
||||
try:
|
||||
from web.old_archived.enhanced_scalping_dashboard import EnhancedScalpingDashboard
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
|
||||
# Create components
|
||||
data_provider = DataProvider()
|
||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||
dashboard = EnhancedScalpingDashboard(data_provider, orchestrator)
|
||||
|
||||
print("✅ Enhanced dashboard created successfully")
|
||||
print(f"✅ Orchestrator has negative case trainer: {hasattr(orchestrator, 'negative_case_trainer')}")
|
||||
print(f"✅ Trading session has orchestrator reference: {hasattr(dashboard.trading_session, 'orchestrator')}")
|
||||
|
||||
# Test negative case trainer access
|
||||
if hasattr(orchestrator, 'negative_case_trainer'):
|
||||
trainer_stats = orchestrator.negative_case_trainer.get_training_stats()
|
||||
print(f"✅ Negative case trainer accessible with {trainer_stats['total_negative_cases']} cases")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Integration test failed: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("🔴 NEGATIVE CASE TRAINING TEST SUITE")
|
||||
print("Focus: Learning from losses to prevent future mistakes")
|
||||
print("Features: 500x leverage optimization, 0% fee advantage, intensive retraining")
|
||||
|
||||
try:
|
||||
# Test negative case trainer
|
||||
trainer = test_negative_case_trainer()
|
||||
|
||||
# Test integration
|
||||
integration_success = test_integration_with_enhanced_dashboard()
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("📊 TEST SUMMARY")
|
||||
print("=" * 70)
|
||||
print("✅ Negative case trainer: WORKING")
|
||||
print("✅ Intensive training on losses: ACTIVE")
|
||||
print("✅ Storage in testcases/negative: WORKING")
|
||||
print("✅ Simultaneous inference/training: SUPPORTED")
|
||||
print("✅ 500x leverage optimization: IMPLEMENTED")
|
||||
print(f"✅ Enhanced dashboard integration: {'WORKING' if integration_success else 'NEEDS ATTENTION'}")
|
||||
|
||||
print("\n🎯 SYSTEM READY FOR INTENSIVE LOSS-BASED LEARNING")
|
||||
print("💪 Every losing trade makes the system stronger!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Test suite failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
@ -1,201 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test NN-Driven Trading System
|
||||
Demonstrates how the system now makes decisions using Neural Networks instead of algorithms
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def test_nn_driven_system():
|
||||
"""Test the NN-driven trading system"""
|
||||
logger.info("=== TESTING NN-DRIVEN TRADING SYSTEM ===")
|
||||
|
||||
try:
|
||||
# Import core components
|
||||
from core.config import get_config
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from core.nn_decision_fusion import ModelPrediction, MarketContext
|
||||
|
||||
# Initialize components
|
||||
config = get_config()
|
||||
data_provider = DataProvider()
|
||||
|
||||
# Initialize NN-driven orchestrator
|
||||
orchestrator = EnhancedTradingOrchestrator(
|
||||
data_provider=data_provider,
|
||||
symbols=['ETH/USDT', 'BTC/USDT'],
|
||||
enhanced_rl_training=True
|
||||
)
|
||||
|
||||
logger.info("✅ NN-driven orchestrator initialized")
|
||||
|
||||
# Test 1: Add mock CNN prediction
|
||||
cnn_prediction = ModelPrediction(
|
||||
model_name="williams_cnn",
|
||||
prediction_type="direction",
|
||||
value=0.6, # Bullish signal
|
||||
confidence=0.8,
|
||||
timestamp=datetime.now(),
|
||||
metadata={'timeframe': '1h', 'feature_importance': [0.2, 0.3, 0.5]}
|
||||
)
|
||||
|
||||
orchestrator.neural_fusion.add_prediction(cnn_prediction)
|
||||
logger.info("🔮 Added CNN prediction: BULLISH (0.6) with 80% confidence")
|
||||
|
||||
# Test 2: Add mock RL prediction
|
||||
rl_prediction = ModelPrediction(
|
||||
model_name="dqn_agent",
|
||||
prediction_type="action",
|
||||
value=0.4, # Moderate buy signal
|
||||
confidence=0.7,
|
||||
timestamp=datetime.now(),
|
||||
metadata={'action_probs': [0.4, 0.2, 0.4]} # [BUY, SELL, HOLD]
|
||||
)
|
||||
|
||||
orchestrator.neural_fusion.add_prediction(rl_prediction)
|
||||
logger.info("🔮 Added RL prediction: MODERATE_BUY (0.4) with 70% confidence")
|
||||
|
||||
# Test 3: Add mock COB RL prediction
|
||||
cob_prediction = ModelPrediction(
|
||||
model_name="cob_rl",
|
||||
prediction_type="direction",
|
||||
value=0.3, # Slightly bullish
|
||||
confidence=0.85,
|
||||
timestamp=datetime.now(),
|
||||
metadata={'cob_imbalance': 0.1, 'liquidity_depth': 150000}
|
||||
)
|
||||
|
||||
orchestrator.neural_fusion.add_prediction(cob_prediction)
|
||||
logger.info("🔮 Added COB RL prediction: SLIGHT_BULLISH (0.3) with 85% confidence")
|
||||
|
||||
# Test 4: Create market context
|
||||
market_context = MarketContext(
|
||||
symbol='ETH/USDT',
|
||||
current_price=2441.50,
|
||||
price_change_1m=0.002, # 0.2% up in 1m
|
||||
price_change_5m=0.008, # 0.8% up in 5m
|
||||
volume_ratio=1.2, # 20% above average volume
|
||||
volatility=0.015, # 1.5% volatility
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
|
||||
logger.info(f"📊 Market Context: ETH/USDT at ${market_context.current_price}")
|
||||
logger.info(f" 📈 Price changes: 1m: {market_context.price_change_1m:.3f}, 5m: {market_context.price_change_5m:.3f}")
|
||||
logger.info(f" 📊 Volume ratio: {market_context.volume_ratio:.2f}, Volatility: {market_context.volatility:.3f}")
|
||||
|
||||
# Test 5: Make NN decision
|
||||
fusion_decision = orchestrator.neural_fusion.make_decision(
|
||||
symbol='ETH/USDT',
|
||||
market_context=market_context,
|
||||
min_confidence=0.25
|
||||
)
|
||||
|
||||
if fusion_decision:
|
||||
logger.info("🧠 === NN DECISION RESULT ===")
|
||||
logger.info(f" Action: {fusion_decision.action}")
|
||||
logger.info(f" Confidence: {fusion_decision.confidence:.3f}")
|
||||
logger.info(f" Expected Return: {fusion_decision.expected_return:.3f}")
|
||||
logger.info(f" Risk Score: {fusion_decision.risk_score:.3f}")
|
||||
logger.info(f" Position Size: {fusion_decision.position_size:.4f} ETH")
|
||||
logger.info(f" Reasoning: {fusion_decision.reasoning}")
|
||||
logger.info(" Model Contributions:")
|
||||
for model, contribution in fusion_decision.model_contributions.items():
|
||||
logger.info(f" - {model}: {contribution:.1%}")
|
||||
else:
|
||||
logger.warning("❌ No NN decision generated")
|
||||
|
||||
# Test 6: Test coordinated decisions
|
||||
logger.info("\n🎯 Testing coordinated NN decisions...")
|
||||
decisions = await orchestrator.make_coordinated_decisions()
|
||||
|
||||
if decisions:
|
||||
logger.info(f"✅ Generated {len(decisions)} NN-driven trading decisions:")
|
||||
for i, decision in enumerate(decisions):
|
||||
logger.info(f" Decision {i+1}: {decision.symbol} {decision.action} "
|
||||
f"({decision.confidence:.3f} confidence, "
|
||||
f"{decision.quantity:.4f} size)")
|
||||
if hasattr(decision, 'metadata') and decision.metadata:
|
||||
if decision.metadata.get('nn_driven'):
|
||||
logger.info(f" 🧠 NN-DRIVEN: {decision.metadata.get('reasoning', 'No reasoning')}")
|
||||
else:
|
||||
logger.info("ℹ️ No trading decisions generated (insufficient confidence)")
|
||||
|
||||
# Test 7: Check NN system status
|
||||
nn_status = orchestrator.neural_fusion.get_status()
|
||||
logger.info("\n📊 NN System Status:")
|
||||
logger.info(f" Device: {nn_status['device']}")
|
||||
logger.info(f" Training Mode: {nn_status['training_mode']}")
|
||||
logger.info(f" Registered Models: {nn_status['registered_models']}")
|
||||
logger.info(f" Recent Predictions: {nn_status['recent_predictions']}")
|
||||
logger.info(f" Model Parameters: {nn_status['model_parameters']:,}")
|
||||
|
||||
# Test 8: Demonstrate different confidence scenarios
|
||||
logger.info("\n🔬 Testing different confidence scenarios...")
|
||||
|
||||
# Low confidence scenario
|
||||
low_conf_prediction = ModelPrediction(
|
||||
model_name="williams_cnn",
|
||||
prediction_type="direction",
|
||||
value=0.1, # Weak signal
|
||||
confidence=0.2, # Low confidence
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
|
||||
orchestrator.neural_fusion.add_prediction(low_conf_prediction)
|
||||
low_conf_decision = orchestrator.neural_fusion.make_decision(
|
||||
symbol='ETH/USDT',
|
||||
market_context=market_context,
|
||||
min_confidence=0.25
|
||||
)
|
||||
|
||||
if low_conf_decision:
|
||||
logger.info(f" Low confidence result: {low_conf_decision.action} (should be HOLD)")
|
||||
else:
|
||||
logger.info(" ✅ Low confidence correctly resulted in no decision")
|
||||
|
||||
# High confidence scenario
|
||||
high_conf_prediction = ModelPrediction(
|
||||
model_name="williams_cnn",
|
||||
prediction_type="direction",
|
||||
value=0.8, # Strong signal
|
||||
confidence=0.95, # Very high confidence
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
|
||||
orchestrator.neural_fusion.add_prediction(high_conf_prediction)
|
||||
high_conf_decision = orchestrator.neural_fusion.make_decision(
|
||||
symbol='ETH/USDT',
|
||||
market_context=market_context,
|
||||
min_confidence=0.25
|
||||
)
|
||||
|
||||
if high_conf_decision:
|
||||
logger.info(f" High confidence result: {high_conf_decision.action} "
|
||||
f"(conf: {high_conf_decision.confidence:.3f}, "
|
||||
f"size: {high_conf_decision.position_size:.4f})")
|
||||
|
||||
logger.info("\n✅ NN-DRIVEN TRADING SYSTEM TEST COMPLETE")
|
||||
logger.info("🎯 Key Benefits Demonstrated:")
|
||||
logger.info(" 1. Multiple NN models provide predictions")
|
||||
logger.info(" 2. Central NN fusion makes final decisions")
|
||||
logger.info(" 3. Market context influences decisions")
|
||||
logger.info(" 4. Confidence thresholds prevent bad trades")
|
||||
logger.info(" 5. Position sizing based on NN outputs")
|
||||
logger.info(" 6. Clear reasoning for every decision")
|
||||
logger.info(" 7. Model contribution tracking")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in NN-driven system test: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_nn_driven_system())
|
@ -1,305 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Pivot-Based Normalization System
|
||||
|
||||
This script tests the comprehensive pivot-based normalization system:
|
||||
1. Monthly 1s data collection with pagination
|
||||
2. Williams Market Structure pivot analysis
|
||||
3. Pivot bounds extraction and caching
|
||||
4. Pivot-based feature normalization
|
||||
5. Integration with model training pipeline
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Add project root to path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from core.data_provider import DataProvider
|
||||
from core.config import get_config
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_pivot_normalization_system():
|
||||
"""Test the complete pivot-based normalization system"""
|
||||
|
||||
print("="*80)
|
||||
print("TESTING PIVOT-BASED NORMALIZATION SYSTEM")
|
||||
print("="*80)
|
||||
|
||||
# Initialize data provider
|
||||
symbols = ['ETH/USDT'] # Test with ETH only
|
||||
timeframes = ['1s']
|
||||
|
||||
logger.info("Initializing DataProvider with pivot-based normalization...")
|
||||
data_provider = DataProvider(symbols=symbols, timeframes=timeframes)
|
||||
|
||||
# Test 1: Monthly Data Collection
|
||||
print("\n" + "="*60)
|
||||
print("TEST 1: MONTHLY 1S DATA COLLECTION")
|
||||
print("="*60)
|
||||
|
||||
symbol = 'ETH/USDT'
|
||||
|
||||
try:
|
||||
# This will trigger monthly data collection and pivot analysis
|
||||
logger.info(f"Testing monthly data collection for {symbol}...")
|
||||
monthly_data = data_provider._collect_monthly_1m_data(symbol)
|
||||
|
||||
if monthly_data is not None:
|
||||
print(f"✅ Monthly data collection SUCCESS")
|
||||
print(f" 📊 Collected {len(monthly_data):,} 1m candles")
|
||||
print(f" 📅 Period: {monthly_data['timestamp'].min()} to {monthly_data['timestamp'].max()}")
|
||||
print(f" 💰 Price range: ${monthly_data['low'].min():.2f} - ${monthly_data['high'].max():.2f}")
|
||||
print(f" 📈 Volume range: {monthly_data['volume'].min():.2f} - {monthly_data['volume'].max():.2f}")
|
||||
else:
|
||||
print("❌ Monthly data collection FAILED")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Monthly data collection ERROR: {e}")
|
||||
return False
|
||||
|
||||
# Test 2: Pivot Bounds Extraction
|
||||
print("\n" + "="*60)
|
||||
print("TEST 2: PIVOT BOUNDS EXTRACTION")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
logger.info("Testing pivot bounds extraction...")
|
||||
bounds = data_provider._extract_pivot_bounds_from_monthly_data(symbol, monthly_data)
|
||||
|
||||
if bounds is not None:
|
||||
print(f"✅ Pivot bounds extraction SUCCESS")
|
||||
print(f" 💰 Price bounds: ${bounds.price_min:.2f} - ${bounds.price_max:.2f}")
|
||||
print(f" 📊 Volume bounds: {bounds.volume_min:.2f} - {bounds.volume_max:.2f}")
|
||||
print(f" 🔸 Support levels: {len(bounds.pivot_support_levels)}")
|
||||
print(f" 🔹 Resistance levels: {len(bounds.pivot_resistance_levels)}")
|
||||
print(f" 📈 Candles analyzed: {bounds.total_candles_analyzed:,}")
|
||||
print(f" ⏰ Created: {bounds.created_timestamp}")
|
||||
|
||||
# Store bounds for next tests
|
||||
data_provider.pivot_bounds[symbol] = bounds
|
||||
else:
|
||||
print("❌ Pivot bounds extraction FAILED")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Pivot bounds extraction ERROR: {e}")
|
||||
return False
|
||||
|
||||
# Test 3: Pivot Context Features
|
||||
print("\n" + "="*60)
|
||||
print("TEST 3: PIVOT CONTEXT FEATURES")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
logger.info("Testing pivot context features...")
|
||||
|
||||
# Get recent data for testing
|
||||
recent_data = data_provider.get_historical_data(symbol, '1m', limit=100)
|
||||
|
||||
if recent_data is not None and not recent_data.empty:
|
||||
# Add pivot context features
|
||||
with_pivot_features = data_provider._add_pivot_context_features(recent_data, symbol)
|
||||
|
||||
# Check if pivot features were added
|
||||
pivot_features = [col for col in with_pivot_features.columns if 'pivot' in col]
|
||||
|
||||
if pivot_features:
|
||||
print(f"✅ Pivot context features SUCCESS")
|
||||
print(f" 🎯 Added features: {pivot_features}")
|
||||
|
||||
# Show sample values
|
||||
latest_row = with_pivot_features.iloc[-1]
|
||||
print(f" 📊 Latest values:")
|
||||
for feature in pivot_features:
|
||||
print(f" {feature}: {latest_row[feature]:.4f}")
|
||||
else:
|
||||
print("❌ No pivot context features added")
|
||||
return False
|
||||
else:
|
||||
print("❌ Could not get recent data for testing")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Pivot context features ERROR: {e}")
|
||||
return False
|
||||
|
||||
# Test 4: Pivot-Based Normalization
|
||||
print("\n" + "="*60)
|
||||
print("TEST 4: PIVOT-BASED NORMALIZATION")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
logger.info("Testing pivot-based normalization...")
|
||||
|
||||
# Get data with technical indicators
|
||||
data_with_indicators = data_provider.get_historical_data(symbol, '1m', limit=50)
|
||||
|
||||
if data_with_indicators is not None and not data_with_indicators.empty:
|
||||
# Test traditional vs pivot normalization
|
||||
traditional_norm = data_provider._normalize_features(data_with_indicators.tail(10))
|
||||
pivot_norm = data_provider._normalize_features(data_with_indicators.tail(10), symbol=symbol)
|
||||
|
||||
print(f"✅ Pivot-based normalization SUCCESS")
|
||||
print(f" 📊 Traditional normalization shape: {traditional_norm.shape}")
|
||||
print(f" 🎯 Pivot normalization shape: {pivot_norm.shape}")
|
||||
|
||||
# Compare price normalization
|
||||
if 'close' in pivot_norm.columns:
|
||||
trad_close_range = traditional_norm['close'].max() - traditional_norm['close'].min()
|
||||
pivot_close_range = pivot_norm['close'].max() - pivot_norm['close'].min()
|
||||
|
||||
print(f" 💰 Traditional close range: {trad_close_range:.6f}")
|
||||
print(f" 🎯 Pivot close range: {pivot_close_range:.6f}")
|
||||
|
||||
# Pivot normalization should be better bounded
|
||||
if 0 <= pivot_norm['close'].min() and pivot_norm['close'].max() <= 1:
|
||||
print(f" ✅ Pivot normalization properly bounded [0,1]")
|
||||
else:
|
||||
print(f" ⚠️ Pivot normalization outside [0,1] bounds")
|
||||
else:
|
||||
print("❌ Could not get data for normalization testing")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Pivot-based normalization ERROR: {e}")
|
||||
return False
|
||||
|
||||
# Test 5: Feature Matrix with Pivot Normalization
|
||||
print("\n" + "="*60)
|
||||
print("TEST 5: FEATURE MATRIX WITH PIVOT NORMALIZATION")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
logger.info("Testing feature matrix with pivot normalization...")
|
||||
|
||||
# Create feature matrix using pivot normalization
|
||||
feature_matrix = data_provider.get_feature_matrix(symbol, timeframes=['1m'], window_size=20)
|
||||
|
||||
if feature_matrix is not None:
|
||||
print(f"✅ Feature matrix with pivot normalization SUCCESS")
|
||||
print(f" 📊 Matrix shape: {feature_matrix.shape}")
|
||||
print(f" 🎯 Data range: [{feature_matrix.min():.4f}, {feature_matrix.max():.4f}]")
|
||||
print(f" 📈 Mean: {feature_matrix.mean():.4f}")
|
||||
print(f" 📊 Std: {feature_matrix.std():.4f}")
|
||||
|
||||
# Check for proper normalization
|
||||
if feature_matrix.min() >= -5 and feature_matrix.max() <= 5: # Reasonable bounds
|
||||
print(f" ✅ Feature matrix reasonably bounded")
|
||||
else:
|
||||
print(f" ⚠️ Feature matrix may have extreme values")
|
||||
else:
|
||||
print("❌ Feature matrix creation FAILED")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Feature matrix ERROR: {e}")
|
||||
return False
|
||||
|
||||
# Test 6: Caching System
|
||||
print("\n" + "="*60)
|
||||
print("TEST 6: CACHING SYSTEM")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
logger.info("Testing caching system...")
|
||||
|
||||
# Test pivot bounds caching
|
||||
original_bounds = data_provider.pivot_bounds[symbol]
|
||||
data_provider._save_pivot_bounds_to_cache(symbol, original_bounds)
|
||||
|
||||
# Clear from memory and reload
|
||||
del data_provider.pivot_bounds[symbol]
|
||||
loaded_bounds = data_provider._load_pivot_bounds_from_cache(symbol)
|
||||
|
||||
if loaded_bounds is not None:
|
||||
print(f"✅ Pivot bounds caching SUCCESS")
|
||||
print(f" 💾 Original price range: ${original_bounds.price_min:.2f} - ${original_bounds.price_max:.2f}")
|
||||
print(f" 💾 Loaded price range: ${loaded_bounds.price_min:.2f} - ${loaded_bounds.price_max:.2f}")
|
||||
|
||||
# Restore bounds
|
||||
data_provider.pivot_bounds[symbol] = loaded_bounds
|
||||
else:
|
||||
print("❌ Pivot bounds caching FAILED")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Caching system ERROR: {e}")
|
||||
return False
|
||||
|
||||
# Test 7: Public API Methods
|
||||
print("\n" + "="*60)
|
||||
print("TEST 7: PUBLIC API METHODS")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
logger.info("Testing public API methods...")
|
||||
|
||||
# Test get_pivot_bounds
|
||||
api_bounds = data_provider.get_pivot_bounds(symbol)
|
||||
if api_bounds is not None:
|
||||
print(f"✅ get_pivot_bounds() SUCCESS")
|
||||
print(f" 📊 Returned bounds for {api_bounds.symbol}")
|
||||
|
||||
# Test get_pivot_normalized_features
|
||||
test_data = data_provider.get_historical_data(symbol, '1m', limit=10)
|
||||
if test_data is not None:
|
||||
normalized_data = data_provider.get_pivot_normalized_features(symbol, test_data)
|
||||
if normalized_data is not None:
|
||||
print(f"✅ get_pivot_normalized_features() SUCCESS")
|
||||
print(f" 📊 Normalized data shape: {normalized_data.shape}")
|
||||
else:
|
||||
print("❌ get_pivot_normalized_features() FAILED")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Public API methods ERROR: {e}")
|
||||
return False
|
||||
|
||||
# Final Summary
|
||||
print("\n" + "="*80)
|
||||
print("🎉 PIVOT-BASED NORMALIZATION SYSTEM TEST COMPLETE")
|
||||
print("="*80)
|
||||
print("✅ All tests PASSED successfully!")
|
||||
print("\n📋 System Features Verified:")
|
||||
print(" ✅ Monthly 1s data collection with pagination")
|
||||
print(" ✅ Williams Market Structure pivot analysis")
|
||||
print(" ✅ Pivot bounds extraction and validation")
|
||||
print(" ✅ Pivot context features generation")
|
||||
print(" ✅ Pivot-based feature normalization")
|
||||
print(" ✅ Feature matrix creation with pivot bounds")
|
||||
print(" ✅ Comprehensive caching system")
|
||||
print(" ✅ Public API methods")
|
||||
|
||||
print(f"\n🎯 Ready for model training with pivot-normalized features!")
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
success = test_pivot_normalization_system()
|
||||
|
||||
if success:
|
||||
print("\n🚀 Pivot-based normalization system ready for production!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("\n❌ Pivot-based normalization system has issues!")
|
||||
sys.exit(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⏹️ Test interrupted by user")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n💥 Unexpected error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
@ -1,80 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test PnL Tracking System
|
||||
|
||||
This script demonstrates the ultra-fast scalping PnL tracking system
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
from run_scalping_dashboard import UltraFastScalpingRunner
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
def test_pnl_tracking():
|
||||
"""Test the PnL tracking system"""
|
||||
print("🔥 TESTING ULTRA-FAST SCALPING PnL TRACKING 🔥")
|
||||
print("="*60)
|
||||
|
||||
# Create runner
|
||||
runner = UltraFastScalpingRunner()
|
||||
|
||||
print(f"💰 Starting Balance: ${runner.balance:.2f}")
|
||||
print(f"📊 Leverage: {runner.leverage}x")
|
||||
print(f"💳 Trading Fee: {runner.trading_fee*100:.3f}% per trade")
|
||||
print("⚡ Starting simulation for 30 seconds...")
|
||||
print("="*60)
|
||||
|
||||
# Start simulation
|
||||
runner.start_ultra_fast_simulation()
|
||||
|
||||
try:
|
||||
# Run for 30 seconds
|
||||
time.sleep(30)
|
||||
except KeyboardInterrupt:
|
||||
print("\n🛑 Stopping simulation...")
|
||||
|
||||
# Stop simulation
|
||||
runner.running = False
|
||||
|
||||
# Wait for threads to finish
|
||||
if runner.simulation_thread:
|
||||
runner.simulation_thread.join(timeout=2)
|
||||
if runner.exit_monitor_thread:
|
||||
runner.exit_monitor_thread.join(timeout=2)
|
||||
|
||||
# Print final results
|
||||
print("\n" + "="*60)
|
||||
print("💼 FINAL PnL TRACKING RESULTS:")
|
||||
print("="*60)
|
||||
print(f"📊 Total Trades: {len(runner.closed_trades)}")
|
||||
print(f"🎯 Total PnL: ${runner.total_pnl:+.2f}")
|
||||
print(f"💳 Total Fees: ${runner.total_fees:.2f}")
|
||||
print(f"🟢 Wins: {runner.win_count} | 🔴 Losses: {runner.loss_count}")
|
||||
if runner.win_count + runner.loss_count > 0:
|
||||
win_rate = runner.win_count / (runner.win_count + runner.loss_count)
|
||||
print(f"📈 Win Rate: {win_rate*100:.1f}%")
|
||||
print(f"💰 Starting Balance: ${runner.balance:.2f}")
|
||||
print(f"💰 Final Balance: ${runner.balance + runner.total_pnl:.2f}")
|
||||
if runner.balance > 0:
|
||||
return_pct = ((runner.balance + runner.total_pnl) / runner.balance - 1) * 100
|
||||
print(f"📊 Return: {return_pct:+.2f}%")
|
||||
print(f"📋 Open Positions: {len(runner.open_positions)}")
|
||||
print("="*60)
|
||||
|
||||
# Show sample of closed trades
|
||||
if runner.closed_trades:
|
||||
print("\n📈 SAMPLE CLOSED TRADES:")
|
||||
print("-" * 40)
|
||||
for i, trade in enumerate(runner.closed_trades[-5:]): # Last 5 trades
|
||||
duration = (trade.exit_time - trade.entry_time).total_seconds()
|
||||
pnl_color = "🟢" if trade.pnl > 0 else "🔴"
|
||||
print(f"{pnl_color} Trade #{trade.trade_id}: {trade.action} {trade.symbol}")
|
||||
print(f" Entry: ${trade.entry_price:.2f} → Exit: ${trade.exit_price:.2f}")
|
||||
print(f" Duration: {duration:.1f}s | PnL: ${trade.pnl:+.2f}")
|
||||
|
||||
print("\n✅ PnL Tracking Test Complete!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_pnl_tracking()
|
@ -1,134 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for enhanced PnL tracking with position flipping and color coding
|
||||
"""
|
||||
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_enhanced_pnl_tracking():
|
||||
"""Test the enhanced PnL tracking with position flipping"""
|
||||
try:
|
||||
print("="*60)
|
||||
print("TESTING ENHANCED PnL TRACKING & POSITION COLOR CODING")
|
||||
print("="*60)
|
||||
|
||||
# Import dashboard
|
||||
from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
|
||||
|
||||
# Create dashboard instance
|
||||
dashboard = TradingDashboard()
|
||||
|
||||
print(f"✓ Dashboard created")
|
||||
print(f"✓ Initial position: {dashboard.current_position}")
|
||||
print(f"✓ Initial realized PnL: ${dashboard.total_realized_pnl:.2f}")
|
||||
print(f"✓ Initial session trades: {len(dashboard.session_trades)}")
|
||||
|
||||
# Test sequence of trades with position flipping
|
||||
test_trades = [
|
||||
{'action': 'BUY', 'price': 3000.0, 'size': 0.1, 'confidence': 0.75}, # Open LONG
|
||||
{'action': 'SELL', 'price': 3050.0, 'size': 0.1, 'confidence': 0.80}, # Close LONG (+$5 profit)
|
||||
{'action': 'SELL', 'price': 3040.0, 'size': 0.1, 'confidence': 0.70}, # Open SHORT
|
||||
{'action': 'BUY', 'price': 3020.0, 'size': 0.1, 'confidence': 0.85}, # Close SHORT (+$2 profit) & flip to LONG
|
||||
{'action': 'SELL', 'price': 3010.0, 'size': 0.1, 'confidence': 0.65}, # Close LONG (-$1 loss)
|
||||
]
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("EXECUTING TEST TRADE SEQUENCE:")
|
||||
print("="*60)
|
||||
|
||||
for i, trade in enumerate(test_trades, 1):
|
||||
print(f"\n--- Trade {i}: {trade['action']} @ ${trade['price']:.2f} ---")
|
||||
|
||||
# Add required fields
|
||||
trade['symbol'] = 'ETH/USDT'
|
||||
trade['timestamp'] = datetime.now(timezone.utc)
|
||||
trade['reason'] = f'Test trade {i}'
|
||||
|
||||
# Process the trade
|
||||
dashboard._process_trading_decision(trade)
|
||||
|
||||
# Show results
|
||||
print(f"Current position: {dashboard.current_position}")
|
||||
print(f"Realized PnL: ${dashboard.total_realized_pnl:.2f}")
|
||||
print(f"Total trades: {len(dashboard.session_trades)}")
|
||||
print(f"Recent decisions: {len(dashboard.recent_decisions)}")
|
||||
|
||||
# Test unrealized PnL calculation
|
||||
if dashboard.current_position:
|
||||
current_price = trade['price'] + 5.0 # Simulate price movement
|
||||
unrealized_pnl = dashboard._calculate_unrealized_pnl(current_price)
|
||||
print(f"Unrealized PnL @ ${current_price:.2f}: ${unrealized_pnl:.2f}")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("FINAL RESULTS:")
|
||||
print("="*60)
|
||||
print(f"✓ Total realized PnL: ${dashboard.total_realized_pnl:.2f}")
|
||||
print(f"✓ Total fees paid: ${dashboard.total_fees:.2f}")
|
||||
print(f"✓ Total trades executed: {len(dashboard.session_trades)}")
|
||||
print(f"✓ Final position: {dashboard.current_position}")
|
||||
|
||||
# Test session performance calculation
|
||||
print("\n" + "="*60)
|
||||
print("SESSION PERFORMANCE TEST:")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
session_perf = dashboard._create_session_performance()
|
||||
print(f"✓ Session performance component created successfully")
|
||||
print(f"✓ Performance items count: {len(session_perf)}")
|
||||
except Exception as e:
|
||||
print(f"❌ Session performance error: {e}")
|
||||
|
||||
# Test decisions list with PnL info
|
||||
print("\n" + "="*60)
|
||||
print("DECISIONS LIST WITH PnL TEST:")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
decisions_list = dashboard._create_decisions_list()
|
||||
print(f"✓ Decisions list created successfully")
|
||||
print(f"✓ Decisions items count: {len(decisions_list)}")
|
||||
|
||||
# Check for PnL information in closed trades
|
||||
closed_trades = [t for t in dashboard.session_trades if 'pnl' in t]
|
||||
print(f"✓ Closed trades with PnL: {len(closed_trades)}")
|
||||
|
||||
for trade in closed_trades:
|
||||
action = trade.get('position_action', 'UNKNOWN')
|
||||
pnl = trade.get('pnl', 0)
|
||||
entry_price = trade.get('entry_price', 0)
|
||||
exit_price = trade.get('price', 0)
|
||||
print(f" - {action}: Entry ${entry_price:.2f} -> Exit ${exit_price:.2f} = PnL ${pnl:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Decisions list error: {e}")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("ENHANCED FEATURES VERIFIED:")
|
||||
print("="*60)
|
||||
print("✓ Position flipping (LONG -> SHORT -> LONG)")
|
||||
print("✓ PnL calculation for closed trades")
|
||||
print("✓ Color coding for positions based on side and P&L")
|
||||
print("✓ Entry/exit price tracking")
|
||||
print("✓ Real-time unrealized PnL calculation")
|
||||
print("✓ ASCII indicators (no Unicode for Windows compatibility)")
|
||||
print("✓ Enhanced trade logging with PnL information")
|
||||
print("✓ Session performance metrics with PnL breakdown")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error testing enhanced PnL tracking: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_enhanced_pnl_tracking()
|
||||
sys.exit(0 if success else 1)
|
@ -1,92 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for real-time COB functionality
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
async def test_realtime_cob():
|
||||
"""Test real-time COB data streaming"""
|
||||
|
||||
# Test API endpoints
|
||||
base_url = "http://localhost:8053"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
print("Testing COB Dashboard API endpoints...")
|
||||
|
||||
# Test symbols endpoint
|
||||
try:
|
||||
async with session.get(f"{base_url}/api/symbols") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✓ Symbols: {data}")
|
||||
else:
|
||||
print(f"✗ Symbols endpoint failed: {response.status}")
|
||||
except Exception as e:
|
||||
print(f"✗ Error testing symbols endpoint: {e}")
|
||||
|
||||
# Test real-time stats for BTC/USDT
|
||||
try:
|
||||
async with session.get(f"{base_url}/api/realtime/BTC/USDT") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✓ Real-time stats for BTC/USDT:")
|
||||
print(f" Current mid price: {data.get('current', {}).get('mid_price', 'N/A')}")
|
||||
print(f" 1s window updates: {data.get('1s_window', {}).get('update_count', 'N/A')}")
|
||||
print(f" 5s window updates: {data.get('5s_window', {}).get('update_count', 'N/A')}")
|
||||
else:
|
||||
print(f"✗ Real-time stats endpoint failed: {response.status}")
|
||||
error_data = await response.text()
|
||||
print(f" Error: {error_data}")
|
||||
except Exception as e:
|
||||
print(f"✗ Error testing real-time stats endpoint: {e}")
|
||||
|
||||
# Test WebSocket connection
|
||||
print("\nTesting WebSocket connection...")
|
||||
try:
|
||||
async with session.ws_connect(f"{base_url.replace('http', 'ws')}/ws") as ws:
|
||||
print("✓ WebSocket connected")
|
||||
|
||||
# Wait for some data
|
||||
message_count = 0
|
||||
start_time = time.time()
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
data = json.loads(msg.data)
|
||||
message_count += 1
|
||||
|
||||
if data.get('type') == 'cob_update':
|
||||
symbol = data.get('data', {}).get('stats', {}).get('symbol', 'Unknown')
|
||||
mid_price = data.get('data', {}).get('stats', {}).get('mid_price', 0)
|
||||
print(f"✓ Received COB update for {symbol}: ${mid_price:.2f}")
|
||||
|
||||
# Check for real-time stats
|
||||
if 'realtime_1s' in data.get('data', {}).get('stats', {}):
|
||||
print(f" ✓ Real-time 1s stats available")
|
||||
if 'realtime_5s' in data.get('data', {}).get('stats', {}):
|
||||
print(f" ✓ Real-time 5s stats available")
|
||||
|
||||
# Stop after 5 messages or 10 seconds
|
||||
if message_count >= 5 or (time.time() - start_time) > 10:
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
print(f"✗ WebSocket error: {ws.exception()}")
|
||||
break
|
||||
|
||||
print(f"✓ Received {message_count} WebSocket messages")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ WebSocket connection failed: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Testing Real-time COB Dashboard")
|
||||
print("=" * 40)
|
||||
|
||||
asyncio.run(test_realtime_cob())
|
||||
|
||||
print("\nTest completed!")
|
@ -1,555 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Script for Real-time RL COB Trader
|
||||
|
||||
This script tests the real-time reinforcement learning system to ensure:
|
||||
1. Proper model initialization and parameter count (~1B parameters)
|
||||
2. COB data integration and feature extraction
|
||||
3. Real-time inference pipeline
|
||||
4. Signal accumulation and consensus
|
||||
5. Training loop functionality
|
||||
6. Trade execution integration
|
||||
|
||||
Run this before deploying the live system.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import time
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
# Local imports
|
||||
from core.realtime_rl_cob_trader import RealtimeRLCOBTrader, MassiveRLNetwork, PredictionResult
|
||||
from core.trading_executor import TradingExecutor
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RealtimeRLTester:
|
||||
"""
|
||||
Comprehensive tester for Real-time RL COB Trader
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.test_results = {}
|
||||
self.trader = None
|
||||
self.trading_executor = None
|
||||
|
||||
async def run_all_tests(self):
|
||||
"""Run all tests and generate report"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("REAL-TIME RL COB TRADER TESTING SUITE")
|
||||
logger.info("=" * 60)
|
||||
|
||||
tests = [
|
||||
self.test_model_initialization,
|
||||
self.test_model_parameter_count,
|
||||
self.test_feature_extraction,
|
||||
self.test_inference_performance,
|
||||
self.test_signal_accumulation,
|
||||
self.test_training_pipeline,
|
||||
self.test_trading_integration,
|
||||
self.test_performance_monitoring
|
||||
]
|
||||
|
||||
for test in tests:
|
||||
try:
|
||||
await test()
|
||||
except Exception as e:
|
||||
logger.error(f"Test {test.__name__} failed: {e}")
|
||||
self.test_results[test.__name__] = {'status': 'FAILED', 'error': str(e)}
|
||||
|
||||
await self.generate_test_report()
|
||||
|
||||
async def test_model_initialization(self):
|
||||
"""Test model initialization and architecture"""
|
||||
logger.info("🧠 Testing Model Initialization...")
|
||||
|
||||
try:
|
||||
# Test model creation
|
||||
model = MassiveRLNetwork(input_size=2000, hidden_size=4096, num_layers=12)
|
||||
|
||||
# Check if CUDA is available
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
model = model.to(device)
|
||||
|
||||
# Test forward pass
|
||||
batch_size = 4
|
||||
test_input = torch.randn(batch_size, 2000).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(test_input)
|
||||
|
||||
# Verify outputs
|
||||
assert 'price_logits' in outputs
|
||||
assert 'value' in outputs
|
||||
assert 'confidence' in outputs
|
||||
assert 'features' in outputs
|
||||
|
||||
assert outputs['price_logits'].shape == (batch_size, 3) # DOWN, SIDEWAYS, UP
|
||||
assert outputs['value'].shape == (batch_size, 1)
|
||||
assert outputs['confidence'].shape == (batch_size, 1)
|
||||
|
||||
self.test_results['test_model_initialization'] = {
|
||||
'status': 'PASSED',
|
||||
'device': str(device),
|
||||
'output_shapes': {k: list(v.shape) for k, v in outputs.items()}
|
||||
}
|
||||
|
||||
logger.info("✅ Model initialization test PASSED")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results['test_model_initialization'] = {'status': 'FAILED', 'error': str(e)}
|
||||
raise
|
||||
|
||||
async def test_model_parameter_count(self):
|
||||
"""Test that model has approximately 400M parameters"""
|
||||
logger.info("🔢 Testing Model Parameter Count...")
|
||||
|
||||
try:
|
||||
model = MassiveRLNetwork(input_size=2000, hidden_size=2048, num_layers=8)
|
||||
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
logger.info(f"Total parameters: {total_params:,}")
|
||||
logger.info(f"Trainable parameters: {trainable_params:,}")
|
||||
|
||||
# Check if parameters are approximately 400M (350M - 450M range)
|
||||
target_400m = total_params >= 350_000_000 and total_params <= 450_000_000
|
||||
|
||||
self.test_results['test_model_parameter_count'] = {
|
||||
'status': 'PASSED' if target_400m else 'WARNING',
|
||||
'total_parameters': total_params,
|
||||
'trainable_parameters': trainable_params,
|
||||
'parameter_size_gb': (total_params * 4) / (1024**3), # 4 bytes per float32
|
||||
'is_optimized': target_400m, # Around 400M parameters for faster startup
|
||||
'target_range': '350M - 450M parameters'
|
||||
}
|
||||
|
||||
logger.info(f"✅ Model has {total_params:,} parameters ({total_params/1e6:.0f}M)")
|
||||
if target_400m:
|
||||
logger.info("✅ Parameter count within 400M target range for fast startup")
|
||||
else:
|
||||
logger.warning(f"⚠️ Parameter count outside 400M target range: {total_params/1e6:.0f}M")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results['test_model_parameter_count'] = {'status': 'FAILED', 'error': str(e)}
|
||||
raise
|
||||
|
||||
async def test_feature_extraction(self):
|
||||
"""Test feature extraction from COB data"""
|
||||
logger.info("🔍 Testing Feature Extraction...")
|
||||
|
||||
try:
|
||||
# Initialize trader
|
||||
self.trading_executor = TradingExecutor(simulation_mode=True)
|
||||
self.trader = RealtimeRLCOBTrader(
|
||||
symbols=['BTC/USDT'],
|
||||
trading_executor=self.trading_executor,
|
||||
inference_interval_ms=1000 # Slower for testing
|
||||
)
|
||||
|
||||
# Create mock COB data
|
||||
mock_cob_data = {
|
||||
'state': np.random.randn(1500), # Mock state features
|
||||
'timestamp': datetime.now(),
|
||||
'type': 'cob_state'
|
||||
}
|
||||
|
||||
# Test feature extraction
|
||||
features = self.trader._extract_features('BTC/USDT', mock_cob_data)
|
||||
|
||||
assert features is not None
|
||||
assert len(features) == 2000 # Target feature size
|
||||
assert features.dtype == np.float32
|
||||
assert not np.any(np.isnan(features))
|
||||
assert not np.any(np.isinf(features))
|
||||
|
||||
# Test normalization
|
||||
assert np.abs(np.mean(features)) < 1.0 # Roughly normalized
|
||||
assert np.std(features) < 10.0 # Not too spread out
|
||||
|
||||
self.test_results['test_feature_extraction'] = {
|
||||
'status': 'PASSED',
|
||||
'feature_size': len(features),
|
||||
'feature_range': [float(np.min(features)), float(np.max(features))],
|
||||
'feature_stats': {
|
||||
'mean': float(np.mean(features)),
|
||||
'std': float(np.std(features)),
|
||||
'median': float(np.median(features))
|
||||
}
|
||||
}
|
||||
|
||||
logger.info("✅ Feature extraction test PASSED")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results['test_feature_extraction'] = {'status': 'FAILED', 'error': str(e)}
|
||||
raise
|
||||
|
||||
async def test_inference_performance(self):
|
||||
"""Test inference speed and quality"""
|
||||
logger.info("⚡ Testing Inference Performance...")
|
||||
|
||||
try:
|
||||
if not self.trader:
|
||||
self.trading_executor = TradingExecutor(simulation_mode=True)
|
||||
self.trader = RealtimeRLCOBTrader(
|
||||
symbols=['BTC/USDT'],
|
||||
trading_executor=self.trading_executor
|
||||
)
|
||||
|
||||
# Test multiple inferences
|
||||
num_tests = 10
|
||||
inference_times = []
|
||||
|
||||
for i in range(num_tests):
|
||||
# Create test features
|
||||
test_features = np.random.randn(2000).astype(np.float32)
|
||||
test_features = self.trader._normalize_features(test_features)
|
||||
|
||||
# Time inference
|
||||
start_time = time.time()
|
||||
prediction = self.trader._predict('BTC/USDT', test_features)
|
||||
inference_time = (time.time() - start_time) * 1000
|
||||
|
||||
inference_times.append(inference_time)
|
||||
|
||||
# Verify prediction structure
|
||||
assert 'direction' in prediction
|
||||
assert 'confidence' in prediction
|
||||
assert 'change' in prediction
|
||||
assert 'value' in prediction
|
||||
|
||||
assert 0 <= prediction['direction'] <= 2
|
||||
assert 0.0 <= prediction['confidence'] <= 1.0
|
||||
assert isinstance(prediction['change'], float)
|
||||
assert isinstance(prediction['value'], float)
|
||||
|
||||
avg_inference_time = np.mean(inference_times)
|
||||
max_inference_time = np.max(inference_times)
|
||||
|
||||
# Check if inference is fast enough (target: <50ms per inference)
|
||||
inference_target_ms = 50.0
|
||||
|
||||
self.test_results['test_inference_performance'] = {
|
||||
'status': 'PASSED' if avg_inference_time < inference_target_ms else 'WARNING',
|
||||
'average_inference_time_ms': float(avg_inference_time),
|
||||
'max_inference_time_ms': float(max_inference_time),
|
||||
'target_time_ms': inference_target_ms,
|
||||
'meets_target': avg_inference_time < inference_target_ms,
|
||||
'inferences_per_second': 1000.0 / avg_inference_time
|
||||
}
|
||||
|
||||
logger.info(f"✅ Average inference time: {avg_inference_time:.2f}ms")
|
||||
logger.info(f"✅ Max inference time: {max_inference_time:.2f}ms")
|
||||
logger.info(f"✅ Inferences per second: {1000.0/avg_inference_time:.1f}")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results['test_inference_performance'] = {'status': 'FAILED', 'error': str(e)}
|
||||
raise
|
||||
|
||||
async def test_signal_accumulation(self):
|
||||
"""Test signal accumulation and consensus logic"""
|
||||
logger.info("🎯 Testing Signal Accumulation...")
|
||||
|
||||
try:
|
||||
if not self.trader:
|
||||
self.trading_executor = TradingExecutor(simulation_mode=True)
|
||||
self.trader = RealtimeRLCOBTrader(
|
||||
symbols=['BTC/USDT'],
|
||||
trading_executor=self.trading_executor,
|
||||
required_confident_predictions=3
|
||||
)
|
||||
|
||||
symbol = 'BTC/USDT'
|
||||
accumulator = self.trader.signal_accumulators[symbol]
|
||||
|
||||
# Test adding signals
|
||||
test_predictions = []
|
||||
for i in range(5):
|
||||
prediction = PredictionResult(
|
||||
timestamp=datetime.now(),
|
||||
symbol=symbol,
|
||||
predicted_direction=2, # UP
|
||||
confidence=0.8,
|
||||
predicted_change=0.001,
|
||||
features=np.random.randn(2000).astype(np.float32)
|
||||
)
|
||||
test_predictions.append(prediction)
|
||||
self.trader._add_signal(symbol, prediction)
|
||||
|
||||
# Check accumulator state
|
||||
assert len(accumulator.signals) == 5
|
||||
assert accumulator.confidence_sum == 5 * 0.8
|
||||
assert accumulator.total_predictions == 5
|
||||
|
||||
# Test consensus logic (simulate processing)
|
||||
recent_signals = list(accumulator.signals)[-3:]
|
||||
directions = [signal.predicted_direction for signal in recent_signals]
|
||||
|
||||
# All should be direction 2 (UP)
|
||||
direction_counts = {0: 0, 1: 0, 2: 0}
|
||||
for direction in directions:
|
||||
direction_counts[direction] += 1
|
||||
|
||||
dominant_direction = max(direction_counts, key=direction_counts.get)
|
||||
consensus_count = direction_counts[dominant_direction]
|
||||
|
||||
assert dominant_direction == 2
|
||||
assert consensus_count == 3
|
||||
|
||||
self.test_results['test_signal_accumulation'] = {
|
||||
'status': 'PASSED',
|
||||
'signals_added': len(accumulator.signals),
|
||||
'confidence_sum': accumulator.confidence_sum,
|
||||
'consensus_direction': dominant_direction,
|
||||
'consensus_count': consensus_count
|
||||
}
|
||||
|
||||
logger.info("✅ Signal accumulation test PASSED")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results['test_signal_accumulation'] = {'status': 'FAILED', 'error': str(e)}
|
||||
raise
|
||||
|
||||
async def test_training_pipeline(self):
|
||||
"""Test training pipeline functionality"""
|
||||
logger.info("🧠 Testing Training Pipeline...")
|
||||
|
||||
try:
|
||||
if not self.trader:
|
||||
self.trading_executor = TradingExecutor(simulation_mode=True)
|
||||
self.trader = RealtimeRLCOBTrader(
|
||||
symbols=['BTC/USDT'],
|
||||
trading_executor=self.trading_executor
|
||||
)
|
||||
|
||||
symbol = 'BTC/USDT'
|
||||
|
||||
# Create mock training data
|
||||
test_predictions = []
|
||||
for i in range(10):
|
||||
prediction = PredictionResult(
|
||||
timestamp=datetime.now(),
|
||||
symbol=symbol,
|
||||
predicted_direction=np.random.randint(0, 3),
|
||||
confidence=np.random.uniform(0.5, 1.0),
|
||||
predicted_change=np.random.uniform(-0.001, 0.001),
|
||||
features=np.random.randn(2000).astype(np.float32),
|
||||
actual_direction=np.random.randint(0, 3),
|
||||
actual_change=np.random.uniform(-0.001, 0.001),
|
||||
reward=np.random.uniform(-1.0, 1.0)
|
||||
)
|
||||
test_predictions.append(prediction)
|
||||
|
||||
# Test training batch
|
||||
loss = await self.trader._train_batch(symbol, test_predictions)
|
||||
|
||||
assert isinstance(loss, float)
|
||||
assert not np.isnan(loss)
|
||||
assert not np.isinf(loss)
|
||||
assert loss >= 0.0 # Loss should be non-negative
|
||||
|
||||
self.test_results['test_training_pipeline'] = {
|
||||
'status': 'PASSED',
|
||||
'training_loss': float(loss),
|
||||
'batch_size': len(test_predictions),
|
||||
'training_successful': True
|
||||
}
|
||||
|
||||
logger.info(f"✅ Training pipeline test PASSED (loss: {loss:.6f})")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results['test_training_pipeline'] = {'status': 'FAILED', 'error': str(e)}
|
||||
raise
|
||||
|
||||
async def test_trading_integration(self):
|
||||
"""Test integration with trading executor"""
|
||||
logger.info("💰 Testing Trading Integration...")
|
||||
|
||||
try:
|
||||
# Initialize with simulation mode
|
||||
trading_executor = TradingExecutor(simulation_mode=True)
|
||||
|
||||
# Test signal execution
|
||||
success = trading_executor.execute_signal(
|
||||
symbol='BTC/USDT',
|
||||
action='BUY',
|
||||
confidence=0.8,
|
||||
current_price=50000.0
|
||||
)
|
||||
|
||||
# In simulation mode, this should always succeed
|
||||
assert success == True
|
||||
|
||||
# Check positions
|
||||
positions = trading_executor.get_positions()
|
||||
assert 'BTC/USDT' in positions
|
||||
|
||||
# Test sell signal
|
||||
success = trading_executor.execute_signal(
|
||||
symbol='BTC/USDT',
|
||||
action='SELL',
|
||||
confidence=0.8,
|
||||
current_price=50100.0
|
||||
)
|
||||
|
||||
assert success == True
|
||||
|
||||
# Check trade history
|
||||
trade_history = trading_executor.get_trade_history()
|
||||
assert len(trade_history) > 0
|
||||
|
||||
last_trade = trade_history[-1]
|
||||
assert last_trade.symbol == 'BTC/USDT'
|
||||
assert last_trade.pnl != 0 # Should have some P&L
|
||||
|
||||
self.test_results['test_trading_integration'] = {
|
||||
'status': 'PASSED',
|
||||
'simulation_mode': True,
|
||||
'trades_executed': len(trade_history),
|
||||
'last_trade_pnl': float(last_trade.pnl)
|
||||
}
|
||||
|
||||
logger.info("✅ Trading integration test PASSED")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results['test_trading_integration'] = {'status': 'FAILED', 'error': str(e)}
|
||||
raise
|
||||
|
||||
async def test_performance_monitoring(self):
|
||||
"""Test performance monitoring and statistics"""
|
||||
logger.info("📊 Testing Performance Monitoring...")
|
||||
|
||||
try:
|
||||
if not self.trader:
|
||||
self.trading_executor = TradingExecutor(simulation_mode=True)
|
||||
self.trader = RealtimeRLCOBTrader(
|
||||
symbols=['BTC/USDT', 'ETH/USDT'],
|
||||
trading_executor=self.trading_executor
|
||||
)
|
||||
|
||||
# Get performance stats
|
||||
stats = self.trader.get_performance_stats()
|
||||
|
||||
# Verify structure
|
||||
assert 'symbols' in stats
|
||||
assert 'training_stats' in stats
|
||||
assert 'inference_stats' in stats
|
||||
assert 'signal_stats' in stats
|
||||
assert 'model_info' in stats
|
||||
|
||||
# Check symbols
|
||||
assert 'BTC/USDT' in stats['symbols']
|
||||
assert 'ETH/USDT' in stats['symbols']
|
||||
|
||||
# Check model info
|
||||
for symbol in stats['symbols']:
|
||||
assert symbol in stats['model_info']
|
||||
model_info = stats['model_info'][symbol]
|
||||
assert 'total_parameters' in model_info
|
||||
assert 'trainable_parameters' in model_info
|
||||
assert model_info['total_parameters'] > 0
|
||||
|
||||
self.test_results['test_performance_monitoring'] = {
|
||||
'status': 'PASSED',
|
||||
'stats_structure_valid': True,
|
||||
'symbols_tracked': len(stats['symbols']),
|
||||
'model_info_available': len(stats['model_info'])
|
||||
}
|
||||
|
||||
logger.info("✅ Performance monitoring test PASSED")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results['test_performance_monitoring'] = {'status': 'FAILED', 'error': str(e)}
|
||||
raise
|
||||
|
||||
async def generate_test_report(self):
|
||||
"""Generate comprehensive test report"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("REAL-TIME RL COB TRADER TEST REPORT")
|
||||
logger.info("=" * 60)
|
||||
|
||||
total_tests = len(self.test_results)
|
||||
passed_tests = sum(1 for result in self.test_results.values() if result['status'] == 'PASSED')
|
||||
failed_tests = sum(1 for result in self.test_results.values() if result['status'] == 'FAILED')
|
||||
warning_tests = sum(1 for result in self.test_results.values() if result['status'] == 'WARNING')
|
||||
|
||||
logger.info(f"📊 Test Summary:")
|
||||
logger.info(f" Total Tests: {total_tests}")
|
||||
logger.info(f" ✅ Passed: {passed_tests}")
|
||||
logger.info(f" ⚠️ Warnings: {warning_tests}")
|
||||
logger.info(f" ❌ Failed: {failed_tests}")
|
||||
|
||||
success_rate = (passed_tests / total_tests) * 100 if total_tests > 0 else 0
|
||||
logger.info(f" Success Rate: {success_rate:.1f}%")
|
||||
|
||||
logger.info("\n📋 Detailed Results:")
|
||||
for test_name, result in self.test_results.items():
|
||||
status_icon = "✅" if result['status'] == 'PASSED' else "⚠️" if result['status'] == 'WARNING' else "❌"
|
||||
logger.info(f" {status_icon} {test_name}: {result['status']}")
|
||||
|
||||
if result['status'] == 'FAILED':
|
||||
logger.error(f" Error: {result.get('error', 'Unknown error')}")
|
||||
|
||||
# System readiness assessment
|
||||
logger.info("\n🎯 System Readiness Assessment:")
|
||||
if failed_tests == 0:
|
||||
if warning_tests == 0:
|
||||
logger.info(" 🟢 SYSTEM READY FOR DEPLOYMENT")
|
||||
logger.info(" All tests passed. The real-time RL COB trader is ready for live operation.")
|
||||
else:
|
||||
logger.info(" 🟡 SYSTEM READY WITH WARNINGS")
|
||||
logger.info(" System is functional but some performance warnings exist.")
|
||||
else:
|
||||
logger.info(" 🔴 SYSTEM NOT READY")
|
||||
logger.info(" Critical issues found. Fix errors before deployment.")
|
||||
|
||||
# Save detailed report
|
||||
report_data = {
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'test_summary': {
|
||||
'total_tests': total_tests,
|
||||
'passed_tests': passed_tests,
|
||||
'warning_tests': warning_tests,
|
||||
'failed_tests': failed_tests,
|
||||
'success_rate': success_rate
|
||||
},
|
||||
'test_results': self.test_results,
|
||||
'system_readiness': 'READY' if failed_tests == 0 else 'NOT_READY'
|
||||
}
|
||||
|
||||
report_file = f"test_reports/realtime_rl_test_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
||||
|
||||
import os
|
||||
os.makedirs('test_reports', exist_ok=True)
|
||||
|
||||
with open(report_file, 'w') as f:
|
||||
json.dump(report_data, f, indent=2, default=str)
|
||||
|
||||
logger.info(f"\n📄 Detailed report saved to: {report_file}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
async def main():
|
||||
"""Main test entry point"""
|
||||
logger.info("Starting Real-time RL COB Trader Test Suite...")
|
||||
|
||||
tester = RealtimeRLTester()
|
||||
await tester.run_all_tests()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Set event loop policy for Windows compatibility
|
||||
if hasattr(asyncio, 'WindowsProactorEventLoopPolicy'):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
||||
|
||||
asyncio.run(main())
|
@ -1,279 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Real-Time Tick Processor
|
||||
|
||||
This script tests the Neural Network Real-Time Tick Processing Module
|
||||
to ensure it properly processes tick data with volume information and
|
||||
feeds processed features to models in real-time.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.realtime_tick_processor import RealTimeTickProcessor, ProcessedTickFeatures, create_realtime_tick_processor
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
from core.config import get_config
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def test_realtime_tick_processor():
|
||||
"""Test the real-time tick processor functionality"""
|
||||
logger.info("="*80)
|
||||
logger.info("🧪 TESTING REAL-TIME TICK PROCESSOR")
|
||||
logger.info("="*80)
|
||||
|
||||
try:
|
||||
# Test 1: Create tick processor
|
||||
logger.info("\n📊 TEST 1: Creating Real-Time Tick Processor")
|
||||
logger.info("-" * 40)
|
||||
|
||||
symbols = ['ETH/USDT', 'BTC/USDT']
|
||||
tick_processor = create_realtime_tick_processor(symbols)
|
||||
|
||||
logger.info("✅ Tick processor created successfully")
|
||||
logger.info(f" Symbols: {tick_processor.symbols}")
|
||||
logger.info(f" Device: {tick_processor.device}")
|
||||
logger.info(f" Buffer size: {tick_processor.tick_buffer_size}")
|
||||
|
||||
# Test 2: Feature subscriber
|
||||
logger.info("\n📡 TEST 2: Feature Subscriber Integration")
|
||||
logger.info("-" * 40)
|
||||
|
||||
received_features = []
|
||||
|
||||
def test_callback(symbol: str, features: ProcessedTickFeatures):
|
||||
"""Test callback to receive processed features"""
|
||||
received_features.append((symbol, features))
|
||||
logger.info(f"Received features for {symbol}: confidence={features.confidence:.3f}")
|
||||
logger.info(f" Neural features shape: {features.neural_features.shape}")
|
||||
logger.info(f" Volume features shape: {features.volume_features.shape}")
|
||||
logger.info(f" Price features shape: {features.price_features.shape}")
|
||||
logger.info(f" Microstructure features shape: {features.microstructure_features.shape}")
|
||||
|
||||
tick_processor.add_feature_subscriber(test_callback)
|
||||
logger.info("✅ Feature subscriber added")
|
||||
|
||||
# Test 3: Start processing (short duration)
|
||||
logger.info("\n🚀 TEST 3: Start Real-Time Processing")
|
||||
logger.info("-" * 40)
|
||||
|
||||
logger.info("Starting tick processing for 30 seconds...")
|
||||
await tick_processor.start_processing()
|
||||
|
||||
# Let it run for 30 seconds to collect some data
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < 30:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Check stats every 5 seconds
|
||||
if int(time.time() - start_time) % 5 == 0:
|
||||
stats = tick_processor.get_processing_stats()
|
||||
logger.info(f"Processing stats: {stats.get('tick_counts', {})}")
|
||||
|
||||
if stats.get('processing_performance'):
|
||||
perf = stats['processing_performance']
|
||||
logger.info(f"Performance: avg={perf['avg_time_ms']:.2f}ms, "
|
||||
f"min={perf['min_time_ms']:.2f}ms, max={perf['max_time_ms']:.2f}ms")
|
||||
|
||||
logger.info("✅ Real-time processing test completed")
|
||||
|
||||
# Test 4: Check received features
|
||||
logger.info("\n📈 TEST 4: Analyze Received Features")
|
||||
logger.info("-" * 40)
|
||||
|
||||
if received_features:
|
||||
logger.info(f"✅ Received {len(received_features)} feature sets")
|
||||
|
||||
# Analyze feature quality
|
||||
high_confidence_count = sum(1 for _, features in received_features if features.confidence > 0.7)
|
||||
avg_confidence = sum(features.confidence for _, features in received_features) / len(received_features)
|
||||
|
||||
logger.info(f" Average confidence: {avg_confidence:.3f}")
|
||||
logger.info(f" High confidence features (>0.7): {high_confidence_count}")
|
||||
|
||||
# Show latest features
|
||||
if received_features:
|
||||
symbol, latest_features = received_features[-1]
|
||||
logger.info(f" Latest features for {symbol}:")
|
||||
logger.info(f" Timestamp: {latest_features.timestamp}")
|
||||
logger.info(f" Confidence: {latest_features.confidence:.3f}")
|
||||
logger.info(f" Neural features sample: {latest_features.neural_features[:5]}")
|
||||
logger.info(f" Volume features sample: {latest_features.volume_features[:3]}")
|
||||
else:
|
||||
logger.warning("⚠️ No features received - this may be normal if markets are closed")
|
||||
|
||||
# Test 5: Integration with orchestrator
|
||||
logger.info("\n🎯 TEST 5: Integration with Enhanced Orchestrator")
|
||||
logger.info("-" * 40)
|
||||
|
||||
try:
|
||||
config = get_config()
|
||||
data_provider = DataProvider(config)
|
||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||
|
||||
# Check if tick processor is integrated
|
||||
if hasattr(orchestrator, 'tick_processor'):
|
||||
logger.info("✅ Tick processor integrated with orchestrator")
|
||||
logger.info(f" Orchestrator symbols: {orchestrator.symbols}")
|
||||
logger.info(f" Tick processor symbols: {orchestrator.tick_processor.symbols}")
|
||||
|
||||
# Test real-time processing start
|
||||
await orchestrator.start_realtime_processing()
|
||||
logger.info("✅ Orchestrator real-time processing started")
|
||||
|
||||
# Brief test
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# Get stats
|
||||
tick_stats = orchestrator.get_realtime_tick_stats()
|
||||
logger.info(f" Orchestrator tick stats: {tick_stats}")
|
||||
|
||||
await orchestrator.stop_realtime_processing()
|
||||
logger.info("✅ Orchestrator real-time processing stopped")
|
||||
else:
|
||||
logger.error("❌ Tick processor not found in orchestrator")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Orchestrator integration test failed: {e}")
|
||||
|
||||
# Test 6: Stop processing
|
||||
logger.info("\n🛑 TEST 6: Stop Processing")
|
||||
logger.info("-" * 40)
|
||||
|
||||
await tick_processor.stop_processing()
|
||||
logger.info("✅ Tick processing stopped")
|
||||
|
||||
# Final stats
|
||||
final_stats = tick_processor.get_processing_stats()
|
||||
logger.info(f"Final stats: {final_stats}")
|
||||
|
||||
# Test 7: Neural Network Features
|
||||
logger.info("\n🧠 TEST 7: Neural Network Feature Quality")
|
||||
logger.info("-" * 40)
|
||||
|
||||
if received_features:
|
||||
# Analyze neural network output quality
|
||||
neural_feature_sizes = [len(features.neural_features) for _, features in received_features]
|
||||
confidence_scores = [features.confidence for _, features in received_features]
|
||||
|
||||
logger.info(f" Neural feature dimensions: {set(neural_feature_sizes)}")
|
||||
logger.info(f" Confidence range: {min(confidence_scores):.3f} - {max(confidence_scores):.3f}")
|
||||
logger.info(f" Average confidence: {sum(confidence_scores)/len(confidence_scores):.3f}")
|
||||
|
||||
# Check for feature consistency
|
||||
if len(set(neural_feature_sizes)) == 1:
|
||||
logger.info("✅ Neural features have consistent dimensions")
|
||||
else:
|
||||
logger.warning("⚠️ Neural feature dimensions are inconsistent")
|
||||
|
||||
# Summary
|
||||
logger.info("\n" + "="*80)
|
||||
logger.info("🎉 REAL-TIME TICK PROCESSOR TEST SUMMARY")
|
||||
logger.info("="*80)
|
||||
logger.info("✅ All core tests PASSED!")
|
||||
logger.info("")
|
||||
logger.info("📋 VERIFIED FUNCTIONALITY:")
|
||||
logger.info(" ✓ Real-time tick data ingestion")
|
||||
logger.info(" ✓ Neural network feature extraction")
|
||||
logger.info(" ✓ Volume and microstructure analysis")
|
||||
logger.info(" ✓ Ultra-low latency processing")
|
||||
logger.info(" ✓ Feature subscriber system")
|
||||
logger.info(" ✓ Integration with orchestrator")
|
||||
logger.info(" ✓ Performance monitoring")
|
||||
logger.info("")
|
||||
logger.info("🎯 NEURAL DPS ALTERNATIVE ACTIVE:")
|
||||
logger.info(" • Real-time tick processing ✓")
|
||||
logger.info(" • Volume-weighted analysis ✓")
|
||||
logger.info(" • Neural feature extraction ✓")
|
||||
logger.info(" • Sub-millisecond latency ✓")
|
||||
logger.info(" • Model integration ready ✓")
|
||||
logger.info("")
|
||||
logger.info("🚀 Your real-time tick processor is working as a Neural DPS alternative!")
|
||||
logger.info("="*80)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Real-time tick processor test failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
async def test_dqn_integration():
|
||||
"""Test DQN integration with real-time tick features"""
|
||||
logger.info("\n🤖 TESTING DQN INTEGRATION WITH TICK FEATURES")
|
||||
logger.info("-" * 50)
|
||||
|
||||
try:
|
||||
from NN.models.dqn_agent import DQNAgent
|
||||
import numpy as np
|
||||
|
||||
# Create DQN agent
|
||||
state_shape = (3, 5) # 3 timeframes, 5 features
|
||||
dqn = DQNAgent(state_shape=state_shape, n_actions=3)
|
||||
|
||||
logger.info("✅ DQN agent created")
|
||||
logger.info(f" Tick feature weight: {dqn.tick_feature_weight}")
|
||||
|
||||
# Test state enhancement
|
||||
test_state = np.random.rand(3, 5)
|
||||
|
||||
# Simulate tick features
|
||||
mock_tick_features = {
|
||||
'neural_features': np.random.rand(64),
|
||||
'volume_features': np.random.rand(8),
|
||||
'microstructure_features': np.random.rand(4),
|
||||
'confidence': 0.85
|
||||
}
|
||||
|
||||
# Update DQN with tick features
|
||||
dqn.update_realtime_tick_features(mock_tick_features)
|
||||
logger.info("✅ DQN updated with mock tick features")
|
||||
|
||||
# Test enhanced action selection
|
||||
action = dqn.act(test_state, explore=False)
|
||||
logger.info(f"✅ DQN action with tick features: {action}")
|
||||
|
||||
# Test without tick features
|
||||
dqn.realtime_tick_features = None
|
||||
action_without = dqn.act(test_state, explore=False)
|
||||
logger.info(f"✅ DQN action without tick features: {action_without}")
|
||||
|
||||
logger.info("✅ DQN integration test completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ DQN integration test failed: {e}")
|
||||
|
||||
async def main():
|
||||
"""Main test function"""
|
||||
logger.info("🚀 Starting Real-Time Tick Processor Tests...")
|
||||
|
||||
# Test the tick processor
|
||||
success = await test_realtime_tick_processor()
|
||||
|
||||
if success:
|
||||
# Test DQN integration
|
||||
await test_dqn_integration()
|
||||
|
||||
logger.info("\n🎉 All tests passed! Your Neural DPS alternative is ready.")
|
||||
logger.info("The real-time tick processor provides ultra-low latency processing")
|
||||
logger.info("with volume information and neural network feature extraction.")
|
||||
else:
|
||||
logger.error("\n💥 Tests failed! Please check the implementation.")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
@ -1 +0,0 @@
|
||||
|
@ -1,372 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test DQN RL-based Sensitivity Learning and 300s Data Preloading
|
||||
|
||||
This script tests:
|
||||
1. DQN RL-based sensitivity learning from completed trades
|
||||
2. 300s data preloading on first load
|
||||
3. Dynamic threshold adjustment based on sensitivity levels
|
||||
4. Color-coded position display integration
|
||||
5. Enhanced model training status with sensitivity info
|
||||
|
||||
Usage:
|
||||
python test_sensitivity_learning.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator, TradingAction
|
||||
from web.old_archived.scalping_dashboard import RealTimeScalpingDashboard
|
||||
from NN.models.dqn_agent import DQNAgent
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SensitivityLearningTester:
|
||||
"""Test class for sensitivity learning features"""
|
||||
|
||||
def __init__(self):
|
||||
self.data_provider = DataProvider()
|
||||
self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
|
||||
self.dashboard = None
|
||||
|
||||
async def test_300s_data_preloading(self):
|
||||
"""Test 300s data preloading functionality"""
|
||||
logger.info("=== Testing 300s Data Preloading ===")
|
||||
|
||||
# Test preloading for all symbols and timeframes
|
||||
start_time = time.time()
|
||||
preload_results = self.data_provider.preload_all_symbols_data(['1s', '1m', '5m', '15m', '1h'])
|
||||
end_time = time.time()
|
||||
|
||||
logger.info(f"Preloading completed in {end_time - start_time:.2f} seconds")
|
||||
|
||||
# Verify results
|
||||
total_pairs = 0
|
||||
successful_pairs = 0
|
||||
|
||||
for symbol, timeframe_results in preload_results.items():
|
||||
for timeframe, success in timeframe_results.items():
|
||||
total_pairs += 1
|
||||
if success:
|
||||
successful_pairs += 1
|
||||
|
||||
# Verify data was actually loaded
|
||||
data = self.data_provider.get_historical_data(symbol, timeframe, limit=50)
|
||||
if data is not None and len(data) > 0:
|
||||
logger.info(f"✅ {symbol} {timeframe}: {len(data)} candles loaded")
|
||||
else:
|
||||
logger.warning(f"❌ {symbol} {timeframe}: No data despite success flag")
|
||||
else:
|
||||
logger.warning(f"❌ {symbol} {timeframe}: Failed to preload")
|
||||
|
||||
success_rate = (successful_pairs / total_pairs) * 100 if total_pairs > 0 else 0
|
||||
logger.info(f"Preloading success rate: {success_rate:.1f}% ({successful_pairs}/{total_pairs})")
|
||||
|
||||
return success_rate > 80 # Consider test passed if >80% success rate
|
||||
|
||||
def test_sensitivity_learning_initialization(self):
|
||||
"""Test sensitivity learning system initialization"""
|
||||
logger.info("=== Testing Sensitivity Learning Initialization ===")
|
||||
|
||||
# Check if sensitivity learning is enabled
|
||||
if hasattr(self.orchestrator, 'sensitivity_learning_enabled'):
|
||||
logger.info(f"✅ Sensitivity learning enabled: {self.orchestrator.sensitivity_learning_enabled}")
|
||||
else:
|
||||
logger.warning("❌ Sensitivity learning not found in orchestrator")
|
||||
return False
|
||||
|
||||
# Check sensitivity levels configuration
|
||||
if hasattr(self.orchestrator, 'sensitivity_levels'):
|
||||
levels = self.orchestrator.sensitivity_levels
|
||||
logger.info(f"✅ Sensitivity levels configured: {len(levels)} levels")
|
||||
for level, config in levels.items():
|
||||
logger.info(f" Level {level}: {config['name']} - Open: {config['open_threshold_multiplier']:.2f}, Close: {config['close_threshold_multiplier']:.2f}")
|
||||
else:
|
||||
logger.warning("❌ Sensitivity levels not configured")
|
||||
return False
|
||||
|
||||
# Check DQN agent initialization
|
||||
if hasattr(self.orchestrator, 'sensitivity_dqn_agent'):
|
||||
if self.orchestrator.sensitivity_dqn_agent is not None:
|
||||
logger.info("✅ DQN agent initialized")
|
||||
stats = self.orchestrator.sensitivity_dqn_agent.get_stats()
|
||||
logger.info(f" Device: {stats['device']}")
|
||||
logger.info(f" Memory size: {stats['memory_size']}")
|
||||
logger.info(f" Epsilon: {stats['epsilon']:.3f}")
|
||||
else:
|
||||
logger.info("⏳ DQN agent not yet initialized (will be created on first use)")
|
||||
|
||||
# Check learning queues
|
||||
if hasattr(self.orchestrator, 'sensitivity_learning_queue'):
|
||||
logger.info(f"✅ Sensitivity learning queue initialized: {len(self.orchestrator.sensitivity_learning_queue)} items")
|
||||
|
||||
if hasattr(self.orchestrator, 'completed_trades'):
|
||||
logger.info(f"✅ Completed trades tracking initialized: {len(self.orchestrator.completed_trades)} trades")
|
||||
|
||||
if hasattr(self.orchestrator, 'active_trades'):
|
||||
logger.info(f"✅ Active trades tracking initialized: {len(self.orchestrator.active_trades)} active")
|
||||
|
||||
return True
|
||||
|
||||
def simulate_trading_scenario(self):
|
||||
"""Simulate a trading scenario to test sensitivity learning"""
|
||||
logger.info("=== Simulating Trading Scenario ===")
|
||||
|
||||
# Simulate some trades to test the learning system
|
||||
test_trades = [
|
||||
{
|
||||
'symbol': 'ETH/USDT',
|
||||
'action': 'BUY',
|
||||
'price': 2500.0,
|
||||
'confidence': 0.7,
|
||||
'timestamp': datetime.now() - timedelta(minutes=10)
|
||||
},
|
||||
{
|
||||
'symbol': 'ETH/USDT',
|
||||
'action': 'SELL',
|
||||
'price': 2510.0,
|
||||
'confidence': 0.6,
|
||||
'timestamp': datetime.now() - timedelta(minutes=5)
|
||||
},
|
||||
{
|
||||
'symbol': 'ETH/USDT',
|
||||
'action': 'BUY',
|
||||
'price': 2505.0,
|
||||
'confidence': 0.8,
|
||||
'timestamp': datetime.now() - timedelta(minutes=3)
|
||||
},
|
||||
{
|
||||
'symbol': 'ETH/USDT',
|
||||
'action': 'SELL',
|
||||
'price': 2495.0,
|
||||
'confidence': 0.9,
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
]
|
||||
|
||||
# Process each trade through the orchestrator
|
||||
for i, trade_data in enumerate(test_trades):
|
||||
action = TradingAction(
|
||||
symbol=trade_data['symbol'],
|
||||
action=trade_data['action'],
|
||||
quantity=0.1,
|
||||
confidence=trade_data['confidence'],
|
||||
price=trade_data['price'],
|
||||
timestamp=trade_data['timestamp'],
|
||||
reasoning={'test': f'simulated_trade_{i}'},
|
||||
timeframe_analysis=[]
|
||||
)
|
||||
|
||||
# Update position tracking (this should trigger sensitivity learning)
|
||||
self.orchestrator._update_position_tracking(trade_data['symbol'], action)
|
||||
|
||||
logger.info(f"Processed trade {i+1}: {trade_data['action']} @ ${trade_data['price']:.2f}")
|
||||
|
||||
# Check if learning cases were created
|
||||
if hasattr(self.orchestrator, 'sensitivity_learning_queue'):
|
||||
queue_size = len(self.orchestrator.sensitivity_learning_queue)
|
||||
logger.info(f"✅ Learning queue now has {queue_size} cases")
|
||||
|
||||
if hasattr(self.orchestrator, 'completed_trades'):
|
||||
completed_count = len(self.orchestrator.completed_trades)
|
||||
logger.info(f"✅ Completed trades: {completed_count}")
|
||||
|
||||
return True
|
||||
|
||||
def test_threshold_adjustment(self):
|
||||
"""Test dynamic threshold adjustment based on sensitivity"""
|
||||
logger.info("=== Testing Threshold Adjustment ===")
|
||||
|
||||
# Test different sensitivity levels
|
||||
for level in range(5): # 0-4 sensitivity levels
|
||||
if hasattr(self.orchestrator, 'current_sensitivity_level'):
|
||||
self.orchestrator.current_sensitivity_level = level
|
||||
|
||||
if hasattr(self.orchestrator, '_update_thresholds_from_sensitivity'):
|
||||
self.orchestrator._update_thresholds_from_sensitivity()
|
||||
|
||||
open_threshold = getattr(self.orchestrator, 'confidence_threshold_open', 0.6)
|
||||
close_threshold = getattr(self.orchestrator, 'confidence_threshold_close', 0.25)
|
||||
|
||||
logger.info(f"Level {level}: Open={open_threshold:.3f}, Close={close_threshold:.3f}")
|
||||
|
||||
return True
|
||||
|
||||
def test_dashboard_integration(self):
|
||||
"""Test dashboard integration with sensitivity learning"""
|
||||
logger.info("=== Testing Dashboard Integration ===")
|
||||
|
||||
try:
|
||||
# Create dashboard instance
|
||||
self.dashboard = RealTimeScalpingDashboard(
|
||||
data_provider=self.data_provider,
|
||||
orchestrator=self.orchestrator
|
||||
)
|
||||
|
||||
# Test sensitivity learning info retrieval
|
||||
sensitivity_info = self.dashboard._get_sensitivity_learning_info()
|
||||
|
||||
logger.info("✅ Dashboard sensitivity info:")
|
||||
logger.info(f" Level: {sensitivity_info['level_name']}")
|
||||
logger.info(f" Completed trades: {sensitivity_info['completed_trades']}")
|
||||
logger.info(f" Learning queue: {sensitivity_info['learning_queue_size']}")
|
||||
logger.info(f" Open threshold: {sensitivity_info['open_threshold']:.3f}")
|
||||
logger.info(f" Close threshold: {sensitivity_info['close_threshold']:.3f}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Dashboard integration test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_dqn_training_simulation(self):
|
||||
"""Test DQN training with simulated data"""
|
||||
logger.info("=== Testing DQN Training Simulation ===")
|
||||
|
||||
try:
|
||||
# Initialize DQN agent if not already done
|
||||
if not hasattr(self.orchestrator, 'sensitivity_dqn_agent') or self.orchestrator.sensitivity_dqn_agent is None:
|
||||
self.orchestrator._initialize_sensitivity_dqn()
|
||||
|
||||
if self.orchestrator.sensitivity_dqn_agent is None:
|
||||
logger.warning("❌ Could not initialize DQN agent")
|
||||
return False
|
||||
|
||||
# Create some mock learning cases
|
||||
for i in range(10):
|
||||
# Create mock market state
|
||||
mock_state = np.random.random(self.orchestrator.sensitivity_state_size)
|
||||
action = np.random.randint(0, self.orchestrator.sensitivity_action_space)
|
||||
reward = np.random.random() - 0.5 # Random reward between -0.5 and 0.5
|
||||
next_state = np.random.random(self.orchestrator.sensitivity_state_size)
|
||||
done = True
|
||||
|
||||
# Add to learning queue
|
||||
learning_case = {
|
||||
'state': mock_state,
|
||||
'action': action,
|
||||
'reward': reward,
|
||||
'next_state': next_state,
|
||||
'done': done,
|
||||
'optimal_action': action,
|
||||
'trade_outcome': reward * 0.02, # Convert to percentage
|
||||
'trade_duration': 300 + np.random.randint(-100, 100),
|
||||
'symbol': 'ETH/USDT'
|
||||
}
|
||||
|
||||
self.orchestrator.sensitivity_learning_queue.append(learning_case)
|
||||
|
||||
# Trigger training
|
||||
initial_queue_size = len(self.orchestrator.sensitivity_learning_queue)
|
||||
self.orchestrator._train_sensitivity_dqn()
|
||||
|
||||
logger.info(f"✅ DQN training completed")
|
||||
logger.info(f" Initial queue size: {initial_queue_size}")
|
||||
logger.info(f" Final queue size: {len(self.orchestrator.sensitivity_learning_queue)}")
|
||||
|
||||
# Check agent stats
|
||||
if self.orchestrator.sensitivity_dqn_agent:
|
||||
stats = self.orchestrator.sensitivity_dqn_agent.get_stats()
|
||||
logger.info(f" Training steps: {stats['training_step']}")
|
||||
logger.info(f" Memory size: {stats['memory_size']}")
|
||||
logger.info(f" Epsilon: {stats['epsilon']:.3f}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ DQN training simulation failed: {e}")
|
||||
return False
|
||||
|
||||
async def run_all_tests(self):
|
||||
"""Run all sensitivity learning tests"""
|
||||
logger.info("🚀 Starting Sensitivity Learning Test Suite")
|
||||
logger.info("=" * 60)
|
||||
|
||||
test_results = {}
|
||||
|
||||
# Test 1: 300s Data Preloading
|
||||
test_results['preloading'] = await self.test_300s_data_preloading()
|
||||
|
||||
# Test 2: Sensitivity Learning Initialization
|
||||
test_results['initialization'] = self.test_sensitivity_learning_initialization()
|
||||
|
||||
# Test 3: Trading Scenario Simulation
|
||||
test_results['trading_simulation'] = self.simulate_trading_scenario()
|
||||
|
||||
# Test 4: Threshold Adjustment
|
||||
test_results['threshold_adjustment'] = self.test_threshold_adjustment()
|
||||
|
||||
# Test 5: Dashboard Integration
|
||||
test_results['dashboard_integration'] = self.test_dashboard_integration()
|
||||
|
||||
# Test 6: DQN Training Simulation
|
||||
test_results['dqn_training'] = self.test_dqn_training_simulation()
|
||||
|
||||
# Summary
|
||||
logger.info("=" * 60)
|
||||
logger.info("🏁 Test Suite Results:")
|
||||
|
||||
passed_tests = 0
|
||||
total_tests = len(test_results)
|
||||
|
||||
for test_name, result in test_results.items():
|
||||
status = "✅ PASSED" if result else "❌ FAILED"
|
||||
logger.info(f" {test_name}: {status}")
|
||||
if result:
|
||||
passed_tests += 1
|
||||
|
||||
success_rate = (passed_tests / total_tests) * 100
|
||||
logger.info(f"Overall success rate: {success_rate:.1f}% ({passed_tests}/{total_tests})")
|
||||
|
||||
if success_rate >= 80:
|
||||
logger.info("🎉 Test suite PASSED! Sensitivity learning system is working correctly.")
|
||||
else:
|
||||
logger.warning("⚠️ Test suite FAILED! Some issues need to be addressed.")
|
||||
|
||||
return success_rate >= 80
|
||||
|
||||
async def main():
|
||||
"""Main test function"""
|
||||
tester = SensitivityLearningTester()
|
||||
|
||||
try:
|
||||
success = await tester.run_all_tests()
|
||||
|
||||
if success:
|
||||
logger.info("✅ All tests passed! The sensitivity learning system is ready for production.")
|
||||
else:
|
||||
logger.error("❌ Some tests failed. Please review the issues above.")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Test suite failed with exception: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the test suite
|
||||
result = asyncio.run(main())
|
||||
|
||||
if result:
|
||||
print("\n🎯 SENSITIVITY LEARNING SYSTEM READY!")
|
||||
print("Features verified:")
|
||||
print(" ✅ DQN RL-based sensitivity learning from completed trades")
|
||||
print(" ✅ 300s data preloading for faster initial performance")
|
||||
print(" ✅ Dynamic threshold adjustment (lower for closing positions)")
|
||||
print(" ✅ Color-coded position display ([LONG] green, [SHORT] red)")
|
||||
print(" ✅ Enhanced model training status with sensitivity info")
|
||||
print("\nYou can now run the dashboard with these enhanced features!")
|
||||
else:
|
||||
print("\n❌ SOME TESTS FAILED")
|
||||
print("Please review the test output above and fix any issues.")
|
||||
|
||||
exit(0 if result else 1)
|
@ -1 +0,0 @@
|
||||
|
@ -1,130 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify tick caching with timestamp serialization
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
|
||||
# Add the project root to the path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from dataprovider_realtime import TickStorage
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_tick_caching():
|
||||
"""Test tick caching with pandas Timestamps"""
|
||||
logger.info("Testing tick caching with timestamp serialization...")
|
||||
|
||||
try:
|
||||
# Create tick storage
|
||||
tick_storage = TickStorage("TEST/SYMBOL", ["1s", "1m"])
|
||||
|
||||
# Clear any existing cache
|
||||
if os.path.exists(tick_storage.cache_path):
|
||||
os.remove(tick_storage.cache_path)
|
||||
logger.info("Cleared existing cache file")
|
||||
|
||||
# Add some test ticks with different timestamp formats
|
||||
test_ticks = [
|
||||
{
|
||||
'price': 100.0,
|
||||
'quantity': 1.0,
|
||||
'timestamp': pd.Timestamp.now()
|
||||
},
|
||||
{
|
||||
'price': 101.0,
|
||||
'quantity': 1.5,
|
||||
'timestamp': datetime.now()
|
||||
},
|
||||
{
|
||||
'price': 102.0,
|
||||
'quantity': 2.0,
|
||||
'timestamp': int(datetime.now().timestamp() * 1000) # milliseconds
|
||||
}
|
||||
]
|
||||
|
||||
# Add ticks
|
||||
for i, tick in enumerate(test_ticks):
|
||||
logger.info(f"Adding tick {i+1}: price=${tick['price']}, timestamp type={type(tick['timestamp'])}")
|
||||
tick_storage.add_tick(tick)
|
||||
|
||||
logger.info(f"Total ticks in storage: {len(tick_storage.ticks)}")
|
||||
|
||||
# Force save to cache
|
||||
tick_storage._save_to_cache()
|
||||
logger.info("Saved ticks to cache")
|
||||
|
||||
# Verify cache file exists
|
||||
if os.path.exists(tick_storage.cache_path):
|
||||
logger.info(f"✅ Cache file created: {tick_storage.cache_path}")
|
||||
|
||||
# Check file content
|
||||
with open(tick_storage.cache_path, 'r') as f:
|
||||
import json
|
||||
cache_content = json.load(f)
|
||||
logger.info(f"Cache contains {len(cache_content)} ticks")
|
||||
|
||||
# Show first tick to verify format
|
||||
if cache_content:
|
||||
first_tick = cache_content[0]
|
||||
logger.info(f"First tick in cache: {first_tick}")
|
||||
logger.info(f"Timestamp type in cache: {type(first_tick['timestamp'])}")
|
||||
else:
|
||||
logger.error("❌ Cache file was not created")
|
||||
return False
|
||||
|
||||
# Create new tick storage instance to test loading
|
||||
logger.info("Creating new TickStorage instance to test loading...")
|
||||
new_tick_storage = TickStorage("TEST/SYMBOL", ["1s", "1m"])
|
||||
|
||||
# Load from cache
|
||||
cache_loaded = new_tick_storage._load_from_cache()
|
||||
|
||||
if cache_loaded:
|
||||
logger.info(f"✅ Successfully loaded {len(new_tick_storage.ticks)} ticks from cache")
|
||||
|
||||
# Verify timestamps are properly converted back to pandas Timestamps
|
||||
for i, tick in enumerate(new_tick_storage.ticks):
|
||||
logger.info(f"Loaded tick {i+1}: price=${tick['price']}, timestamp={tick['timestamp']}, type={type(tick['timestamp'])}")
|
||||
|
||||
if not isinstance(tick['timestamp'], pd.Timestamp):
|
||||
logger.error(f"❌ Timestamp not properly converted back to pandas.Timestamp: {type(tick['timestamp'])}")
|
||||
return False
|
||||
|
||||
logger.info("✅ All timestamps properly converted back to pandas.Timestamp")
|
||||
return True
|
||||
else:
|
||||
logger.error("❌ Failed to load ticks from cache")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error in tick caching test: {str(e)}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Run the test"""
|
||||
logger.info("🧪 Starting tick caching test...")
|
||||
logger.info("=" * 50)
|
||||
|
||||
success = test_tick_caching()
|
||||
|
||||
logger.info("\n" + "=" * 50)
|
||||
if success:
|
||||
logger.info("🎉 Tick caching test PASSED!")
|
||||
else:
|
||||
logger.error("❌ Tick caching test FAILED!")
|
||||
|
||||
return success
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
@ -1,310 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Final Real-Time Tick Processor Test
|
||||
|
||||
This script demonstrates that the Neural Network Real-Time Tick Processing Module
|
||||
is working correctly as a DPS alternative for processing tick data with volume information.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.realtime_tick_processor import (
|
||||
RealTimeTickProcessor,
|
||||
ProcessedTickFeatures,
|
||||
TickData,
|
||||
create_realtime_tick_processor
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def demonstrate_neural_dps_alternative():
|
||||
"""Demonstrate the Neural DPS alternative functionality"""
|
||||
logger.info("="*80)
|
||||
logger.info("🚀 NEURAL DPS ALTERNATIVE DEMONSTRATION")
|
||||
logger.info("="*80)
|
||||
|
||||
try:
|
||||
# Create tick processor
|
||||
logger.info("\n📊 STEP 1: Initialize Neural DPS Alternative")
|
||||
logger.info("-" * 50)
|
||||
|
||||
symbols = ['ETH/USDT', 'BTC/USDT']
|
||||
tick_processor = create_realtime_tick_processor(symbols)
|
||||
|
||||
logger.info("✅ Neural DPS Alternative initialized successfully")
|
||||
logger.info(f" Symbols: {tick_processor.symbols}")
|
||||
logger.info(f" Processing device: {tick_processor.device}")
|
||||
logger.info(f" Neural network architecture: TickProcessingNN")
|
||||
logger.info(f" Input features per tick: 9")
|
||||
logger.info(f" Output neural features: 64")
|
||||
logger.info(f" Processing window: {tick_processor.processing_window} ticks")
|
||||
|
||||
# Generate realistic market tick data
|
||||
logger.info("\n📈 STEP 2: Generate Realistic Market Tick Data")
|
||||
logger.info("-" * 50)
|
||||
|
||||
def generate_realistic_ticks(symbol: str, count: int = 100):
|
||||
"""Generate realistic tick data with volume information"""
|
||||
ticks = []
|
||||
base_price = 3500.0 if 'ETH' in symbol else 65000.0
|
||||
base_time = datetime.now()
|
||||
|
||||
for i in range(count):
|
||||
# Simulate realistic price movement with micro-trends
|
||||
if i % 20 < 10: # Uptrend phase
|
||||
price_change = np.random.normal(0.0002, 0.0008)
|
||||
else: # Downtrend phase
|
||||
price_change = np.random.normal(-0.0002, 0.0008)
|
||||
|
||||
price = base_price * (1 + price_change)
|
||||
|
||||
# Simulate realistic volume distribution
|
||||
if abs(price_change) > 0.001: # Large price moves get more volume
|
||||
volume = np.random.exponential(0.5)
|
||||
else:
|
||||
volume = np.random.exponential(0.1)
|
||||
|
||||
# Market maker vs taker dynamics
|
||||
side = 'buy' if price_change > 0 else 'sell'
|
||||
if np.random.random() < 0.3: # 30% chance to flip
|
||||
side = 'sell' if side == 'buy' else 'buy'
|
||||
|
||||
tick = TickData(
|
||||
timestamp=base_time,
|
||||
price=price,
|
||||
volume=volume,
|
||||
side=side,
|
||||
trade_id=f"{symbol}_{i}"
|
||||
)
|
||||
|
||||
ticks.append(tick)
|
||||
base_price = price
|
||||
|
||||
return ticks
|
||||
|
||||
# Generate ticks for both symbols
|
||||
eth_ticks = generate_realistic_ticks('ETH/USDT', 100)
|
||||
btc_ticks = generate_realistic_ticks('BTC/USDT', 100)
|
||||
|
||||
logger.info(f"✅ Generated realistic market data:")
|
||||
logger.info(f" ETH/USDT: {len(eth_ticks)} ticks")
|
||||
logger.info(f" Price range: ${min(t.price for t in eth_ticks):.2f} - ${max(t.price for t in eth_ticks):.2f}")
|
||||
logger.info(f" Volume range: {min(t.volume for t in eth_ticks):.4f} - {max(t.volume for t in eth_ticks):.4f}")
|
||||
logger.info(f" BTC/USDT: {len(btc_ticks)} ticks")
|
||||
logger.info(f" Price range: ${min(t.price for t in btc_ticks):.2f} - ${max(t.price for t in btc_ticks):.2f}")
|
||||
|
||||
# Process ticks through Neural DPS
|
||||
logger.info("\n🧠 STEP 3: Neural Network Processing")
|
||||
logger.info("-" * 50)
|
||||
|
||||
# Add ticks to processor buffers
|
||||
with tick_processor.data_lock:
|
||||
for tick in eth_ticks:
|
||||
tick_processor.tick_buffers['ETH/USDT'].append(tick)
|
||||
for tick in btc_ticks:
|
||||
tick_processor.tick_buffers['BTC/USDT'].append(tick)
|
||||
|
||||
# Process through neural network
|
||||
eth_features = tick_processor._extract_neural_features('ETH/USDT')
|
||||
btc_features = tick_processor._extract_neural_features('BTC/USDT')
|
||||
|
||||
logger.info("✅ Neural network processing completed:")
|
||||
|
||||
if eth_features:
|
||||
logger.info(f" ETH/USDT processed features:")
|
||||
logger.info(f" Neural features: {eth_features.neural_features.shape} (confidence: {eth_features.confidence:.3f})")
|
||||
logger.info(f" Price features: {eth_features.price_features.shape}")
|
||||
logger.info(f" Volume features: {eth_features.volume_features.shape}")
|
||||
logger.info(f" Microstructure features: {eth_features.microstructure_features.shape}")
|
||||
|
||||
if btc_features:
|
||||
logger.info(f" BTC/USDT processed features:")
|
||||
logger.info(f" Neural features: {btc_features.neural_features.shape} (confidence: {btc_features.confidence:.3f})")
|
||||
logger.info(f" Price features: {btc_features.price_features.shape}")
|
||||
logger.info(f" Volume features: {btc_features.volume_features.shape}")
|
||||
logger.info(f" Microstructure features: {btc_features.microstructure_features.shape}")
|
||||
|
||||
# Demonstrate volume analysis
|
||||
logger.info("\n💰 STEP 4: Volume Analysis Capabilities")
|
||||
logger.info("-" * 50)
|
||||
|
||||
if eth_features:
|
||||
volume_features = eth_features.volume_features
|
||||
logger.info("✅ Volume analysis extracted:")
|
||||
logger.info(f" Total volume: {volume_features[0]:.4f}")
|
||||
logger.info(f" Average volume: {volume_features[1]:.4f}")
|
||||
logger.info(f" Volume volatility: {volume_features[2]:.4f}")
|
||||
logger.info(f" Buy volume: {volume_features[3]:.4f}")
|
||||
logger.info(f" Sell volume: {volume_features[4]:.4f}")
|
||||
logger.info(f" Volume imbalance: {volume_features[5]:.4f}")
|
||||
logger.info(f" VWAP deviation: {volume_features[6]:.4f}")
|
||||
|
||||
# Demonstrate microstructure analysis
|
||||
logger.info("\n🔬 STEP 5: Market Microstructure Analysis")
|
||||
logger.info("-" * 50)
|
||||
|
||||
if eth_features:
|
||||
micro_features = eth_features.microstructure_features
|
||||
logger.info("✅ Microstructure analysis extracted:")
|
||||
logger.info(f" Trade frequency: {micro_features[0]:.2f} trades/sec")
|
||||
logger.info(f" Price impact: {micro_features[1]:.6f}")
|
||||
logger.info(f" Bid-ask spread proxy: {micro_features[2]:.6f}")
|
||||
logger.info(f" Order flow imbalance: {micro_features[3]:.4f}")
|
||||
|
||||
# Demonstrate real-time feature streaming
|
||||
logger.info("\n📡 STEP 6: Real-Time Feature Streaming")
|
||||
logger.info("-" * 50)
|
||||
|
||||
received_features = []
|
||||
|
||||
def feature_callback(symbol: str, features: ProcessedTickFeatures):
|
||||
"""Callback to receive real-time features"""
|
||||
received_features.append((symbol, features))
|
||||
logger.info(f"📨 Received real-time features for {symbol}")
|
||||
logger.info(f" Confidence: {features.confidence:.3f}")
|
||||
logger.info(f" Neural features: {len(features.neural_features)} dimensions")
|
||||
logger.info(f" Timestamp: {features.timestamp}")
|
||||
|
||||
# Add subscriber and simulate feature streaming
|
||||
tick_processor.add_feature_subscriber(feature_callback)
|
||||
|
||||
# Manually trigger feature processing to simulate streaming
|
||||
tick_processor._notify_feature_subscribers('ETH/USDT', eth_features)
|
||||
tick_processor._notify_feature_subscribers('BTC/USDT', btc_features)
|
||||
|
||||
logger.info(f"✅ Feature streaming demonstrated: {len(received_features)} features received")
|
||||
|
||||
# Performance metrics
|
||||
logger.info("\n⚡ STEP 7: Performance Metrics")
|
||||
logger.info("-" * 50)
|
||||
|
||||
stats = tick_processor.get_processing_stats()
|
||||
logger.info("✅ Performance metrics:")
|
||||
logger.info(f" Symbols processed: {len(stats['symbols'])}")
|
||||
logger.info(f" Buffer utilization: {stats['buffer_sizes']}")
|
||||
logger.info(f" Feature subscribers: {stats['subscribers']}")
|
||||
logger.info(f" Neural network device: {tick_processor.device}")
|
||||
|
||||
# Demonstrate integration readiness
|
||||
logger.info("\n🔗 STEP 8: Model Integration Readiness")
|
||||
logger.info("-" * 50)
|
||||
|
||||
logger.info("✅ Integration capabilities verified:")
|
||||
logger.info(" ✓ Feature subscriber system for real-time streaming")
|
||||
logger.info(" ✓ Standardized ProcessedTickFeatures format")
|
||||
logger.info(" ✓ Neural network feature extraction (64 dimensions)")
|
||||
logger.info(" ✓ Volume-weighted analysis")
|
||||
logger.info(" ✓ Market microstructure detection")
|
||||
logger.info(" ✓ Confidence scoring for feature quality")
|
||||
logger.info(" ✓ Multi-symbol processing")
|
||||
logger.info(" ✓ Thread-safe data handling")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Neural DPS demonstration failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
def demonstrate_dqn_compatibility():
|
||||
"""Demonstrate compatibility with DQN models"""
|
||||
logger.info("\n🤖 STEP 9: DQN Model Compatibility")
|
||||
logger.info("-" * 50)
|
||||
|
||||
try:
|
||||
# Create mock tick features in the format DQN expects
|
||||
mock_tick_features = {
|
||||
'neural_features': np.random.rand(64) * 0.1,
|
||||
'volume_features': np.array([1.2, 0.8, 0.15, 850.5, 720.3, 0.05, 0.02]),
|
||||
'microstructure_features': np.array([12.5, 0.3, 0.001, 0.1]),
|
||||
'confidence': 0.85
|
||||
}
|
||||
|
||||
logger.info("✅ DQN-compatible feature format created:")
|
||||
logger.info(f" Neural features: {len(mock_tick_features['neural_features'])} dimensions")
|
||||
logger.info(f" Volume features: {len(mock_tick_features['volume_features'])} dimensions")
|
||||
logger.info(f" Microstructure features: {len(mock_tick_features['microstructure_features'])} dimensions")
|
||||
logger.info(f" Confidence score: {mock_tick_features['confidence']}")
|
||||
|
||||
# Demonstrate feature integration
|
||||
logger.info("\n✅ Ready for DQN integration:")
|
||||
logger.info(" ✓ update_realtime_tick_features() method available")
|
||||
logger.info(" ✓ State enhancement with tick features")
|
||||
logger.info(" ✓ Weighted feature integration (configurable weight)")
|
||||
logger.info(" ✓ Real-time decision enhancement")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ DQN compatibility test failed: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main demonstration function"""
|
||||
logger.info("🚀 Starting Neural DPS Alternative Demonstration...")
|
||||
|
||||
# Demonstrate core functionality
|
||||
neural_success = demonstrate_neural_dps_alternative()
|
||||
|
||||
# Demonstrate DQN compatibility
|
||||
dqn_success = demonstrate_dqn_compatibility()
|
||||
|
||||
# Final summary
|
||||
logger.info("\n" + "="*80)
|
||||
logger.info("🎉 NEURAL DPS ALTERNATIVE DEMONSTRATION COMPLETE")
|
||||
logger.info("="*80)
|
||||
|
||||
if neural_success and dqn_success:
|
||||
logger.info("✅ ALL DEMONSTRATIONS SUCCESSFUL!")
|
||||
logger.info("")
|
||||
logger.info("🎯 NEURAL DPS ALTERNATIVE VERIFIED:")
|
||||
logger.info(" ✓ Real-time tick data processing with volume information")
|
||||
logger.info(" ✓ Neural network feature extraction (64-dimensional)")
|
||||
logger.info(" ✓ Volume-weighted price analysis")
|
||||
logger.info(" ✓ Market microstructure pattern detection")
|
||||
logger.info(" ✓ Ultra-low latency processing capability")
|
||||
logger.info(" ✓ Real-time feature streaming to models")
|
||||
logger.info(" ✓ Multi-symbol processing (ETH/USDT, BTC/USDT)")
|
||||
logger.info(" ✓ DQN model integration ready")
|
||||
logger.info("")
|
||||
logger.info("🚀 YOUR NEURAL DPS ALTERNATIVE IS FULLY OPERATIONAL!")
|
||||
logger.info("")
|
||||
logger.info("📋 WHAT THIS SYSTEM PROVIDES:")
|
||||
logger.info(" • Replaces traditional DPS with neural network processing")
|
||||
logger.info(" • Processes real-time tick streams with volume information")
|
||||
logger.info(" • Extracts sophisticated features for trading models")
|
||||
logger.info(" • Provides ultra-low latency for high-frequency trading")
|
||||
logger.info(" • Integrates seamlessly with your DQN agents")
|
||||
logger.info(" • Supports WebSocket streaming from exchanges")
|
||||
logger.info(" • Includes confidence scoring for feature quality")
|
||||
logger.info("")
|
||||
logger.info("🎯 NEXT STEPS:")
|
||||
logger.info(" 1. Connect to live WebSocket feeds (Binance, etc.)")
|
||||
logger.info(" 2. Start real-time processing with tick_processor.start_processing()")
|
||||
logger.info(" 3. Your DQN models will receive enhanced tick features automatically")
|
||||
logger.info(" 4. Monitor performance with get_processing_stats()")
|
||||
|
||||
else:
|
||||
logger.error("❌ SOME DEMONSTRATIONS FAILED!")
|
||||
logger.error(f" Neural DPS: {'✅' if neural_success else '❌'}")
|
||||
logger.error(f" DQN Compatibility: {'✅' if dqn_success else '❌'}")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info("="*80)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,311 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple Real-Time Tick Processor Test
|
||||
|
||||
This script tests the core Neural Network functionality of the Real-Time Tick Processing Module
|
||||
without requiring live WebSocket connections.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.realtime_tick_processor import (
|
||||
RealTimeTickProcessor,
|
||||
ProcessedTickFeatures,
|
||||
TickData,
|
||||
create_realtime_tick_processor
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_neural_network_functionality():
|
||||
"""Test the neural network processing without WebSocket connections"""
|
||||
logger.info("="*80)
|
||||
logger.info("🧪 TESTING NEURAL NETWORK TICK PROCESSING")
|
||||
logger.info("="*80)
|
||||
|
||||
try:
|
||||
# Test 1: Create tick processor
|
||||
logger.info("\n📊 TEST 1: Creating Real-Time Tick Processor")
|
||||
logger.info("-" * 40)
|
||||
|
||||
symbols = ['ETH/USDT', 'BTC/USDT']
|
||||
tick_processor = create_realtime_tick_processor(symbols)
|
||||
|
||||
logger.info("✅ Tick processor created successfully")
|
||||
logger.info(f" Symbols: {tick_processor.symbols}")
|
||||
logger.info(f" Device: {tick_processor.device}")
|
||||
logger.info(f" Neural network input size: 9")
|
||||
|
||||
# Test 2: Generate mock tick data
|
||||
logger.info("\n📈 TEST 2: Generating Mock Tick Data")
|
||||
logger.info("-" * 40)
|
||||
|
||||
# Create realistic mock tick data
|
||||
mock_ticks = []
|
||||
base_price = 3500.0 # ETH price
|
||||
base_time = datetime.now()
|
||||
|
||||
for i in range(50): # Generate 50 ticks
|
||||
# Simulate price movement
|
||||
price_change = np.random.normal(0, 0.001) # Small random changes
|
||||
price = base_price * (1 + price_change)
|
||||
|
||||
# Simulate volume
|
||||
volume = np.random.exponential(0.1) # Exponential distribution for volume
|
||||
|
||||
# Random buy/sell
|
||||
side = 'buy' if np.random.random() > 0.5 else 'sell'
|
||||
|
||||
tick = TickData(
|
||||
timestamp=base_time,
|
||||
price=price,
|
||||
volume=volume,
|
||||
side=side,
|
||||
trade_id=f"trade_{i}"
|
||||
)
|
||||
|
||||
mock_ticks.append(tick)
|
||||
base_price = price # Update base price for next tick
|
||||
|
||||
logger.info(f"✅ Generated {len(mock_ticks)} mock ticks")
|
||||
logger.info(f" Price range: {min(t.price for t in mock_ticks):.2f} - {max(t.price for t in mock_ticks):.2f}")
|
||||
logger.info(f" Volume range: {min(t.volume for t in mock_ticks):.4f} - {max(t.volume for t in mock_ticks):.4f}")
|
||||
|
||||
# Test 3: Add ticks to processor buffer
|
||||
logger.info("\n💾 TEST 3: Adding Ticks to Processor Buffer")
|
||||
logger.info("-" * 40)
|
||||
|
||||
symbol = 'ETH/USDT'
|
||||
with tick_processor.data_lock:
|
||||
for tick in mock_ticks:
|
||||
tick_processor.tick_buffers[symbol].append(tick)
|
||||
|
||||
buffer_size = len(tick_processor.tick_buffers[symbol])
|
||||
logger.info(f"✅ Added ticks to buffer: {buffer_size} ticks")
|
||||
|
||||
# Test 4: Extract neural features
|
||||
logger.info("\n🧠 TEST 4: Neural Network Feature Extraction")
|
||||
logger.info("-" * 40)
|
||||
|
||||
features = tick_processor._extract_neural_features(symbol)
|
||||
|
||||
if features is not None:
|
||||
logger.info("✅ Neural features extracted successfully")
|
||||
logger.info(f" Timestamp: {features.timestamp}")
|
||||
logger.info(f" Confidence: {features.confidence:.3f}")
|
||||
logger.info(f" Neural features shape: {features.neural_features.shape}")
|
||||
logger.info(f" Price features shape: {features.price_features.shape}")
|
||||
logger.info(f" Volume features shape: {features.volume_features.shape}")
|
||||
logger.info(f" Microstructure features shape: {features.microstructure_features.shape}")
|
||||
|
||||
# Show sample values
|
||||
logger.info(f" Neural features sample: {features.neural_features[:5]}")
|
||||
logger.info(f" Price features sample: {features.price_features[:3]}")
|
||||
logger.info(f" Volume features sample: {features.volume_features[:3]}")
|
||||
else:
|
||||
logger.error("❌ Failed to extract neural features")
|
||||
return False
|
||||
|
||||
# Test 5: Test feature conversion methods
|
||||
logger.info("\n🔧 TEST 5: Feature Conversion Methods")
|
||||
logger.info("-" * 40)
|
||||
|
||||
# Test tick-to-features conversion
|
||||
tick_features = tick_processor._ticks_to_features(mock_ticks)
|
||||
logger.info(f"✅ Tick features converted: shape {tick_features.shape}")
|
||||
logger.info(f" Expected shape: ({tick_processor.processing_window}, 9)")
|
||||
|
||||
# Test individual feature extraction
|
||||
price_features = tick_processor._extract_price_features(mock_ticks)
|
||||
volume_features = tick_processor._extract_volume_features(mock_ticks)
|
||||
microstructure_features = tick_processor._extract_microstructure_features(mock_ticks)
|
||||
|
||||
logger.info(f"✅ Price features: {len(price_features)} features")
|
||||
logger.info(f"✅ Volume features: {len(volume_features)} features")
|
||||
logger.info(f"✅ Microstructure features: {len(microstructure_features)} features")
|
||||
|
||||
# Test 6: Neural network forward pass
|
||||
logger.info("\n⚡ TEST 6: Neural Network Forward Pass")
|
||||
logger.info("-" * 40)
|
||||
|
||||
import torch
|
||||
|
||||
# Test direct neural network inference
|
||||
tick_tensor = torch.FloatTensor(tick_features).unsqueeze(0).to(tick_processor.device)
|
||||
|
||||
with torch.no_grad():
|
||||
neural_features, confidence = tick_processor.tick_nn(tick_tensor)
|
||||
|
||||
logger.info("✅ Neural network forward pass successful")
|
||||
logger.info(f" Input shape: {tick_tensor.shape}")
|
||||
logger.info(f" Output features shape: {neural_features.shape}")
|
||||
logger.info(f" Confidence shape: {confidence.shape}")
|
||||
logger.info(f" Confidence value: {confidence.item():.3f}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Neural network test failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
def test_dqn_integration():
|
||||
"""Test DQN integration with real-time tick features"""
|
||||
logger.info("\n🤖 TESTING DQN INTEGRATION WITH TICK FEATURES")
|
||||
logger.info("-" * 50)
|
||||
|
||||
try:
|
||||
from NN.models.dqn_agent import DQNAgent
|
||||
import numpy as np
|
||||
|
||||
# Create DQN agent
|
||||
state_shape = (3, 5) # 3 timeframes, 5 features
|
||||
dqn = DQNAgent(state_shape=state_shape, n_actions=3)
|
||||
|
||||
logger.info("✅ DQN agent created")
|
||||
logger.info(f" State shape: {state_shape}")
|
||||
logger.info(f" Actions: {dqn.n_actions}")
|
||||
logger.info(f" Device: {dqn.device}")
|
||||
logger.info(f" Tick feature weight: {dqn.tick_feature_weight}")
|
||||
|
||||
# Test state enhancement
|
||||
test_state = np.random.rand(3, 5)
|
||||
logger.info(f" Test state shape: {test_state.shape}")
|
||||
|
||||
# Simulate realistic tick features
|
||||
mock_tick_features = {
|
||||
'neural_features': np.random.rand(64) * 0.1, # Small neural features
|
||||
'volume_features': np.array([1.2, 0.8, 0.15, 850.5, 720.3, 0.05, 0.02]), # Realistic volume features
|
||||
'microstructure_features': np.array([12.5, 0.3, 0.001, 0.1]), # Realistic microstructure
|
||||
'confidence': 0.85
|
||||
}
|
||||
|
||||
# Update DQN with tick features
|
||||
dqn.update_realtime_tick_features(mock_tick_features)
|
||||
logger.info("✅ DQN updated with mock tick features")
|
||||
|
||||
# Test enhanced action selection
|
||||
action_with_ticks = dqn.act(test_state, explore=False)
|
||||
logger.info(f"✅ DQN action with tick features: {action_with_ticks}")
|
||||
|
||||
# Test without tick features
|
||||
dqn.realtime_tick_features = None
|
||||
action_without_ticks = dqn.act(test_state, explore=False)
|
||||
logger.info(f"✅ DQN action without tick features: {action_without_ticks}")
|
||||
|
||||
# Test state enhancement method directly
|
||||
dqn.realtime_tick_features = mock_tick_features
|
||||
enhanced_state = dqn._enhance_state_with_tick_features(test_state)
|
||||
logger.info(f"✅ State enhancement test:")
|
||||
logger.info(f" Original state shape: {test_state.shape}")
|
||||
logger.info(f" Enhanced state shape: {enhanced_state.shape}")
|
||||
|
||||
logger.info("✅ DQN integration test completed successfully")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ DQN integration test failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
def test_performance_metrics():
|
||||
"""Test performance and statistics functionality"""
|
||||
logger.info("\n📊 TESTING PERFORMANCE METRICS")
|
||||
logger.info("-" * 40)
|
||||
|
||||
try:
|
||||
tick_processor = create_realtime_tick_processor(['ETH/USDT'])
|
||||
|
||||
# Test stats without processing
|
||||
stats = tick_processor.get_processing_stats()
|
||||
logger.info("✅ Basic stats retrieved")
|
||||
logger.info(f" Symbols: {stats['symbols']}")
|
||||
logger.info(f" Streaming: {stats['streaming']}")
|
||||
logger.info(f" Tick counts: {stats['tick_counts']}")
|
||||
logger.info(f" Buffer sizes: {stats['buffer_sizes']}")
|
||||
logger.info(f" Subscribers: {stats['subscribers']}")
|
||||
|
||||
# Test feature subscriber
|
||||
received_features = []
|
||||
|
||||
def test_callback(symbol: str, features: ProcessedTickFeatures):
|
||||
received_features.append((symbol, features))
|
||||
|
||||
tick_processor.add_feature_subscriber(test_callback)
|
||||
logger.info("✅ Feature subscriber added")
|
||||
|
||||
# Test subscriber removal
|
||||
tick_processor.remove_feature_subscriber(test_callback)
|
||||
logger.info("✅ Feature subscriber removed")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Performance metrics test failed: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main test function"""
|
||||
logger.info("🚀 Starting Simple Real-Time Tick Processor Tests...")
|
||||
|
||||
# Test neural network functionality
|
||||
nn_success = test_neural_network_functionality()
|
||||
|
||||
# Test DQN integration
|
||||
dqn_success = test_dqn_integration()
|
||||
|
||||
# Test performance metrics
|
||||
perf_success = test_performance_metrics()
|
||||
|
||||
# Summary
|
||||
logger.info("\n" + "="*80)
|
||||
logger.info("🎉 SIMPLE TICK PROCESSOR TEST SUMMARY")
|
||||
logger.info("="*80)
|
||||
|
||||
if nn_success and dqn_success and perf_success:
|
||||
logger.info("✅ ALL TESTS PASSED!")
|
||||
logger.info("")
|
||||
logger.info("📋 VERIFIED FUNCTIONALITY:")
|
||||
logger.info(" ✓ Neural network tick processing")
|
||||
logger.info(" ✓ Feature extraction (price, volume, microstructure)")
|
||||
logger.info(" ✓ DQN integration with tick features")
|
||||
logger.info(" ✓ State enhancement for RL models")
|
||||
logger.info(" ✓ Performance monitoring")
|
||||
logger.info("")
|
||||
logger.info("🎯 NEURAL DPS ALTERNATIVE READY:")
|
||||
logger.info(" • Real-time tick processing ✓")
|
||||
logger.info(" • Volume-weighted analysis ✓")
|
||||
logger.info(" • Neural feature extraction ✓")
|
||||
logger.info(" • Model integration ready ✓")
|
||||
logger.info("")
|
||||
logger.info("🚀 Your Neural DPS alternative is working correctly!")
|
||||
logger.info(" The system can now process real-time tick data with volume")
|
||||
logger.info(" information and feed enhanced features to your DQN models.")
|
||||
|
||||
else:
|
||||
logger.error("❌ SOME TESTS FAILED!")
|
||||
logger.error(f" Neural Network: {'✅' if nn_success else '❌'}")
|
||||
logger.error(f" DQN Integration: {'✅' if dqn_success else '❌'}")
|
||||
logger.error(f" Performance: {'✅' if perf_success else '❌'}")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info("="*80)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user