training
This commit is contained in:
84
debug_training_methods.py
Normal file
84
debug_training_methods.py
Normal file
@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Debug Training Methods
|
||||
|
||||
This script checks what training methods are available on each model.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
async def debug_training_methods():
|
||||
"""Debug the available training methods on each model"""
|
||||
print("=== Debugging Training Methods ===")
|
||||
|
||||
# Initialize orchestrator
|
||||
print("1. Initializing orchestrator...")
|
||||
data_provider = DataProvider()
|
||||
orchestrator = TradingOrchestrator(data_provider=data_provider)
|
||||
|
||||
# Wait for initialization
|
||||
await asyncio.sleep(2)
|
||||
|
||||
print("\n2. Checking available training methods on each model:")
|
||||
|
||||
for model_name, model_interface in orchestrator.model_registry.models.items():
|
||||
print(f"\n--- {model_name} ---")
|
||||
print(f"Interface type: {type(model_interface).__name__}")
|
||||
|
||||
# Get underlying model
|
||||
underlying_model = getattr(model_interface, 'model', None)
|
||||
if underlying_model:
|
||||
print(f"Underlying model type: {type(underlying_model).__name__}")
|
||||
else:
|
||||
print("No underlying model found")
|
||||
continue
|
||||
|
||||
# Check for training methods
|
||||
training_methods = []
|
||||
for method in ['train_on_outcome', 'add_experience', 'remember', 'replay', 'add_training_sample', 'train', 'train_with_reward', 'update_loss']:
|
||||
if hasattr(underlying_model, method):
|
||||
training_methods.append(method)
|
||||
|
||||
print(f"Available training methods: {training_methods}")
|
||||
|
||||
# Check for specific attributes
|
||||
attributes = []
|
||||
for attr in ['memory', 'batch_size', 'training_data']:
|
||||
if hasattr(underlying_model, attr):
|
||||
attr_value = getattr(underlying_model, attr)
|
||||
if attr == 'memory' and hasattr(attr_value, '__len__'):
|
||||
attributes.append(f"{attr}(len={len(attr_value)})")
|
||||
elif attr == 'training_data' and hasattr(attr_value, '__len__'):
|
||||
attributes.append(f"{attr}(len={len(attr_value)})")
|
||||
else:
|
||||
attributes.append(f"{attr}={attr_value}")
|
||||
|
||||
print(f"Relevant attributes: {attributes}")
|
||||
|
||||
# Check if it's an RL agent
|
||||
if hasattr(underlying_model, 'act') and hasattr(underlying_model, 'remember'):
|
||||
print("✅ Detected as RL Agent")
|
||||
elif hasattr(underlying_model, 'predict') and hasattr(underlying_model, 'add_training_sample'):
|
||||
print("✅ Detected as CNN Model")
|
||||
else:
|
||||
print("❓ Unknown model type")
|
||||
|
||||
print("\n3. Testing a simple training attempt:")
|
||||
|
||||
# Get a prediction first
|
||||
predictions = await orchestrator._get_all_predictions('ETH/USDT')
|
||||
print(f"Got {len(predictions)} predictions")
|
||||
|
||||
# Try to trigger training for each model
|
||||
for model_name in orchestrator.model_registry.models.keys():
|
||||
print(f"\nTesting training for {model_name}...")
|
||||
try:
|
||||
await orchestrator._trigger_immediate_training_for_model(model_name, 'ETH/USDT')
|
||||
print(f"✅ Training attempt completed for {model_name}")
|
||||
except Exception as e:
|
||||
print(f"❌ Training failed for {model_name}: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(debug_training_methods())
|
Reference in New Issue
Block a user