231 lines
7.2 KiB
Python
231 lines
7.2 KiB
Python
#!/usr/bin/env python
|
|
"""
|
|
Training Configuration for GOGO2 Trading System
|
|
|
|
This module provides a central configuration for all training scripts,
|
|
ensuring they use real market data and follow consistent practices.
|
|
|
|
Usage:
|
|
import train_config
|
|
config = train_config.get_config('supervised') # or 'reinforcement' or 'hybrid'
|
|
"""
|
|
|
|
import os
|
|
import logging
|
|
import json
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
# Every training script logs to the console and to its own timestamped file
# under ./logs, so runs can be inspected and compared afterwards.
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
log_file = log_dir.joinpath(f"training_{datetime.now():%Y%m%d_%H%M%S}.log")

# Root logging is configured as a side effect of importing this module.
# Per the logging docs, basicConfig does nothing if the root logger already
# has handlers.
logging.basicConfig(
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('training')
|
|
|
|
# Registry of supported training modes, keyed by the name get_config() accepts.
# Every mode currently reads market data through the same interface, so that
# field is filled in once; each row below is
# (key, description, entry-point script, model class).
TRAINING_TYPES = {
    key: {
        'description': description,
        'script': script,
        'model_class': model_class,
        'data_interface': 'MultiTimeframeDataInterface',
    }
    for key, description, script, model_class in (
        ('supervised', 'Supervised learning using CNN model',
         'train_with_realtime.py', 'CNNModelPyTorch'),
        ('reinforcement', 'Reinforcement learning using DQN agent',
         'train_rl_with_realtime.py', 'DQNAgent'),
        # Hybrid mode: script and model class are still to be implemented.
        ('hybrid', 'Combined supervised and reinforcement learning',
         'train_hybrid.py', 'HybridModel'),
    )
}
|
|
|
|
# Default configuration
|
|
DEFAULT_CONFIG = {
|
|
# Market data configuration
|
|
'market_data': {
|
|
'use_real_data_only': True, # IMPORTANT: Only use real market data, never synthetic
|
|
'symbol': 'BTC/USDT',
|
|
'timeframes': ['1m', '5m', '15m'],
|
|
'window_size': 24,
|
|
'data_refresh_interval': 300, # seconds
|
|
'use_indicators': True
|
|
},
|
|
|
|
# Training parameters
|
|
'training': {
|
|
'max_training_time': 12 * 3600, # seconds (12 hours)
|
|
'checkpoint_interval': 3600, # seconds (1 hour)
|
|
'batch_size': 64,
|
|
'learning_rate': 0.0001,
|
|
'optimizer': 'adam',
|
|
'loss_function': 'custom_pnl' # Focus on profitability
|
|
},
|
|
|
|
# Model paths
|
|
'paths': {
|
|
'models_dir': 'NN/models/saved',
|
|
'logs_dir': 'logs',
|
|
'tensorboard_dir': 'runs'
|
|
},
|
|
|
|
# GPU configuration
|
|
'hardware': {
|
|
'use_gpu': True,
|
|
'mixed_precision': True,
|
|
'device': 'cuda' if os.environ.get('CUDA_VISIBLE_DEVICES') is not None else 'cpu'
|
|
}
|
|
}
|
|
|
|
def get_config(training_type='supervised', custom_config=None):
    """
    Get configuration for a specific training type

    Builds an independent deep copy of DEFAULT_CONFIG, records the training
    type and its TRAINING_TYPES entry, merges ``custom_config`` on top and
    validates the result.

    Args:
        training_type (str): Type of training ('supervised', 'reinforcement', or 'hybrid')
        custom_config (dict): Optional custom configuration to merge (not mutated)

    Returns:
        dict: Complete configuration. It shares no mutable objects with
        DEFAULT_CONFIG, TRAINING_TYPES or ``custom_config``, so callers may
        modify it freely.

    Raises:
        ValueError: If ``training_type`` is unknown. Errors raised by
            _validate_config for the merged configuration propagate as-is.
    """
    import copy  # only needed here

    if training_type not in TRAINING_TYPES:
        raise ValueError(f"Invalid training type: {training_type}. Must be one of {list(TRAINING_TYPES.keys())}")

    # A shallow copy would share the nested section dicts with DEFAULT_CONFIG,
    # so _deep_update (or any caller edit) would silently change the defaults
    # seen by every later get_config() call.
    config = copy.deepcopy(DEFAULT_CONFIG)

    # Add training type-specific configuration (copied for the same reason).
    config['training_type'] = training_type
    config['training_info'] = copy.deepcopy(TRAINING_TYPES[training_type])

    # Override with custom configuration if provided. It is copied so the
    # returned config does not alias the caller's nested objects.
    if custom_config:
        _deep_update(config, copy.deepcopy(custom_config))

    _validate_config(config)

    return config
|
|
|
|
def save_config(config, filepath=None):
    """
    Save configuration to a JSON file

    Args:
        config (dict): JSON-serialisable configuration to save
        filepath (str | os.PathLike): Path to save to (default:
            ``configs/training_<training_type>_<YYYYmmdd_HHMMSS>.json``)

    Returns:
        The path the configuration was written to: ``filepath`` as given,
        or the generated default path string.

    Raises:
        TypeError: If ``config`` contains values json cannot serialise.
        OSError: If the directory or file cannot be created.
    """
    if filepath is None:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        training_type = config.get('training_type', 'unknown')
        filepath = f"configs/training_{training_type}_{timestamp}.json"

    # os.path.dirname() is '' for a bare filename such as "cfg.json", and
    # os.makedirs('') raises FileNotFoundError. Only create a parent
    # directory when the path actually has one.
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)

    logger.info(f"Configuration saved to {filepath}")
    return filepath
|
|
|
|
def load_config(filepath):
    """
    Load configuration from a JSON file and validate it

    Args:
        filepath (str | os.PathLike): Path to load from

    Returns:
        dict: Loaded configuration

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        ValueError: If the top-level JSON value is not an object. Errors
            raised by _validate_config also propagate.
    """
    # Explicit UTF-8 so hand-edited files decode identically on every
    # platform instead of depending on the locale's default encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # _validate_config needs a mapping. Report a top-level list or scalar
    # clearly instead of failing with an AttributeError inside validation.
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {filepath} must contain a JSON object, "
            f"got {type(config).__name__}"
        )

    _validate_config(config)

    logger.info(f"Configuration loaded from {filepath}")
    return config
|
|
|
|
def _deep_update(target, source):
|
|
"""
|
|
Deep update a nested dictionary
|
|
|
|
Args:
|
|
target (dict): Target dictionary to update
|
|
source (dict): Source dictionary with updates
|
|
|
|
Returns:
|
|
dict: Updated target dictionary
|
|
"""
|
|
for key, value in source.items():
|
|
if key in target and isinstance(target[key], dict) and isinstance(value, dict):
|
|
_deep_update(target[key], value)
|
|
else:
|
|
target[key] = value
|
|
return target
|
|
|
|
def _validate_config(config):
|
|
"""
|
|
Validate configuration to ensure it follows required guidelines
|
|
|
|
Args:
|
|
config (dict): Configuration to validate
|
|
|
|
Returns:
|
|
bool: True if valid, raises exception otherwise
|
|
"""
|
|
# Enforce real data policy
|
|
if config.get('use_real_data_only', True) is not True:
|
|
logger.error("POLICY VIOLATION: Real market data policy requires only using real data")
|
|
raise ValueError("Configuration violates policy: Must use only real market data, never synthetic")
|
|
|
|
# Add explicit check at the beginning of the validation function
|
|
if 'allow_synthetic_data' in config and config['allow_synthetic_data'] is True:
|
|
logger.error("POLICY VIOLATION: Synthetic data is not allowed under any circumstances")
|
|
raise ValueError("Configuration violates policy: Synthetic data is explicitly forbidden")
|
|
|
|
# Validate symbol
|
|
if not config['market_data']['symbol'] or '/' not in config['market_data']['symbol']:
|
|
raise ValueError(f"Invalid symbol format: {config['market_data']['symbol']}")
|
|
|
|
# Validate timeframes
|
|
if not config['market_data']['timeframes']:
|
|
raise ValueError("At least one timeframe must be specified")
|
|
|
|
# Ensure window size is reasonable
|
|
if config['market_data']['window_size'] < 10 or config['market_data']['window_size'] > 500:
|
|
raise ValueError(f"Window size out of reasonable range: {config['market_data']['window_size']}")
|
|
|
|
return True
|
|
|
|
if __name__ == "__main__":
    # List every registered training mode with its description.
    heading = "Available Training Configurations:"
    print(heading)
    print("=" * 40)
    for mode_name, mode_info in TRAINING_TYPES.items():
        print(f"{mode_name.upper()}: {mode_info['description']}")

    # Build the supervised configuration and persist it. With no explicit
    # path, save_config writes a timestamped file under configs/.
    save_config(get_config('supervised'))

    print("\nDefault configuration generated and saved.")
    print(f"Log file: {log_file}")