references
This commit is contained in:
@ -2503,3 +2503,157 @@ class TradingOrchestrator:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error checking signal confirmation for {symbol}: {e}")
|
logger.error(f"Error checking signal confirmation for {symbol}: {e}")
|
||||||
return None
|
return None
|
||||||
|
def _initialize_checkpoint_manager(self):
|
||||||
|
"""Initialize the checkpoint manager for model persistence"""
|
||||||
|
try:
|
||||||
|
from utils.checkpoint_manager import get_checkpoint_manager
|
||||||
|
self.checkpoint_manager = get_checkpoint_manager()
|
||||||
|
|
||||||
|
# Initialize model states dictionary to track performance
|
||||||
|
self.model_states = {
|
||||||
|
'dqn': {'initial_loss': None, 'current_loss': None, 'best_loss': float('inf'), 'checkpoint_loaded': False},
|
||||||
|
'cnn': {'initial_loss': None, 'current_loss': None, 'best_loss': float('inf'), 'checkpoint_loaded': False},
|
||||||
|
'cob_rl': {'initial_loss': None, 'current_loss': None, 'best_loss': float('inf'), 'checkpoint_loaded': False},
|
||||||
|
'extrema': {'initial_loss': None, 'current_loss': None, 'best_loss': float('inf'), 'checkpoint_loaded': False}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info("Checkpoint manager initialized for model persistence")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error initializing checkpoint manager: {e}")
|
||||||
|
self.checkpoint_manager = None de
|
||||||
|
f _save_training_checkpoints(self, models_trained: List[str], performance_score: float):
|
||||||
|
"""Save checkpoints for trained models if performance improved
|
||||||
|
|
||||||
|
This is CRITICAL for preserving training progress across restarts.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not self.checkpoint_manager:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Increment training counter
|
||||||
|
self.training_iterations += 1
|
||||||
|
|
||||||
|
# Save checkpoints for each trained model
|
||||||
|
for model_name in models_trained:
|
||||||
|
try:
|
||||||
|
model_obj = None
|
||||||
|
current_loss = None
|
||||||
|
model_type = model_name
|
||||||
|
|
||||||
|
# Get model object and calculate current performance
|
||||||
|
if model_name == 'dqn' and self.rl_agent:
|
||||||
|
model_obj = self.rl_agent
|
||||||
|
# Use current loss from model state or estimate from performance
|
||||||
|
current_loss = self.model_states['dqn'].get('current_loss')
|
||||||
|
if current_loss is None:
|
||||||
|
# Estimate loss from performance score (inverse relationship)
|
||||||
|
current_loss = max(0.001, 1.0 - performance_score)
|
||||||
|
|
||||||
|
# Update model state tracking
|
||||||
|
self.model_states['dqn']['current_loss'] = current_loss
|
||||||
|
|
||||||
|
# If this is the first loss value, set it as initial and best
|
||||||
|
if self.model_states['dqn']['initial_loss'] is None:
|
||||||
|
self.model_states['dqn']['initial_loss'] = current_loss
|
||||||
|
if self.model_states['dqn']['best_loss'] is None or current_loss < self.model_states['dqn']['best_loss']:
|
||||||
|
self.model_states['dqn']['best_loss'] = current_loss
|
||||||
|
|
||||||
|
elif model_name == 'cnn' and self.cnn_model:
|
||||||
|
model_obj = self.cnn_model
|
||||||
|
# Use current loss from model state or estimate from performance
|
||||||
|
current_loss = self.model_states['cnn'].get('current_loss')
|
||||||
|
if current_loss is None:
|
||||||
|
# Estimate loss from performance score (inverse relationship)
|
||||||
|
current_loss = max(0.001, 1.0 - performance_score)
|
||||||
|
|
||||||
|
# Update model state tracking
|
||||||
|
self.model_states['cnn']['current_loss'] = current_loss
|
||||||
|
|
||||||
|
# If this is the first loss value, set it as initial and best
|
||||||
|
if self.model_states['cnn']['initial_loss'] is None:
|
||||||
|
self.model_states['cnn']['initial_loss'] = current_loss
|
||||||
|
if self.model_states['cnn']['best_loss'] is None or current_loss < self.model_states['cnn']['best_loss']:
|
||||||
|
self.model_states['cnn']['best_loss'] = current_loss
|
||||||
|
|
||||||
|
elif model_name == 'cob_rl' and self.cob_rl_agent:
|
||||||
|
model_obj = self.cob_rl_agent
|
||||||
|
# Use current loss from model state or estimate from performance
|
||||||
|
current_loss = self.model_states['cob_rl'].get('current_loss')
|
||||||
|
if current_loss is None:
|
||||||
|
# Estimate loss from performance score (inverse relationship)
|
||||||
|
current_loss = max(0.001, 1.0 - performance_score)
|
||||||
|
|
||||||
|
# Update model state tracking
|
||||||
|
self.model_states['cob_rl']['current_loss'] = current_loss
|
||||||
|
|
||||||
|
# If this is the first loss value, set it as initial and best
|
||||||
|
if self.model_states['cob_rl']['initial_loss'] is None:
|
||||||
|
self.model_states['cob_rl']['initial_loss'] = current_loss
|
||||||
|
if self.model_states['cob_rl']['best_loss'] is None or current_loss < self.model_states['cob_rl']['best_loss']:
|
||||||
|
self.model_states['cob_rl']['best_loss'] = current_loss
|
||||||
|
|
||||||
|
elif model_name == 'extrema' and hasattr(self, 'extrema_trainer') and self.extrema_trainer:
|
||||||
|
model_obj = self.extrema_trainer
|
||||||
|
# Use current loss from model state or estimate from performance
|
||||||
|
current_loss = self.model_states['extrema'].get('current_loss')
|
||||||
|
if current_loss is None:
|
||||||
|
# Estimate loss from performance score (inverse relationship)
|
||||||
|
current_loss = max(0.001, 1.0 - performance_score)
|
||||||
|
|
||||||
|
# Update model state tracking
|
||||||
|
self.model_states['extrema']['current_loss'] = current_loss
|
||||||
|
|
||||||
|
# If this is the first loss value, set it as initial and best
|
||||||
|
if self.model_states['extrema']['initial_loss'] is None:
|
||||||
|
self.model_states['extrema']['initial_loss'] = current_loss
|
||||||
|
if self.model_states['extrema']['best_loss'] is None or current_loss < self.model_states['extrema']['best_loss']:
|
||||||
|
self.model_states['extrema']['best_loss'] = current_loss
|
||||||
|
|
||||||
|
# Skip if we couldn't get a model object
|
||||||
|
if model_obj is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Prepare performance metrics for checkpoint
|
||||||
|
performance_metrics = {
|
||||||
|
'loss': current_loss,
|
||||||
|
'accuracy': performance_score, # Use confidence as a proxy for accuracy
|
||||||
|
}
|
||||||
|
|
||||||
|
# Prepare training metadata
|
||||||
|
training_metadata = {
|
||||||
|
'training_iteration': self.training_iterations,
|
||||||
|
'timestamp': datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Save checkpoint using checkpoint manager
|
||||||
|
from utils.checkpoint_manager import save_checkpoint
|
||||||
|
checkpoint_metadata = save_checkpoint(
|
||||||
|
model=model_obj,
|
||||||
|
model_name=model_name,
|
||||||
|
model_type=model_type,
|
||||||
|
performance_metrics=performance_metrics,
|
||||||
|
training_metadata=training_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
if checkpoint_metadata:
|
||||||
|
logger.info(f"Saved checkpoint for {model_name}: {checkpoint_metadata.checkpoint_id} (loss={current_loss:.4f})")
|
||||||
|
|
||||||
|
# Also save periodically based on training iterations
|
||||||
|
if self.training_iterations % 100 == 0:
|
||||||
|
# Force save every 100 training iterations regardless of performance
|
||||||
|
checkpoint_metadata = save_checkpoint(
|
||||||
|
model=model_obj,
|
||||||
|
model_name=model_name,
|
||||||
|
model_type=model_type,
|
||||||
|
performance_metrics=performance_metrics,
|
||||||
|
training_metadata=training_metadata,
|
||||||
|
force_save=True
|
||||||
|
)
|
||||||
|
if checkpoint_metadata:
|
||||||
|
logger.info(f"Periodic checkpoint saved for {model_name}: {checkpoint_metadata.checkpoint_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving checkpoint for {model_name}: {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in _save_training_checkpoints: {e}")
|
@ -25,7 +25,7 @@ except ImportError:
|
|||||||
key, value = line.strip().split('=', 1)
|
key, value = line.strip().split('=', 1)
|
||||||
os.environ[key] = value
|
os.environ[key] = value
|
||||||
|
|
||||||
from NN.exchanges.bybit_interface import BybitInterface
|
from core.exchanges.bybit_interface import BybitInterface
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@ -104,7 +104,7 @@ class BybitEthFuturesTest:
|
|||||||
# First test simple connectivity without auth
|
# First test simple connectivity without auth
|
||||||
print("Testing basic API connectivity...")
|
print("Testing basic API connectivity...")
|
||||||
try:
|
try:
|
||||||
from NN.exchanges.bybit_rest_client import BybitRestClient
|
from core.exchanges.bybit_rest_client import BybitRestClient
|
||||||
client = BybitRestClient(
|
client = BybitRestClient(
|
||||||
api_key="dummy",
|
api_key="dummy",
|
||||||
api_secret="dummy",
|
api_secret="dummy",
|
||||||
|
@ -24,7 +24,7 @@ except ImportError:
|
|||||||
key, value = line.strip().split('=', 1)
|
key, value = line.strip().split('=', 1)
|
||||||
os.environ[key] = value
|
os.environ[key] = value
|
||||||
|
|
||||||
from NN.exchanges.bybit_interface import BybitInterface
|
from core.exchanges.bybit_interface import BybitInterface
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
|
@ -23,7 +23,7 @@ except ImportError:
|
|||||||
key, value = line.strip().split('=', 1)
|
key, value = line.strip().split('=', 1)
|
||||||
os.environ[key] = value
|
os.environ[key] = value
|
||||||
|
|
||||||
from NN.exchanges.bybit_interface import BybitInterface
|
from core.exchanges.bybit_interface import BybitInterface
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
|
@ -11,7 +11,7 @@ import logging
|
|||||||
# Add the project root to the path
|
# Add the project root to the path
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
from NN.exchanges.bybit_rest_client import BybitRestClient
|
from core.exchanges.bybit_rest_client import BybitRestClient
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
|
@ -15,8 +15,8 @@ load_dotenv()
|
|||||||
sys.path.append(os.path.join(os.path.dirname(__file__), 'NN'))
|
sys.path.append(os.path.join(os.path.dirname(__file__), 'NN'))
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), 'core'))
|
sys.path.append(os.path.join(os.path.dirname(__file__), 'core'))
|
||||||
|
|
||||||
from NN.exchanges.exchange_factory import ExchangeFactory
|
from core.exchanges.exchange_factory import ExchangeFactory
|
||||||
from NN.exchanges.deribit_interface import DeribitInterface
|
from core.exchanges.deribit_interface import DeribitInterface
|
||||||
from core.config import get_config
|
from core.config import get_config
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
|
Reference in New Issue
Block a user