wip
@@ -1,400 +1,118 @@
 #!/usr/bin/env python3
 """
-Test Training Data Collection System
+Test Training Data Collection and Checkpoint Storage
 
-This script demonstrates and tests the comprehensive training data collection
-system with data validation, rapid change detection, and profitable setup replay.
+This script tests if the training system is working correctly and storing checkpoints.
 """
 
-import asyncio
-import os
 import sys
 import logging
-import numpy as np
-import pandas as pd
-import time
-from datetime import datetime, timedelta
-from typing import Dict, List, Any
+import asyncio
+from pathlib import Path
+from datetime import datetime
 
 # Add project root to path
 project_root = Path(__file__).parent
 sys.path.insert(0, str(project_root))
 
+from core.config import get_config, setup_logging
+from core.orchestrator import TradingOrchestrator
+from core.data_provider import DataProvider
+from utils.checkpoint_manager import get_checkpoint_manager
 
 # Setup logging
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
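+# setup_logging() below reads the logging configuration from core.config, which is
+# presumably why the module-level logging.basicConfig call above was dropped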
+setup_logging()
 logger = logging.getLogger(__name__)
 
-# Import our training system components
-from core.training_data_collector import (
-    TrainingDataCollector,
-    RapidChangeDetector,
-    ModelInputPackage,
-    TrainingOutcome,
-    TrainingEpisode
-)
-from core.cnn_training_pipeline import (
-    CNNPivotPredictor,
-    CNNTrainer
-)
-from core.data_provider import DataProvider
 
-def create_sample_ohlcv_data() -> Dict[str, pd.DataFrame]:
-    """Create sample OHLCV data for testing"""
-    timeframes = ['1s', '1m', '5m', '15m', '1h']
-    ohlcv_data = {}
+async def test_training_system():
+    """Test if the training system is working and storing checkpoints"""
+    logger.info("Testing training system and checkpoint storage...")
 
-    for timeframe in timeframes:
-        # Create sample data
-        dates = pd.date_range(start='2024-01-01', periods=300, freq='1min')
+    # Initialize components
+    data_provider = DataProvider()
+    orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
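+    # enhanced_rl_training=True is assumed to be the switch that lets the orchestrator
+    # train its models and persist checkpoints; if no new checkpoints show up later in
+    # this test, this flag and the orchestrator's saving logic are the places to look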
+
+    # Get checkpoint manager
+    checkpoint_manager = get_checkpoint_manager()
+
+    # Check if checkpoint directory exists
+    checkpoint_dir = Path("models/saved")
+    if not checkpoint_dir.exists():
+        logger.warning(f"Checkpoint directory {checkpoint_dir} does not exist. Creating...")
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+    # Check for existing checkpoints
+    checkpoint_stats = checkpoint_manager.get_checkpoint_stats()
+    logger.info(f"Found {checkpoint_stats['total_checkpoints']} existing checkpoints.")
+    logger.info(f"Total checkpoint size: {checkpoint_stats['total_size_mb']:.2f} MB")
+
+    # List checkpoint files
+    checkpoint_files = list(checkpoint_dir.glob("*.pt"))
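+    # NOTE: only *.pt files directly under models/saved are listed here; checkpoints
+    # saved in subdirectories or in another format would not appear in this report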
+    if checkpoint_files:
+        logger.info("Recent checkpoint files:")
+        for i, file in enumerate(sorted(checkpoint_files, key=lambda f: f.stat().st_mtime, reverse=True)[:5]):
+            file_size = file.stat().st_size / (1024 * 1024)  # Convert to MB
+            modified_time = datetime.fromtimestamp(file.stat().st_mtime).strftime("%Y-%m-%d %H:%M:%S")
+            logger.info(f" {i+1}. {file.name} ({file_size:.2f} MB, modified: {modified_time})")
+    else:
+        logger.warning("No checkpoint files found.")
+
+    # Test training by making trading decisions
+    logger.info("\nTesting training by making trading decisions...")
+    symbols = orchestrator.symbols
+
+    for symbol in symbols:
+        logger.info(f"Making trading decision for {symbol}...")
+        decision = await orchestrator.make_trading_decision(symbol)
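+        # decision may legitimately be None when the orchestrator produces no signal;
+        # both outcomes are logged in the per-symbol check further down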
-
-        # Generate realistic price data
-        base_price = 3000.0  # ETH price
-        price_data = []
-        current_price = base_price
-
-        for i in range(300):
-            # Add some randomness
-            change = np.random.normal(0, 0.002)  # 0.2% std dev
-            current_price *= (1 + change)
-
-            # OHLCV for this period
-            open_price = current_price
-            high_price = current_price * (1 + abs(np.random.normal(0, 0.001)))
-            low_price = current_price * (1 - abs(np.random.normal(0, 0.001)))
-            close_price = current_price * (1 + np.random.normal(0, 0.0005))
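-            # NOTE: close is drawn independently of high/low, so a sample can fall
-            # outside the [low, high] band; acceptable for a smoke test, but not
-            # strictly valid OHLC data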
-            volume = np.random.uniform(100, 1000)
-
-            price_data.append({
-                'timestamp': dates[i],
-                'open': open_price,
-                'high': high_price,
-                'low': low_price,
-                'close': close_price,
-                'volume': volume
-            })
-
-            current_price = close_price
-
-        df = pd.DataFrame(price_data)
-        df.set_index('timestamp', inplace=True)
-        ohlcv_data[timeframe] = df
-
-    return ohlcv_data
-
-def create_sample_tick_data() -> List[Dict[str, Any]]:
-    """Create sample tick data for testing"""
-    tick_data = []
-    base_price = 3000.0
-
-    for i in range(100):
-        tick_data.append({
-            'timestamp': datetime.now() - timedelta(seconds=100-i),
-            'price': base_price + np.random.normal(0, 5),
-            'volume': np.random.uniform(0.1, 10.0),
-            'side': 'buy' if np.random.random() > 0.5 else 'sell',
-            'trade_id': f'trade_{i}',
-            'quantity': np.random.uniform(0.1, 5.0)
-        })
-
-    return tick_data
-
-def create_sample_cob_data() -> Dict[str, Any]:
-    """Create sample COB data for testing"""
-    return {
-        'timestamp': datetime.now(),
-        'bid_levels': [3000 - i for i in range(10)],
-        'ask_levels': [3000 + i for i in range(10)],
-        'bid_volumes': [np.random.uniform(1, 10) for _ in range(10)],
-        'ask_volumes': [np.random.uniform(1, 10) for _ in range(10)],
-        'spread': 1.0,
-        'depth': 100.0
-    }
-
-def test_rapid_change_detector():
-    """Test the rapid change detection system"""
-    logger.info("=== Testing Rapid Change Detector ===")
-
-    detector = RapidChangeDetector(
-        velocity_threshold=0.5,
-        volatility_multiplier=3.0,
-        lookback_minutes=5
-    )
-
-    symbol = 'ETHUSDT'
-    base_price = 3000.0
-
-    # Add normal price points
-    for i in range(120):  # 2 minutes of data
-        timestamp = datetime.now() - timedelta(seconds=120-i)
-        price = base_price + np.random.normal(0, 1)  # Small changes
-        detector.add_price_point(symbol, timestamp, price)
-
-    # Check for rapid change (should be False)
-    is_rapid, velocity, volatility_spike = detector.detect_rapid_change(symbol)
-    logger.info(f"Normal conditions - Rapid change: {is_rapid}, Velocity: {velocity:.3f}")
-
-    # Add rapid price change
-    for i in range(60):  # 1 minute of rapid changes
-        timestamp = datetime.now() - timedelta(seconds=60-i)
-        price = base_price + 50 + i * 0.5  # Rapid increase
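-        # roughly 0.5 price units per one-second step, on top of an immediate +50 jump
-        # from the last "normal" price; with velocity_threshold=0.5 (units depending on
-        # the detector's definition), this is intended to trip the detection below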
-        detector.add_price_point(symbol, timestamp, price)
-
-    # Check for rapid change (should be True)
-    is_rapid, velocity, volatility_spike = detector.detect_rapid_change(symbol)
-    logger.info(f"Rapid change conditions - Rapid change: {is_rapid}, Velocity: {velocity:.3f}")
-
-    return detector
-
-def test_training_data_collector():
-    """Test the training data collection system"""
-    logger.info("=== Testing Training Data Collector ===")
-
-    # Initialize collector
-    collector = TrainingDataCollector(
-        storage_dir="test_training_data",
-        max_episodes_per_symbol=100
-    )
-
-    collector.start_collection()
-
-    symbol = 'ETHUSDT'
-
-    # Create sample data
-    ohlcv_data = create_sample_ohlcv_data()
-    tick_data = create_sample_tick_data()
-    cob_data = create_sample_cob_data()
-    technical_indicators = {
-        'rsi_14': 65.5,
-        'macd': 0.5,
-        'sma_20': 3000.0,
-        'ema_12': 3005.0,
-        'bollinger_upper': 3050.0,
-        'bollinger_lower': 2950.0
-    }
-    pivot_points = [
-        {'timestamp': datetime.now(), 'price': 3020.0, 'type': 'high'},
-        {'timestamp': datetime.now() - timedelta(minutes=30), 'price': 2980.0, 'type': 'low'}
-    ]
-
-    # Create CNN and RL features
-    cnn_features = np.random.randn(2000).astype(np.float32)
-    rl_state = np.random.randn(2000).astype(np.float32)
-    orchestrator_context = {
-        'market_session': 'european',
-        'volatility_regime': 'medium',
-        'trend_direction': 'uptrend'
-    }
-
-    # Collect training data
-    episode_id = collector.collect_training_data(
-        symbol=symbol,
-        ohlcv_data=ohlcv_data,
-        tick_data=tick_data,
-        cob_data=cob_data,
-        technical_indicators=technical_indicators,
-        pivot_points=pivot_points,
-        cnn_features=cnn_features,
-        rl_state=rl_state,
-        orchestrator_context=orchestrator_context
-    )
-
-    logger.info(f"Created training episode: {episode_id}")
-
-    # Test data validation
-    validation_results = collector.validate_data_integrity()
-    logger.info(f"Data integrity validation: {validation_results}")
-
-    # Get statistics
-    stats = collector.get_collection_statistics()
-    logger.info(f"Collection statistics: {stats}")
-
-    collector.stop_collection()
-
-    return collector
-
-def test_cnn_training_pipeline():
-    """Test the CNN training pipeline"""
-    logger.info("=== Testing CNN Training Pipeline ===")
-
-    # Initialize CNN model and trainer
-    model = CNNPivotPredictor(
-        input_channels=10,
-        sequence_length=300,
-        hidden_dim=128,  # Smaller for testing
-        num_pivot_classes=3
-    )
-
-    trainer = CNNTrainer(
-        model=model,
-        device='cpu',  # Use CPU for testing
-        learning_rate=0.001,
-        storage_dir="test_cnn_training"
-    )
-
-    # Create sample training episodes
-    episodes = []
-    for i in range(50):  # Create 50 sample episodes
-        # Create sample input package
-        input_package = ModelInputPackage(
-            timestamp=datetime.now() - timedelta(minutes=i),
-            symbol='ETHUSDT',
-            ohlcv_data=create_sample_ohlcv_data(),
-            tick_data=create_sample_tick_data(),
-            cob_data=create_sample_cob_data(),
-            technical_indicators={'rsi': 50.0, 'macd': 0.0},
-            pivot_points=[],
-            cnn_features=np.random.randn(2000).astype(np.float32),
-            rl_state=np.random.randn(2000).astype(np.float32),
-            orchestrator_context={}
-        )
-
-        # Create sample outcome
-        outcome = TrainingOutcome(
-            input_package_hash=input_package.data_hash,
-            timestamp=input_package.timestamp,
-            symbol='ETHUSDT',
-            price_change_1m=np.random.normal(0, 0.01),
-            price_change_5m=np.random.normal(0, 0.02),
-            price_change_15m=np.random.normal(0, 0.03),
-            price_change_1h=np.random.normal(0, 0.05),
-            max_profit_potential=abs(np.random.normal(0, 0.02)),
-            max_loss_potential=abs(np.random.normal(0, 0.015)),
-            optimal_entry_price=3000.0,
-            optimal_exit_price=3000.0 + np.random.normal(0, 10),
-            optimal_holding_time=timedelta(minutes=np.random.randint(5, 60)),
-            is_profitable=np.random.random() > 0.4,  # 60% profitable
-            profitability_score=np.random.uniform(0.3, 1.0),
-            risk_reward_ratio=np.random.uniform(1.0, 3.0),
-            is_rapid_change=np.random.random() > 0.8,  # 20% rapid changes
-            change_velocity=np.random.uniform(0.1, 2.0),
-            volatility_spike=np.random.random() > 0.9,
-            outcome_validated=True
-        )
-
-        # Create training episode
-        episode = TrainingEpisode(
-            episode_id=f"test_episode_{i}",
-            input_package=input_package,
-            model_predictions={},
-            actual_outcome=outcome,
-            episode_type='normal'
-        )
-
-        episodes.append(episode)
-
-    # Test training on episodes
-    results = trainer._train_on_episodes(episodes, training_mode='test_batch')
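-    # NOTE: _train_on_episodes is an underscore-prefixed internal method; the test
-    # calls it directly, presumably to exercise a single training step in isolation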
-    logger.info(f"Training results: {results}")
-
-    # Test profitable episode training
-    profitable_results = trainer.train_on_profitable_episodes(
-        symbol='ETHUSDT',
-        min_profitability=0.7,
-        max_episodes=20
-    )
-    logger.info(f"Profitable training results: {profitable_results}")
-
-    # Get training statistics
-    stats = trainer.get_training_statistics()
-    logger.info(f"Training statistics: {stats}")
-
-    return trainer
-
-def test_integration():
-    """Test the complete integration"""
-    logger.info("=== Testing Complete Integration ===")
-
-    try:
-        # Test individual components
-        detector = test_rapid_change_detector()
-        collector = test_training_data_collector()
-        trainer = test_cnn_training_pipeline()
-
-        logger.info("✅ All components tested successfully!")
-
-        # Test data flow
-        logger.info("Testing data flow integration...")
-
-        # Simulate real-time data collection and training
-        symbol = 'ETHUSDT'
-
-        # Collect multiple data points
-        for i in range(10):
-            ohlcv_data = create_sample_ohlcv_data()
-            tick_data = create_sample_tick_data()
-            cob_data = create_sample_cob_data()
-
-            episode_id = collector.collect_training_data(
-                symbol=symbol,
-                ohlcv_data=ohlcv_data,
-                tick_data=tick_data,
-                cob_data=cob_data,
-                technical_indicators={'rsi': 50.0 + i},
-                pivot_points=[],
-                cnn_features=np.random.randn(2000).astype(np.float32),
-                rl_state=np.random.randn(2000).astype(np.float32),
-                orchestrator_context={}
-            )
-
-            logger.info(f"Collected episode {i+1}: {episode_id}")
-            time.sleep(0.1)  # Small delay
-
-        # Get final statistics
-        final_stats = collector.get_collection_statistics()
-        logger.info(f"Final collection statistics: {final_stats}")
-
-        logger.info("✅ Integration test completed successfully!")
-
-        return True
-
-    except Exception as e:
-        logger.error(f"❌ Integration test failed: {e}")
-        import traceback
-        logger.error(traceback.format_exc())
-        return False
-
-def main():
-    """Main test function"""
-    logger.info("=" * 80)
-    logger.info("COMPREHENSIVE TRAINING DATA COLLECTION SYSTEM TEST")
-    logger.info("=" * 80)
-
-    start_time = time.time()
-
-    try:
-        # Run integration test
-        success = test_integration()
-
-        end_time = time.time()
-        duration = end_time - start_time
-
-        logger.info("=" * 80)
-        if success:
-            logger.info("✅ ALL TESTS PASSED!")
+        if decision:
+            logger.info(f"Decision for {symbol}: {decision.action} (confidence: {decision.confidence:.3f})")
         else:
-            logger.info("❌ SOME TESTS FAILED!")
+            logger.warning(f"No decision made for {symbol}.")
+
+    # Check if new checkpoints were created
+    new_checkpoint_stats = checkpoint_manager.get_checkpoint_stats()
+    new_checkpoints = new_checkpoint_stats['total_checkpoints'] - checkpoint_stats['total_checkpoints']
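+    # A simple before/after count diff: this assumes nothing else creates or prunes
+    # checkpoints while the test is running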
+
+    if new_checkpoints > 0:
+        logger.info(f"\nSuccess! {new_checkpoints} new checkpoints were created.")
+        logger.info("Training system is working correctly.")
+    else:
+        logger.warning("\nNo new checkpoints were created.")
+        logger.warning("This could be normal if the training threshold wasn't met.")
+        logger.warning("Check the orchestrator's checkpoint saving logic.")
+
+    # Check model states
+    model_states = orchestrator.get_model_states()
+    logger.info("\nModel states:")
+    for model_name, state in model_states.items():
+        checkpoint_loaded = state.get('checkpoint_loaded', False)
+        checkpoint_filename = state.get('checkpoint_filename', 'none')
+        current_loss = state.get('current_loss', None)
-
-        logger.info(f"Test duration: {duration:.2f} seconds")
-        logger.info("=" * 80)
+        status = "LOADED" if checkpoint_loaded else "FRESH"
+        loss_str = f"{current_loss:.4f}" if current_loss is not None else "N/A"
-
-        # Display summary
-        logger.info("\n📊 SYSTEM CAPABILITIES DEMONSTRATED:")
-        logger.info("✓ Comprehensive training data collection with validation")
-        logger.info("✓ Rapid price change detection for premium training examples")
-        logger.info("✓ Data integrity validation and completeness checking")
-        logger.info("✓ CNN training pipeline with backpropagation data storage")
-        logger.info("✓ Profitable episode prioritization and replay")
-        logger.info("✓ Training session value calculation and ranking")
-        logger.info("✓ Real-time data integration capabilities")
-
-        logger.info("\n🎯 NEXT STEPS:")
-        logger.info("1. Integrate with existing DataProvider for real market data")
-        logger.info("2. Connect with actual CNN and RL models")
-        logger.info("3. Implement outcome validation with real price data")
-        logger.info("4. Add dashboard integration for monitoring")
-        logger.info("5. Scale up for production deployment")
-
-    except Exception as e:
-        logger.error(f"❌ Test execution failed: {e}")
-        import traceback
-        logger.error(traceback.format_exc())
+        logger.info(f" {model_name}: {status}, Loss: {loss_str}, Checkpoint: {checkpoint_filename}")
+
+    return new_checkpoints > 0
+
+async def main():
+    """Main function"""
+    logger.info("=" * 70)
+    logger.info("TRAINING SYSTEM TEST")
+    logger.info("=" * 70)
+
+    success = await test_training_system()
+
+    if success:
+        logger.info("\nTraining system test passed!")
+        return 0
+    else:
+        logger.warning("\nTraining system test completed with warnings.")
+        logger.info("Check the logs for details.")
+        return 1
 
 if __name__ == "__main__":
-    main()
+    sys.exit(asyncio.run(main()))