references

This commit is contained in:
Dobromir Popov
2025-07-22 16:53:36 +03:00
parent 1d224e5b8c
commit 9b72b18eb7
6 changed files with 162 additions and 8 deletions

View File

@@ -2503,3 +2503,157 @@ class TradingOrchestrator:
        except Exception as e:
            logger.error(f"Error checking signal confirmation for {symbol}: {e}")
            return None
def _initialize_checkpoint_manager(self):
    """Initialize the checkpoint manager used for model persistence.

    Sets ``self.checkpoint_manager`` and builds ``self.model_states``, a
    per-model dict tracking loss statistics (``initial_loss``,
    ``current_loss``, ``best_loss``) and whether a checkpoint was loaded.
    On any failure ``self.checkpoint_manager`` is set to ``None`` so callers
    can degrade gracefully (``_save_training_checkpoints`` becomes a no-op).
    """
    try:
        from utils.checkpoint_manager import get_checkpoint_manager
        self.checkpoint_manager = get_checkpoint_manager()

        # One entry per persistable model. Losses start unknown (None) and
        # best_loss at +inf so the first observed loss always becomes "best".
        self.model_states = {
            name: {
                'initial_loss': None,
                'current_loss': None,
                'best_loss': float('inf'),
                'checkpoint_loaded': False,
            }
            for name in ('dqn', 'cnn', 'cob_rl', 'extrema')
        }
        logger.info("Checkpoint manager initialized for model persistence")
    except Exception as e:
        logger.error(f"Error initializing checkpoint manager: {e}")
        self.checkpoint_manager = None

def _update_model_loss_state(self, model_name: str, performance_score: float) -> float:
    """Update loss tracking for *model_name* and return its current loss.

    If no loss has been recorded yet, estimate one from the performance
    score (inverse relationship, floored at 0.001). Maintains the
    ``initial_loss`` (first value seen) and ``best_loss`` (minimum seen)
    statistics in ``self.model_states``.
    """
    state = self.model_states[model_name]
    current_loss = state.get('current_loss')
    if current_loss is None:
        # Estimate loss from performance score (inverse relationship).
        current_loss = max(0.001, 1.0 - performance_score)
    state['current_loss'] = current_loss
    # First loss value seen becomes the baseline.
    if state['initial_loss'] is None:
        state['initial_loss'] = current_loss
    if state['best_loss'] is None or current_loss < state['best_loss']:
        state['best_loss'] = current_loss
    return current_loss

def _save_training_checkpoints(self, models_trained: List[str], performance_score: float):
    """Save checkpoints for each trained model.

    This is CRITICAL for preserving training progress across restarts.

    Args:
        models_trained: Logical model names ('dqn', 'cnn', 'cob_rl',
            'extrema') that were just trained. Unknown or unloaded models
            are skipped silently.
        performance_score: Score in [0, 1] used both as an accuracy proxy
            and, when no loss has been recorded yet, to estimate one.

    Side effects: increments ``self.training_iterations``, updates
    ``self.model_states``, and writes checkpoints via the checkpoint
    manager. Every 100 iterations a checkpoint is force-saved regardless
    of performance. Errors are logged, never raised.
    """
    try:
        if not self.checkpoint_manager:
            return

        # Increment training counter
        self.training_iterations += 1

        # Map each logical model name to the live model object (if any).
        # getattr with a default makes a missing attribute an ordinary
        # "skip" instead of a logged AttributeError.
        model_lookup = {
            'dqn': getattr(self, 'rl_agent', None),
            'cnn': getattr(self, 'cnn_model', None),
            'cob_rl': getattr(self, 'cob_rl_agent', None),
            'extrema': getattr(self, 'extrema_trainer', None),
        }

        # Hoisted out of the loop: the import is loop-invariant.
        from utils.checkpoint_manager import save_checkpoint

        for model_name in models_trained:
            try:
                model_obj = model_lookup.get(model_name)
                # Skip unknown model names and models that are not loaded.
                if model_obj is None:
                    continue

                current_loss = self._update_model_loss_state(model_name, performance_score)

                # Prepare performance metrics for the checkpoint.
                performance_metrics = {
                    'loss': current_loss,
                    'accuracy': performance_score,  # confidence used as accuracy proxy
                }
                training_metadata = {
                    'training_iteration': self.training_iterations,
                    'timestamp': datetime.now().isoformat()
                }

                checkpoint_metadata = save_checkpoint(
                    model=model_obj,
                    model_name=model_name,
                    model_type=model_name,
                    performance_metrics=performance_metrics,
                    training_metadata=training_metadata
                )
                if checkpoint_metadata:
                    logger.info(f"Saved checkpoint for {model_name}: {checkpoint_metadata.checkpoint_id} (loss={current_loss:.4f})")

                # Force a save every 100 training iterations regardless of
                # performance, so progress is never more than 100 iterations
                # from a persisted state.
                if self.training_iterations % 100 == 0:
                    checkpoint_metadata = save_checkpoint(
                        model=model_obj,
                        model_name=model_name,
                        model_type=model_name,
                        performance_metrics=performance_metrics,
                        training_metadata=training_metadata,
                        force_save=True
                    )
                    if checkpoint_metadata:
                        logger.info(f"Periodic checkpoint saved for {model_name}: {checkpoint_metadata.checkpoint_id}")
            except Exception as e:
                logger.error(f"Error saving checkpoint for {model_name}: {e}")
    except Exception as e:
        logger.error(f"Error in _save_training_checkpoints: {e}")

View File

@@ -25,7 +25,7 @@ except ImportError:
         key, value = line.strip().split('=', 1)
         os.environ[key] = value

-from NN.exchanges.bybit_interface import BybitInterface
+from core.exchanges.bybit_interface import BybitInterface

 # Configure logging
 logging.basicConfig(
@@ -104,7 +104,7 @@ class BybitEthFuturesTest:
         # First test simple connectivity without auth
         print("Testing basic API connectivity...")
         try:
-            from NN.exchanges.bybit_rest_client import BybitRestClient
+            from core.exchanges.bybit_rest_client import BybitRestClient
             client = BybitRestClient(
                 api_key="dummy",
                 api_secret="dummy",

View File

@@ -24,7 +24,7 @@ except ImportError:
         key, value = line.strip().split('=', 1)
         os.environ[key] = value

-from NN.exchanges.bybit_interface import BybitInterface
+from core.exchanges.bybit_interface import BybitInterface

 # Configure logging
 logging.basicConfig(

View File

@@ -23,7 +23,7 @@ except ImportError:
         key, value = line.strip().split('=', 1)
         os.environ[key] = value

-from NN.exchanges.bybit_interface import BybitInterface
+from core.exchanges.bybit_interface import BybitInterface

 # Configure logging
 logging.basicConfig(

View File

@@ -11,7 +11,7 @@ import logging
 # Add the project root to the path
 sys.path.append(os.path.dirname(os.path.abspath(__file__)))

-from NN.exchanges.bybit_rest_client import BybitRestClient
+from core.exchanges.bybit_rest_client import BybitRestClient

 # Configure logging
 logging.basicConfig(

View File

@@ -15,8 +15,8 @@ load_dotenv()
 sys.path.append(os.path.join(os.path.dirname(__file__), 'NN'))
 sys.path.append(os.path.join(os.path.dirname(__file__), 'core'))

-from NN.exchanges.exchange_factory import ExchangeFactory
-from NN.exchanges.deribit_interface import DeribitInterface
+from core.exchanges.exchange_factory import ExchangeFactory
+from core.exchanges.deribit_interface import DeribitInterface
 from core.config import get_config

 # Setup logging