references
This commit is contained in:
@ -2503,3 +2503,157 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking signal confirmation for {symbol}: {e}")
|
||||
return None
|
||||
def _initialize_checkpoint_manager(self):
    """Initialize the checkpoint manager used to persist model training state.

    On success, sets ``self.checkpoint_manager`` and ``self.model_states`` —
    one tracking entry per managed model ('dqn', 'cnn', 'cob_rl', 'extrema')
    recording initial/current/best loss and whether a checkpoint was restored.
    On any failure, logs the error and sets ``self.checkpoint_manager`` to
    ``None`` so callers can guard on it.
    """
    try:
        from utils.checkpoint_manager import get_checkpoint_manager
        self.checkpoint_manager = get_checkpoint_manager()

        # best_loss starts at +inf so the first observed loss always
        # becomes the new best; checkpoint_loaded flips when a saved
        # checkpoint is restored for that model.
        self.model_states = {
            name: {
                'initial_loss': None,
                'current_loss': None,
                'best_loss': float('inf'),
                'checkpoint_loaded': False,
            }
            for name in ('dqn', 'cnn', 'cob_rl', 'extrema')
        }

        logger.info("Checkpoint manager initialized for model persistence")
    except Exception as e:
        logger.error(f"Error initializing checkpoint manager: {e}")
        self.checkpoint_manager = None
f _save_training_checkpoints(self, models_trained: List[str], performance_score: float):
|
||||
"""Save checkpoints for trained models if performance improved
|
||||
|
||||
This is CRITICAL for preserving training progress across restarts.
|
||||
"""
|
||||
try:
|
||||
if not self.checkpoint_manager:
|
||||
return
|
||||
|
||||
# Increment training counter
|
||||
self.training_iterations += 1
|
||||
|
||||
# Save checkpoints for each trained model
|
||||
for model_name in models_trained:
|
||||
try:
|
||||
model_obj = None
|
||||
current_loss = None
|
||||
model_type = model_name
|
||||
|
||||
# Get model object and calculate current performance
|
||||
if model_name == 'dqn' and self.rl_agent:
|
||||
model_obj = self.rl_agent
|
||||
# Use current loss from model state or estimate from performance
|
||||
current_loss = self.model_states['dqn'].get('current_loss')
|
||||
if current_loss is None:
|
||||
# Estimate loss from performance score (inverse relationship)
|
||||
current_loss = max(0.001, 1.0 - performance_score)
|
||||
|
||||
# Update model state tracking
|
||||
self.model_states['dqn']['current_loss'] = current_loss
|
||||
|
||||
# If this is the first loss value, set it as initial and best
|
||||
if self.model_states['dqn']['initial_loss'] is None:
|
||||
self.model_states['dqn']['initial_loss'] = current_loss
|
||||
if self.model_states['dqn']['best_loss'] is None or current_loss < self.model_states['dqn']['best_loss']:
|
||||
self.model_states['dqn']['best_loss'] = current_loss
|
||||
|
||||
elif model_name == 'cnn' and self.cnn_model:
|
||||
model_obj = self.cnn_model
|
||||
# Use current loss from model state or estimate from performance
|
||||
current_loss = self.model_states['cnn'].get('current_loss')
|
||||
if current_loss is None:
|
||||
# Estimate loss from performance score (inverse relationship)
|
||||
current_loss = max(0.001, 1.0 - performance_score)
|
||||
|
||||
# Update model state tracking
|
||||
self.model_states['cnn']['current_loss'] = current_loss
|
||||
|
||||
# If this is the first loss value, set it as initial and best
|
||||
if self.model_states['cnn']['initial_loss'] is None:
|
||||
self.model_states['cnn']['initial_loss'] = current_loss
|
||||
if self.model_states['cnn']['best_loss'] is None or current_loss < self.model_states['cnn']['best_loss']:
|
||||
self.model_states['cnn']['best_loss'] = current_loss
|
||||
|
||||
elif model_name == 'cob_rl' and self.cob_rl_agent:
|
||||
model_obj = self.cob_rl_agent
|
||||
# Use current loss from model state or estimate from performance
|
||||
current_loss = self.model_states['cob_rl'].get('current_loss')
|
||||
if current_loss is None:
|
||||
# Estimate loss from performance score (inverse relationship)
|
||||
current_loss = max(0.001, 1.0 - performance_score)
|
||||
|
||||
# Update model state tracking
|
||||
self.model_states['cob_rl']['current_loss'] = current_loss
|
||||
|
||||
# If this is the first loss value, set it as initial and best
|
||||
if self.model_states['cob_rl']['initial_loss'] is None:
|
||||
self.model_states['cob_rl']['initial_loss'] = current_loss
|
||||
if self.model_states['cob_rl']['best_loss'] is None or current_loss < self.model_states['cob_rl']['best_loss']:
|
||||
self.model_states['cob_rl']['best_loss'] = current_loss
|
||||
|
||||
elif model_name == 'extrema' and hasattr(self, 'extrema_trainer') and self.extrema_trainer:
|
||||
model_obj = self.extrema_trainer
|
||||
# Use current loss from model state or estimate from performance
|
||||
current_loss = self.model_states['extrema'].get('current_loss')
|
||||
if current_loss is None:
|
||||
# Estimate loss from performance score (inverse relationship)
|
||||
current_loss = max(0.001, 1.0 - performance_score)
|
||||
|
||||
# Update model state tracking
|
||||
self.model_states['extrema']['current_loss'] = current_loss
|
||||
|
||||
# If this is the first loss value, set it as initial and best
|
||||
if self.model_states['extrema']['initial_loss'] is None:
|
||||
self.model_states['extrema']['initial_loss'] = current_loss
|
||||
if self.model_states['extrema']['best_loss'] is None or current_loss < self.model_states['extrema']['best_loss']:
|
||||
self.model_states['extrema']['best_loss'] = current_loss
|
||||
|
||||
# Skip if we couldn't get a model object
|
||||
if model_obj is None:
|
||||
continue
|
||||
|
||||
# Prepare performance metrics for checkpoint
|
||||
performance_metrics = {
|
||||
'loss': current_loss,
|
||||
'accuracy': performance_score, # Use confidence as a proxy for accuracy
|
||||
}
|
||||
|
||||
# Prepare training metadata
|
||||
training_metadata = {
|
||||
'training_iteration': self.training_iterations,
|
||||
'timestamp': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Save checkpoint using checkpoint manager
|
||||
from utils.checkpoint_manager import save_checkpoint
|
||||
checkpoint_metadata = save_checkpoint(
|
||||
model=model_obj,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata=training_metadata
|
||||
)
|
||||
|
||||
if checkpoint_metadata:
|
||||
logger.info(f"Saved checkpoint for {model_name}: {checkpoint_metadata.checkpoint_id} (loss={current_loss:.4f})")
|
||||
|
||||
# Also save periodically based on training iterations
|
||||
if self.training_iterations % 100 == 0:
|
||||
# Force save every 100 training iterations regardless of performance
|
||||
checkpoint_metadata = save_checkpoint(
|
||||
model=model_obj,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata=training_metadata,
|
||||
force_save=True
|
||||
)
|
||||
if checkpoint_metadata:
|
||||
logger.info(f"Periodic checkpoint saved for {model_name}: {checkpoint_metadata.checkpoint_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint for {model_name}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _save_training_checkpoints: {e}")
|
@ -25,7 +25,7 @@ except ImportError:
|
||||
key, value = line.strip().split('=', 1)
|
||||
os.environ[key] = value
|
||||
|
||||
from NN.exchanges.bybit_interface import BybitInterface
|
||||
from core.exchanges.bybit_interface import BybitInterface
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@ -104,7 +104,7 @@ class BybitEthFuturesTest:
|
||||
# First test simple connectivity without auth
|
||||
print("Testing basic API connectivity...")
|
||||
try:
|
||||
from NN.exchanges.bybit_rest_client import BybitRestClient
|
||||
from core.exchanges.bybit_rest_client import BybitRestClient
|
||||
client = BybitRestClient(
|
||||
api_key="dummy",
|
||||
api_secret="dummy",
|
||||
|
@ -24,7 +24,7 @@ except ImportError:
|
||||
key, value = line.strip().split('=', 1)
|
||||
os.environ[key] = value
|
||||
|
||||
from NN.exchanges.bybit_interface import BybitInterface
|
||||
from core.exchanges.bybit_interface import BybitInterface
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
|
@ -23,7 +23,7 @@ except ImportError:
|
||||
key, value = line.strip().split('=', 1)
|
||||
os.environ[key] = value
|
||||
|
||||
from NN.exchanges.bybit_interface import BybitInterface
|
||||
from core.exchanges.bybit_interface import BybitInterface
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
|
@ -11,7 +11,7 @@ import logging
|
||||
# Add the project root to the path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from NN.exchanges.bybit_rest_client import BybitRestClient
|
||||
from core.exchanges.bybit_rest_client import BybitRestClient
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
|
@ -15,8 +15,8 @@ load_dotenv()
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), 'NN'))
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), 'core'))
|
||||
|
||||
from NN.exchanges.exchange_factory import ExchangeFactory
|
||||
from NN.exchanges.deribit_interface import DeribitInterface
|
||||
from core.exchanges.exchange_factory import ExchangeFactory
|
||||
from core.exchanges.deribit_interface import DeribitInterface
|
||||
from core.config import get_config
|
||||
|
||||
# Setup logging
|
||||
|
Reference in New Issue
Block a user