wip cnn training and cob

This commit is contained in:
Dobromir Popov
2025-07-23 23:33:36 +03:00
parent 8677c4c01c
commit 5437495003
4 changed files with 599 additions and 210 deletions

175
test_cnn_integration.py Normal file
View File

@ -0,0 +1,175 @@
#!/usr/bin/env python3
"""
Test CNN Integration
This script tests if the CNN adapter is working properly and identifies issues.
"""
import logging
import sys
import os
from datetime import datetime
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_cnn_adapter():
"""Test CNN adapter initialization and basic functionality"""
try:
logger.info("Testing CNN adapter initialization...")
# Test 1: Import CNN adapter
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
logger.info("✅ CNN adapter import successful")
# Test 2: Initialize CNN adapter
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
logger.info("✅ CNN adapter initialization successful")
# Test 3: Check adapter attributes
logger.info(f"CNN adapter model: {cnn_adapter.model}")
logger.info(f"CNN adapter device: {cnn_adapter.device}")
logger.info(f"CNN adapter model_name: {cnn_adapter.model_name}")
# Test 4: Check metrics tracking
logger.info(f"Inference count: {cnn_adapter.inference_count}")
logger.info(f"Training count: {cnn_adapter.training_count}")
logger.info(f"Training data length: {len(cnn_adapter.training_data)}")
# Test 5: Test simple training sample addition
cnn_adapter.add_training_sample("ETH/USDT", "BUY", 0.1)
logger.info(f"✅ Training sample added, new length: {len(cnn_adapter.training_data)}")
# Test 6: Test training if we have enough samples
if len(cnn_adapter.training_data) >= 2:
# Add another sample to have minimum for training
cnn_adapter.add_training_sample("ETH/USDT", "SELL", -0.05)
# Try training
training_result = cnn_adapter.train(epochs=1)
logger.info(f"✅ Training successful: {training_result}")
# Check if metrics were updated
logger.info(f"Last training time: {cnn_adapter.last_training_time}")
logger.info(f"Last training loss: {cnn_adapter.last_training_loss}")
logger.info(f"Training count: {cnn_adapter.training_count}")
else:
logger.info("⚠️ Not enough training samples for training test")
return True
except Exception as e:
logger.error(f"❌ CNN adapter test failed: {e}")
import traceback
traceback.print_exc()
return False
def test_base_data_input():
"""Test BaseDataInput creation"""
try:
logger.info("Testing BaseDataInput creation...")
# Test 1: Import BaseDataInput
from core.data_models import BaseDataInput, OHLCVBar, COBData
logger.info("✅ BaseDataInput import successful")
# Test 2: Create sample OHLCV bars
sample_bars = []
for i in range(10): # Create 10 sample bars
bar = OHLCVBar(
symbol="ETH/USDT",
timestamp=datetime.now(),
open=3500.0 + i,
high=3510.0 + i,
low=3490.0 + i,
close=3505.0 + i,
volume=1000.0,
timeframe="1s"
)
sample_bars.append(bar)
logger.info(f"✅ Created {len(sample_bars)} sample OHLCV bars")
# Test 3: Create BaseDataInput
base_data = BaseDataInput(
symbol="ETH/USDT",
timestamp=datetime.now(),
ohlcv_1s=sample_bars,
ohlcv_1m=sample_bars,
ohlcv_1h=sample_bars,
ohlcv_1d=sample_bars,
btc_ohlcv_1s=sample_bars
)
logger.info("✅ BaseDataInput created successfully")
# Test 4: Validate BaseDataInput
is_valid = base_data.validate()
logger.info(f"BaseDataInput validation: {is_valid}")
# Test 5: Get feature vector
feature_vector = base_data.get_feature_vector()
logger.info(f"✅ Feature vector created, shape: {feature_vector.shape}")
return base_data
except Exception as e:
logger.error(f"❌ BaseDataInput test failed: {e}")
import traceback
traceback.print_exc()
return None
def test_cnn_prediction():
"""Test CNN prediction with BaseDataInput"""
try:
logger.info("Testing CNN prediction...")
# Get CNN adapter and base data
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
base_data = test_base_data_input()
if not base_data:
logger.error("❌ Cannot test prediction without valid BaseDataInput")
return False
# Test prediction
model_output = cnn_adapter.predict(base_data)
logger.info(f"✅ Prediction successful: {model_output.predictions['action']} ({model_output.confidence:.3f})")
# Check if metrics were updated
logger.info(f"Inference count after prediction: {cnn_adapter.inference_count}")
logger.info(f"Last inference time: {cnn_adapter.last_inference_time}")
logger.info(f"Last prediction output: {cnn_adapter.last_prediction_output}")
return True
except Exception as e:
logger.error(f"❌ CNN prediction test failed: {e}")
import traceback
traceback.print_exc()
return False
def main():
"""Run all tests"""
logger.info("🧪 Starting CNN Integration Tests")
# Test 1: CNN Adapter
if not test_cnn_adapter():
logger.error("❌ CNN adapter test failed - stopping")
return False
# Test 2: CNN Prediction
if not test_cnn_prediction():
logger.error("❌ CNN prediction test failed - stopping")
return False
logger.info("✅ All CNN integration tests passed!")
logger.info("🎯 The CNN adapter should now work properly in the dashboard")
return True
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)