folder structure reorganize
@@ -1,395 +1,204 @@
#!/usr/bin/env python3
"""
Comprehensive Training Integration Tests
Test Training Integration with Dashboard

This module consolidates and improves test functionality from multiple test files:
- CNN training tests (from test_cnn_only.py, test_training.py)
- Model testing (from test_model.py)
- Chart data testing (from test_chart_data.py)
- Integration testing between components

This script tests the enhanced dashboard's ability to:
1. Stream training data to CNN and DQN models
2. Display real-time training metrics and progress
3. Show model learning curves and performance
4. Integrate with the continuous training system
"""

import sys
import os
import logging
import time
import unittest
import tempfile
import asyncio
from datetime import datetime, timedelta
from pathlib import Path

# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from core.config import setup_logging, get_config
from core.data_provider import DataProvider
from training.cnn_trainer import CNNTrainer
from training.rl_trainer import RLTrainer
from dataprovider_realtime import RealTimeChart, TickStorage, BinanceHistoricalData

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TestDataProviders(unittest.TestCase):
    """Test suite for data provider functionality"""

    def test_binance_historical_data(self):
        """Test Binance historical data fetching"""
        logger.info("Testing Binance historical data fetch...")


def test_training_integration():
    """Test the training integration functionality"""
    try:
        print("="*60)
        print("TESTING TRAINING INTEGRATION WITH DASHBOARD")
        print("="*60)

        try:
            binance_data = BinanceHistoricalData()
            df = binance_data.get_historical_candles("ETH/USDT", 60, 100)

            self.assertIsNotNone(df, "Should fetch data successfully")
            self.assertFalse(df.empty, "Data should not be empty")
            self.assertGreater(len(df), 0, "Should have candles")

            # Verify data structure
            required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
            for col in required_columns:
                self.assertIn(col, df.columns, f"Should have {col} column")

            logger.info(f"✅ Successfully fetched {len(df)} candles")
            return True

        except Exception as e:
            logger.warning(f"Binance API test failed: {e}")
            self.skipTest("Binance API not available")

    def test_tick_storage(self):
        """Test TickStorage functionality"""
        logger.info("Testing TickStorage data loading...")
        # Import dashboard
        from web.dashboard import TradingDashboard
        from core.data_provider import DataProvider
        from core.orchestrator import TradingOrchestrator

        try:
            tick_storage = TickStorage("ETH/USDT", ["1m", "5m", "1h"])
            success = tick_storage.load_historical_data("ETH/USDT", limit=100)

            if success:
                # Check timeframes
                for tf in ["1m", "5m", "1h"]:
                    candles = tick_storage.get_candles(tf)
                    logger.info(f"  {tf}: {len(candles)} candles")

                logger.info("✅ TickStorage working correctly")
                return True
        # Create components
        data_provider = DataProvider()
        orchestrator = TradingOrchestrator(data_provider)
        dashboard = TradingDashboard(data_provider, orchestrator)

        print(f"✓ Dashboard created with training integration")
        print(f"✓ Continuous training active: {getattr(dashboard, 'training_active', False)}")

        # Test 1: Simulate tick data for training
        print("\n📊 TEST 1: Simulating Tick Data")
        print("-" * 40)

        # Add simulated tick data to cache
        base_price = 3500.0
        for i in range(1000):
            tick_data = {
                'timestamp': datetime.now() - timedelta(seconds=1000-i),
                'price': base_price + (i % 100) * 0.1,
                'volume': 100 + (i % 50),
                'side': 'buy' if i % 2 == 0 else 'sell'
            }
            dashboard.tick_cache.append(tick_data)

        print(f"✓ Added {len(dashboard.tick_cache)} ticks to cache")

        # Test 2: Prepare training data
        print("\n🔄 TEST 2: Preparing Training Data")
        print("-" * 40)

        training_data = dashboard._prepare_training_data()
        if training_data:
            print(f"✓ Training data prepared successfully")
            print(f"  - OHLCV bars: {len(training_data['ohlcv'])}")
            print(f"  - Features: {training_data['features']}")
            print(f"  - Symbol: {training_data['symbol']}")
        else:
            print("❌ Failed to prepare training data")

        # Test 3: Format data for CNN
        print("\n🧠 TEST 3: CNN Data Formatting")
        print("-" * 40)

        if training_data:
            cnn_data = dashboard._format_data_for_cnn(training_data)
            if cnn_data and 'sequences' in cnn_data:
                print(f"✓ CNN data formatted successfully")
                print(f"  - Sequences shape: {cnn_data['sequences'].shape}")
                print(f"  - Targets shape: {cnn_data['targets'].shape}")
                print(f"  - Sequence length: {cnn_data['sequence_length']}")
            else:
                self.skipTest("Could not load tick storage data")

        except Exception as e:
            logger.warning(f"TickStorage test failed: {e}")
            self.skipTest("TickStorage not available")

    def test_chart_initialization(self):
        """Test RealTimeChart initialization"""
        logger.info("Testing RealTimeChart initialization...")
                print("❌ Failed to format CNN data")

        # Test 4: Format data for RL
        print("\n🤖 TEST 4: RL Data Formatting")
        print("-" * 40)

        if training_data:
            rl_experiences = dashboard._format_data_for_rl(training_data)
            if rl_experiences:
                print(f"✓ RL experiences formatted successfully")
                print(f"  - Number of experiences: {len(rl_experiences)}")
                print(f"  - Experience format: (state, action, reward, next_state, done)")
                print(f"  - Sample experience shapes: {[len(exp) for exp in rl_experiences[:3]]}")
            else:
                print("❌ Failed to format RL experiences")

        # Test 5: Send training data to models
        print("\n📤 TEST 5: Sending Training Data to Models")
        print("-" * 40)

        success = dashboard.send_training_data_to_models()
        print(f"✓ Training data sent: {success}")

        if hasattr(dashboard, 'training_stats'):
            stats = dashboard.training_stats
            print(f"  - Total training sessions: {stats.get('total_training_sessions', 0)}")
            print(f"  - CNN training count: {stats.get('cnn_training_count', 0)}")
            print(f"  - RL training count: {stats.get('rl_training_count', 0)}")
            print(f"  - Training data points: {stats.get('training_data_points', 0)}")

        # Test 6: Training metrics display
        print("\n📈 TEST 6: Training Metrics Display")
        print("-" * 40)

        training_metrics = dashboard._create_training_metrics()
        print(f"✓ Training metrics created: {len(training_metrics)} components")

        # Test 7: Model training status
        print("\n🔍 TEST 7: Model Training Status")
        print("-" * 40)

        training_status = dashboard._get_model_training_status()
        print(f"✓ Training status retrieved")
        print(f"  - CNN status: {training_status['cnn']['status']}")
        print(f"  - CNN accuracy: {training_status['cnn']['accuracy']:.1%}")
        print(f"  - RL status: {training_status['rl']['status']}")
        print(f"  - RL win rate: {training_status['rl']['win_rate']:.1%}")

        # Test 8: Training events log
        print("\n📝 TEST 8: Training Events Log")
        print("-" * 40)

        training_events = dashboard._get_recent_training_events()
        print(f"✓ Training events retrieved: {len(training_events)} events")

        # Test 9: Mini training chart
        print("\n📊 TEST 9: Mini Training Chart")
        print("-" * 40)

        try:
            chart = RealTimeChart(app=None, symbol="ETH/USDT", standalone=False)

            # Test getting candles
            candles_1m = chart.get_candles(60)

            self.assertIsInstance(candles_1m, list, "Should return list of candles")
            logger.info(f"✅ Chart initialized with {len(candles_1m)} 1m candles")

            training_chart = dashboard._create_mini_training_chart(training_status)
            print(f"✓ Mini training chart created")
            print(f"  - Chart type: {type(training_chart)}")
        except Exception as e:
            logger.warning(f"Chart initialization failed: {e}")
            self.skipTest("Chart initialization not available")


class TestCNNTraining(unittest.TestCase):
    """Test suite for CNN training functionality"""

    def setUp(self):
        """Set up test fixtures"""
        self.temp_dir = tempfile.mkdtemp()
        setup_logging()
            print(f"❌ Error creating training chart: {e}")

    def tearDown(self):
        """Clean up test fixtures"""
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_cnn_quick_training(self):
        """Test quick CNN training with small dataset"""
        logger.info("Testing CNN quick training...")
        # Test 10: Continuous training loop
        print("\n🔄 TEST 10: Continuous Training Loop")
        print("-" * 40)

        print(f"✓ Continuous training active: {getattr(dashboard, 'training_active', False)}")
        if hasattr(dashboard, 'training_thread'):
            print(f"✓ Training thread alive: {dashboard.training_thread.is_alive()}")

        # Test 11: Integration with existing continuous training system
        print("\n🔗 TEST 11: Integration with Continuous Training System")
        print("-" * 40)

        try:
            config = get_config()
            # Check if we can get tick cache for external training
            tick_cache = dashboard.get_tick_cache_for_training()
            print(f"✓ Tick cache accessible: {len(tick_cache)} ticks")

            # Test configuration
            symbols = ['ETH/USDT']
            timeframes = ['1m', '5m']
            num_samples = 100  # Very small for testing
            epochs = 1
            batch_size = 16

            # Override config for quick test
            config._config['timeframes'] = timeframes

            trainer = CNNTrainer(config)
            trainer.batch_size = batch_size
            trainer.epochs = epochs

            # Train model
            save_path = os.path.join(self.temp_dir, 'test_cnn.pt')
            results = trainer.train(symbols, save_path=save_path, num_samples=num_samples)

            # Verify results
            self.assertIsInstance(results, dict, "Should return results dict")
            self.assertIn('best_val_accuracy', results, "Should have accuracy metric")
            self.assertIn('total_epochs', results, "Should have epoch count")
            self.assertIn('training_time', results, "Should have training time")

            # Verify model was saved
            self.assertTrue(os.path.exists(save_path), "Model should be saved")

            logger.info(f"✅ CNN training completed successfully")
            logger.info(f"   Best accuracy: {results['best_val_accuracy']:.4f}")
            logger.info(f"   Training time: {results['training_time']:.2f}s")
            # Check if we can get 1-second bars
            one_second_bars = dashboard.get_one_second_bars()
            print(f"✓ 1-second bars accessible: {len(one_second_bars)} bars")

        except Exception as e:
            logger.error(f"CNN training test failed: {e}")
            raise
        finally:
            if hasattr(trainer, 'close_tensorboard'):
                trainer.close_tensorboard()


class TestRLTraining(unittest.TestCase):
    """Test suite for RL training functionality"""

    def setUp(self):
        """Set up test fixtures"""
        self.temp_dir = tempfile.mkdtemp()
        setup_logging()
            print(f"❌ Error accessing training data: {e}")

    def tearDown(self):
        """Clean up test fixtures"""
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_rl_quick_training(self):
        """Test quick RL training with small dataset"""
        logger.info("Testing RL quick training...")
        print("\n" + "="*60)
        print("TRAINING INTEGRATION TEST COMPLETED")
        print("="*60)

        try:
            # Setup minimal configuration
            data_provider = DataProvider(['ETH/USDT'], ['1m', '5m'])
            trainer = RLTrainer(data_provider)

            # Configure for very quick test
            trainer.num_episodes = 5
            trainer.max_steps_per_episode = 50
            trainer.evaluation_frequency = 3
            trainer.save_frequency = 10  # Don't save during test

            # Train
            save_path = os.path.join(self.temp_dir, 'test_rl.pt')
            results = trainer.train(save_path=save_path)

            # Verify results
            self.assertIsInstance(results, dict, "Should return results dict")
            self.assertIn('total_episodes', results, "Should have episode count")
            self.assertIn('best_reward', results, "Should have best reward")
            self.assertIn('final_evaluation', results, "Should have final evaluation")

            logger.info(f"✅ RL training completed successfully")
            logger.info(f"   Total episodes: {results['total_episodes']}")
            logger.info(f"   Best reward: {results['best_reward']:.4f}")

        except Exception as e:
            logger.error(f"RL training test failed: {e}")
            raise


class TestExtendedTraining(unittest.TestCase):
    """Test suite for extended training functionality (from test_model.py)"""

    def test_metrics_tracking(self):
        """Test comprehensive metrics tracking functionality"""
        logger.info("Testing extended metrics tracking...")
        # Summary
        print("\n📋 SUMMARY:")
        print(f"✓ Dashboard with training integration: WORKING")
        print(f"✓ Training data preparation: WORKING")
        print(f"✓ CNN data formatting: WORKING")
        print(f"✓ RL data formatting: WORKING")
        print(f"✓ Training metrics display: WORKING")
        print(f"✓ Continuous training: ACTIVE")
        print(f"✓ Model status tracking: WORKING")
        print(f"✓ Training events logging: WORKING")

        # Test metrics history structure
        metrics_history = {
            "epoch": [],
            "train_loss": [],
            "val_loss": [],
            "train_acc": [],
            "val_acc": [],
            "train_pnl": [],
            "val_pnl": [],
            "train_win_rate": [],
            "val_win_rate": [],
            "signal_distribution": []
        }
        return True

        # Simulate adding metrics
        for epoch in range(3):
            metrics_history["epoch"].append(epoch + 1)
            metrics_history["train_loss"].append(0.5 - epoch * 0.1)
            metrics_history["val_loss"].append(0.6 - epoch * 0.1)
            metrics_history["train_acc"].append(0.6 + epoch * 0.05)
            metrics_history["val_acc"].append(0.55 + epoch * 0.05)
            metrics_history["train_pnl"].append(epoch * 0.1)
            metrics_history["val_pnl"].append(epoch * 0.08)
            metrics_history["train_win_rate"].append(0.5 + epoch * 0.1)
            metrics_history["val_win_rate"].append(0.45 + epoch * 0.1)
            metrics_history["signal_distribution"].append({
                "BUY": 0.3, "SELL": 0.3, "HOLD": 0.4
            })

        # Verify structure
        self.assertEqual(len(metrics_history["epoch"]), 3)
        self.assertEqual(len(metrics_history["train_loss"]), 3)
        self.assertEqual(len(metrics_history["signal_distribution"]), 3)

        # Verify improvement
        self.assertLess(metrics_history["train_loss"][-1], metrics_history["train_loss"][0])
        self.assertGreater(metrics_history["train_acc"][-1], metrics_history["train_acc"][0])

        logger.info("✅ Metrics tracking test passed")

    def test_signal_distribution_calculation(self):
        """Test signal distribution calculation"""
        import numpy as np

        # Mock predictions (SELL=0, HOLD=1, BUY=2)
        predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0])

        buy_count = np.sum(predictions == 2)
        sell_count = np.sum(predictions == 0)
        hold_count = np.sum(predictions == 1)
        total = len(predictions)

        distribution = {
            "BUY": buy_count / total,
            "SELL": sell_count / total,
            "HOLD": hold_count / total
        }

        # Verify calculations
        self.assertAlmostEqual(distribution["BUY"], 0.3, places=2)
        self.assertAlmostEqual(distribution["SELL"], 0.3, places=2)
        self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2)
        self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2)

        logger.info("✅ Signal distribution calculation test passed")


class TestIntegration(unittest.TestCase):
    """Integration tests between components"""

    def test_training_pipeline_integration(self):
        """Test that CNN and RL training can work together"""
        logger.info("Testing training pipeline integration...")

        with tempfile.TemporaryDirectory() as temp_dir:
            try:
                # Quick CNN training
                config = get_config()
                config._config['timeframes'] = ['1m']

                cnn_trainer = CNNTrainer(config)
                cnn_trainer.epochs = 1
                cnn_trainer.batch_size = 8

                cnn_path = os.path.join(temp_dir, 'test_cnn.pt')
                cnn_results = cnn_trainer.train(['ETH/USDT'], save_path=cnn_path, num_samples=50)

                # Quick RL training
                data_provider = DataProvider(['ETH/USDT'], ['1m'])
                rl_trainer = RLTrainer(data_provider)
                rl_trainer.num_episodes = 3
                rl_trainer.max_steps_per_episode = 25

                rl_path = os.path.join(temp_dir, 'test_rl.pt')
                rl_results = rl_trainer.train(save_path=rl_path)

                # Verify both trained successfully
                self.assertIsInstance(cnn_results, dict)
                self.assertIsInstance(rl_results, dict)
                self.assertTrue(os.path.exists(cnn_path))
                self.assertTrue(os.path.exists(rl_path))

                logger.info("✅ Training pipeline integration test passed")

            except Exception as e:
                logger.error(f"Integration test failed: {e}")
                raise
            finally:
                if 'cnn_trainer' in locals():
                    cnn_trainer.close_tensorboard()


def run_quick_tests():
    """Run only the quickest tests for fast validation"""
    test_suites = [
        unittest.TestLoader().loadTestsFromTestCase(TestExtendedTraining),
    ]

    combined_suite = unittest.TestSuite(test_suites)
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(combined_suite)

    return result.wasSuccessful()


def run_data_tests():
    """Run data provider tests"""
    test_suites = [
        unittest.TestLoader().loadTestsFromTestCase(TestDataProviders),
    ]

    combined_suite = unittest.TestSuite(test_suites)
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(combined_suite)

    return result.wasSuccessful()


def run_training_tests():
    """Run training tests (slower)"""
    test_suites = [
        unittest.TestLoader().loadTestsFromTestCase(TestCNNTraining),
        unittest.TestLoader().loadTestsFromTestCase(TestRLTraining),
    ]

    combined_suite = unittest.TestSuite(test_suites)
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(combined_suite)

    return result.wasSuccessful()


def run_all_tests():
    """Run all test suites"""
    test_suites = [
        unittest.TestLoader().loadTestsFromTestCase(TestDataProviders),
        unittest.TestLoader().loadTestsFromTestCase(TestCNNTraining),
        unittest.TestLoader().loadTestsFromTestCase(TestRLTraining),
        unittest.TestLoader().loadTestsFromTestCase(TestExtendedTraining),
        unittest.TestLoader().loadTestsFromTestCase(TestIntegration),
    ]

    combined_suite = unittest.TestSuite(test_suites)
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(combined_suite)

    return result.wasSuccessful()

    except Exception as e:
        logger.error(f"Training integration test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


if __name__ == "__main__":
    setup_logging()
    logger.info("Running comprehensive training integration tests...")

    if len(sys.argv) > 1:
        test_type = sys.argv[1]
        if test_type == "quick":
            success = run_quick_tests()
        elif test_type == "data":
            success = run_data_tests()
        elif test_type == "training":
            success = run_training_tests()
        else:
            success = run_all_tests()
    else:
        success = run_all_tests()

    success = test_training_integration()
    if success:
        logger.info("✅ All tests passed!")
        sys.exit(0)
        print("\n🎉 All training integration tests passed!")
    else:
        logger.error("❌ Some tests failed!")
        print("\n❌ Some training integration tests failed!")
        sys.exit(1)