"""
Test Continuous CNN Training

This script demonstrates how the CNN model can be trained with each new inference
result using collected data, implementing a continuous learning loop.
"""

import logging
import time
from datetime import datetime
import random
import os

from core.standardized_data_provider import StandardizedDataProvider
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
from core.data_models import create_model_output

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def simulate_market_feedback(action, symbol):
    """
    Simulate market feedback for a given action.

    In a real system, this would be replaced with actual market performance data.
    The market direction is drawn uniformly at random from 'up', 'down' and
    'sideways', mapped to 'BUY', 'SELL' and 'HOLD' respectively.

    Args:
        action: Trading action ('BUY', 'SELL', 'HOLD') chosen by the model.
        symbol: Trading symbol (currently unused by the simulation).

    Returns:
        tuple: (best_action, reward) where best_action is the action that matched
        the simulated market direction, and reward is drawn from [0.01, 0.1] when
        ``action`` matched it, otherwise from [-0.1, -0.01].
    """
    # Simulate market movement (random for demonstration)
    market_direction = random.choice(['up', 'down', 'sideways'])

    # Determine actual best action based on market direction
    if market_direction == 'up':
        best_action = 'BUY'
    elif market_direction == 'down':
        best_action = 'SELL'
    else:
        best_action = 'HOLD'

    # Calculate reward based on whether the action matched the best action
    if action == best_action:
        reward = random.uniform(0.01, 0.1)  # Positive reward for correct action
    else:
        reward = random.uniform(-0.1, -0.01)  # Negative reward for incorrect action

    logger.info(f"Market went {market_direction}, best action was {best_action}, model chose {action}, reward: {reward:.4f}")

    return best_action, reward


def _run_learning_loop(data_provider, cnn_adapter, symbols, num_iterations=10,
                       training_frequency=3, min_samples_to_train=3):
    """Run the predict -> feedback -> add-sample loop, training periodically.

    Each iteration picks a random symbol, fetches its base data, predicts with
    ``cnn_adapter``, stores the output via ``data_provider``, simulates market
    feedback, and adds the result as a training sample. Every
    ``training_frequency`` iterations (once at least ``min_samples_to_train``
    samples exist) the adapter is trained for one epoch.

    Args:
        data_provider: Provider exposing ``get_base_data_input`` and
            ``store_model_output``.
        cnn_adapter: Adapter exposing ``predict``, ``add_training_sample`` and
            ``train``.
        symbols: Symbols to sample from each iteration.
        num_iterations: Number of loop iterations to run.
        training_frequency: Train on every N-th iteration (1-based).
        min_samples_to_train: Minimum collected samples required before training.

    Returns:
        int: Number of training samples added during the loop.
    """
    samples_collected = 0

    logger.info(f"Starting continuous learning loop with {num_iterations} iterations")

    for i in range(num_iterations):
        logger.info(f"\nIteration {i+1}/{num_iterations}")

        # Get standardized input data
        symbol = random.choice(symbols)
        logger.info(f"Getting data for {symbol}...")
        base_data = data_provider.get_base_data_input(symbol)

        if base_data is None:
            # Skipped iterations neither collect a sample nor sleep.
            logger.warning(f"Failed to get base data input for {symbol}, skipping iteration")
            continue

        # Make prediction
        logger.info(f"Making prediction for {symbol}...")
        model_output = cnn_adapter.predict(base_data)

        # Log prediction
        action = model_output.predictions['action']
        confidence = model_output.confidence
        logger.info(f"Prediction: {action} with confidence {confidence:.4f}")

        # Store model output
        data_provider.store_model_output(model_output)

        # Simulate market feedback
        best_action, reward = simulate_market_feedback(action, symbol)

        # Add training sample (labelled with the best action, not the predicted one)
        logger.info(f"Adding training sample: action={best_action}, reward={reward:.4f}")
        cnn_adapter.add_training_sample(base_data, best_action, reward)
        samples_collected += 1

        # Train model periodically
        if (i + 1) % training_frequency == 0 and samples_collected >= min_samples_to_train:
            logger.info(f"Training model with {samples_collected} samples...")
            # NOTE(review): `or {}` guards against train() returning None;
            # confirm the adapter's return type.
            metrics = cnn_adapter.train(epochs=1) or {}

            # Log training metrics
            logger.info(f"Training metrics: loss={metrics.get('loss', 0.0):.4f}, accuracy={metrics.get('accuracy', 0.0):.4f}")

        # Simulate time passing
        time.sleep(1)

    return samples_collected


def _run_final_evaluation(data_provider, cnn_adapter, symbol):
    """Make one final prediction for ``symbol`` and log stored performance metrics.

    Logs a warning and returns early if no base data is available for ``symbol``.
    """
    logger.info("Performing final evaluation...")

    # Get data for evaluation
    base_data = data_provider.get_base_data_input(symbol)

    if base_data is None:
        logger.warning("Failed to get base data input for final evaluation")
        return

    # Make prediction
    model_output = cnn_adapter.predict(base_data)

    # Log prediction
    action = model_output.predictions['action']
    confidence = model_output.confidence
    logger.info(f"Final prediction for {symbol}: {action} with confidence {confidence:.4f}")

    # Get model output manager
    output_manager = data_provider.get_model_output_manager()

    # Evaluate model performance
    metrics = output_manager.evaluate_model_performance(symbol, cnn_adapter.model_name)
    logger.info(f"Performance metrics: {metrics}")


def test_continuous_training():
    """Test continuous training of the CNN model with new inference results.

    Builds a StandardizedDataProvider for ETH/USDT and BTC/USDT, loads the best
    EnhancedCNNAdapter checkpoint from ``models/enhanced_cnn`` (creating the
    directory if needed), runs a 10-iteration learning loop that trains every
    3 iterations, and finishes with a final evaluation on ETH/USDT.

    Raises:
        Exception: Any error from the data provider or CNN adapter is logged
            with its traceback and then re-raised. Previously such errors were
            swallowed, so pytest (which collects this function because of its
            ``test_`` prefix) could never report a failure, and the script
            exited with status 0.
    """
    try:
        # Initialize data provider
        symbols = ['ETH/USDT', 'BTC/USDT']
        timeframes = ['1s', '1m', '1h', '1d']
        data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)

        # Initialize CNN adapter
        checkpoint_dir = "models/enhanced_cnn"
        os.makedirs(checkpoint_dir, exist_ok=True)
        cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=checkpoint_dir)

        # Load best checkpoint if available
        cnn_adapter.load_best_checkpoint()

        # Continuous learning loop
        _run_learning_loop(
            data_provider,
            cnn_adapter,
            symbols,
            num_iterations=10,
            training_frequency=3,  # Train every N iterations
            min_samples_to_train=3,
        )

        logger.info("\nContinuous learning loop completed")

        # Final evaluation
        _run_final_evaluation(data_provider, cnn_adapter, 'ETH/USDT')

        logger.info("Test completed successfully")

    except Exception as e:
        logger.error(f"Error in test: {e}", exc_info=True)
        # Re-raise so failures are visible to pytest and to the script's exit code.
        raise


if __name__ == "__main__":
    test_continuous_training()