#!/usr/bin/env python
"""
Training Configuration for GOGO2 Trading System

This module provides a central configuration for all training scripts,
ensuring they use real market data and follow consistent practices.

Importing this module has side effects: it creates a ``logs`` directory in
the current working directory and configures the root logger to write to a
timestamped file there as well as to the console.

Usage:
    import train_config
    config = train_config.get_config('supervised')  # or 'reinforcement' or 'hybrid'
"""

import copy
import os
import logging
import json
from datetime import datetime
from pathlib import Path

# Ensure consistent logging across all training scripts
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
log_file = log_dir / f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger('training')

# Define available training types
TRAINING_TYPES = {
    'supervised': {
        'description': 'Supervised learning using CNN model',
        'script': 'train_with_realtime.py',
        'model_class': 'CNNModelPyTorch',
        'data_interface': 'MultiTimeframeDataInterface'
    },
    'reinforcement': {
        'description': 'Reinforcement learning using DQN agent',
        'script': 'train_rl_with_realtime.py',
        'model_class': 'DQNAgent',
        'data_interface': 'MultiTimeframeDataInterface'
    },
    'hybrid': {
        'description': 'Combined supervised and reinforcement learning',
        'script': 'train_hybrid.py',  # To be implemented
        'model_class': 'HybridModel',  # To be implemented
        'data_interface': 'MultiTimeframeDataInterface'
    }
}

# Default configuration
DEFAULT_CONFIG = {
    # Market data configuration
    'market_data': {
        'use_real_data_only': True,  # IMPORTANT: Only use real market data, never synthetic
        'symbol': 'BTC/USDT',
        'timeframes': ['1m', '5m', '15m'],
        'window_size': 24,
        'data_refresh_interval': 300,  # seconds
        'use_indicators': True
    },

    # Training parameters
    'training': {
        'max_training_time': 12 * 3600,  # seconds (12 hours)
        'checkpoint_interval': 3600,  # seconds (1 hour)
        'batch_size': 64,
        'learning_rate': 0.0001,
        'optimizer': 'adam',
        'loss_function': 'custom_pnl'  # Focus on profitability
    },

    # Model paths
    'paths': {
        'models_dir': 'NN/models/saved',
        'logs_dir': 'logs',
        'tensorboard_dir': 'runs'
    },

    # GPU configuration
    'hardware': {
        'use_gpu': True,
        'mixed_precision': True,
        # NOTE(review): this only tests whether CUDA_VISIBLE_DEVICES is set at
        # all; it does not verify that a CUDA device is actually usable.
        'device': 'cuda' if os.environ.get('CUDA_VISIBLE_DEVICES') is not None else 'cpu'
    }
}


def get_config(training_type='supervised', custom_config=None):
    """
    Get configuration for a specific training type

    The returned dictionary is an independent deep copy: neither
    ``DEFAULT_CONFIG`` nor ``TRAINING_TYPES`` is modified, even when
    ``custom_config`` overrides nested values.

    Args:
        training_type (str): Type of training ('supervised', 'reinforcement', or 'hybrid')
        custom_config (dict): Optional custom configuration to merge

    Returns:
        dict: Complete configuration

    Raises:
        ValueError: If ``training_type`` is unknown or the merged
            configuration fails validation.
    """
    if training_type not in TRAINING_TYPES:
        raise ValueError(f"Invalid training type: {training_type}. Must be one of {list(TRAINING_TYPES.keys())}")

    # Deep copy so that merging overrides below cannot leak into the shared
    # module-level defaults (a shallow copy would share the nested dicts).
    config = copy.deepcopy(DEFAULT_CONFIG)

    # Add training type-specific configuration (copied for the same reason)
    config['training_type'] = training_type
    config['training_info'] = dict(TRAINING_TYPES[training_type])

    # Override with custom configuration if provided
    if custom_config:
        _deep_update(config, custom_config)

    # Validate configuration
    _validate_config(config)

    return config


def save_config(config, filepath=None):
    """
    Save configuration to a JSON file

    Args:
        config (dict): Configuration to save
        filepath (str): Path to save to (default: based on training type and timestamp)

    Returns:
        str: Path where configuration was saved
    """
    if filepath is None:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        training_type = config.get('training_type', 'unknown')
        filepath = f"configs/training_{training_type}_{timestamp}.json"

    # Path.parent is '.' for a bare filename, so this also works when the
    # target has no directory component (os.makedirs('') would raise).
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)

    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)

    logger.info("Configuration saved to %s", filepath)
    return filepath


def load_config(filepath):
    """
    Load configuration from a JSON file

    Args:
        filepath (str): Path to load from

    Returns:
        dict: Loaded configuration

    Raises:
        ValueError: If the loaded configuration fails validation.
        KeyError: If the loaded configuration lacks a required key.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Validate the loaded configuration
    _validate_config(config)

    logger.info("Configuration loaded from %s", filepath)
    return config


def _deep_update(target, source):
    """
    Deep update a nested dictionary

    Values from ``source`` replace those in ``target``; when both sides hold
    a dict under the same key, the dicts are merged recursively instead.
    ``target`` is modified in place.

    Args:
        target (dict): Target dictionary to update
        source (dict): Source dictionary with updates

    Returns:
        dict: Updated target dictionary
    """
    for key, value in source.items():
        if key in target and isinstance(target[key], dict) and isinstance(value, dict):
            _deep_update(target[key], value)
        else:
            target[key] = value
    return target


def _validate_config(config):
    """
    Validate configuration to ensure it follows required guidelines

    The data-policy flags are checked both at the top level of the
    configuration and inside its ``market_data`` section, since the default
    configuration stores ``use_real_data_only`` under ``market_data``.

    Args:
        config (dict): Configuration to validate

    Returns:
        bool: True if valid, raises exception otherwise

    Raises:
        ValueError: If a data-policy flag is violated, or if the symbol,
            timeframes, or window size is invalid.
        KeyError: If ``market_data`` or one of its required keys is missing.
    """
    # Collect every scope in which a policy flag may legitimately appear
    policy_scopes = [config]
    market_section = config.get('market_data')
    if isinstance(market_section, dict):
        policy_scopes.append(market_section)

    # Enforce real data policy
    for scope in policy_scopes:
        if scope.get('use_real_data_only', True) is not True:
            logger.error("POLICY VIOLATION: Real market data policy requires only using real data")
            raise ValueError("Configuration violates policy: Must use only real market data, never synthetic")

    # Explicitly reject any attempt to enable synthetic data
    for scope in policy_scopes:
        if scope.get('allow_synthetic_data') is True:
            logger.error("POLICY VIOLATION: Synthetic data is not allowed under any circumstances")
            raise ValueError("Configuration violates policy: Synthetic data is explicitly forbidden")

    market_data = config['market_data']

    # Validate symbol
    symbol = market_data['symbol']
    if not symbol or '/' not in symbol:
        raise ValueError(f"Invalid symbol format: {symbol}")

    # Validate timeframes
    if not market_data['timeframes']:
        raise ValueError("At least one timeframe must be specified")

    # Ensure window size is reasonable
    window_size = market_data['window_size']
    if not 10 <= window_size <= 500:
        raise ValueError(f"Window size out of reasonable range: {window_size}")

    return True


def main():
    """Print the available training types and save a default supervised config."""
    # Show available training configurations
    print("Available Training Configurations:")
    print("=" * 40)
    for training_type, info in TRAINING_TYPES.items():
        print(f"{training_type.upper()}: {info['description']}")

    # Example of getting and saving a configuration
    config = get_config('supervised')
    save_config(config)

    print("\nDefault configuration generated and saved.")
    print(f"Log file: {log_file}")


if __name__ == "__main__":
    main()