#!/usr/bin/env python3
"""
Comprehensive Training Integration Tests
This module consolidates and improves test functionality from multiple test files:
- CNN training tests (from test_cnn_only.py, test_training.py)
- Model testing (from test_model.py)
- Chart data testing (from test_chart_data.py)
- Integration testing between components
"""

import sys
import os
import logging
import time
import unittest
import tempfile
from pathlib import Path

# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from core.config import setup_logging, get_config
from core.data_provider import DataProvider
from training.cnn_trainer import CNNTrainer
from training.rl_trainer import RLTrainer
from dataprovider_realtime import RealTimeChart, TickStorage, BinanceHistoricalData

logger = logging.getLogger(__name__)


class TestDataProviders(unittest.TestCase):
    """Test suite for data provider functionality"""

    def test_binance_historical_data(self):
        """Test Binance historical data fetching"""
        logger.info("Testing Binance historical data fetch...")
        try:
            binance_data = BinanceHistoricalData()
            df = binance_data.get_historical_candles("ETH/USDT", 60, 100)

            self.assertIsNotNone(df, "Should fetch data successfully")
            self.assertFalse(df.empty, "Data should not be empty")
            self.assertGreater(len(df), 0, "Should have candles")

            # Verify data structure
            required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
            for col in required_columns:
                self.assertIn(col, df.columns, f"Should have {col} column")

            logger.info(f"✅ Successfully fetched {len(df)} candles")
            return True
        except Exception as e:
            logger.warning(f"Binance API test failed: {e}")
            self.skipTest("Binance API not available")

    def test_tick_storage(self):
        """Test TickStorage functionality"""
        logger.info("Testing TickStorage data loading...")
        try:
            tick_storage = TickStorage("ETH/USDT", ["1m", "5m", "1h"])
            success = tick_storage.load_historical_data("ETH/USDT", limit=100)

            if success:
                # Check timeframes
                for tf in ["1m", "5m", "1h"]:
                    candles = tick_storage.get_candles(tf)
                    logger.info(f" {tf}: {len(candles)} candles")
                logger.info("✅ TickStorage working correctly")
                return True
            else:
                self.skipTest("Could not load tick storage data")
        except Exception as e:
            logger.warning(f"TickStorage test failed: {e}")
            self.skipTest("TickStorage not available")

    def test_chart_initialization(self):
        """Test RealTimeChart initialization"""
        logger.info("Testing RealTimeChart initialization...")
        try:
            chart = RealTimeChart(app=None, symbol="ETH/USDT", standalone=False)

            # Test getting candles
            candles_1m = chart.get_candles(60)
            self.assertIsInstance(candles_1m, list, "Should return list of candles")

            logger.info(f"✅ Chart initialized with {len(candles_1m)} 1m candles")
        except Exception as e:
            logger.warning(f"Chart initialization failed: {e}")
            self.skipTest("Chart initialization not available")


class TestCNNTraining(unittest.TestCase):
    """Test suite for CNN training functionality"""

    def setUp(self):
        """Set up test fixtures"""
        self.temp_dir = tempfile.mkdtemp()
        setup_logging()

    def tearDown(self):
        """Clean up test fixtures"""
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_cnn_quick_training(self):
        """Test quick CNN training with small dataset"""
        logger.info("Testing CNN quick training...")
        try:
            config = get_config()

            # Test configuration
            symbols = ['ETH/USDT']
            timeframes = ['1m', '5m']
            num_samples = 100  # Very small for testing
            epochs = 1
            batch_size = 16

            # Override config for quick test
            config._config['timeframes'] = timeframes

            trainer = CNNTrainer(config)
            trainer.batch_size = batch_size
            trainer.epochs = epochs

            # Train model
            save_path = os.path.join(self.temp_dir, 'test_cnn.pt')
            results = trainer.train(symbols, save_path=save_path, num_samples=num_samples)

            # Verify results
            self.assertIsInstance(results, dict, "Should return results dict")
            self.assertIn('best_val_accuracy', results, "Should have accuracy metric")
            self.assertIn('total_epochs', results, "Should have epoch count")
            self.assertIn('training_time', results, "Should have training time")

            # Verify model was saved
            self.assertTrue(os.path.exists(save_path), "Model should be saved")

            logger.info("✅ CNN training completed successfully")
            logger.info(f" Best accuracy: {results['best_val_accuracy']:.4f}")
            logger.info(f" Training time: {results['training_time']:.2f}s")
        except Exception as e:
            logger.error(f"CNN training test failed: {e}")
            raise
        finally:
            # Guard against 'trainer' never being assigned if setup fails early
            if 'trainer' in locals() and hasattr(trainer, 'close_tensorboard'):
                trainer.close_tensorboard()


class TestRLTraining(unittest.TestCase):
    """Test suite for RL training functionality"""

    def setUp(self):
        """Set up test fixtures"""
        self.temp_dir = tempfile.mkdtemp()
        setup_logging()

    def tearDown(self):
        """Clean up test fixtures"""
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_rl_quick_training(self):
        """Test quick RL training with small dataset"""
        logger.info("Testing RL quick training...")
        try:
            # Setup minimal configuration
            data_provider = DataProvider(['ETH/USDT'], ['1m', '5m'])
            trainer = RLTrainer(data_provider)

            # Configure for very quick test
            trainer.num_episodes = 5
            trainer.max_steps_per_episode = 50
            trainer.evaluation_frequency = 3
            trainer.save_frequency = 10  # Don't save during test

            # Train
            save_path = os.path.join(self.temp_dir, 'test_rl.pt')
            results = trainer.train(save_path=save_path)

            # Verify results
            self.assertIsInstance(results, dict, "Should return results dict")
            self.assertIn('total_episodes', results, "Should have episode count")
            self.assertIn('best_reward', results, "Should have best reward")
            self.assertIn('final_evaluation', results, "Should have final evaluation")

            logger.info("✅ RL training completed successfully")
            logger.info(f" Total episodes: {results['total_episodes']}")
            logger.info(f" Best reward: {results['best_reward']:.4f}")
        except Exception as e:
            logger.error(f"RL training test failed: {e}")
            raise


class TestExtendedTraining(unittest.TestCase):
    """Test suite for extended training functionality (from test_model.py)"""

    def test_metrics_tracking(self):
        """Test comprehensive metrics tracking functionality"""
        logger.info("Testing extended metrics tracking...")

        # Test metrics history structure
        metrics_history = {
            "epoch": [],
            "train_loss": [],
            "val_loss": [],
            "train_acc": [],
            "val_acc": [],
            "train_pnl": [],
            "val_pnl": [],
            "train_win_rate": [],
            "val_win_rate": [],
            "signal_distribution": []
        }

        # Simulate adding metrics
        for epoch in range(3):
            metrics_history["epoch"].append(epoch + 1)
            metrics_history["train_loss"].append(0.5 - epoch * 0.1)
            metrics_history["val_loss"].append(0.6 - epoch * 0.1)
            metrics_history["train_acc"].append(0.6 + epoch * 0.05)
            metrics_history["val_acc"].append(0.55 + epoch * 0.05)
            metrics_history["train_pnl"].append(epoch * 0.1)
            metrics_history["val_pnl"].append(epoch * 0.08)
            metrics_history["train_win_rate"].append(0.5 + epoch * 0.1)
            metrics_history["val_win_rate"].append(0.45 + epoch * 0.1)
            metrics_history["signal_distribution"].append({
                "BUY": 0.3, "SELL": 0.3, "HOLD": 0.4
            })

        # Verify structure
        self.assertEqual(len(metrics_history["epoch"]), 3)
        self.assertEqual(len(metrics_history["train_loss"]), 3)
        self.assertEqual(len(metrics_history["signal_distribution"]), 3)

        # Verify improvement
        self.assertLess(metrics_history["train_loss"][-1], metrics_history["train_loss"][0])
        self.assertGreater(metrics_history["train_acc"][-1], metrics_history["train_acc"][0])

        logger.info("✅ Metrics tracking test passed")

    def test_signal_distribution_calculation(self):
        """Test signal distribution calculation"""
        import numpy as np

        # Mock predictions (SELL=0, HOLD=1, BUY=2)
        predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0])

        buy_count = np.sum(predictions == 2)
        sell_count = np.sum(predictions == 0)
        hold_count = np.sum(predictions == 1)
        total = len(predictions)

        distribution = {
            "BUY": buy_count / total,
            "SELL": sell_count / total,
            "HOLD": hold_count / total
        }

        # Verify calculations
        self.assertAlmostEqual(distribution["BUY"], 0.3, places=2)
        self.assertAlmostEqual(distribution["SELL"], 0.3, places=2)
        self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2)
        self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2)

        logger.info("✅ Signal distribution calculation test passed")


class TestIntegration(unittest.TestCase):
    """Integration tests between components"""

    def test_training_pipeline_integration(self):
        """Test that CNN and RL training can work together"""
        logger.info("Testing training pipeline integration...")

        with tempfile.TemporaryDirectory() as temp_dir:
            try:
                # Quick CNN training
                config = get_config()
                config._config['timeframes'] = ['1m']

                cnn_trainer = CNNTrainer(config)
                cnn_trainer.epochs = 1
                cnn_trainer.batch_size = 8

                cnn_path = os.path.join(temp_dir, 'test_cnn.pt')
                cnn_results = cnn_trainer.train(['ETH/USDT'], save_path=cnn_path, num_samples=50)

                # Quick RL training
                data_provider = DataProvider(['ETH/USDT'], ['1m'])
                rl_trainer = RLTrainer(data_provider)
                rl_trainer.num_episodes = 3
                rl_trainer.max_steps_per_episode = 25

                rl_path = os.path.join(temp_dir, 'test_rl.pt')
                rl_results = rl_trainer.train(save_path=rl_path)

                # Verify both trained successfully
                self.assertIsInstance(cnn_results, dict)
                self.assertIsInstance(rl_results, dict)
                self.assertTrue(os.path.exists(cnn_path))
                self.assertTrue(os.path.exists(rl_path))

                logger.info("✅ Training pipeline integration test passed")
            except Exception as e:
                logger.error(f"Integration test failed: {e}")
                raise
            finally:
                if 'cnn_trainer' in locals():
                    cnn_trainer.close_tensorboard()


def run_quick_tests():
    """Run only the quickest tests for fast validation"""
    test_suites = [
        unittest.TestLoader().loadTestsFromTestCase(TestExtendedTraining),
    ]
    combined_suite = unittest.TestSuite(test_suites)
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(combined_suite)
    return result.wasSuccessful()


def run_data_tests():
    """Run data provider tests"""
    test_suites = [
        unittest.TestLoader().loadTestsFromTestCase(TestDataProviders),
    ]
    combined_suite = unittest.TestSuite(test_suites)
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(combined_suite)
    return result.wasSuccessful()


def run_training_tests():
    """Run training tests (slower)"""
    test_suites = [
        unittest.TestLoader().loadTestsFromTestCase(TestCNNTraining),
        unittest.TestLoader().loadTestsFromTestCase(TestRLTraining),
    ]
    combined_suite = unittest.TestSuite(test_suites)
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(combined_suite)
    return result.wasSuccessful()


def run_all_tests():
    """Run all test suites"""
    test_suites = [
        unittest.TestLoader().loadTestsFromTestCase(TestDataProviders),
        unittest.TestLoader().loadTestsFromTestCase(TestCNNTraining),
        unittest.TestLoader().loadTestsFromTestCase(TestRLTraining),
        unittest.TestLoader().loadTestsFromTestCase(TestExtendedTraining),
        unittest.TestLoader().loadTestsFromTestCase(TestIntegration),
    ]
    combined_suite = unittest.TestSuite(test_suites)
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(combined_suite)
    return result.wasSuccessful()


if __name__ == "__main__":
    setup_logging()
    logger.info("Running comprehensive training integration tests...")

    if len(sys.argv) > 1:
        test_type = sys.argv[1]
        if test_type == "quick":
            success = run_quick_tests()
        elif test_type == "data":
            success = run_data_tests()
        elif test_type == "training":
            success = run_training_tests()
        else:
            success = run_all_tests()
    else:
        success = run_all_tests()

    if success:
        logger.info("✅ All tests passed!")
        sys.exit(0)
    else:
        logger.error("❌ Some tests failed!")
        sys.exit(1)