fix emojies
This commit is contained in:
19
.cursorrules
19
.cursorrules
@@ -5,6 +5,9 @@
|
||||
|
||||
## Unicode and Encoding Rules
|
||||
- **NEVER use Unicode characters that may not be supported by Windows console (cp1252)**
|
||||
- **ABSOLUTELY NO EMOJIS** in any code, logs, or console output (e.g., ✅, ✓, ❌, ⚠️, 🚀, 📊, 💾, 🔄, ⏳, 🎯, 📈, 📉, 🔍, ⚡, 💡, 🛠️, 🔧, 🎉, ⭐, 📁, 📋)
|
||||
- Use plain ASCII text for all log messages, print statements, and console output
|
||||
- Replace emojis with descriptive text (e.g., "OK", "ERROR", "WARNING", "SUCCESS")
|
||||
|
||||
|
||||
## Code Structure and Versioning Rules
|
||||
@@ -24,14 +27,18 @@
|
||||
|
||||
## Logging Best Practices
|
||||
- Use structured logging with clear, ASCII-only messages
|
||||
- **NEVER use emojis or Unicode symbols in log messages**
|
||||
- Include relevant context in log messages without Unicode characters
|
||||
- Use logger.info(), logger.error(), etc. with plain text
|
||||
- Use logger.info(), logger.error(), etc. with plain text only
|
||||
- Use descriptive prefixes instead of emojis (e.g., "SUCCESS:", "ERROR:", "WARNING:")
|
||||
- Example: `logger.info("TRADING: Starting Live Scalping Dashboard at http://127.0.0.1:8051")`
|
||||
- Example: `logger.info("SUCCESS: Model checkpoint loaded successfully")`
|
||||
|
||||
## Error Handling
|
||||
- Always include proper exception handling
|
||||
- Log errors with ASCII-only characters
|
||||
- Provide meaningful error messages without emojis
|
||||
- Log errors with ASCII-only characters - **NO EMOJIS**
|
||||
- Provide meaningful error messages using plain text descriptors
|
||||
- Use text prefixes like "ERROR:", "FAILED:", "WARNING:" instead of emoji symbols
|
||||
- Include stack traces for debugging when appropriate
|
||||
|
||||
## File Naming Conventions
|
||||
@@ -59,8 +66,10 @@
|
||||
|
||||
## Code Review Checklist
|
||||
Before submitting code changes, verify:
|
||||
- [ ] No Unicode/emoji characters in logging or console output
|
||||
- [ ] **ABSOLUTELY NO EMOJIS OR UNICODE SYMBOLS** in any code, logs, or output
|
||||
- [ ] All log messages use plain ASCII text only (logger.info, logger.error, print, etc.)
|
||||
- [ ] No duplicate implementations of existing functionality
|
||||
- [ ] Proper error handling with ASCII-only messages
|
||||
- [ ] Windows compatibility maintained
|
||||
- [ ] Windows compatibility maintained (PowerShell console safe)
|
||||
- [ ] Existing code structure preserved and enhanced rather than replaced
|
||||
- [ ] Use descriptive text instead of symbols: "OK" not "✓", "ERROR" not "❌", "SUCCESS" not "✅"
|
||||
@@ -83,7 +83,7 @@ class HistoricalDataLoader:
|
||||
# For 1s/1m, we want to return immediately if valid
|
||||
if timeframe not in ['1s', '1m']:
|
||||
elapsed_ms = (time.time() - start_time_ms) * 1000
|
||||
logger.debug(f"⚡ Memory cache hit for {symbol} {timeframe} ({elapsed_ms:.1f}ms)")
|
||||
logger.debug(f"Memory cache hit for {symbol} {timeframe} ({elapsed_ms:.1f}ms)")
|
||||
return cached_data
|
||||
|
||||
try:
|
||||
@@ -221,7 +221,7 @@ class HistoricalDataLoader:
|
||||
timeframe=timeframe,
|
||||
df=df
|
||||
)
|
||||
logger.info(f"💾 Stored {stored_count} new candles in DuckDB")
|
||||
logger.info(f"Stored {stored_count} new candles in DuckDB")
|
||||
|
||||
# Cache in memory
|
||||
self.memory_cache[cache_key] = (df.copy(), datetime.now())
|
||||
|
||||
@@ -22,14 +22,14 @@ def test_data_loader():
|
||||
# Initialize DataProvider
|
||||
print("\n1. Initializing DataProvider...")
|
||||
data_provider = DataProvider()
|
||||
print(f" ✓ DataProvider initialized")
|
||||
print(f" DataProvider initialized")
|
||||
print(f" - Symbols: {data_provider.symbols}")
|
||||
print(f" - Timeframes: {data_provider.timeframes}")
|
||||
|
||||
# Initialize HistoricalDataLoader
|
||||
print("\n2. Initializing HistoricalDataLoader...")
|
||||
data_loader = HistoricalDataLoader(data_provider)
|
||||
print(f" ✓ HistoricalDataLoader initialized")
|
||||
print(f" HistoricalDataLoader initialized")
|
||||
|
||||
# Test loading data for ETH/USDT
|
||||
print("\n3. Testing data loading for ETH/USDT...")
|
||||
@@ -39,7 +39,7 @@ def test_data_loader():
|
||||
for timeframe in timeframes:
|
||||
df = data_loader.get_data(symbol, timeframe, limit=100)
|
||||
if df is not None and not df.empty:
|
||||
print(f" ✓ {timeframe}: Loaded {len(df)} candles")
|
||||
print(f" {timeframe}: Loaded {len(df)} candles")
|
||||
print(f" Latest: {df.index[-1]} - Close: ${df['close'].iloc[-1]:.2f}")
|
||||
else:
|
||||
print(f" ✗ {timeframe}: No data available")
|
||||
@@ -47,7 +47,7 @@ def test_data_loader():
|
||||
# Test multi-timeframe loading
|
||||
print("\n4. Testing multi-timeframe loading...")
|
||||
multi_data = data_loader.get_multi_timeframe_data(symbol, timeframes, limit=50)
|
||||
print(f" ✓ Loaded data for {len(multi_data)} timeframes")
|
||||
print(f" Loaded data for {len(multi_data)} timeframes")
|
||||
for tf, df in multi_data.items():
|
||||
print(f" {tf}: {len(df)} candles")
|
||||
|
||||
@@ -59,24 +59,24 @@ def test_data_loader():
|
||||
range_preset = '1d'
|
||||
start_time, end_time = time_manager.calculate_time_range(center_time, range_preset)
|
||||
|
||||
print(f" ✓ Time range calculated for '{range_preset}':")
|
||||
print(f" Time range calculated for '{range_preset}':")
|
||||
print(f" Start: {start_time}")
|
||||
print(f" End: {end_time}")
|
||||
|
||||
increment = time_manager.get_navigation_increment(range_preset)
|
||||
print(f" ✓ Navigation increment: {increment}")
|
||||
print(f" Navigation increment: {increment}")
|
||||
|
||||
# Test data boundaries
|
||||
print("\n6. Testing data boundaries...")
|
||||
earliest, latest = data_loader.get_data_boundaries(symbol, '1m')
|
||||
if earliest and latest:
|
||||
print(f" ✓ Data available from {earliest} to {latest}")
|
||||
print(f" Data available from {earliest} to {latest}")
|
||||
print(f" Total span: {latest - earliest}")
|
||||
else:
|
||||
print(f" ✗ Could not determine data boundaries")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✓ All tests completed successfully!")
|
||||
print("All tests completed successfully!")
|
||||
print("=" * 60)
|
||||
print("\nThe data loader is ready to use with the annotation UI.")
|
||||
print("It uses the same DataProvider as training/inference systems.")
|
||||
|
||||
@@ -36,7 +36,7 @@ class BinanceExample:
|
||||
if isinstance(data, OrderBookSnapshot):
|
||||
self.orderbook_count += 1
|
||||
logger.info(
|
||||
f"📊 Order Book {self.orderbook_count}: {data.symbol} - "
|
||||
f"Order Book {self.orderbook_count}: {data.symbol} - "
|
||||
f"Mid: ${data.mid_price:.2f}, Spread: ${data.spread:.2f}, "
|
||||
f"Bids: {len(data.bids)}, Asks: {len(data.asks)}"
|
||||
)
|
||||
@@ -68,22 +68,22 @@ class BinanceExample:
|
||||
logger.info(" Connected to Binance successfully")
|
||||
|
||||
# Get available symbols
|
||||
logger.info("📋 Getting available symbols...")
|
||||
logger.info("Getting available symbols...")
|
||||
symbols = await self.connector.get_symbols()
|
||||
logger.info(f"📋 Found {len(symbols)} trading symbols")
|
||||
logger.info(f"Found {len(symbols)} trading symbols")
|
||||
|
||||
# Show some popular symbols
|
||||
popular_symbols = ['BTCUSDT', 'ETHUSDT', 'ADAUSDT', 'BNBUSDT']
|
||||
available_popular = [s for s in popular_symbols if s in symbols]
|
||||
logger.info(f"📋 Popular symbols available: {available_popular}")
|
||||
logger.info(f"Popular symbols available: {available_popular}")
|
||||
|
||||
# Get order book snapshot
|
||||
if 'BTCUSDT' in symbols:
|
||||
logger.info("📊 Getting BTC order book snapshot...")
|
||||
logger.info("Getting BTC order book snapshot...")
|
||||
orderbook = await self.connector.get_orderbook_snapshot('BTCUSDT', depth=10)
|
||||
if orderbook:
|
||||
logger.info(
|
||||
f"📊 BTC Order Book: Mid=${orderbook.mid_price:.2f}, "
|
||||
f"BTC Order Book: Mid=${orderbook.mid_price:.2f}, "
|
||||
f"Spread=${orderbook.spread:.2f}"
|
||||
)
|
||||
|
||||
@@ -102,18 +102,18 @@ class BinanceExample:
|
||||
logger.info(" Subscribed to ETHUSDT order book")
|
||||
|
||||
# Let it run for a while
|
||||
logger.info("⏳ Collecting data for 30 seconds...")
|
||||
logger.info("Collecting data for 30 seconds...")
|
||||
await asyncio.sleep(30)
|
||||
|
||||
# Show statistics
|
||||
stats = self.connector.get_binance_stats()
|
||||
logger.info("📈 Final Statistics:")
|
||||
logger.info(f" 📊 Order books received: {self.orderbook_count}")
|
||||
logger.info("Final Statistics:")
|
||||
logger.info(f" Order books received: {self.orderbook_count}")
|
||||
logger.info(f" 💰 Trades received: {self.trade_count}")
|
||||
logger.info(f" 📡 Total messages: {stats['message_count']}")
|
||||
logger.info(f" Errors: {stats['error_count']}")
|
||||
logger.info(f" 🔗 Active streams: {stats['active_streams']}")
|
||||
logger.info(f" 📋 Subscriptions: {list(stats['subscriptions'].keys())}")
|
||||
logger.info(f" Subscriptions: {list(stats['subscriptions'].keys())}")
|
||||
|
||||
# Unsubscribe and disconnect
|
||||
logger.info("🔌 Cleaning up...")
|
||||
@@ -129,7 +129,7 @@ class BinanceExample:
|
||||
logger.info(" Disconnected successfully")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("⏹️ Interrupted by user")
|
||||
logger.info("Interrupted by user")
|
||||
except Exception as e:
|
||||
logger.error(f" Example failed: {e}")
|
||||
finally:
|
||||
|
||||
@@ -45,7 +45,7 @@ class MultiExchangeManager:
|
||||
try:
|
||||
if isinstance(data, OrderBookSnapshot):
|
||||
self.data_received[exchange]['orderbooks'] += 1
|
||||
logger.info(f"📊 {exchange.upper()}: Order book for {data.symbol} - "
|
||||
logger.info(f"{exchange.upper()}: Order book for {data.symbol} - "
|
||||
f"Bids: {len(data.bids)}, Asks: {len(data.asks)}")
|
||||
|
||||
# Show best bid/ask if available
|
||||
@@ -102,7 +102,7 @@ class MultiExchangeManager:
|
||||
if connector.is_connected:
|
||||
# Subscribe to order book
|
||||
await connector.subscribe_orderbook(symbol)
|
||||
logger.info(f"📈 Subscribed to {symbol} order book on {name}")
|
||||
logger.info(f"Subscribed to {symbol} order book on {name}")
|
||||
|
||||
# Subscribe to trades
|
||||
await connector.subscribe_trades(symbol)
|
||||
@@ -131,7 +131,7 @@ class MultiExchangeManager:
|
||||
|
||||
def _print_statistics(self):
|
||||
"""Print current data statistics."""
|
||||
logger.info("📊 Current Statistics:")
|
||||
logger.info("Current Statistics:")
|
||||
total_orderbooks = 0
|
||||
total_trades = 0
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ def check_health():
|
||||
all_healthy = False
|
||||
|
||||
if all_healthy:
|
||||
print("\n🎉 Overall Health: HEALTHY")
|
||||
print("\nOverall Health: HEALTHY")
|
||||
return 0
|
||||
else:
|
||||
print("\n Overall Health: DEGRADED")
|
||||
@@ -91,7 +91,7 @@ def main():
|
||||
print("=" * 50)
|
||||
|
||||
if api_healthy and ws_healthy:
|
||||
print("🎉 COBY System: FULLY HEALTHY")
|
||||
print("COBY System: FULLY HEALTHY")
|
||||
return 0
|
||||
elif api_healthy:
|
||||
print(" COBY System: PARTIALLY HEALTHY (API only)")
|
||||
|
||||
@@ -14,7 +14,7 @@ try:
|
||||
from api.rest_api import create_app
|
||||
from caching.redis_manager import redis_manager
|
||||
from utils.logging import get_logger, setup_logging
|
||||
print("✓ All imports successful")
|
||||
print("All imports successful")
|
||||
except ImportError as e:
|
||||
print(f"✗ Import error: {e}")
|
||||
sys.exit(1)
|
||||
@@ -29,11 +29,11 @@ async def test_health_endpoints():
|
||||
# Test Redis manager
|
||||
await redis_manager.initialize()
|
||||
ping_result = await redis_manager.ping()
|
||||
print(f"✓ Redis ping: {ping_result}")
|
||||
print(f"Redis ping: {ping_result}")
|
||||
|
||||
# Test app creation
|
||||
app = create_app()
|
||||
print("✓ FastAPI app created successfully")
|
||||
print("FastAPI app created successfully")
|
||||
|
||||
# Test health endpoint logic
|
||||
from api.response_formatter import ResponseFormatter
|
||||
@@ -46,7 +46,7 @@ async def test_health_endpoints():
|
||||
}
|
||||
|
||||
response = formatter.status_response(health_data)
|
||||
print(f"✓ Health response format: {type(response)}")
|
||||
print(f"Health response format: {type(response)}")
|
||||
|
||||
return True
|
||||
|
||||
@@ -63,19 +63,19 @@ async def test_static_files():
|
||||
index_path = os.path.join(static_path, "index.html")
|
||||
|
||||
if os.path.exists(static_path):
|
||||
print(f"✓ Static directory exists: {static_path}")
|
||||
print(f"Static directory exists: {static_path}")
|
||||
else:
|
||||
print(f"✗ Static directory missing: {static_path}")
|
||||
return False
|
||||
|
||||
if os.path.exists(index_path):
|
||||
print(f"✓ Index.html exists: {index_path}")
|
||||
print(f"Index.html exists: {index_path}")
|
||||
|
||||
# Test reading the file
|
||||
with open(index_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
if "COBY" in content:
|
||||
print("✓ Index.html contains COBY content")
|
||||
print("Index.html contains COBY content")
|
||||
else:
|
||||
print("✗ Index.html missing COBY content")
|
||||
return False
|
||||
@@ -101,7 +101,7 @@ async def test_websocket_config():
|
||||
host=config.api.host,
|
||||
port=config.api.websocket_port
|
||||
)
|
||||
print(f"✓ WebSocket server configured: {config.api.host}:{config.api.websocket_port}")
|
||||
print(f"WebSocket server configured: {config.api.host}:{config.api.websocket_port}")
|
||||
|
||||
return True
|
||||
|
||||
@@ -139,13 +139,13 @@ async def main():
|
||||
total = len(results)
|
||||
|
||||
for i, result in enumerate(results):
|
||||
status = "✓ PASS" if result else "✗ FAIL"
|
||||
status = "PASS" if result else "FAIL"
|
||||
print(f" Test {i+1}: {status}")
|
||||
|
||||
print(f"\nOverall: {passed}/{total} tests passed")
|
||||
|
||||
if passed == total:
|
||||
print("🎉 All tests passed! COBY system should work correctly.")
|
||||
print("All tests passed! COBY system should work correctly.")
|
||||
return 0
|
||||
else:
|
||||
print(" Some tests failed. Please check the issues above.")
|
||||
|
||||
@@ -40,10 +40,10 @@ async def test_database_connection():
|
||||
|
||||
# Test storage stats
|
||||
stats = await manager.get_storage_stats()
|
||||
logger.info(f"📊 Found {len(stats.get('table_sizes', []))} tables")
|
||||
logger.info(f"Found {len(stats.get('table_sizes', []))} tables")
|
||||
|
||||
for table_info in stats.get('table_sizes', []):
|
||||
logger.info(f" 📋 {table_info['table']}: {table_info['size']}")
|
||||
logger.info(f" {table_info['table']}: {table_info['size']}")
|
||||
|
||||
await manager.close()
|
||||
return True
|
||||
@@ -55,7 +55,7 @@ async def test_database_connection():
|
||||
|
||||
async def test_data_storage():
|
||||
"""Test storing and retrieving data"""
|
||||
logger.info("💾 Testing data storage operations...")
|
||||
logger.info("Testing data storage operations...")
|
||||
|
||||
try:
|
||||
manager = TimescaleManager()
|
||||
@@ -181,7 +181,7 @@ async def test_batch_operations():
|
||||
|
||||
async def test_configuration():
|
||||
"""Test configuration system"""
|
||||
logger.info("⚙️ Testing configuration system...")
|
||||
logger.info("Testing configuration system...")
|
||||
|
||||
try:
|
||||
# Test database configuration
|
||||
@@ -237,7 +237,7 @@ async def run_all_tests():
|
||||
|
||||
# Summary
|
||||
logger.info("\n" + "=" * 50)
|
||||
logger.info("📋 TEST SUMMARY")
|
||||
logger.info("TEST SUMMARY")
|
||||
logger.info("=" * 50)
|
||||
|
||||
passed = sum(1 for _, result in results if result)
|
||||
@@ -250,7 +250,7 @@ async def run_all_tests():
|
||||
logger.info(f"\nOverall: {passed}/{total} tests passed")
|
||||
|
||||
if passed == total:
|
||||
logger.info("🎉 All tests passed! System is ready.")
|
||||
logger.info("All tests passed! System is ready.")
|
||||
return True
|
||||
else:
|
||||
logger.error(" Some tests failed. Check configuration and database connection.")
|
||||
@@ -265,7 +265,7 @@ if __name__ == "__main__":
|
||||
success = asyncio.run(run_all_tests())
|
||||
|
||||
if success:
|
||||
print("\n🎉 Integration tests completed successfully!")
|
||||
print("\nIntegration tests completed successfully!")
|
||||
print("The system is ready for the next development phase.")
|
||||
sys.exit(0)
|
||||
else:
|
||||
|
||||
@@ -104,7 +104,7 @@ class TestAllConnectors:
|
||||
await connector.unsubscribe_orderbook('BTCUSDT')
|
||||
await connector.unsubscribe_trades('ETHUSDT')
|
||||
|
||||
print(f"✓ {name} subscription interface works")
|
||||
print(f"{name} subscription interface works")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ {name} subscription interface failed: {e}")
|
||||
@@ -140,7 +140,7 @@ class TestAllConnectors:
|
||||
assert 'exchange' in stats
|
||||
assert stats['exchange'] == name
|
||||
assert 'connection_status' in stats
|
||||
print(f"✓ {name} statistics interface works")
|
||||
print(f"{name} statistics interface works")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ {name} statistics interface failed: {e}")
|
||||
@@ -160,7 +160,7 @@ class TestAllConnectors:
|
||||
connector.remove_data_callback(test_callback)
|
||||
assert test_callback not in connector.data_callbacks
|
||||
|
||||
print(f"✓ {name} callback system works")
|
||||
print(f"{name} callback system works")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ {name} callback system failed: {e}")
|
||||
@@ -176,7 +176,7 @@ class TestAllConnectors:
|
||||
is_connected = connector.is_connected
|
||||
assert isinstance(is_connected, bool)
|
||||
|
||||
print(f"✓ {name} connection status interface works")
|
||||
print(f"{name} connection status interface works")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ {name} connection status interface failed: {e}")
|
||||
@@ -206,25 +206,25 @@ async def test_connector_compatibility():
|
||||
|
||||
# Test initialization
|
||||
assert connector.exchange_name == name
|
||||
print(f" ✓ Initialization: {connector.exchange_name}")
|
||||
print(f" Initialization: {connector.exchange_name}")
|
||||
|
||||
# Test symbol normalization
|
||||
btc_symbol = connector.normalize_symbol('BTCUSDT')
|
||||
eth_symbol = connector.normalize_symbol('ETHUSDT')
|
||||
print(f" ✓ Symbol normalization: BTCUSDT -> {btc_symbol}, ETHUSDT -> {eth_symbol}")
|
||||
print(f" Symbol normalization: BTCUSDT -> {btc_symbol}, ETHUSDT -> {eth_symbol}")
|
||||
|
||||
# Test message type detection
|
||||
test_msg = {'type': 'test'} if name != 'kraken' else [1, {}, 'test', 'symbol']
|
||||
msg_type = connector._get_message_type(test_msg)
|
||||
print(f" ✓ Message type detection: {msg_type}")
|
||||
print(f" Message type detection: {msg_type}")
|
||||
|
||||
# Test statistics
|
||||
stats = connector.get_stats()
|
||||
print(f" ✓ Statistics: {len(stats)} fields")
|
||||
print(f" Statistics: {len(stats)} fields")
|
||||
|
||||
# Test connection status
|
||||
status = connector.get_connection_status()
|
||||
print(f" ✓ Connection status: {status.value}")
|
||||
print(f" Connection status: {status.value}")
|
||||
|
||||
print(f" {name.upper()} connector passed all tests")
|
||||
|
||||
@@ -265,7 +265,7 @@ async def test_multi_connector_data_flow():
|
||||
try:
|
||||
await connector.subscribe_orderbook(symbol)
|
||||
await connector.subscribe_trades(symbol)
|
||||
print(f"✓ Subscribed to {symbol} on {name}")
|
||||
print(f"Subscribed to {symbol} on {name}")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to subscribe to {symbol} on {name}: {e}")
|
||||
|
||||
|
||||
@@ -306,7 +306,7 @@ async def test_bybit_integration():
|
||||
test_message = {'topic': 'orderbook.50.BTCUSDT', 'data': {}}
|
||||
assert connector._get_message_type(test_message) == 'orderbook'
|
||||
|
||||
print("✓ Bybit connector integration test passed")
|
||||
print("Bybit connector integration test passed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -349,7 +349,7 @@ async def test_coinbase_integration():
|
||||
test_message = {'type': 'l2update', 'product_id': 'BTC-USD'}
|
||||
assert connector._get_message_type(test_message) == 'l2update'
|
||||
|
||||
print("✓ Coinbase connector integration test passed")
|
||||
print("Coinbase connector integration test passed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -383,7 +383,7 @@ async def test_kraken_integration():
|
||||
status_message = {'event': 'systemStatus', 'status': 'online'}
|
||||
assert connector._get_message_type(status_message) == 'systemStatus'
|
||||
|
||||
print("✓ Kraken connector integration test passed")
|
||||
print("Kraken connector integration test passed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -344,7 +344,7 @@ async def test_integration_suite():
|
||||
|
||||
price = adapter.get_current_price('BTCUSDT')
|
||||
assert price == 50000.0
|
||||
logger.info(f"✓ Current price retrieval: {price}")
|
||||
logger.info(f"Current price retrieval: {price}")
|
||||
|
||||
# Test subscription
|
||||
callback_called = False
|
||||
@@ -355,17 +355,17 @@ async def test_integration_suite():
|
||||
|
||||
subscriber_id = adapter.subscribe_to_ticks(test_callback, ['BTCUSDT'])
|
||||
assert subscriber_id != ""
|
||||
logger.info(f"✓ Subscription created: {subscriber_id}")
|
||||
logger.info(f"Subscription created: {subscriber_id}")
|
||||
|
||||
# Test data quality
|
||||
quality = adapter.get_data_quality_indicators('BTCUSDT')
|
||||
assert quality['symbol'] == 'BTCUSDT'
|
||||
logger.info(f"✓ Data quality check: {quality['quality_score']}")
|
||||
logger.info(f"Data quality check: {quality['quality_score']}")
|
||||
|
||||
# Test system metadata
|
||||
metadata = adapter.get_system_metadata()
|
||||
assert metadata['system'] == 'COBY'
|
||||
logger.info(f"✓ System metadata: {metadata['mode']}")
|
||||
logger.info(f"System metadata: {metadata['mode']}")
|
||||
|
||||
logger.info("All integration tests passed successfully!")
|
||||
return True
|
||||
@@ -379,7 +379,7 @@ if __name__ == "__main__":
|
||||
# Run the integration tests
|
||||
success = asyncio.run(test_integration_suite())
|
||||
if success:
|
||||
print("✓ COBY orchestrator integration tests completed successfully")
|
||||
print("COBY orchestrator integration tests completed successfully")
|
||||
else:
|
||||
print("✗ COBY orchestrator integration tests failed")
|
||||
exit(1)
|
||||
@@ -1052,7 +1052,7 @@ class TradingTransformerTrainer:
|
||||
|
||||
# Move model to device
|
||||
self.model.to(self.device)
|
||||
logger.info(f"✅ Model moved to device: {self.device}")
|
||||
logger.info(f"Model moved to device: {self.device}")
|
||||
|
||||
# Log GPU info if available
|
||||
if torch.cuda.is_available():
|
||||
|
||||
@@ -2794,7 +2794,7 @@ class EnhancedRealtimeTrainingSystem:
|
||||
)
|
||||
|
||||
if checkpoint_path:
|
||||
logger.info(f"💾 Saved checkpoint for {model_name}: {checkpoint_path} (loss: {loss:.4f})")
|
||||
logger.info(f"Saved checkpoint for {model_name}: {checkpoint_path} (loss: {loss:.4f})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint for {model_name}: {e}")
|
||||
@@ -118,7 +118,7 @@ class COBIntegration:
|
||||
async def _on_enhanced_cob_update(self, symbol: str, cob_data: Dict):
|
||||
"""Handle COB updates from Enhanced WebSocket"""
|
||||
try:
|
||||
logger.debug(f"📊 Enhanced WebSocket COB update for {symbol}")
|
||||
logger.debug(f"Enhanced WebSocket COB update for {symbol}")
|
||||
|
||||
# Convert enhanced WebSocket data to COB format for existing callbacks
|
||||
# Notify CNN callbacks
|
||||
|
||||
@@ -29,7 +29,7 @@ async def test_bybit_balance():
|
||||
print("ERROR: Failed to connect to Bybit")
|
||||
return
|
||||
|
||||
print("✓ Connected to Bybit successfully")
|
||||
print("Connected to Bybit successfully")
|
||||
|
||||
# Test get_balance for USDT
|
||||
print("\nTesting get_balance('USDT')...")
|
||||
|
||||
@@ -97,7 +97,7 @@ def debug_interface():
|
||||
print(f"Manual signature: {manual_signature}")
|
||||
|
||||
# Compare parameters
|
||||
print(f"\n📊 COMPARISON:")
|
||||
print(f"\nCOMPARISON:")
|
||||
print(f"symbol: Interface='{interface_params['symbol']}', Manual='{manual_params['symbol']}' {'' if interface_params['symbol'] == manual_params['symbol'] else ''}")
|
||||
print(f"side: Interface='{interface_params['side']}', Manual='{manual_params['side']}' {'' if interface_params['side'] == manual_params['side'] else ''}")
|
||||
print(f"type: Interface='{interface_params['type']}', Manual='{manual_params['type']}' {'' if interface_params['type'] == manual_params['type'] else ''}")
|
||||
@@ -111,7 +111,7 @@ def debug_interface():
|
||||
print(f"timeInForce: Interface='{interface_params['timeInForce']}', Manual=None (EXTRA PARAMETER)")
|
||||
|
||||
# Test without timeInForce
|
||||
print(f"\n🔧 TESTING WITHOUT timeInForce:")
|
||||
print(f"\nTESTING WITHOUT timeInForce:")
|
||||
interface_params_minimal = interface_params.copy()
|
||||
del interface_params_minimal['timeInForce']
|
||||
|
||||
|
||||
@@ -817,7 +817,7 @@ class TradingOrchestrator:
|
||||
'status': 'loaded'
|
||||
}
|
||||
|
||||
logger.info(f"✅ Loaded transformer checkpoint: {os.path.basename(checkpoint_path)}")
|
||||
logger.info(f"Loaded transformer checkpoint: {os.path.basename(checkpoint_path)}")
|
||||
logger.info(f" Epoch: {epoch}, Loss: {loss:.6f}, Accuracy: {accuracy:.2%}, LR: {learning_rate:.6f}")
|
||||
checkpoint_loaded = True
|
||||
else:
|
||||
@@ -1154,7 +1154,7 @@ class TradingOrchestrator:
|
||||
|
||||
logger.info("Orchestrator session data cleared")
|
||||
logger.info("🧠 Model states preserved for continued training")
|
||||
logger.info("📊 Prediction history cleared")
|
||||
logger.info("Prediction history cleared")
|
||||
logger.info("💼 Position tracking reset")
|
||||
|
||||
except Exception as e:
|
||||
@@ -1711,10 +1711,10 @@ class TradingOrchestrator:
|
||||
self.dashboard, "update_cob_data_from_orchestrator"
|
||||
):
|
||||
self.dashboard.update_cob_data_from_orchestrator(symbol, cob_data)
|
||||
logger.debug(f"📊 Sent COB data for {symbol} to dashboard")
|
||||
logger.debug(f"Sent COB data for {symbol} to dashboard")
|
||||
else:
|
||||
logger.debug(
|
||||
f"📊 No dashboard connected to receive COB data for {symbol}"
|
||||
f"No dashboard connected to receive COB data for {symbol}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -2811,6 +2811,25 @@ class TradingOrchestrator:
|
||||
self.trading_executor = trading_executor
|
||||
logger.info("Trading executor set for position tracking and P&L feedback")
|
||||
|
||||
def get_latest_transformer_prediction(self, symbol: str = 'ETH/USDT') -> Optional[Dict]:
|
||||
"""
|
||||
Get latest transformer prediction with next_candles data for ghost candle display
|
||||
Returns dict with predicted OHLCV for each timeframe
|
||||
"""
|
||||
try:
|
||||
if not self.primary_transformer:
|
||||
return None
|
||||
|
||||
# Get recent predictions from storage
|
||||
if symbol in self.recent_transformer_predictions and self.recent_transformer_predictions[symbol]:
|
||||
return dict(self.recent_transformer_predictions[symbol][-1])
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting latest transformer prediction: {e}")
|
||||
return None
|
||||
|
||||
def store_transformer_prediction(self, symbol: str, prediction: Dict):
|
||||
"""Store a transformer prediction for visualization and tracking"""
|
||||
try:
|
||||
|
||||
@@ -661,7 +661,7 @@ class OvernightTrainingCoordinator:
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Overall statistics
|
||||
logger.info(f"📊 OVERALL STATISTICS:")
|
||||
logger.info(f"OVERALL STATISTICS:")
|
||||
logger.info(f" Total Signals Processed: {self.performance_stats['total_signals']}")
|
||||
logger.info(f" Total Trades Executed: {self.performance_stats['total_trades']}")
|
||||
logger.info(f" Successful Trades: {self.performance_stats['successful_trades']}")
|
||||
@@ -679,7 +679,7 @@ class OvernightTrainingCoordinator:
|
||||
executed_trades = [r for r in recent_records if r.executed]
|
||||
successful_trades = [r for r in executed_trades if r.trade_pnl and r.trade_pnl > 0]
|
||||
|
||||
logger.info(f"📈 RECENT PERFORMANCE (Last 20 signals):")
|
||||
logger.info(f"RECENT PERFORMANCE (Last 20 signals):")
|
||||
logger.info(f" Signals: {len(recent_records)}")
|
||||
logger.info(f" Executed: {len(executed_trades)}")
|
||||
logger.info(f" Successful: {len(successful_trades)}")
|
||||
|
||||
@@ -75,7 +75,7 @@ class RealtimePredictionLoop:
|
||||
new_candle_detected, timeframe = await self._detect_new_candle(symbol)
|
||||
|
||||
if new_candle_detected:
|
||||
logger.info(f"📊 New {timeframe} candle detected for {symbol} - running predictions")
|
||||
logger.info(f"New {timeframe} candle detected for {symbol} - running predictions")
|
||||
await self._run_all_model_predictions(symbol, trigger=f"new_{timeframe}_candle")
|
||||
|
||||
# 2. Check for pivot point
|
||||
|
||||
@@ -803,7 +803,7 @@ class TradingExecutor:
|
||||
self.max_profitability_multiplier,
|
||||
self.profitability_reward_multiplier + self.profitability_adjustment_step
|
||||
)
|
||||
logger.info(f"🎯 SUCCESS RATE HIGH ({success_rate:.1%}) - Increased profitability multiplier: {old_multiplier:.1f} → {self.profitability_reward_multiplier:.1f}")
|
||||
logger.info(f"SUCCESS RATE HIGH ({success_rate:.1%}) - Increased profitability multiplier: {old_multiplier:.1f} -> {self.profitability_reward_multiplier:.1f}")
|
||||
|
||||
# Decrease multiplier if success rate < 51%
|
||||
elif success_rate < self.success_rate_decrease_threshold:
|
||||
@@ -811,7 +811,7 @@ class TradingExecutor:
|
||||
self.min_profitability_multiplier,
|
||||
self.profitability_reward_multiplier - self.profitability_adjustment_step
|
||||
)
|
||||
logger.info(f" SUCCESS RATE LOW ({success_rate:.1%}) - Decreased profitability multiplier: {old_multiplier:.1f} → {self.profitability_reward_multiplier:.1f}")
|
||||
logger.info(f" SUCCESS RATE LOW ({success_rate:.1%}) - Decreased profitability multiplier: {old_multiplier:.1f} -> {self.profitability_reward_multiplier:.1f}")
|
||||
|
||||
else:
|
||||
logger.debug(f"Success rate {success_rate:.1%} in acceptable range - keeping multiplier at {self.profitability_reward_multiplier:.1f}")
|
||||
@@ -2168,9 +2168,9 @@ class TradingExecutor:
|
||||
f.write(f"Export Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
f.write(f"Data File: {filename}\n")
|
||||
|
||||
logger.info(f"📊 Trade history exported to: {filepath}")
|
||||
logger.info(f"📈 Trade summary saved to: {summary_filepath}")
|
||||
logger.info(f"📊 Total Trades: {total_trades} | Win Rate: {win_rate:.1f}% | Total P&L: ${total_pnl:.2f}")
|
||||
logger.info(f"Trade history exported to: {filepath}")
|
||||
logger.info(f"Trade summary saved to: {summary_filepath}")
|
||||
logger.info(f"Total Trades: {total_trades} | Win Rate: {win_rate:.1f}% | Total P&L: ${total_pnl:.2f}")
|
||||
|
||||
return str(filepath)
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ class UnifiedDataProviderExtension:
|
||||
|
||||
# Initialize cache manager
|
||||
self.cache_manager = DataCacheManager(cache_duration_seconds=300)
|
||||
logger.info("✓ Cache manager initialized")
|
||||
logger.info("Cache manager initialized")
|
||||
|
||||
# Initialize database connection
|
||||
self.db_connection = DatabaseConnectionManager(self.config)
|
||||
@@ -77,11 +77,11 @@ class UnifiedDataProviderExtension:
|
||||
logger.error("Failed to initialize database connection")
|
||||
return False
|
||||
|
||||
logger.info("✓ Database connection initialized")
|
||||
logger.info("Database connection initialized")
|
||||
|
||||
# Initialize query manager
|
||||
self.db_query_manager = UnifiedDatabaseQueryManager(self.db_connection)
|
||||
logger.info("✓ Query manager initialized")
|
||||
logger.info("Query manager initialized")
|
||||
|
||||
# Initialize ingestion pipeline
|
||||
self.ingestion_pipeline = DataIngestionPipeline(
|
||||
@@ -93,7 +93,7 @@ class UnifiedDataProviderExtension:
|
||||
|
||||
# Start ingestion pipeline
|
||||
self.ingestion_pipeline.start()
|
||||
logger.info("✓ Ingestion pipeline started")
|
||||
logger.info("Ingestion pipeline started")
|
||||
|
||||
self._initialized = True
|
||||
logger.info(" Unified storage system initialized successfully")
|
||||
@@ -112,11 +112,11 @@ class UnifiedDataProviderExtension:
|
||||
|
||||
if self.ingestion_pipeline:
|
||||
await self.ingestion_pipeline.stop()
|
||||
logger.info("✓ Ingestion pipeline stopped")
|
||||
logger.info("Ingestion pipeline stopped")
|
||||
|
||||
if self.db_connection:
|
||||
await self.db_connection.close()
|
||||
logger.info("✓ Database connection closed")
|
||||
logger.info("Database connection closed")
|
||||
|
||||
self._initialized = False
|
||||
logger.info(" Unified storage system shutdown complete")
|
||||
|
||||
@@ -94,7 +94,7 @@ def main():
|
||||
|
||||
if success:
|
||||
logger.info("Example completed successfully!")
|
||||
print("\n🎉 Report Data Crawler is ready for use!")
|
||||
print("\nReport Data Crawler is ready for use!")
|
||||
print("\nUsage:")
|
||||
print("1. data_provider.generate_trading_report('BTC/USDT') - Get formatted report")
|
||||
print("2. data_provider.crawl_comprehensive_report('BTC/USDT') - Get raw data")
|
||||
|
||||
@@ -39,18 +39,18 @@ async def example_realtime_data():
|
||||
print("\n2. Getting latest real-time data...")
|
||||
inference_data = await data_provider.get_inference_data_unified('ETH/USDT')
|
||||
|
||||
print(f"\n📊 Inference Data:")
|
||||
print(f"\nInference Data:")
|
||||
print(f" Symbol: {inference_data.symbol}")
|
||||
print(f" Timestamp: {inference_data.timestamp}")
|
||||
print(f" Data Source: {inference_data.data_source}")
|
||||
print(f" Query Latency: {inference_data.query_latency_ms:.2f}ms")
|
||||
|
||||
# Check data completeness
|
||||
print(f"\n✓ Complete Data: {inference_data.has_complete_data()}")
|
||||
print(f"\nComplete Data: {inference_data.has_complete_data()}")
|
||||
|
||||
# Get data summary
|
||||
summary = inference_data.get_data_summary()
|
||||
print(f"\n📈 Data Summary:")
|
||||
print(f"\nData Summary:")
|
||||
print(f" OHLCV 1s rows: {summary['ohlcv_1s_rows']}")
|
||||
print(f" OHLCV 1m rows: {summary['ohlcv_1m_rows']}")
|
||||
print(f" OHLCV 1h rows: {summary['ohlcv_1h_rows']}")
|
||||
@@ -64,7 +64,7 @@ async def example_realtime_data():
|
||||
|
||||
# Get technical indicators
|
||||
if inference_data.indicators:
|
||||
print(f"\n📉 Technical Indicators:")
|
||||
print(f"\nTechnical Indicators:")
|
||||
for indicator, value in inference_data.indicators.items():
|
||||
print(f" {indicator}: {value:.4f}")
|
||||
|
||||
@@ -94,21 +94,21 @@ async def example_historical_data():
|
||||
context_window_minutes=5
|
||||
)
|
||||
|
||||
print(f"\n📊 Inference Data:")
|
||||
print(f"\nInference Data:")
|
||||
print(f" Symbol: {inference_data.symbol}")
|
||||
print(f" Timestamp: {inference_data.timestamp}")
|
||||
print(f" Data Source: {inference_data.data_source}")
|
||||
print(f" Query Latency: {inference_data.query_latency_ms:.2f}ms")
|
||||
|
||||
# Show multi-timeframe data
|
||||
print(f"\n📈 Multi-Timeframe Data:")
|
||||
print(f"\nMulti-Timeframe Data:")
|
||||
for tf in ['1s', '1m', '5m', '15m', '1h', '1d']:
|
||||
df = inference_data.get_timeframe_data(tf)
|
||||
print(f" {tf}: {len(df)} candles")
|
||||
|
||||
# Show context data
|
||||
if inference_data.context_data is not None:
|
||||
print(f"\n🔍 Context Data: {len(inference_data.context_data)} rows")
|
||||
print(f"\nContext Data: {len(inference_data.context_data)} rows")
|
||||
|
||||
# Cleanup
|
||||
await data_provider.disable_unified_storage()
|
||||
@@ -134,7 +134,7 @@ async def example_multi_timeframe():
|
||||
limit=100
|
||||
)
|
||||
|
||||
print(f"\n📊 Multi-Timeframe Data:")
|
||||
print(f"\nMulti-Timeframe Data:")
|
||||
for timeframe, df in multi_tf.items():
|
||||
print(f"\n {timeframe}:")
|
||||
print(f" Rows: {len(df)}")
|
||||
@@ -163,7 +163,7 @@ async def example_orderbook():
|
||||
print("\n2. Getting order book data...")
|
||||
orderbook = await data_provider.get_order_book_data_unified('ETH/USDT')
|
||||
|
||||
print(f"\n📊 Order Book:")
|
||||
print(f"\nOrder Book:")
|
||||
print(f" Symbol: {orderbook.symbol}")
|
||||
print(f" Timestamp: {orderbook.timestamp}")
|
||||
print(f" Mid Price: ${orderbook.mid_price:.2f}")
|
||||
@@ -181,7 +181,7 @@ async def example_orderbook():
|
||||
|
||||
# Show imbalances
|
||||
imbalances = orderbook.get_imbalance_summary()
|
||||
print(f"\n📉 Imbalances:")
|
||||
print(f"\nImbalances:")
|
||||
for key, value in imbalances.items():
|
||||
print(f" {key}: {value:.4f}")
|
||||
|
||||
@@ -211,7 +211,7 @@ async def example_statistics():
|
||||
stats = data_provider.get_unified_storage_stats()
|
||||
|
||||
if stats.get('cache'):
|
||||
print(f"\n📊 Cache Statistics:")
|
||||
print(f"\nCache Statistics:")
|
||||
cache_stats = stats['cache']
|
||||
print(f" Hit Rate: {cache_stats.get('hit_rate_percent', 0):.2f}%")
|
||||
print(f" Total Entries: {cache_stats.get('total_entries', 0)}")
|
||||
@@ -219,7 +219,7 @@ async def example_statistics():
|
||||
print(f" Cache Misses: {cache_stats.get('cache_misses', 0)}")
|
||||
|
||||
if stats.get('database'):
|
||||
print(f"\n💾 Database Statistics:")
|
||||
print(f"\nDatabase Statistics:")
|
||||
db_stats = stats['database']
|
||||
print(f" Total Queries: {db_stats.get('total_queries', 0)}")
|
||||
print(f" Failed Queries: {db_stats.get('failed_queries', 0)}")
|
||||
|
||||
@@ -183,9 +183,9 @@ def main():
|
||||
for proc in remaining:
|
||||
kill_process(proc, force=True)
|
||||
|
||||
print(f"\n✓ Killed {killed_count} dashboard process(es)")
|
||||
print(f"\nKilled {killed_count} dashboard process(es)")
|
||||
else:
|
||||
print("\n✓ No processes to kill")
|
||||
print("\nNo processes to kill")
|
||||
|
||||
print("\nPort status:")
|
||||
for port in DASHBOARD_PORTS:
|
||||
|
||||
@@ -159,7 +159,7 @@ async def test_basic_operations(pool):
|
||||
VALUES (NOW(), 'ETH/USDT', '1s', 2000.0, 2001.0, 1999.0, 2000.5, 100.0)
|
||||
ON CONFLICT (timestamp, symbol, timeframe) DO NOTHING
|
||||
""")
|
||||
logger.info("✓ OHLCV insert successful")
|
||||
logger.info("OHLCV insert successful")
|
||||
|
||||
# Test query
|
||||
logger.info("Testing OHLCV query...")
|
||||
@@ -170,7 +170,7 @@ async def test_basic_operations(pool):
|
||||
LIMIT 1
|
||||
""")
|
||||
if result:
|
||||
logger.info(f"✓ OHLCV query successful: {dict(result)}")
|
||||
logger.info(f"OHLCV query successful: {dict(result)}")
|
||||
|
||||
# Test order book insert
|
||||
logger.info("Testing order book insert...")
|
||||
@@ -180,7 +180,7 @@ async def test_basic_operations(pool):
|
||||
VALUES (NOW(), 'ETH/USDT', 'binance', '[]'::jsonb, '[]'::jsonb, 2000.0, 0.1)
|
||||
ON CONFLICT (timestamp, symbol, exchange) DO NOTHING
|
||||
""")
|
||||
logger.info("✓ Order book insert successful")
|
||||
logger.info("Order book insert successful")
|
||||
|
||||
# Test imbalances insert
|
||||
logger.info("Testing imbalances insert...")
|
||||
@@ -190,9 +190,9 @@ async def test_basic_operations(pool):
|
||||
VALUES (NOW(), 'ETH/USDT', 0.5, 0.4, 0.3, 0.2)
|
||||
ON CONFLICT (timestamp, symbol) DO NOTHING
|
||||
""")
|
||||
logger.info("✓ Imbalances insert successful")
|
||||
logger.info("Imbalances insert successful")
|
||||
|
||||
logger.info("\n✓ All basic operations successful")
|
||||
logger.info("\nAll basic operations successful")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -25,7 +25,7 @@ if torch.cuda.is_available():
|
||||
print("=" * 80)
|
||||
try:
|
||||
x = torch.tensor([1.0, 2.0, 3.0]).cuda()
|
||||
print("✓ PASSED: Simple tensor creation on GPU")
|
||||
print("PASSED: Simple tensor creation on GPU")
|
||||
except Exception as e:
|
||||
print(f"✗ FAILED: {e}")
|
||||
sys.exit(1)
|
||||
@@ -38,7 +38,7 @@ if torch.cuda.is_available():
|
||||
a = torch.randn(100, 100).cuda()
|
||||
b = torch.randn(100, 100).cuda()
|
||||
c = torch.matmul(a, b)
|
||||
print("✓ PASSED: Matrix multiplication on GPU")
|
||||
print("PASSED: Matrix multiplication on GPU")
|
||||
except Exception as e:
|
||||
print(f"✗ FAILED: {e}")
|
||||
sys.exit(1)
|
||||
@@ -51,8 +51,8 @@ if torch.cuda.is_available():
|
||||
x = torch.randn(10, 20).cuda()
|
||||
linear = torch.nn.Linear(20, 10).cuda()
|
||||
y = linear(x)
|
||||
print("✓ PASSED: Linear layer on GPU")
|
||||
print("✓ Your GPU is fully compatible!")
|
||||
print("PASSED: Linear layer on GPU")
|
||||
print("Your GPU is fully compatible!")
|
||||
except RuntimeError as e:
|
||||
if "invalid device function" in str(e):
|
||||
print(f"✗ FAILED: {e}")
|
||||
@@ -77,7 +77,7 @@ if torch.cuda.is_available():
|
||||
x = torch.randn(1, 3, 32, 32).cuda()
|
||||
conv = torch.nn.Conv2d(3, 16, 3).cuda()
|
||||
y = conv(x)
|
||||
print("✓ PASSED: Convolutional layer on GPU")
|
||||
print("PASSED: Convolutional layer on GPU")
|
||||
except Exception as e:
|
||||
print(f"✗ FAILED: {e}")
|
||||
|
||||
@@ -89,7 +89,7 @@ if torch.cuda.is_available():
|
||||
x = torch.randn(1, 10, 512).cuda()
|
||||
transformer = torch.nn.TransformerEncoderLayer(d_model=512, nhead=8).cuda()
|
||||
y = transformer(x)
|
||||
print("✓ PASSED: Transformer layer on GPU")
|
||||
print("PASSED: Transformer layer on GPU")
|
||||
except Exception as e:
|
||||
print(f"✗ FAILED: {e}")
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ def test_checkpoint_reading():
|
||||
print(f" Exists: {os.path.exists(checkpoint_dir)}")
|
||||
|
||||
if not os.path.exists(checkpoint_dir):
|
||||
print(f" ❌ Directory not found")
|
||||
print(f" Directory not found")
|
||||
continue
|
||||
|
||||
# Find checkpoint files
|
||||
@@ -28,7 +28,7 @@ def test_checkpoint_reading():
|
||||
print(f" Checkpoints found: {len(checkpoint_files)}")
|
||||
|
||||
if not checkpoint_files:
|
||||
print(f" ❌ No checkpoint files")
|
||||
print(f" No checkpoint files")
|
||||
continue
|
||||
|
||||
# Try to load best checkpoint
|
||||
@@ -55,13 +55,13 @@ def test_checkpoint_reading():
|
||||
'accuracy': accuracy
|
||||
}
|
||||
except Exception as e:
|
||||
print(f" ❌ Error: {e}")
|
||||
print(f" Error: {e}")
|
||||
|
||||
if best_checkpoint:
|
||||
print(f" ✅ Best: {best_checkpoint['filename']}")
|
||||
print(f" Best: {best_checkpoint['filename']}")
|
||||
print(f" Epoch: {best_checkpoint['epoch']}, Loss: {best_checkpoint['loss']:.6f}, Accuracy: {best_checkpoint['accuracy']:.2%}")
|
||||
else:
|
||||
print(f" ❌ No valid checkpoint found")
|
||||
print(f" No valid checkpoint found")
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_checkpoint_reading()
|
||||
|
||||
@@ -203,14 +203,14 @@ print(" - Data retrieval: WORKING")
|
||||
print(" - SQL queries: WORKING")
|
||||
print(" - Annotation manager: WORKING")
|
||||
|
||||
print("\n📊 Performance:")
|
||||
print("\nPerformance:")
|
||||
print(f" - Initialization: {init_time:.2f}s")
|
||||
if 'fetch_time' in locals():
|
||||
print(f" - Data fetch: {fetch_time:.2f}s")
|
||||
if 'query_time' in locals():
|
||||
print(f" - DuckDB query: {query_time*1000:.1f}ms")
|
||||
|
||||
print("\n💡 Benefits:")
|
||||
print("\nBenefits:")
|
||||
print(" - Single storage system (no dual cache)")
|
||||
print(" - Native Parquet support (fast queries)")
|
||||
print(" - Full SQL capabilities (complex queries)")
|
||||
|
||||
@@ -30,14 +30,14 @@ print()
|
||||
|
||||
# Check GPU availability
|
||||
if torch.cuda.is_available():
|
||||
print(f"GPU Available: ✅ CUDA")
|
||||
print(f"GPU Available: CUDA")
|
||||
print(f" Device: {torch.cuda.get_device_name(0)}")
|
||||
print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
|
||||
|
||||
# Move model to GPU
|
||||
device = torch.device('cuda')
|
||||
model = model.to(device)
|
||||
print(f" Model moved to GPU ✅")
|
||||
print(f" Model moved to GPU")
|
||||
|
||||
# Test forward pass
|
||||
batch_size = 1
|
||||
@@ -50,18 +50,18 @@ if torch.cuda.is_available():
|
||||
with torch.no_grad():
|
||||
outputs = model(price_data_1m=price_data_1m)
|
||||
|
||||
print(f" Forward pass successful ✅")
|
||||
print(f" Forward pass successful")
|
||||
print(f" GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
|
||||
print(f" GPU memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
|
||||
|
||||
elif hasattr(torch.version, 'hip') and torch.version.hip:
|
||||
print(f"GPU Available: ✅ ROCm/HIP")
|
||||
print(f"GPU Available: ROCm/HIP")
|
||||
device = torch.device('cuda') # ROCm uses 'cuda' device name
|
||||
model = model.to(device)
|
||||
print(f" Model moved to GPU ✅")
|
||||
print(f" Model moved to GPU")
|
||||
else:
|
||||
print(f"GPU Available: ❌ CPU only")
|
||||
print(f"GPU Available: CPU only")
|
||||
print(f" Training will use CPU (slower)")
|
||||
|
||||
print()
|
||||
print("Model ready for training! 🚀")
|
||||
print("Model ready for training!")
|
||||
|
||||
@@ -109,9 +109,9 @@ def test_model_outputs():
|
||||
print(f" Mean value: {candle_pred.mean().item():.6f}")
|
||||
|
||||
if candle_pred.min() >= 0.0 and candle_pred.max() <= 1.0:
|
||||
print(" ✅ PASS: Values in [0, 1] range (Sigmoid working!)")
|
||||
print(" PASS: Values in [0, 1] range (Sigmoid working!)")
|
||||
else:
|
||||
print(" ❌ FAIL: Values outside [0, 1] range!")
|
||||
print(" FAIL: Values outside [0, 1] range!")
|
||||
|
||||
# Check price prediction is in [-1, 1] range (thanks to Tanh)
|
||||
if 'price_prediction' in outputs:
|
||||
@@ -121,9 +121,9 @@ def test_model_outputs():
|
||||
print(f" Value: {price_pred.item():.6f}")
|
||||
|
||||
if price_pred.min() >= -1.0 and price_pred.max() <= 1.0:
|
||||
print(" ✅ PASS: Values in [-1, 1] range (Tanh working!)")
|
||||
print(" PASS: Values in [-1, 1] range (Tanh working!)")
|
||||
else:
|
||||
print(" ❌ FAIL: Values outside [-1, 1] range!")
|
||||
print(" FAIL: Values outside [-1, 1] range!")
|
||||
|
||||
# Check action probabilities sum to 1
|
||||
if 'action_probs' in outputs:
|
||||
@@ -135,9 +135,9 @@ def test_model_outputs():
|
||||
print(f" Sum: {action_probs[0].sum().item():.6f}")
|
||||
|
||||
if abs(action_probs[0].sum().item() - 1.0) < 0.001:
|
||||
print(" ✅ PASS: Probabilities sum to 1.0")
|
||||
print(" PASS: Probabilities sum to 1.0")
|
||||
else:
|
||||
print(" ❌ FAIL: Probabilities don't sum to 1.0!")
|
||||
print(" FAIL: Probabilities don't sum to 1.0!")
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -180,18 +180,18 @@ def test_denormalization():
|
||||
for i, name in enumerate(['Open', 'High', 'Low', 'Close']):
|
||||
value = denorm_candle[0, i].item()
|
||||
if value < expected_min_price or value > expected_max_price:
|
||||
print(f" ❌ FAIL: {name} price ${value:.2f} outside expected range!")
|
||||
print(f" FAIL: {name} price ${value:.2f} outside expected range!")
|
||||
prices_ok = False
|
||||
|
||||
if prices_ok:
|
||||
print(f" ✅ PASS: All prices in expected range [${expected_min_price}, ${expected_max_price}]")
|
||||
print(f" PASS: All prices in expected range [${expected_min_price}, ${expected_max_price}]")
|
||||
|
||||
# Verify volume
|
||||
volume = denorm_candle[0, 4].item()
|
||||
if norm_params['volume_min'] <= volume <= norm_params['volume_max']:
|
||||
print(f" ✅ PASS: Volume {volume:.2f} in expected range [{norm_params['volume_min']}, {norm_params['volume_max']}]")
|
||||
print(f" PASS: Volume {volume:.2f} in expected range [{norm_params['volume_min']}, {norm_params['volume_max']}]")
|
||||
else:
|
||||
print(f" ❌ FAIL: Volume {volume:.2f} outside expected range!")
|
||||
print(f" FAIL: Volume {volume:.2f} outside expected range!")
|
||||
|
||||
def test_loss_magnitude():
|
||||
"""Test that losses are in reasonable ranges"""
|
||||
@@ -227,15 +227,15 @@ def test_loss_magnitude():
|
||||
all_ok = True
|
||||
|
||||
if result['total_loss'] < 100.0:
|
||||
print(f" ✅ PASS: Total loss < 100 (was {result['total_loss']:.6f})")
|
||||
print(f" PASS: Total loss < 100 (was {result['total_loss']:.6f})")
|
||||
else:
|
||||
print(f" ❌ FAIL: Total loss too high! ({result['total_loss']:.6f})")
|
||||
print(f" FAIL: Total loss too high! ({result['total_loss']:.6f})")
|
||||
all_ok = False
|
||||
|
||||
if result['candle_loss'] < 10.0:
|
||||
print(f" ✅ PASS: Candle loss < 10 (was {result['candle_loss']:.6f})")
|
||||
print(f" PASS: Candle loss < 10 (was {result['candle_loss']:.6f})")
|
||||
else:
|
||||
print(f" ❌ FAIL: Candle loss too high! ({result['candle_loss']:.6f})")
|
||||
print(f" FAIL: Candle loss too high! ({result['candle_loss']:.6f})")
|
||||
all_ok = False
|
||||
|
||||
# Check denormalized losses if available
|
||||
@@ -244,15 +244,15 @@ def test_loss_magnitude():
|
||||
for tf, loss in result['candle_loss_denorm'].items():
|
||||
print(f" {tf}: ${loss:.2f}")
|
||||
if loss < 1000.0:
|
||||
print(f" ✅ PASS: Real price error < $1000")
|
||||
print(f" PASS: Real price error < $1000")
|
||||
else:
|
||||
print(f" ❌ FAIL: Real price error too high!")
|
||||
print(f" FAIL: Real price error too high!")
|
||||
all_ok = False
|
||||
|
||||
if all_ok:
|
||||
print("\n ✅ ALL TESTS PASSED: Losses in reasonable ranges!")
|
||||
print("\n ALL TESTS PASSED: Losses in reasonable ranges!")
|
||||
else:
|
||||
print("\n ❌ SOME TESTS FAILED: Check model/normalization!")
|
||||
print("\n SOME TESTS FAILED: Check model/normalization!")
|
||||
|
||||
return result
|
||||
|
||||
@@ -275,7 +275,7 @@ def main():
|
||||
print("\n" + "=" * 80)
|
||||
print("TEST SUMMARY")
|
||||
print("=" * 80)
|
||||
print("\nIf all tests passed (✅), the normalization fix is working correctly!")
|
||||
print("\nIf all tests passed, the normalization fix is working correctly!")
|
||||
print("You should now see reasonable losses in training logs:")
|
||||
print(" - Total loss: ~0.5-1.0 (not billions!)")
|
||||
print(" - Candle loss: ~0.1-0.3")
|
||||
|
||||
@@ -77,7 +77,7 @@ def test_pivot_levels():
|
||||
if __name__ == "__main__":
|
||||
success = test_pivot_levels()
|
||||
if success:
|
||||
print("\n🎉 Pivot levels test completed!")
|
||||
print("\nPivot levels test completed!")
|
||||
else:
|
||||
print("\n Pivot levels test failed!")
|
||||
sys.exit(1)
|
||||
|
||||
@@ -55,7 +55,7 @@ def test_training():
|
||||
|
||||
# List available methods
|
||||
methods = [m for m in dir(trainer) if not m.startswith('_') and callable(getattr(trainer, m))]
|
||||
print(f" 📋 Trainer methods: {', '.join(methods[:10])}...")
|
||||
print(f" Trainer methods: {', '.join(methods[:10])}...")
|
||||
|
||||
if not available_models:
|
||||
print(" No models available!")
|
||||
|
||||
@@ -1962,6 +1962,10 @@ class CleanTradingDashboard:
|
||||
def update_price_chart(n, pivots_value, relayout_data):
|
||||
"""Update price chart every second, persisting user zoom/pan"""
|
||||
try:
|
||||
# Validate and train on predictions every update (once per second)
|
||||
# This checks if any predictions can be validated against real candles
|
||||
self._validate_and_train_on_predictions('ETH/USDT')
|
||||
|
||||
show_pivots = bool(pivots_value and 'enabled' in pivots_value)
|
||||
fig, legend_children = self._create_price_chart('ETH/USDT', show_pivots=show_pivots, return_legend=True)
|
||||
|
||||
@@ -4061,6 +4065,107 @@ class CleanTradingDashboard:
|
||||
logger.debug(f"Error getting CNN predictions: {e}")
|
||||
return []
|
||||
|
||||
def _get_live_transformer_prediction_with_next_candles(self, symbol: str = 'ETH/USDT') -> Optional[Dict]:
|
||||
"""
|
||||
Get live transformer prediction including next_candles for ghost candle display
|
||||
This makes a real-time prediction with the transformer model
|
||||
"""
|
||||
try:
|
||||
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer') or not self.orchestrator.primary_transformer:
|
||||
return None
|
||||
|
||||
transformer = self.orchestrator.primary_transformer
|
||||
transformer.eval()
|
||||
|
||||
# Get recent market data for all timeframes
|
||||
price_data_1s = self.data_provider.get_ohlcv(symbol, '1s', limit=200)
|
||||
price_data_1m = self.data_provider.get_ohlcv(symbol, '1m', limit=150)
|
||||
price_data_1h = self.data_provider.get_ohlcv(symbol, '1h', limit=24)
|
||||
price_data_1d = self.data_provider.get_ohlcv(symbol, '1d', limit=14)
|
||||
btc_data_1m = self.data_provider.get_ohlcv('BTC/USDT', '1m', limit=150)
|
||||
|
||||
if not price_data_1m or len(price_data_1m) < 10:
|
||||
return None
|
||||
|
||||
# Convert to tensors (simplified - you may need proper normalization)
|
||||
import torch
|
||||
device = next(transformer.parameters()).device
|
||||
|
||||
def ohlcv_to_tensor(data, limit=None):
|
||||
if not data:
|
||||
return None
|
||||
data = data[-limit:] if limit and len(data) > limit else data
|
||||
arr = np.array([[d['open'], d['high'], d['low'], d['close'], d['volume']] for d in data], dtype=np.float32)
|
||||
return torch.from_numpy(arr).unsqueeze(0).to(device) # Add batch dim
|
||||
|
||||
# Create input tensors
|
||||
inputs = {
|
||||
'price_data_1s': ohlcv_to_tensor(price_data_1s, 200),
|
||||
'price_data_1m': ohlcv_to_tensor(price_data_1m, 150),
|
||||
'price_data_1h': ohlcv_to_tensor(price_data_1h, 24),
|
||||
'price_data_1d': ohlcv_to_tensor(price_data_1d, 14),
|
||||
'btc_data_1m': ohlcv_to_tensor(btc_data_1m, 150)
|
||||
}
|
||||
|
||||
# Forward pass
|
||||
with torch.no_grad():
|
||||
outputs = transformer(**inputs)
|
||||
|
||||
# Extract next_candles predictions
|
||||
next_candles = outputs.get('next_candles', {})
|
||||
if not next_candles:
|
||||
return None
|
||||
|
||||
# Convert tensors to lists for JSON serialization
|
||||
predicted_candles = {}
|
||||
for tf, candle_tensor in next_candles.items():
|
||||
if candle_tensor is not None:
|
||||
# candle_tensor shape: [batch, 5] where 5 is [O, H, L, C, V]
|
||||
candle_values = candle_tensor.squeeze(0).cpu().numpy().tolist()
|
||||
predicted_candles[tf] = candle_values
|
||||
|
||||
# Get current price for action determination
|
||||
current_price = price_data_1m[-1]['close']
|
||||
predicted_1m_close = predicted_candles.get('1m', [0,0,0,current_price,0])[3]
|
||||
|
||||
# Determine action based on price change
|
||||
price_change = (predicted_1m_close - current_price) / current_price
|
||||
if price_change > 0.001:
|
||||
action = 'BUY'
|
||||
elif price_change < -0.001:
|
||||
action = 'SELL'
|
||||
else:
|
||||
action = 'HOLD'
|
||||
|
||||
# Get confidence from outputs if available
|
||||
confidence = 0.7 # Default confidence
|
||||
if 'confidence' in outputs:
|
||||
conf_tensor = outputs['confidence']
|
||||
confidence = float(conf_tensor.squeeze(0).cpu().numpy()[0])
|
||||
|
||||
prediction = {
|
||||
'timestamp': datetime.now(),
|
||||
'symbol': symbol,
|
||||
'action': action,
|
||||
'confidence': confidence,
|
||||
'predicted_price': predicted_1m_close,
|
||||
'current_price': current_price,
|
||||
'price_change': price_change,
|
||||
'next_candles': predicted_candles, # This is what the frontend needs!
|
||||
'type': 'transformer_prediction'
|
||||
}
|
||||
|
||||
# Store prediction for tracking
|
||||
self.orchestrator.store_transformer_prediction(symbol, prediction)
|
||||
|
||||
logger.debug(f"Generated transformer prediction with next_candles for {len(predicted_candles)} timeframes")
|
||||
|
||||
return prediction
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting live transformer prediction: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _get_recent_transformer_predictions(self, symbol: str) -> List[Dict]:
|
||||
"""Get recent Transformer predictions from orchestrator"""
|
||||
try:
|
||||
@@ -4109,6 +4214,217 @@ class CleanTradingDashboard:
|
||||
logger.debug(f"Error getting prediction accuracy history: {e}")
|
||||
return []
|
||||
|
||||
def _validate_and_train_on_predictions(self, symbol: str = 'ETH/USDT'):
|
||||
"""
|
||||
Validate pending predictions against real candles and train on them
|
||||
This is called periodically to check if predictions can be validated
|
||||
"""
|
||||
try:
|
||||
if not self.orchestrator or not hasattr(self.orchestrator, 'recent_transformer_predictions'):
|
||||
return
|
||||
|
||||
# Get recent predictions for this symbol
|
||||
predictions = self.orchestrator.recent_transformer_predictions.get(symbol, [])
|
||||
if not predictions:
|
||||
return
|
||||
|
||||
for prediction in list(predictions):
|
||||
# Skip if already validated
|
||||
if prediction.get('validated', False):
|
||||
continue
|
||||
|
||||
# Check if prediction has next_candles
|
||||
next_candles = prediction.get('next_candles', {})
|
||||
if not next_candles:
|
||||
continue
|
||||
|
||||
pred_timestamp = prediction.get('timestamp')
|
||||
if not pred_timestamp:
|
||||
continue
|
||||
|
||||
# Check each timeframe
|
||||
for timeframe, predicted_ohlcv in next_candles.items():
|
||||
try:
|
||||
# Calculate when this prediction should be validated
|
||||
# For '1s' prediction, validate after 1 second
|
||||
# For '1m' prediction, validate after 60 seconds
|
||||
validation_delay = {'1s': 1, '1m': 60, '1h': 3600, '1d': 86400}.get(timeframe, 60)
|
||||
|
||||
# Check if enough time has passed
|
||||
current_time = datetime.now()
|
||||
if not isinstance(pred_timestamp, datetime):
|
||||
pred_timestamp = pd.to_datetime(pred_timestamp)
|
||||
|
||||
time_elapsed = (current_time - pred_timestamp).total_seconds()
|
||||
if time_elapsed < validation_delay:
|
||||
continue # Not ready to validate yet
|
||||
|
||||
# Get the actual candle at the predicted time
|
||||
target_time = pred_timestamp + timedelta(seconds=validation_delay)
|
||||
actual_candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=10)
|
||||
|
||||
if not actual_candles:
|
||||
continue
|
||||
|
||||
# Find the candle closest to target_time
|
||||
actual_candle = None
|
||||
for candle in actual_candles:
|
||||
candle_time = candle.get('timestamp') or candle.get('time')
|
||||
if not candle_time:
|
||||
continue
|
||||
if not isinstance(candle_time, datetime):
|
||||
candle_time = pd.to_datetime(candle_time)
|
||||
|
||||
# Check if this is the target candle (within 1 second tolerance)
|
||||
if abs((candle_time - target_time).total_seconds()) < 1:
|
||||
actual_candle = candle
|
||||
break
|
||||
|
||||
if not actual_candle:
|
||||
continue # Actual candle not available yet
|
||||
|
||||
# Extract actual OHLCV
|
||||
actual_ohlcv = [
|
||||
actual_candle['open'],
|
||||
actual_candle['high'],
|
||||
actual_candle['low'],
|
||||
actual_candle['close'],
|
||||
actual_candle.get('volume', 0)
|
||||
]
|
||||
|
||||
# Calculate accuracy
|
||||
errors = {
|
||||
'open': abs(predicted_ohlcv[0] - actual_ohlcv[0]),
|
||||
'high': abs(predicted_ohlcv[1] - actual_ohlcv[1]),
|
||||
'low': abs(predicted_ohlcv[2] - actual_ohlcv[2]),
|
||||
'close': abs(predicted_ohlcv[3] - actual_ohlcv[3]),
|
||||
'volume': abs(predicted_ohlcv[4] - actual_ohlcv[4])
|
||||
}
|
||||
|
||||
pct_errors = {
|
||||
'open': (errors['open'] / actual_ohlcv[0]) * 100 if actual_ohlcv[0] > 0 else 0,
|
||||
'high': (errors['high'] / actual_ohlcv[1]) * 100 if actual_ohlcv[1] > 0 else 0,
|
||||
'low': (errors['low'] / actual_ohlcv[2]) * 100 if actual_ohlcv[2] > 0 else 0,
|
||||
'close': (errors['close'] / actual_ohlcv[3]) * 100 if actual_ohlcv[3] > 0 else 0,
|
||||
}
|
||||
|
||||
avg_pct_error = (pct_errors['open'] + pct_errors['high'] + pct_errors['low'] + pct_errors['close']) / 4
|
||||
accuracy = max(0, 100 - avg_pct_error)
|
||||
|
||||
# Check direction correctness
|
||||
pred_direction = 'up' if predicted_ohlcv[3] >= predicted_ohlcv[0] else 'down'
|
||||
actual_direction = 'up' if actual_ohlcv[3] >= actual_ohlcv[0] else 'down'
|
||||
direction_correct = pred_direction == actual_direction
|
||||
|
||||
logger.info(f"Validated {timeframe} prediction: accuracy={accuracy:.1f}%, direction_correct={direction_correct}")
|
||||
|
||||
# Train on this validated prediction
|
||||
self._train_transformer_on_validated_prediction(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
predicted_ohlcv=predicted_ohlcv,
|
||||
actual_ohlcv=actual_ohlcv,
|
||||
accuracy=accuracy,
|
||||
direction_correct=direction_correct
|
||||
)
|
||||
|
||||
# Mark prediction as validated
|
||||
prediction['validated'] = True
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error validating {timeframe} prediction: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _validate_and_train_on_predictions: {e}", exc_info=True)
|
||||
|
||||
def _train_transformer_on_validated_prediction(self, symbol: str, timeframe: str,
|
||||
predicted_ohlcv: list, actual_ohlcv: list,
|
||||
accuracy: float, direction_correct: bool):
|
||||
"""
|
||||
Train transformer on validated prediction using backpropagation
|
||||
This implements online learning from prediction errors
|
||||
"""
|
||||
try:
|
||||
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
|
||||
return
|
||||
|
||||
transformer = self.orchestrator.primary_transformer
|
||||
if not HAS_TORCH:
|
||||
return
|
||||
|
||||
# Calculate sample weight based on accuracy
|
||||
# Low accuracy = higher weight (learn more from mistakes)
|
||||
if accuracy < 50:
|
||||
sample_weight = 3.0
|
||||
elif accuracy < 70:
|
||||
sample_weight = 2.0
|
||||
elif accuracy < 85:
|
||||
sample_weight = 1.0
|
||||
else:
|
||||
sample_weight = 0.5
|
||||
|
||||
if not direction_correct:
|
||||
sample_weight *= 1.5 # Wrong direction is critical
|
||||
|
||||
logger.info(f"[{timeframe}] Training on validated prediction: accuracy={accuracy:.1f}%, weight={sample_weight:.1f}x")
|
||||
|
||||
# Get market state for training
|
||||
market_data_1m = self.data_provider.get_ohlcv(symbol, '1m', limit=150)
|
||||
if not market_data_1m or len(market_data_1m) < 10:
|
||||
return
|
||||
|
||||
# Prepare training batch
|
||||
device = next(transformer.parameters()).device
|
||||
transformer.train()
|
||||
|
||||
def ohlcv_to_tensor(data, limit=None):
|
||||
if not data:
|
||||
return None
|
||||
data = data[-limit:] if limit and len(data) > limit else data
|
||||
arr = np.array([[d['open'], d['high'], d['low'], d['close'], d['volume']] for d in data], dtype=np.float32)
|
||||
return torch.from_numpy(arr).unsqueeze(0).to(device)
|
||||
|
||||
# Create input tensors
|
||||
price_data_1s = self.data_provider.get_ohlcv(symbol, '1s', limit=200)
|
||||
price_data_1h = self.data_provider.get_ohlcv(symbol, '1h', limit=24)
|
||||
price_data_1d = self.data_provider.get_ohlcv(symbol, '1d', limit=14)
|
||||
btc_data_1m = self.data_provider.get_ohlcv('BTC/USDT', '1m', limit=150)
|
||||
|
||||
inputs = {
|
||||
'price_data_1s': ohlcv_to_tensor(price_data_1s, 200),
|
||||
'price_data_1m': ohlcv_to_tensor(market_data_1m, 150),
|
||||
'price_data_1h': ohlcv_to_tensor(price_data_1h, 24),
|
||||
'price_data_1d': ohlcv_to_tensor(price_data_1d, 14),
|
||||
'btc_data_1m': ohlcv_to_tensor(btc_data_1m, 150)
|
||||
}
|
||||
|
||||
# Forward pass
|
||||
outputs = transformer(**inputs)
|
||||
|
||||
# Get predicted candle for this timeframe
|
||||
next_candles = outputs.get('next_candles', {})
|
||||
if timeframe not in next_candles:
|
||||
return
|
||||
|
||||
pred_tensor = next_candles[timeframe] # [batch, 5]
|
||||
actual_tensor = torch.tensor([actual_ohlcv], dtype=torch.float32, device=device) # [batch, 5]
|
||||
|
||||
# Calculate loss
|
||||
criterion = torch.nn.MSELoss()
|
||||
loss = criterion(pred_tensor, actual_tensor) * sample_weight
|
||||
|
||||
# Backpropagation
|
||||
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
logger.info(f"[{timeframe}] Backpropagation complete: loss={loss.item():.6f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training on validated prediction: {e}", exc_info=True)
|
||||
|
||||
def _add_signals_to_mini_chart(self, fig: go.Figure, symbol: str, ws_data_1s: pd.DataFrame, row: int = 2):
|
||||
"""Add signals to the 1s mini chart - LIMITED TO PRICE DATA TIME RANGE"""
|
||||
try:
|
||||
@@ -6247,11 +6563,11 @@ class CleanTradingDashboard:
|
||||
'type': 'cnn_pivot'
|
||||
}
|
||||
|
||||
# Get latest Transformer prediction
|
||||
# Get latest Transformer prediction with next_candles for ghost candle display
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer'):
|
||||
try:
|
||||
if hasattr(self.orchestrator, 'get_latest_transformer_prediction'):
|
||||
transformer_pred = self.orchestrator.get_latest_transformer_prediction()
|
||||
# Get live prediction with next_candles
|
||||
transformer_pred = self._get_live_transformer_prediction_with_next_candles('ETH/USDT')
|
||||
if transformer_pred:
|
||||
latest_predictions['transformer'] = {
|
||||
'timestamp': transformer_pred.get('timestamp', datetime.now()),
|
||||
@@ -6259,8 +6575,11 @@ class CleanTradingDashboard:
|
||||
'confidence': transformer_pred.get('confidence', 0),
|
||||
'predicted_price': transformer_pred.get('predicted_price', 0),
|
||||
'price_change': transformer_pred.get('price_change', 0),
|
||||
'type': 'transformer_prediction'
|
||||
'type': 'transformer_prediction',
|
||||
# Add predicted_candle data for ghost candle display
|
||||
'predicted_candle': transformer_pred.get('next_candles', {})
|
||||
}
|
||||
logger.debug(f"Sent transformer prediction with {len(transformer_pred.get('next_candles', {}))} timeframe candles to frontend")
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting transformer prediction: {e}")
|
||||
|
||||
@@ -8036,13 +8355,13 @@ class CleanTradingDashboard:
|
||||
logger.info("=" * 60)
|
||||
logger.info(" SESSION CLEAR COMPLETED SUCCESSFULLY")
|
||||
logger.info("=" * 60)
|
||||
logger.info("📊 Session P&L reset to $0.00")
|
||||
logger.info("📈 All positions closed")
|
||||
logger.info("📋 Trade history cleared")
|
||||
logger.info("🎯 Success rate calculations reset")
|
||||
logger.info("📈 Model performance metrics reset")
|
||||
logger.info("Session P&L reset to $0.00")
|
||||
logger.info("All positions closed")
|
||||
logger.info("Trade history cleared")
|
||||
logger.info("Success rate calculations reset")
|
||||
logger.info("Model performance metrics reset")
|
||||
logger.info(" All caches cleared")
|
||||
logger.info("📁 Trade log files cleared")
|
||||
logger.info("Trade log files cleared")
|
||||
logger.info("=" * 60)
|
||||
|
||||
except Exception as e:
|
||||
@@ -8277,7 +8596,7 @@ class CleanTradingDashboard:
|
||||
self.trading_executor._last_stats_update = None
|
||||
|
||||
logger.info(" Trading executor state cleared completely")
|
||||
logger.info("📊 Success rate calculations will start fresh")
|
||||
logger.info("Success rate calculations will start fresh")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing trading executor state: {e}")
|
||||
@@ -8319,13 +8638,13 @@ class CleanTradingDashboard:
|
||||
try:
|
||||
# Store Decision Fusion model
|
||||
if hasattr(self.orchestrator, 'decision_fusion_network') and self.orchestrator.decision_fusion_network:
|
||||
logger.info("💾 Storing Decision Fusion model...")
|
||||
logger.info("Storing Decision Fusion model...")
|
||||
# Add storage logic here
|
||||
except Exception as e:
|
||||
logger.warning(f" Failed to store Decision Fusion model: {e}")
|
||||
|
||||
# 5. Verification Step - Try to load checkpoints to verify they work
|
||||
logger.info("🔍 Verifying stored checkpoints...")
|
||||
logger.info("Verifying stored checkpoints...")
|
||||
|
||||
for model_name, checkpoint_path in stored_models:
|
||||
try:
|
||||
@@ -8379,7 +8698,7 @@ class CleanTradingDashboard:
|
||||
with open(metadata_path, 'w') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
logger.info(f"📋 Stored session metadata: {metadata_path}")
|
||||
logger.info(f"Stored session metadata: {metadata_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store metadata: {e}")
|
||||
@@ -8388,7 +8707,7 @@ class CleanTradingDashboard:
|
||||
if hasattr(self.orchestrator, '_save_ui_state'):
|
||||
try:
|
||||
self.orchestrator._save_ui_state()
|
||||
logger.info("💾 Saved orchestrator UI state")
|
||||
logger.info("Saved orchestrator UI state")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save UI state: {e}")
|
||||
|
||||
@@ -8397,10 +8716,10 @@ class CleanTradingDashboard:
|
||||
successful_verifications = len([r for r in verification_results if r[1]])
|
||||
|
||||
if stored_models:
|
||||
logger.info(f"📊 STORAGE SUMMARY:")
|
||||
logger.info(f"STORAGE SUMMARY:")
|
||||
logger.info(f" Models stored: {successful_stores}")
|
||||
logger.info(f" Verifications passed: {successful_verifications}/{len(verification_results)}")
|
||||
logger.info(f" 📋 Models: {[name for name, _ in stored_models]}")
|
||||
logger.info(f" Models: {[name for name, _ in stored_models]}")
|
||||
|
||||
# Update button display with success info
|
||||
return True
|
||||
@@ -9290,7 +9609,7 @@ class CleanTradingDashboard:
|
||||
self.cob_cache[symbol]['websocket_status'] = websocket_status
|
||||
self.cob_cache[symbol]['source'] = source
|
||||
|
||||
logger.debug(f"📊 Enhanced COB update for {symbol}: {websocket_status} via {source}")
|
||||
logger.debug(f"Enhanced COB update for {symbol}: {websocket_status} via {source}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" Error handling enhanced COB update for {symbol}: {e}")
|
||||
|
||||
@@ -52,7 +52,7 @@ class DashboardComponentManager:
|
||||
# Determine signal style
|
||||
if executed:
|
||||
badge_class = "bg-success"
|
||||
status = "✓"
|
||||
status = "OK"
|
||||
elif blocked:
|
||||
badge_class = "bg-danger"
|
||||
status = "✗"
|
||||
|
||||
@@ -284,7 +284,7 @@ class PredictionChartComponent:
|
||||
model_stats = prediction_stats.get('models', [])
|
||||
|
||||
return html.Div([
|
||||
html.H4("📊 Prediction Tracking & Performance", className="mb-3"),
|
||||
html.H4("Prediction Tracking & Performance", className="mb-3"),
|
||||
|
||||
# Summary cards
|
||||
html.Div([
|
||||
@@ -337,7 +337,7 @@ class PredictionChartComponent:
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating prediction panel: {e}")
|
||||
return html.Div([
|
||||
html.H4("📊 Prediction Tracking & Performance"),
|
||||
html.H4("Prediction Tracking & Performance"),
|
||||
html.P(f"Error loading prediction data: {str(e)}", className="text-danger")
|
||||
])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user