#!/usr/bin/env python3 """ Comprehensive Training Integration Tests This module consolidates and improves test functionality from multiple test files: - CNN training tests (from test_cnn_only.py, test_training.py) - Model testing (from test_model.py) - Chart data testing (from test_chart_data.py) - Integration testing between components """ import sys import os import logging import time import unittest import tempfile from pathlib import Path # Add project root to path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) from core.config import setup_logging, get_config from core.data_provider import DataProvider from training.cnn_trainer import CNNTrainer from training.rl_trainer import RLTrainer from dataprovider_realtime import RealTimeChart, TickStorage, BinanceHistoricalData logger = logging.getLogger(__name__) class TestDataProviders(unittest.TestCase): """Test suite for data provider functionality""" def test_binance_historical_data(self): """Test Binance historical data fetching""" logger.info("Testing Binance historical data fetch...") try: binance_data = BinanceHistoricalData() df = binance_data.get_historical_candles("ETH/USDT", 60, 100) self.assertIsNotNone(df, "Should fetch data successfully") self.assertFalse(df.empty, "Data should not be empty") self.assertGreater(len(df), 0, "Should have candles") # Verify data structure required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume'] for col in required_columns: self.assertIn(col, df.columns, f"Should have {col} column") logger.info(f"✅ Successfully fetched {len(df)} candles") return True except Exception as e: logger.warning(f"Binance API test failed: {e}") self.skipTest("Binance API not available") def test_tick_storage(self): """Test TickStorage functionality""" logger.info("Testing TickStorage data loading...") try: tick_storage = TickStorage("ETH/USDT", ["1m", "5m", "1h"]) success = tick_storage.load_historical_data("ETH/USDT", limit=100) if success: # Check timeframes for tf in ["1m", "5m", "1h"]: candles = tick_storage.get_candles(tf) logger.info(f" {tf}: {len(candles)} candles") logger.info("✅ TickStorage working correctly") return True else: self.skipTest("Could not load tick storage data") except Exception as e: logger.warning(f"TickStorage test failed: {e}") self.skipTest("TickStorage not available") def test_chart_initialization(self): """Test RealTimeChart initialization""" logger.info("Testing RealTimeChart initialization...") try: chart = RealTimeChart(app=None, symbol="ETH/USDT", standalone=False) # Test getting candles candles_1m = chart.get_candles(60) self.assertIsInstance(candles_1m, list, "Should return list of candles") logger.info(f"✅ Chart initialized with {len(candles_1m)} 1m candles") except Exception as e: logger.warning(f"Chart initialization failed: {e}") self.skipTest("Chart initialization not available") class TestCNNTraining(unittest.TestCase): """Test suite for CNN training functionality""" def setUp(self): """Set up test fixtures""" self.temp_dir = tempfile.mkdtemp() setup_logging() def tearDown(self): """Clean up test fixtures""" import shutil shutil.rmtree(self.temp_dir, ignore_errors=True) def test_cnn_quick_training(self): """Test quick CNN training with small dataset""" logger.info("Testing CNN quick training...") try: config = get_config() # Test configuration symbols = ['ETH/USDT'] timeframes = ['1m', '5m'] num_samples = 100 # Very small for testing epochs = 1 batch_size = 16 # Override config for quick test config._config['timeframes'] = timeframes trainer = CNNTrainer(config) trainer.batch_size = batch_size trainer.epochs = epochs # Train model save_path = os.path.join(self.temp_dir, 'test_cnn.pt') results = trainer.train(symbols, save_path=save_path, num_samples=num_samples) # Verify results self.assertIsInstance(results, dict, "Should return results dict") self.assertIn('best_val_accuracy', results, "Should have accuracy metric") self.assertIn('total_epochs', results, "Should have epoch count") self.assertIn('training_time', results, "Should have training time") # Verify model was saved self.assertTrue(os.path.exists(save_path), "Model should be saved") logger.info(f"✅ CNN training completed successfully") logger.info(f" Best accuracy: {results['best_val_accuracy']:.4f}") logger.info(f" Training time: {results['training_time']:.2f}s") except Exception as e: logger.error(f"CNN training test failed: {e}") raise finally: if hasattr(trainer, 'close_tensorboard'): trainer.close_tensorboard() class TestRLTraining(unittest.TestCase): """Test suite for RL training functionality""" def setUp(self): """Set up test fixtures""" self.temp_dir = tempfile.mkdtemp() setup_logging() def tearDown(self): """Clean up test fixtures""" import shutil shutil.rmtree(self.temp_dir, ignore_errors=True) def test_rl_quick_training(self): """Test quick RL training with small dataset""" logger.info("Testing RL quick training...") try: # Setup minimal configuration data_provider = DataProvider(['ETH/USDT'], ['1m', '5m']) trainer = RLTrainer(data_provider) # Configure for very quick test trainer.num_episodes = 5 trainer.max_steps_per_episode = 50 trainer.evaluation_frequency = 3 trainer.save_frequency = 10 # Don't save during test # Train save_path = os.path.join(self.temp_dir, 'test_rl.pt') results = trainer.train(save_path=save_path) # Verify results self.assertIsInstance(results, dict, "Should return results dict") self.assertIn('total_episodes', results, "Should have episode count") self.assertIn('best_reward', results, "Should have best reward") self.assertIn('final_evaluation', results, "Should have final evaluation") logger.info(f"✅ RL training completed successfully") logger.info(f" Total episodes: {results['total_episodes']}") logger.info(f" Best reward: {results['best_reward']:.4f}") except Exception as e: logger.error(f"RL training test failed: {e}") raise class TestExtendedTraining(unittest.TestCase): """Test suite for extended training functionality (from test_model.py)""" def test_metrics_tracking(self): """Test comprehensive metrics tracking functionality""" logger.info("Testing extended metrics tracking...") # Test metrics history structure metrics_history = { "epoch": [], "train_loss": [], "val_loss": [], "train_acc": [], "val_acc": [], "train_pnl": [], "val_pnl": [], "train_win_rate": [], "val_win_rate": [], "signal_distribution": [] } # Simulate adding metrics for epoch in range(3): metrics_history["epoch"].append(epoch + 1) metrics_history["train_loss"].append(0.5 - epoch * 0.1) metrics_history["val_loss"].append(0.6 - epoch * 0.1) metrics_history["train_acc"].append(0.6 + epoch * 0.05) metrics_history["val_acc"].append(0.55 + epoch * 0.05) metrics_history["train_pnl"].append(epoch * 0.1) metrics_history["val_pnl"].append(epoch * 0.08) metrics_history["train_win_rate"].append(0.5 + epoch * 0.1) metrics_history["val_win_rate"].append(0.45 + epoch * 0.1) metrics_history["signal_distribution"].append({ "BUY": 0.3, "SELL": 0.3, "HOLD": 0.4 }) # Verify structure self.assertEqual(len(metrics_history["epoch"]), 3) self.assertEqual(len(metrics_history["train_loss"]), 3) self.assertEqual(len(metrics_history["signal_distribution"]), 3) # Verify improvement self.assertLess(metrics_history["train_loss"][-1], metrics_history["train_loss"][0]) self.assertGreater(metrics_history["train_acc"][-1], metrics_history["train_acc"][0]) logger.info("✅ Metrics tracking test passed") def test_signal_distribution_calculation(self): """Test signal distribution calculation""" import numpy as np # Mock predictions (SELL=0, HOLD=1, BUY=2) predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0]) buy_count = np.sum(predictions == 2) sell_count = np.sum(predictions == 0) hold_count = np.sum(predictions == 1) total = len(predictions) distribution = { "BUY": buy_count / total, "SELL": sell_count / total, "HOLD": hold_count / total } # Verify calculations self.assertAlmostEqual(distribution["BUY"], 0.3, places=2) self.assertAlmostEqual(distribution["SELL"], 0.3, places=2) self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2) self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2) logger.info("✅ Signal distribution calculation test passed") class TestIntegration(unittest.TestCase): """Integration tests between components""" def test_training_pipeline_integration(self): """Test that CNN and RL training can work together""" logger.info("Testing training pipeline integration...") with tempfile.TemporaryDirectory() as temp_dir: try: # Quick CNN training config = get_config() config._config['timeframes'] = ['1m'] cnn_trainer = CNNTrainer(config) cnn_trainer.epochs = 1 cnn_trainer.batch_size = 8 cnn_path = os.path.join(temp_dir, 'test_cnn.pt') cnn_results = cnn_trainer.train(['ETH/USDT'], save_path=cnn_path, num_samples=50) # Quick RL training data_provider = DataProvider(['ETH/USDT'], ['1m']) rl_trainer = RLTrainer(data_provider) rl_trainer.num_episodes = 3 rl_trainer.max_steps_per_episode = 25 rl_path = os.path.join(temp_dir, 'test_rl.pt') rl_results = rl_trainer.train(save_path=rl_path) # Verify both trained successfully self.assertIsInstance(cnn_results, dict) self.assertIsInstance(rl_results, dict) self.assertTrue(os.path.exists(cnn_path)) self.assertTrue(os.path.exists(rl_path)) logger.info("✅ Training pipeline integration test passed") except Exception as e: logger.error(f"Integration test failed: {e}") raise finally: if 'cnn_trainer' in locals(): cnn_trainer.close_tensorboard() def run_quick_tests(): """Run only the quickest tests for fast validation""" test_suites = [ unittest.TestLoader().loadTestsFromTestCase(TestExtendedTraining), ] combined_suite = unittest.TestSuite(test_suites) runner = unittest.TextTestRunner(verbosity=2) result = runner.run(combined_suite) return result.wasSuccessful() def run_data_tests(): """Run data provider tests""" test_suites = [ unittest.TestLoader().loadTestsFromTestCase(TestDataProviders), ] combined_suite = unittest.TestSuite(test_suites) runner = unittest.TextTestRunner(verbosity=2) result = runner.run(combined_suite) return result.wasSuccessful() def run_training_tests(): """Run training tests (slower)""" test_suites = [ unittest.TestLoader().loadTestsFromTestCase(TestCNNTraining), unittest.TestLoader().loadTestsFromTestCase(TestRLTraining), ] combined_suite = unittest.TestSuite(test_suites) runner = unittest.TextTestRunner(verbosity=2) result = runner.run(combined_suite) return result.wasSuccessful() def run_all_tests(): """Run all test suites""" test_suites = [ unittest.TestLoader().loadTestsFromTestCase(TestDataProviders), unittest.TestLoader().loadTestsFromTestCase(TestCNNTraining), unittest.TestLoader().loadTestsFromTestCase(TestRLTraining), unittest.TestLoader().loadTestsFromTestCase(TestExtendedTraining), unittest.TestLoader().loadTestsFromTestCase(TestIntegration), ] combined_suite = unittest.TestSuite(test_suites) runner = unittest.TextTestRunner(verbosity=2) result = runner.run(combined_suite) return result.wasSuccessful() if __name__ == "__main__": setup_logging() logger.info("Running comprehensive training integration tests...") if len(sys.argv) > 1: test_type = sys.argv[1] if test_type == "quick": success = run_quick_tests() elif test_type == "data": success = run_data_tests() elif test_type == "training": success = run_training_tests() else: success = run_all_tests() else: success = run_all_tests() if success: logger.info("✅ All tests passed!") sys.exit(0) else: logger.error("❌ Some tests failed!") sys.exit(1)