checkpoint manager
test_continuous_cnn_training.py (new file, 155 lines)
@@ -0,0 +1,155 @@
"""
Test Continuous CNN Training

This script demonstrates how the CNN model can be trained with each new inference result
using collected data, implementing a continuous learning loop.
"""
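# Usage: run from the repository root so the core/ package is importable
# (path assumption; adjust PYTHONPATH to your layout):
#
#   python test_continuous_cnn_training.py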

import logging
import time
from datetime import datetime
import random
import os

from core.standardized_data_provider import StandardizedDataProvider
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
from core.data_models import create_model_output

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def simulate_market_feedback(action, symbol):
    """
    Simulate market feedback for a given action

    In a real system, this would be replaced with actual market performance data
    (a sketch of one possible replacement follows this function)

    Args:
        action: Trading action ('BUY', 'SELL', 'HOLD')
        symbol: Trading symbol

    Returns:
        tuple: (best_action, reward)
    """
    # Simulate market movement (random for demonstration)
    market_direction = random.choice(['up', 'down', 'sideways'])

    # Determine actual best action based on market direction
    if market_direction == 'up':
        best_action = 'BUY'
    elif market_direction == 'down':
        best_action = 'SELL'
    else:
        best_action = 'HOLD'

    # Calculate reward based on whether the action matched the best action
    if action == best_action:
        reward = random.uniform(0.01, 0.1)  # Positive reward for correct action
    else:
        reward = random.uniform(-0.1, -0.01)  # Negative reward for incorrect action

    logger.info(f"Market went {market_direction}, best action was {best_action}, model chose {action}, reward: {reward:.4f}")

    return best_action, reward

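# A real feedback source would compare the prediction against subsequent price
# action rather than a coin flip. A minimal sketch, assuming a hypothetical
# get_price_change(symbol, window) helper (not part of this repo):
#
#   change = get_price_change(symbol, window='1m')   # signed fractional move
#   if action == 'BUY':
#       reward = change
#   elif action == 'SELL':
#       reward = -change
#   else:
#       reward = -abs(change) * 0.1                  # small penalty for sitting out moves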

def test_continuous_training():
    """Test continuous training of the CNN model with new inference results"""
    try:
        # Initialize data provider
        symbols = ['ETH/USDT', 'BTC/USDT']
        timeframes = ['1s', '1m', '1h', '1d']
        data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)

        # Initialize CNN adapter
        checkpoint_dir = "models/enhanced_cnn"
        os.makedirs(checkpoint_dir, exist_ok=True)
        cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=checkpoint_dir)

        # Load best checkpoint if available
        cnn_adapter.load_best_checkpoint()
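        # If no saved checkpoint exists yet, the adapter presumably starts from
        # freshly initialized weights (assumption; exact behavior depends on
        # EnhancedCNNAdapter.load_best_checkpoint).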

        # Continuous learning loop
        num_iterations = 10
        training_frequency = 3  # Train every N iterations
        samples_collected = 0

        logger.info(f"Starting continuous learning loop with {num_iterations} iterations")

        for i in range(num_iterations):
            logger.info(f"\nIteration {i+1}/{num_iterations}")

            # Get standardized input data
            symbol = random.choice(symbols)
            logger.info(f"Getting data for {symbol}...")
            base_data = data_provider.get_base_data_input(symbol)

            if base_data is None:
                logger.warning(f"Failed to get base data input for {symbol}, skipping iteration")
                continue

            # Make prediction
            logger.info(f"Making prediction for {symbol}...")
            model_output = cnn_adapter.predict(base_data)

            # Log prediction
            action = model_output.predictions['action']
            confidence = model_output.confidence
            logger.info(f"Prediction: {action} with confidence {confidence:.4f}")

            # Store model output
            data_provider.store_model_output(model_output)

            # Simulate market feedback
            best_action, reward = simulate_market_feedback(action, symbol)

            # Add training sample
            logger.info(f"Adding training sample: action={best_action}, reward={reward:.4f}")
            cnn_adapter.add_training_sample(base_data, best_action, reward)
            samples_collected += 1
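            # Note: the sample is labeled with best_action (the hindsight-correct
            # move), not the model's own prediction, so training pushes the model
            # toward what the simulated market actually rewarded.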

            # Train model periodically
            if (i + 1) % training_frequency == 0 and samples_collected >= 3:
                logger.info(f"Training model with {samples_collected} samples...")
                metrics = cnn_adapter.train(epochs=1)

                # Log training metrics
                logger.info(f"Training metrics: loss={metrics.get('loss', 0.0):.4f}, accuracy={metrics.get('accuracy', 0.0):.4f}")
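                # Checkpointing after each training pass would let the best weights
                # survive restarts; here it is assumed to happen inside train(),
                # paired with the load_best_checkpoint() call at startup.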

            # Simulate time passing
            time.sleep(1)

        logger.info("\nContinuous learning loop completed")

        # Final evaluation
        logger.info("Performing final evaluation...")

        # Get data for evaluation
        symbol = 'ETH/USDT'
        base_data = data_provider.get_base_data_input(symbol)

        if base_data is not None:
            # Make prediction
            model_output = cnn_adapter.predict(base_data)

            # Log prediction
            action = model_output.predictions['action']
            confidence = model_output.confidence
            logger.info(f"Final prediction for {symbol}: {action} with confidence {confidence:.4f}")

            # Get model output manager
            output_manager = data_provider.get_model_output_manager()

            # Evaluate model performance
            metrics = output_manager.evaluate_model_performance(symbol, cnn_adapter.model_name)
            logger.info(f"Performance metrics: {metrics}")
        else:
            logger.warning("Failed to get base data input for final evaluation")

        logger.info("Test completed successfully")

    except Exception as e:
        logger.error(f"Error in test: {e}", exc_info=True)


if __name__ == "__main__":
    test_continuous_training()