long term CNN training
This commit is contained in:
@ -1567,12 +1567,22 @@ class TradingOrchestrator:
|
||||
async def _trading_decision_loop(self):
|
||||
"""Main trading decision loop"""
|
||||
logger.info("Trading decision loop started")
|
||||
long_term_training_counter = 0
|
||||
while self.running:
|
||||
try:
|
||||
# Only make decisions for the primary trading symbol
|
||||
await self.make_trading_decision(self.symbol)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Trigger long-term training every 60 seconds (60 iterations)
|
||||
long_term_training_counter += 1
|
||||
if long_term_training_counter >= 60:
|
||||
try:
|
||||
await self.trigger_cnn_long_term_training()
|
||||
long_term_training_counter = 0
|
||||
except Exception as e:
|
||||
logger.debug(f"Error in periodic long-term training: {e}")
|
||||
|
||||
await asyncio.sleep(self.decision_frequency)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in trading decision loop: {e}")
|
||||
@ -3413,6 +3423,49 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in immediate training for previous inference {model_name}: {e}")
|
||||
|
||||
async def trigger_cnn_long_term_training(self):
|
||||
"""Trigger long-term training on CNN stored inference records"""
|
||||
try:
|
||||
if hasattr(self, "cnn_model") and self.cnn_model and hasattr(self, "cnn_optimizer"):
|
||||
if hasattr(self.cnn_model, "train_on_stored_records"):
|
||||
# Get current price for all symbols
|
||||
symbols = ["ETH/USDT"] # Add more symbols as needed
|
||||
|
||||
for symbol in symbols:
|
||||
current_price = self._get_current_price(symbol)
|
||||
if current_price and hasattr(self.cnn_model, "inference_records"):
|
||||
# Update all stored records with current price information
|
||||
for record in self.cnn_model.inference_records:
|
||||
if "metadata" in record:
|
||||
record_metadata = record["metadata"]
|
||||
record_price = record_metadata.get("current_price")
|
||||
|
||||
if record_price and current_price:
|
||||
price_change_pct = ((current_price - record_price) / record_price) * 100
|
||||
time_diff = (datetime.now() - datetime.fromisoformat(record_metadata.get("timestamp", ""))).total_seconds() / 60
|
||||
|
||||
# Update with actual price changes and time differences
|
||||
record_metadata["actual_price_changes"] = {
|
||||
"short_term": price_change_pct,
|
||||
"mid_term": price_change_pct * 0.8,
|
||||
"long_term": price_change_pct * 0.6
|
||||
}
|
||||
record_metadata["time_diffs"] = {
|
||||
"short_term": min(time_diff, 1.0),
|
||||
"mid_term": min(time_diff, 5.0),
|
||||
"long_term": min(time_diff, 15.0)
|
||||
}
|
||||
|
||||
# Train on all stored records
|
||||
long_term_loss = self.cnn_model.train_on_stored_records(self.cnn_optimizer, min_records=3)
|
||||
if long_term_loss > 0:
|
||||
logger.info(f"CNN long-term training completed: loss={long_term_loss:.4f}, records={len(self.cnn_model.inference_records)}")
|
||||
else:
|
||||
logger.debug(f"CNN long-term training skipped: insufficient records ({len(self.cnn_model.inference_records)})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in CNN long-term training: {e}")
|
||||
|
||||
async def _evaluate_and_train_on_record(self, record: Dict, current_price: float):
|
||||
"""Evaluate prediction outcome and train model"""
|
||||
try:
|
||||
@ -4625,6 +4678,40 @@ class TradingOrchestrator:
|
||||
logger.debug(
|
||||
f"CNN direct training completed: loss={current_loss:.4f}, time={training_duration_ms:.1f}ms"
|
||||
)
|
||||
|
||||
# Trigger long-term training on stored inference records
|
||||
if hasattr(self.cnn_model, "train_on_stored_records") and hasattr(self, "cnn_optimizer"):
|
||||
try:
|
||||
# Update metadata in stored records with actual price changes
|
||||
symbol = record.get("symbol", "ETH/USDT")
|
||||
current_price = self._get_current_price(symbol)
|
||||
inference_price = record.get("inference_price")
|
||||
|
||||
if inference_price and current_price:
|
||||
price_change_pct = ((current_price - inference_price) / inference_price) * 100
|
||||
|
||||
# Update the most recent inference record with actual price changes
|
||||
if hasattr(self.cnn_model, "inference_records") and self.cnn_model.inference_records:
|
||||
latest_record = self.cnn_model.inference_records[-1]
|
||||
if "metadata" in latest_record:
|
||||
latest_record["metadata"]["actual_price_changes"] = {
|
||||
"short_term": price_change_pct,
|
||||
"mid_term": price_change_pct * 0.8, # Slight decay for longer timeframes
|
||||
"long_term": price_change_pct * 0.6
|
||||
}
|
||||
latest_record["metadata"]["time_diffs"] = {
|
||||
"short_term": 1.0, # 1 minute
|
||||
"mid_term": 5.0, # 5 minutes
|
||||
"long_term": 15.0 # 15 minutes
|
||||
}
|
||||
|
||||
# Train on stored records
|
||||
long_term_loss = self.cnn_model.train_on_stored_records(self.cnn_optimizer, min_records=5)
|
||||
if long_term_loss > 0:
|
||||
logger.debug(f"CNN long-term training completed: loss={long_term_loss:.4f}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Error in CNN long-term training: {e}")
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"No model input available for CNN training")
|
||||
@ -4947,6 +5034,34 @@ class TradingOrchestrator:
|
||||
)
|
||||
predictions.append(prediction)
|
||||
|
||||
# Store inference record in CNN model for long-term training
|
||||
if hasattr(self.cnn_model, "store_inference_record"):
|
||||
try:
|
||||
# Get current price for metadata
|
||||
current_price = self._get_current_price(symbol)
|
||||
|
||||
# Create metadata with price information for long-term training
|
||||
metadata = {
|
||||
"symbol": symbol,
|
||||
"current_price": current_price,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"prediction_action": action,
|
||||
"prediction_confidence": confidence,
|
||||
"actual_price_changes": {}, # Will be populated during training
|
||||
"time_diffs": {} # Will be populated during training
|
||||
}
|
||||
|
||||
# Store the inference record in the CNN model
|
||||
self.cnn_model.store_inference_record(
|
||||
input_data=features_tensor,
|
||||
prediction_output=(q_values, extrema_pred, price_pred, features_refined, advanced_pred, multi_timeframe_pred),
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
logger.debug(f"Stored CNN inference record for long-term training")
|
||||
except Exception as e:
|
||||
logger.debug(f"Error storing CNN inference record: {e}")
|
||||
|
||||
logger.debug(
|
||||
f"Added CNN prediction: {action} ({confidence:.3f})"
|
||||
)
|
||||
@ -5029,6 +5144,34 @@ class TradingOrchestrator:
|
||||
)
|
||||
predictions.append(pred)
|
||||
|
||||
# Store inference record in CNN model for long-term training (fallback method)
|
||||
if hasattr(model.model, "store_inference_record"):
|
||||
try:
|
||||
# Get current price for metadata
|
||||
current_price = self._get_current_price(symbol)
|
||||
|
||||
# Create metadata with price information for long-term training
|
||||
metadata = {
|
||||
"symbol": symbol,
|
||||
"current_price": current_price,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"prediction_action": best_action,
|
||||
"prediction_confidence": float(confidence),
|
||||
"actual_price_changes": {}, # Will be populated during training
|
||||
"time_diffs": {} # Will be populated during training
|
||||
}
|
||||
|
||||
# Store the inference record in the CNN model
|
||||
model.model.store_inference_record(
|
||||
input_data=features_tensor,
|
||||
prediction_output=None, # Not available in fallback method
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
logger.debug(f"Stored CNN inference record for long-term training (fallback)")
|
||||
except Exception as e:
|
||||
logger.debug(f"Error storing CNN inference record (fallback): {e}")
|
||||
|
||||
# Note: Inference data will be stored in main prediction loop to avoid duplication
|
||||
|
||||
# Capture for dashboard
|
||||
|
@ -21,9 +21,9 @@
|
||||
"training_enabled": true
|
||||
},
|
||||
"dqn_agent": {
|
||||
"inference_enabled": true,
|
||||
"training_enabled": true
|
||||
"inference_enabled": "inference_enabled",
|
||||
"training_enabled": false
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-07-30T00:41:19.241862"
|
||||
"timestamp": "2025-07-30T09:19:11.731827"
|
||||
}
|
Reference in New Issue
Block a user