84 lines
3.3 KiB
Python
84 lines
3.3 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Debug Training Methods
|
|
|
|
This script checks what training methods are available on each model.
|
|
"""
|
|
|
|
import asyncio
|
|
from core.orchestrator import TradingOrchestrator
|
|
from core.data_provider import DataProvider
|
|
|
|
# Method names whose presence is reported as an available training method.
_TRAINING_METHOD_NAMES = (
    'train_on_outcome',
    'add_experience',
    'remember',
    'replay',
    'add_training_sample',
    'train',
    'train_with_reward',
    'update_loss',
)

# Attributes reported for each model; the "sized" ones are summarised by
# their length instead of being printed in full.
_REPORTED_ATTRIBUTES = ('memory', 'batch_size', 'training_data')
_SIZED_ATTRIBUTES = frozenset({'memory', 'training_data'})


def _available_training_methods(model):
    """Return the names from ``_TRAINING_METHOD_NAMES`` that *model* exposes."""
    return [name for name in _TRAINING_METHOD_NAMES if hasattr(model, name)]


def _describe_attributes(model):
    """Return printable descriptions of the reported attributes on *model*.

    ``memory`` and ``training_data`` are rendered as ``name(len=N)`` when the
    value supports ``len()``; every other case is rendered as ``name=value``.
    Attributes the model does not have are omitted.
    """
    descriptions = []
    for attr in _REPORTED_ATTRIBUTES:
        if not hasattr(model, attr):
            continue
        value = getattr(model, attr)
        if attr in _SIZED_ATTRIBUTES and hasattr(value, '__len__'):
            descriptions.append(f"{attr}(len={len(value)})")
        else:
            descriptions.append(f"{attr}={value}")
    return descriptions


def _classify_model(model):
    """Return the detection message for *model* based on the methods it has."""
    if hasattr(model, 'act') and hasattr(model, 'remember'):
        return "✅ Detected as RL Agent"
    if hasattr(model, 'predict') and hasattr(model, 'add_training_sample'):
        return "✅ Detected as CNN Model"
    return "❓ Unknown model type"


async def debug_training_methods():
    """Debug the available training methods on each model.

    Builds a ``TradingOrchestrator`` backed by a fresh ``DataProvider``, prints
    a report for every entry in ``orchestrator.model_registry.models`` (the
    interface type, the underlying model type, the training methods and
    attributes it exposes, and a coarse model classification), then requests
    predictions for ``'ETH/USDT'`` and attempts an immediate training pass for
    each registered model. All results are printed; nothing is returned.
    """
    print("=== Debugging Training Methods ===")

    # Initialize orchestrator
    print("1. Initializing orchestrator...")
    data_provider = DataProvider()
    orchestrator = TradingOrchestrator(data_provider=data_provider)

    # Wait for initialization
    # NOTE(review): presumably the orchestrator finishes setup in the
    # background and 2 seconds is enough -- confirm against the orchestrator.
    await asyncio.sleep(2)

    print("\n2. Checking available training methods on each model:")

    for model_name, model_interface in orchestrator.model_registry.models.items():
        print(f"\n--- {model_name} ---")
        print(f"Interface type: {type(model_interface).__name__}")

        # Get underlying model. Compare against None explicitly: a model
        # object may be falsy (e.g. it defines __len__ and is empty) and
        # must still be inspected rather than reported as missing.
        underlying_model = getattr(model_interface, 'model', None)
        if underlying_model is None:
            print("No underlying model found")
            continue
        print(f"Underlying model type: {type(underlying_model).__name__}")

        # Check for training methods
        training_methods = _available_training_methods(underlying_model)
        print(f"Available training methods: {training_methods}")

        # Check for specific attributes
        attributes = _describe_attributes(underlying_model)
        print(f"Relevant attributes: {attributes}")

        # Check if it's an RL agent
        print(_classify_model(underlying_model))

    print("\n3. Testing a simple training attempt:")

    # Get a prediction first. A failure here is reported but must not stop
    # the per-model training checks below, which are the point of this step.
    try:
        predictions = await orchestrator._get_all_predictions('ETH/USDT')
    except Exception as e:
        print(f"❌ Failed to get predictions: {e}")
    else:
        print(f"Got {len(predictions)} predictions")

    # Try to trigger training for each model
    for model_name in orchestrator.model_registry.models:
        print(f"\nTesting training for {model_name}...")
        try:
            await orchestrator._trigger_immediate_training_for_model(model_name, 'ETH/USDT')
            print(f"✅ Training attempt completed for {model_name}")
        except Exception as e:
            # Best-effort diagnostic: report the failure and move on to the
            # next model instead of aborting the whole run.
            print(f"❌ Training failed for {model_name}: {e}")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the diagnostic coroutine to completion.
    asyncio.run(debug_training_methods())