155 lines
5.8 KiB
Python
155 lines
5.8 KiB
Python
"""
|
|
Test Continuous CNN Training
|
|
|
|
This script demonstrates how the CNN model can be trained with each new inference result
|
|
using collected data, implementing a continuous learning loop.
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
from datetime import datetime
|
|
import random
|
|
import os
|
|
|
|
from core.standardized_data_provider import StandardizedDataProvider
|
|
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
|
|
from core.data_models import create_model_output
|
|
|
|
# Configure logging: INFO and above, each record prefixed with its timestamp,
# logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the functions in this script.
logger = logging.getLogger(__name__)
|
|
|
|
def simulate_market_feedback(action, symbol):
    """
    Produce a synthetic (random) market outcome for a model's action.

    Stand-in for real market performance data: a market direction is drawn
    at random, mapped to the action that would have been correct, and the
    model's action is rewarded or penalised depending on whether it matched.

    Args:
        action: Trading action chosen by the model ('BUY', 'SELL', 'HOLD')
        symbol: Trading symbol (accepted for interface parity; the
            simulation does not use it)

    Returns:
        tuple: (best_action, reward). best_action is the action matching the
        simulated market move. reward is drawn from uniform(0.01, 0.1) when
        the model's action matched, else from uniform(-0.1, -0.01).
    """
    # Simulated market direction -> the action that would have been right.
    # Dict insertion order keeps the random.choice population identical to
    # ['up', 'down', 'sideways'].
    correct_action_for = {'up': 'BUY', 'down': 'SELL', 'sideways': 'HOLD'}
    market_direction = random.choice(list(correct_action_for))
    best_action = correct_action_for[market_direction]

    guessed_right = action == best_action
    reward = (random.uniform(0.01, 0.1) if guessed_right
              else random.uniform(-0.1, -0.01))

    logger.info(f"Market went {market_direction}, best action was {best_action}, model chose {action}, reward: {reward:.4f}")

    return best_action, reward
|
|
|
|
def _run_learning_loop(data_provider, cnn_adapter, symbols,
                       num_iterations=10, training_frequency=3,
                       min_samples_for_training=3):
    """Run the predict -> feedback -> add-sample -> periodic-train loop.

    Each iteration picks a random symbol, predicts on its standardized
    input, stores the model output, obtains simulated market feedback and
    adds it as a training sample. Every ``training_frequency`` iterations the
    model is trained for one epoch, provided at least
    ``min_samples_for_training`` samples have been collected.

    Args:
        data_provider: Provider exposing get_base_data_input / store_model_output.
        cnn_adapter: Adapter exposing predict / add_training_sample / train.
        symbols: Symbols to sample from each iteration.
        num_iterations: Number of loop iterations.
        training_frequency: Train every N iterations.
        min_samples_for_training: Minimum collected samples before training.

    Returns:
        int: Number of training samples added during the loop.
    """
    samples_collected = 0

    logger.info(f"Starting continuous learning loop with {num_iterations} iterations")

    for i in range(num_iterations):
        logger.info(f"\nIteration {i+1}/{num_iterations}")

        # Get standardized input data
        symbol = random.choice(symbols)
        logger.info(f"Getting data for {symbol}...")
        base_data = data_provider.get_base_data_input(symbol)

        if base_data is None:
            logger.warning(f"Failed to get base data input for {symbol}, skipping iteration")
            continue

        # Make prediction
        logger.info(f"Making prediction for {symbol}...")
        model_output = cnn_adapter.predict(base_data)

        action = model_output.predictions['action']
        confidence = model_output.confidence
        logger.info(f"Prediction: {action} with confidence {confidence:.4f}")

        # Store model output
        data_provider.store_model_output(model_output)

        # Simulate market feedback and record it as a training sample
        best_action, reward = simulate_market_feedback(action, symbol)
        logger.info(f"Adding training sample: action={best_action}, reward={reward:.4f}")
        cnn_adapter.add_training_sample(base_data, best_action, reward)
        samples_collected += 1

        # Train model periodically
        if (i + 1) % training_frequency == 0 and samples_collected >= min_samples_for_training:
            logger.info(f"Training model with {samples_collected} samples...")
            # NOTE(review): train()'s return type isn't visible here; fall back
            # to an empty dict so a None result doesn't crash the .get() calls.
            train_metrics = cnn_adapter.train(epochs=1) or {}
            logger.info(f"Training metrics: loss={train_metrics.get('loss', 0.0):.4f}, accuracy={train_metrics.get('accuracy', 0.0):.4f}")

        # Simulate time passing
        time.sleep(1)

    logger.info("\nContinuous learning loop completed")
    return samples_collected


def _final_evaluation(data_provider, cnn_adapter, symbol='ETH/USDT'):
    """Make one final prediction for ``symbol`` and log performance metrics.

    Logs a warning and returns early if no input data is available.
    """
    logger.info("Performing final evaluation...")

    base_data = data_provider.get_base_data_input(symbol)
    if base_data is None:
        logger.warning("Failed to get base data input for final evaluation")
        return

    model_output = cnn_adapter.predict(base_data)
    action = model_output.predictions['action']
    confidence = model_output.confidence
    logger.info(f"Final prediction for {symbol}: {action} with confidence {confidence:.4f}")

    # Evaluate the model using outputs stored during the loop
    output_manager = data_provider.get_model_output_manager()
    performance_metrics = output_manager.evaluate_model_performance(symbol, cnn_adapter.model_name)
    logger.info(f"Performance metrics: {performance_metrics}")


def test_continuous_training():
    """Test continuous training of the CNN model with new inference results.

    Builds a StandardizedDataProvider and an EnhancedCNNAdapter, loads the
    best checkpoint if any, runs the continuous learning loop, then performs
    a final evaluation on ETH/USDT.

    Raises:
        Exception: Any error is logged with its traceback and then re-raised,
            so pytest and the shell see the failure instead of a silent pass.
    """
    try:
        # Initialize data provider
        symbols = ['ETH/USDT', 'BTC/USDT']
        timeframes = ['1s', '1m', '1h', '1d']
        data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)

        # Initialize CNN adapter
        checkpoint_dir = "models/enhanced_cnn"
        os.makedirs(checkpoint_dir, exist_ok=True)
        cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=checkpoint_dir)

        # Load best checkpoint if available
        cnn_adapter.load_best_checkpoint()

        _run_learning_loop(data_provider, cnn_adapter, symbols)
        _final_evaluation(data_provider, cnn_adapter, 'ETH/USDT')

        logger.info("Test completed successfully")
    except Exception as e:
        logger.error(f"Error in test: {e}", exc_info=True)
        # Re-raise: swallowing here made this test_* function always "pass".
        raise
|
|
|
|
if __name__ == "__main__":
|
|
test_continuous_training() |