gogo2/train_config.py
Dobromir Popov c0872248ab misc
2025-05-13 17:19:52 +03:00

231 lines
7.2 KiB
Python

#!/usr/bin/env python
"""
Training Configuration for GOGO2 Trading System
This module provides a central configuration for all training scripts,
ensuring they use real market data and follow consistent practices.
Usage:
import train_config
config = train_config.get_config('supervised') # or 'reinforcement' or 'hybrid'
"""
import copy
import json
import logging
import os
from datetime import datetime
from pathlib import Path
# Shared logging for every training script that imports this module: one
# timestamped log file under ./logs for this run, plus console output.
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
_run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = log_dir / f"training_{_run_stamp}.log"

# Root logger writes to both the run's log file and a console stream handler.
_log_handlers = [logging.FileHandler(log_file), logging.StreamHandler()]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)
logger = logging.getLogger('training')
# Registry of supported training modes. Each entry records the entry-point
# script, the model class it trains, and the data interface it reads from.
# NOTE: the 'hybrid' script and model class are placeholders (to be implemented).
_SHARED_DATA_INTERFACE = 'MultiTimeframeDataInterface'

TRAINING_TYPES = {
    'supervised': dict(
        description='Supervised learning using CNN model',
        script='train_with_realtime.py',
        model_class='CNNModelPyTorch',
        data_interface=_SHARED_DATA_INTERFACE,
    ),
    'reinforcement': dict(
        description='Reinforcement learning using DQN agent',
        script='train_rl_with_realtime.py',
        model_class='DQNAgent',
        data_interface=_SHARED_DATA_INTERFACE,
    ),
    'hybrid': dict(
        description='Combined supervised and reinforcement learning',
        script='train_hybrid.py',  # To be implemented
        model_class='HybridModel',  # To be implemented
        data_interface=_SHARED_DATA_INTERFACE,
    ),
}
def detect_default_device(environ=None):
    """
    Return the default device string ('cuda' or 'cpu') implied by the environment.

    Only the ``CUDA_VISIBLE_DEVICES`` variable is inspected; no GPU library is
    queried. The rules are:

    * unset -> 'cpu' (the historical default of this module);
    * empty string or '-1' -> 'cpu' (the conventional ways to hide every GPU);
    * anything else -> 'cuda'.

    Args:
        environ (Mapping[str, str] | None): Environment mapping to read;
            defaults to ``os.environ``.

    Returns:
        str: 'cuda' or 'cpu'.
    """
    env = os.environ if environ is None else environ
    visible = env.get('CUDA_VISIBLE_DEVICES')
    # NOTE(review): an unset variable usually means "all GPUs visible"; training
    # scripts may want to confirm availability with their GPU library instead.
    if visible is None or visible.strip() in ('', '-1'):
        return 'cpu'
    return 'cuda'


# Default configuration
DEFAULT_CONFIG = {
    # Market data configuration
    'market_data': {
        'use_real_data_only': True,  # IMPORTANT: Only use real market data, never synthetic
        'symbol': 'BTC/USDT',
        'timeframes': ['1m', '5m', '15m'],
        'window_size': 24,
        'data_refresh_interval': 300,  # seconds
        'use_indicators': True,
    },
    # Training parameters
    'training': {
        'max_training_time': 12 * 3600,  # seconds (12 hours)
        'checkpoint_interval': 3600,  # seconds (1 hour)
        'batch_size': 64,
        'learning_rate': 0.0001,
        'optimizer': 'adam',
        'loss_function': 'custom_pnl',  # Focus on profitability
    },
    # Model paths
    'paths': {
        'models_dir': 'NN/models/saved',
        'logs_dir': 'logs',
        'tensorboard_dir': 'runs',
    },
    # GPU configuration
    'hardware': {
        'use_gpu': True,
        'mixed_precision': True,
        # Resolved once, at import time, from CUDA_VISIBLE_DEVICES only.
        'device': detect_default_device(),
    },
}
def get_config(training_type='supervised', custom_config=None):
    """
    Build the configuration for a specific training type.

    The returned dict is an independent deep copy. Mutating it, or merging
    ``custom_config`` into it, never alters ``DEFAULT_CONFIG``,
    ``TRAINING_TYPES`` or the caller's ``custom_config``.

    Args:
        training_type (str): One of the keys of ``TRAINING_TYPES``
            ('supervised', 'reinforcement', or 'hybrid').
        custom_config (dict): Optional overrides, merged recursively on top
            of the defaults.

    Returns:
        dict: Complete, validated configuration.

    Raises:
        ValueError: If ``training_type`` is unknown or the merged
            configuration fails validation.
    """
    if training_type not in TRAINING_TYPES:
        raise ValueError(f"Invalid training type: {training_type}. Must be one of {list(TRAINING_TYPES.keys())}")

    # deepcopy, not dict.copy(): _deep_update recurses into nested dicts and,
    # with a shallow copy, would write overrides straight into DEFAULT_CONFIG,
    # leaking them into every later get_config() call.
    config = copy.deepcopy(DEFAULT_CONFIG)

    # Add training type-specific configuration
    config['training_type'] = training_type
    config['training_info'] = copy.deepcopy(TRAINING_TYPES[training_type])

    if custom_config:
        # Copy the overrides too, so the result shares no nested objects
        # with the caller's dict.
        _deep_update(config, copy.deepcopy(custom_config))

    _validate_config(config)
    return config
def save_config(config, filepath=None):
    """
    Save a configuration to a JSON file.

    Args:
        config (dict): Configuration to save; must be JSON-serialisable.
        filepath (str | os.PathLike): Destination path. Defaults to
            ``configs/training_<training_type>_<YYYYmmdd_HHMMSS>.json``.

    Returns:
        str | os.PathLike: The path the configuration was written to (the
        ``filepath`` argument itself when one was given).

    Raises:
        TypeError: If ``config`` contains values JSON cannot encode. In that
            case no file is created or overwritten.
    """
    if filepath is None:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        training_type = config.get('training_type', 'unknown')
        filepath = f"configs/training_{training_type}_{timestamp}.json"

    # Serialise before opening the file, so an unencodable config cannot
    # leave a truncated or half-written JSON file behind.
    payload = json.dumps(config, indent=2)

    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError; only create parent directories when there are any.
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(payload)

    logger.info("Configuration saved to %s", filepath)
    return filepath
def load_config(filepath):
    """
    Load and validate a configuration from a JSON file.

    Args:
        filepath (str | os.PathLike): Path to the JSON file.

    Returns:
        dict: The loaded configuration.

    Raises:
        ValueError: If the file's top-level JSON value is not an object, or
            if the configuration fails validation. Malformed JSON raises
            json.JSONDecodeError, which is a ValueError subclass.
    """
    # JSON text is UTF-8 by specification; don't depend on the platform's
    # locale encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # _validate_config treats the config as a mapping. Reject any other
    # top-level JSON value with a clear message instead of an AttributeError.
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration in {filepath} must be a JSON object, got {type(config).__name__}"
        )

    # Validate the loaded configuration
    _validate_config(config)
    logger.info("Configuration loaded from %s", filepath)
    return config
def _deep_update(target, source):
    """
    Recursively merge ``source`` into ``target``, in place.

    When both sides hold a dict under the same key, the merge descends into
    that dict. Otherwise the value from ``source`` replaces or adds the entry.

    Args:
        target (dict): Dictionary modified in place.
        source (dict): Dictionary whose entries take precedence.

    Returns:
        dict: ``target`` itself.
    """
    for name, incoming in source.items():
        existing = target.get(name)
        if not (isinstance(existing, dict) and isinstance(incoming, dict)):
            target[name] = incoming
            continue
        _deep_update(existing, incoming)
    return target
def _validate_config(config):
    """
    Validate a configuration against the data policy and basic sanity rules.

    The real-data policy flags (``use_real_data_only`` and
    ``allow_synthetic_data``) are enforced both at the top level of the config
    and inside its ``'market_data'`` section. DEFAULT_CONFIG stores
    ``use_real_data_only`` in the latter, so a top-level-only check would
    miss overrides made there.

    Args:
        config (dict): Configuration to validate. Must contain a
            ``'market_data'`` section with ``symbol``, ``timeframes`` and
            ``window_size`` entries.

    Returns:
        bool: True if the configuration is valid.

    Raises:
        ValueError: On a policy violation or an invalid/out-of-range value.
        KeyError: If ``market_data`` or one of its required entries is missing.
    """
    # Enforce the real-data policy wherever the flags may have been set.
    policy_scopes = [config]
    market_section = config.get('market_data')
    if isinstance(market_section, dict):
        policy_scopes.append(market_section)

    for scope in policy_scopes:
        # A missing flag means "real data only"; any value other than True is a violation.
        if scope.get('use_real_data_only', True) is not True:
            logger.error("POLICY VIOLATION: Real market data policy requires only using real data")
            raise ValueError("Configuration violates policy: Must use only real market data, never synthetic")
        if scope.get('allow_synthetic_data') is True:
            logger.error("POLICY VIOLATION: Synthetic data is not allowed under any circumstances")
            raise ValueError("Configuration violates policy: Synthetic data is explicitly forbidden")

    market = config['market_data']

    # Symbols are expected in 'BASE/QUOTE' form, e.g. 'BTC/USDT'.
    symbol = market['symbol']
    if not symbol or '/' not in symbol:
        raise ValueError(f"Invalid symbol format: {symbol}")

    if not market['timeframes']:
        raise ValueError("At least one timeframe must be specified")

    # Chained comparison also rejects NaN, which `< 10 or > 500` let through.
    window_size = market['window_size']
    if not 10 <= window_size <= 500:
        raise ValueError(f"Window size out of reasonable range: {window_size}")

    return True
if __name__ == "__main__":
    # List every registered training mode with its description.
    print("Available Training Configurations:", "=" * 40, sep="\n")
    for mode_name, mode_info in TRAINING_TYPES.items():
        print(f"{mode_name.upper()}: {mode_info['description']}")

    # Build the default supervised configuration and persist it as an example.
    example_config = get_config('supervised')
    save_config(example_config)

    print("\nDefault configuration generated and saved.")
    print(f"Log file: {log_file}")