checkpoint manager

This commit is contained in:
Dobromir Popov
2025-07-23 21:40:04 +03:00
parent bab39fa68f
commit 45a62443a0
9 changed files with 1587 additions and 709 deletions

View File

@ -0,0 +1,155 @@
"""
Test Continuous CNN Training
This script demonstrates how the CNN model can be trained with each new inference result
using collected data, implementing a continuous learning loop.
"""
import logging
import time
from datetime import datetime
import random
import os
from core.standardized_data_provider import StandardizedDataProvider
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
from core.data_models import create_model_output
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def simulate_market_feedback(action, symbol):
"""
Simulate market feedback for a given action
In a real system, this would be replaced with actual market performance data
Args:
action: Trading action ('BUY', 'SELL', 'HOLD')
symbol: Trading symbol
Returns:
tuple: (actual_action, reward)
"""
# Simulate market movement (random for demonstration)
market_direction = random.choice(['up', 'down', 'sideways'])
# Determine actual best action based on market direction
if market_direction == 'up':
best_action = 'BUY'
elif market_direction == 'down':
best_action = 'SELL'
else:
best_action = 'HOLD'
# Calculate reward based on whether the action matched the best action
if action == best_action:
reward = random.uniform(0.01, 0.1) # Positive reward for correct action
else:
reward = random.uniform(-0.1, -0.01) # Negative reward for incorrect action
logger.info(f"Market went {market_direction}, best action was {best_action}, model chose {action}, reward: {reward:.4f}")
return best_action, reward
def test_continuous_training():
"""Test continuous training of the CNN model with new inference results"""
try:
# Initialize data provider
symbols = ['ETH/USDT', 'BTC/USDT']
timeframes = ['1s', '1m', '1h', '1d']
data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)
# Initialize CNN adapter
checkpoint_dir = "models/enhanced_cnn"
os.makedirs(checkpoint_dir, exist_ok=True)
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=checkpoint_dir)
# Load best checkpoint if available
cnn_adapter.load_best_checkpoint()
# Continuous learning loop
num_iterations = 10
training_frequency = 3 # Train every N iterations
samples_collected = 0
logger.info(f"Starting continuous learning loop with {num_iterations} iterations")
for i in range(num_iterations):
logger.info(f"\nIteration {i+1}/{num_iterations}")
# Get standardized input data
symbol = random.choice(symbols)
logger.info(f"Getting data for {symbol}...")
base_data = data_provider.get_base_data_input(symbol)
if base_data is None:
logger.warning(f"Failed to get base data input for {symbol}, skipping iteration")
continue
# Make prediction
logger.info(f"Making prediction for {symbol}...")
model_output = cnn_adapter.predict(base_data)
# Log prediction
action = model_output.predictions['action']
confidence = model_output.confidence
logger.info(f"Prediction: {action} with confidence {confidence:.4f}")
# Store model output
data_provider.store_model_output(model_output)
# Simulate market feedback
best_action, reward = simulate_market_feedback(action, symbol)
# Add training sample
logger.info(f"Adding training sample: action={best_action}, reward={reward:.4f}")
cnn_adapter.add_training_sample(base_data, best_action, reward)
samples_collected += 1
# Train model periodically
if (i + 1) % training_frequency == 0 and samples_collected >= 3:
logger.info(f"Training model with {samples_collected} samples...")
metrics = cnn_adapter.train(epochs=1)
# Log training metrics
logger.info(f"Training metrics: loss={metrics.get('loss', 0.0):.4f}, accuracy={metrics.get('accuracy', 0.0):.4f}")
# Simulate time passing
time.sleep(1)
logger.info("\nContinuous learning loop completed")
# Final evaluation
logger.info("Performing final evaluation...")
# Get data for evaluation
symbol = 'ETH/USDT'
base_data = data_provider.get_base_data_input(symbol)
if base_data is not None:
# Make prediction
model_output = cnn_adapter.predict(base_data)
# Log prediction
action = model_output.predictions['action']
confidence = model_output.confidence
logger.info(f"Final prediction for {symbol}: {action} with confidence {confidence:.4f}")
# Get model output manager
output_manager = data_provider.get_model_output_manager()
# Evaluate model performance
metrics = output_manager.evaluate_model_performance(symbol, cnn_adapter.model_name)
logger.info(f"Performance metrics: {metrics}")
else:
logger.warning(f"Failed to get base data input for final evaluation")
logger.info("Test completed successfully")
except Exception as e:
logger.error(f"Error in test: {e}", exc_info=True)
if __name__ == "__main__":
test_continuous_training()