Dobromir Popov
2025-07-23 13:39:41 +03:00
parent 944a7b79e6
commit df17a99247
13 changed files with 663 additions and 695 deletions


@@ -1,400 +1,118 @@
#!/usr/bin/env python3
"""
Test Training Data Collection System
Test Training Data Collection and Checkpoint Storage
This script demonstrates and tests the comprehensive training data collection
system with data validation, rapid change detection, and profitable setup replay.
This script tests if the training system is working correctly and storing checkpoints.
"""

import os
import sys
import logging
import asyncio
from pathlib import Path
from datetime import datetime

# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import get_config, setup_logging
from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider
from utils.checkpoint_manager import get_checkpoint_manager

# Setup logging
setup_logging()
logger = logging.getLogger(__name__)

async def test_training_system():
"""Test if the training system is working and storing checkpoints"""
logger.info("Testing training system and checkpoint storage...")
for timeframe in timeframes:
# Create sample data
dates = pd.date_range(start='2024-01-01', periods=300, freq='1min')
    # Initialize components
    data_provider = DataProvider()
    orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
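    # NOTE (assumption): enhanced_rl_training=True is expected to enable the
    # training path that persists checkpoints; with it off, the before/after
    # checkpoint comparison below would likely report zero new checkpoints.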

    # Get checkpoint manager
    checkpoint_manager = get_checkpoint_manager()

    # Check if checkpoint directory exists
    checkpoint_dir = Path("models/saved")
    if not checkpoint_dir.exists():
        logger.warning(f"Checkpoint directory {checkpoint_dir} does not exist. Creating...")
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Check for existing checkpoints
    checkpoint_stats = checkpoint_manager.get_checkpoint_stats()
    logger.info(f"Found {checkpoint_stats['total_checkpoints']} existing checkpoints.")
    logger.info(f"Total checkpoint size: {checkpoint_stats['total_size_mb']:.2f} MB")
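    # NOTE (assumption): get_checkpoint_stats() is taken to return a dict with at
    # least 'total_checkpoints' (int) and 'total_size_mb' (float), the keys used above.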

    # List checkpoint files
    checkpoint_files = list(checkpoint_dir.glob("*.pt"))
    if checkpoint_files:
        logger.info("Recent checkpoint files:")
        for i, file in enumerate(sorted(checkpoint_files, key=lambda f: f.stat().st_mtime, reverse=True)[:5]):
            file_size = file.stat().st_size / (1024 * 1024)  # Convert to MB
            modified_time = datetime.fromtimestamp(file.stat().st_mtime).strftime("%Y-%m-%d %H:%M:%S")
            logger.info(f"  {i+1}. {file.name} ({file_size:.2f} MB, modified: {modified_time})")
    else:
        logger.warning("No checkpoint files found.")

    # Test training by making trading decisions
    logger.info("\nTesting training by making trading decisions...")
    symbols = orchestrator.symbols
    for symbol in symbols:
        logger.info(f"Making trading decision for {symbol}...")
        decision = await orchestrator.make_trading_decision(symbol)
        if decision:
            logger.info(f"Decision for {symbol}: {decision.action} (confidence: {decision.confidence:.3f})")
        else:
            logger.warning(f"No decision made for {symbol}.")

    # Check if new checkpoints were created
    new_checkpoint_stats = checkpoint_manager.get_checkpoint_stats()
    new_checkpoints = new_checkpoint_stats['total_checkpoints'] - checkpoint_stats['total_checkpoints']
    if new_checkpoints > 0:
        logger.info(f"\nSuccess! {new_checkpoints} new checkpoints were created.")
        logger.info("Training system is working correctly.")
    else:
        logger.warning("\nNo new checkpoints were created.")
        logger.warning("This could be normal if the training threshold wasn't met.")
        logger.warning("Check the orchestrator's checkpoint saving logic.")

    # Check model states
    model_states = orchestrator.get_model_states()
    logger.info("\nModel states:")
    for model_name, state in model_states.items():
        checkpoint_loaded = state.get('checkpoint_loaded', False)
        checkpoint_filename = state.get('checkpoint_filename', 'none')
        current_loss = state.get('current_loss', None)
        status = "LOADED" if checkpoint_loaded else "FRESH"
        loss_str = f"{current_loss:.4f}" if current_loss is not None else "N/A"
logger.info(f" {model_name}: {status}, Loss: {loss_str}, Checkpoint: {checkpoint_filename}")
return new_checkpoints > 0

async def main():
    """Main function"""
    logger.info("=" * 70)
    logger.info("TRAINING SYSTEM TEST")
    logger.info("=" * 70)

    success = await test_training_system()

    if success:
        logger.info("\nTraining system test passed!")
        return 0
    else:
        logger.warning("\nTraining system test completed with warnings.")
        logger.info("Check the logs for details.")
        return 1

if __name__ == "__main__":
    sys.exit(asyncio.run(main()))
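
The pass/fail criterion above is simply a before/after comparison of checkpoint counts. Below is a minimal, self-contained sketch of that check with a stand-in for the real checkpoint manager; FakeCheckpointManager and its directory layout are assumptions for illustration, not part of the repository:

from pathlib import Path

class FakeCheckpointManager:
    """Hypothetical stand-in: counts *.pt files under a checkpoint directory."""
    def __init__(self, checkpoint_dir: str = "models/saved"):
        self.checkpoint_dir = Path(checkpoint_dir)

    def get_checkpoint_stats(self) -> dict:
        files = list(self.checkpoint_dir.glob("*.pt")) if self.checkpoint_dir.exists() else []
        return {
            "total_checkpoints": len(files),
            "total_size_mb": sum(f.stat().st_size for f in files) / (1024 * 1024),
        }

if __name__ == "__main__":
    manager = FakeCheckpointManager()
    before = manager.get_checkpoint_stats()
    # ... trigger training here (e.g., the orchestrator decisions above) ...
    after = manager.get_checkpoint_stats()
    created = after["total_checkpoints"] - before["total_checkpoints"]
    print(f"New checkpoints created: {created}")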