long term CNN training

This commit is contained in:
Dobromir Popov
2025-07-30 09:35:53 +03:00
parent ec24d55e00
commit 6ca19f4536
2 changed files with 146 additions and 3 deletions

View File

@ -1567,12 +1567,22 @@ class TradingOrchestrator:
async def _trading_decision_loop(self):
"""Main trading decision loop"""
logger.info("Trading decision loop started")
long_term_training_counter = 0
while self.running:
try:
# Only make decisions for the primary trading symbol
await self.make_trading_decision(self.symbol)
await asyncio.sleep(1)
# Trigger long-term training every 60 seconds (60 iterations)
long_term_training_counter += 1
if long_term_training_counter >= 60:
try:
await self.trigger_cnn_long_term_training()
long_term_training_counter = 0
except Exception as e:
logger.debug(f"Error in periodic long-term training: {e}")
await asyncio.sleep(self.decision_frequency)
except Exception as e:
logger.error(f"Error in trading decision loop: {e}")
@ -3413,6 +3423,49 @@ class TradingOrchestrator:
except Exception as e:
logger.error(f"Error in immediate training for previous inference {model_name}: {e}")
async def trigger_cnn_long_term_training(self):
"""Trigger long-term training on CNN stored inference records"""
try:
if hasattr(self, "cnn_model") and self.cnn_model and hasattr(self, "cnn_optimizer"):
if hasattr(self.cnn_model, "train_on_stored_records"):
# Get current price for all symbols
symbols = ["ETH/USDT"] # Add more symbols as needed
for symbol in symbols:
current_price = self._get_current_price(symbol)
if current_price and hasattr(self.cnn_model, "inference_records"):
# Update all stored records with current price information
for record in self.cnn_model.inference_records:
if "metadata" in record:
record_metadata = record["metadata"]
record_price = record_metadata.get("current_price")
if record_price and current_price:
price_change_pct = ((current_price - record_price) / record_price) * 100
time_diff = (datetime.now() - datetime.fromisoformat(record_metadata.get("timestamp", ""))).total_seconds() / 60
# Update with actual price changes and time differences
record_metadata["actual_price_changes"] = {
"short_term": price_change_pct,
"mid_term": price_change_pct * 0.8,
"long_term": price_change_pct * 0.6
}
record_metadata["time_diffs"] = {
"short_term": min(time_diff, 1.0),
"mid_term": min(time_diff, 5.0),
"long_term": min(time_diff, 15.0)
}
# Train on all stored records
long_term_loss = self.cnn_model.train_on_stored_records(self.cnn_optimizer, min_records=3)
if long_term_loss > 0:
logger.info(f"CNN long-term training completed: loss={long_term_loss:.4f}, records={len(self.cnn_model.inference_records)}")
else:
logger.debug(f"CNN long-term training skipped: insufficient records ({len(self.cnn_model.inference_records)})")
except Exception as e:
logger.error(f"Error in CNN long-term training: {e}")
async def _evaluate_and_train_on_record(self, record: Dict, current_price: float):
"""Evaluate prediction outcome and train model"""
try:
@ -4625,6 +4678,40 @@ class TradingOrchestrator:
logger.debug(
f"CNN direct training completed: loss={current_loss:.4f}, time={training_duration_ms:.1f}ms"
)
# Trigger long-term training on stored inference records
if hasattr(self.cnn_model, "train_on_stored_records") and hasattr(self, "cnn_optimizer"):
try:
# Update metadata in stored records with actual price changes
symbol = record.get("symbol", "ETH/USDT")
current_price = self._get_current_price(symbol)
inference_price = record.get("inference_price")
if inference_price and current_price:
price_change_pct = ((current_price - inference_price) / inference_price) * 100
# Update the most recent inference record with actual price changes
if hasattr(self.cnn_model, "inference_records") and self.cnn_model.inference_records:
latest_record = self.cnn_model.inference_records[-1]
if "metadata" in latest_record:
latest_record["metadata"]["actual_price_changes"] = {
"short_term": price_change_pct,
"mid_term": price_change_pct * 0.8, # Slight decay for longer timeframes
"long_term": price_change_pct * 0.6
}
latest_record["metadata"]["time_diffs"] = {
"short_term": 1.0, # 1 minute
"mid_term": 5.0, # 5 minutes
"long_term": 15.0 # 15 minutes
}
# Train on stored records
long_term_loss = self.cnn_model.train_on_stored_records(self.cnn_optimizer, min_records=5)
if long_term_loss > 0:
logger.debug(f"CNN long-term training completed: loss={long_term_loss:.4f}")
except Exception as e:
logger.debug(f"Error in CNN long-term training: {e}")
return True
else:
logger.warning(f"No model input available for CNN training")
@ -4947,6 +5034,34 @@ class TradingOrchestrator:
)
predictions.append(prediction)
# Store inference record in CNN model for long-term training
if hasattr(self.cnn_model, "store_inference_record"):
try:
# Get current price for metadata
current_price = self._get_current_price(symbol)
# Create metadata with price information for long-term training
metadata = {
"symbol": symbol,
"current_price": current_price,
"timestamp": datetime.now().isoformat(),
"prediction_action": action,
"prediction_confidence": confidence,
"actual_price_changes": {}, # Will be populated during training
"time_diffs": {} # Will be populated during training
}
# Store the inference record in the CNN model
self.cnn_model.store_inference_record(
input_data=features_tensor,
prediction_output=(q_values, extrema_pred, price_pred, features_refined, advanced_pred, multi_timeframe_pred),
metadata=metadata
)
logger.debug(f"Stored CNN inference record for long-term training")
except Exception as e:
logger.debug(f"Error storing CNN inference record: {e}")
logger.debug(
f"Added CNN prediction: {action} ({confidence:.3f})"
)
@ -5029,6 +5144,34 @@ class TradingOrchestrator:
)
predictions.append(pred)
# Store inference record in CNN model for long-term training (fallback method)
if hasattr(model.model, "store_inference_record"):
try:
# Get current price for metadata
current_price = self._get_current_price(symbol)
# Create metadata with price information for long-term training
metadata = {
"symbol": symbol,
"current_price": current_price,
"timestamp": datetime.now().isoformat(),
"prediction_action": best_action,
"prediction_confidence": float(confidence),
"actual_price_changes": {}, # Will be populated during training
"time_diffs": {} # Will be populated during training
}
# Store the inference record in the CNN model
model.model.store_inference_record(
input_data=features_tensor,
prediction_output=None, # Not available in fallback method
metadata=metadata
)
logger.debug(f"Stored CNN inference record for long-term training (fallback)")
except Exception as e:
logger.debug(f"Error storing CNN inference record (fallback): {e}")
# Note: Inference data will be stored in main prediction loop to avoid duplication
# Capture for dashboard

View File

@ -21,9 +21,9 @@
"training_enabled": true
},
"dqn_agent": {
"inference_enabled": true,
"training_enabled": true
"inference_enabled": "inference_enabled",
"training_enabled": false
}
},
"timestamp": "2025-07-30T00:41:19.241862"
"timestamp": "2025-07-30T09:19:11.731827"
}