From 6ca19f45360d8d322df61f7c84941a841e605cc3 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Wed, 30 Jul 2025 09:35:53 +0300 Subject: [PATCH] long term CNN training --- core/orchestrator.py | 143 +++++++++++++++++++++++++++++++++++++++++++ data/ui_state.json | 6 +- 2 files changed, 146 insertions(+), 3 deletions(-) diff --git a/core/orchestrator.py b/core/orchestrator.py index 493fd58..28e6bf9 100644 --- a/core/orchestrator.py +++ b/core/orchestrator.py @@ -1567,12 +1567,22 @@ class TradingOrchestrator: async def _trading_decision_loop(self): """Main trading decision loop""" logger.info("Trading decision loop started") + long_term_training_counter = 0 while self.running: try: # Only make decisions for the primary trading symbol await self.make_trading_decision(self.symbol) await asyncio.sleep(1) + # Trigger long-term training every 60 seconds (60 iterations) + long_term_training_counter += 1 + if long_term_training_counter >= 60: + try: + await self.trigger_cnn_long_term_training() + long_term_training_counter = 0 + except Exception as e: + logger.debug(f"Error in periodic long-term training: {e}") + await asyncio.sleep(self.decision_frequency) except Exception as e: logger.error(f"Error in trading decision loop: {e}") @@ -3413,6 +3423,49 @@ class TradingOrchestrator: except Exception as e: logger.error(f"Error in immediate training for previous inference {model_name}: {e}") + async def trigger_cnn_long_term_training(self): + """Trigger long-term training on CNN stored inference records""" + try: + if hasattr(self, "cnn_model") and self.cnn_model and hasattr(self, "cnn_optimizer"): + if hasattr(self.cnn_model, "train_on_stored_records"): + # Get current price for all symbols + symbols = ["ETH/USDT"] # Add more symbols as needed + + for symbol in symbols: + current_price = self._get_current_price(symbol) + if current_price and hasattr(self.cnn_model, "inference_records"): + # Update all stored records with current price information + for record in self.cnn_model.inference_records: + if "metadata" in record: + record_metadata = record["metadata"] + record_price = record_metadata.get("current_price") + + if record_price and current_price: + price_change_pct = ((current_price - record_price) / record_price) * 100 + time_diff = (datetime.now() - datetime.fromisoformat(record_metadata.get("timestamp", ""))).total_seconds() / 60 + + # Update with actual price changes and time differences + record_metadata["actual_price_changes"] = { + "short_term": price_change_pct, + "mid_term": price_change_pct * 0.8, + "long_term": price_change_pct * 0.6 + } + record_metadata["time_diffs"] = { + "short_term": min(time_diff, 1.0), + "mid_term": min(time_diff, 5.0), + "long_term": min(time_diff, 15.0) + } + + # Train on all stored records + long_term_loss = self.cnn_model.train_on_stored_records(self.cnn_optimizer, min_records=3) + if long_term_loss > 0: + logger.info(f"CNN long-term training completed: loss={long_term_loss:.4f}, records={len(self.cnn_model.inference_records)}") + else: + logger.debug(f"CNN long-term training skipped: insufficient records ({len(self.cnn_model.inference_records)})") + + except Exception as e: + logger.error(f"Error in CNN long-term training: {e}") + async def _evaluate_and_train_on_record(self, record: Dict, current_price: float): """Evaluate prediction outcome and train model""" try: @@ -4625,6 +4678,40 @@ class TradingOrchestrator: logger.debug( f"CNN direct training completed: loss={current_loss:.4f}, time={training_duration_ms:.1f}ms" ) + + # Trigger long-term training on stored inference records + if hasattr(self.cnn_model, "train_on_stored_records") and hasattr(self, "cnn_optimizer"): + try: + # Update metadata in stored records with actual price changes + symbol = record.get("symbol", "ETH/USDT") + current_price = self._get_current_price(symbol) + inference_price = record.get("inference_price") + + if inference_price and current_price: + price_change_pct = ((current_price - inference_price) / inference_price) * 100 + + # Update the most recent inference record with actual price changes + if hasattr(self.cnn_model, "inference_records") and self.cnn_model.inference_records: + latest_record = self.cnn_model.inference_records[-1] + if "metadata" in latest_record: + latest_record["metadata"]["actual_price_changes"] = { + "short_term": price_change_pct, + "mid_term": price_change_pct * 0.8, # Slight decay for longer timeframes + "long_term": price_change_pct * 0.6 + } + latest_record["metadata"]["time_diffs"] = { + "short_term": 1.0, # 1 minute + "mid_term": 5.0, # 5 minutes + "long_term": 15.0 # 15 minutes + } + + # Train on stored records + long_term_loss = self.cnn_model.train_on_stored_records(self.cnn_optimizer, min_records=5) + if long_term_loss > 0: + logger.debug(f"CNN long-term training completed: loss={long_term_loss:.4f}") + except Exception as e: + logger.debug(f"Error in CNN long-term training: {e}") + return True else: logger.warning(f"No model input available for CNN training") @@ -4947,6 +5034,34 @@ class TradingOrchestrator: ) predictions.append(prediction) + # Store inference record in CNN model for long-term training + if hasattr(self.cnn_model, "store_inference_record"): + try: + # Get current price for metadata + current_price = self._get_current_price(symbol) + + # Create metadata with price information for long-term training + metadata = { + "symbol": symbol, + "current_price": current_price, + "timestamp": datetime.now().isoformat(), + "prediction_action": action, + "prediction_confidence": confidence, + "actual_price_changes": {}, # Will be populated during training + "time_diffs": {} # Will be populated during training + } + + # Store the inference record in the CNN model + self.cnn_model.store_inference_record( + input_data=features_tensor, + prediction_output=(q_values, extrema_pred, price_pred, features_refined, advanced_pred, multi_timeframe_pred), + metadata=metadata + ) + + logger.debug(f"Stored CNN inference record for long-term training") + except Exception as e: + logger.debug(f"Error storing CNN inference record: {e}") + logger.debug( f"Added CNN prediction: {action} ({confidence:.3f})" ) @@ -5029,6 +5144,34 @@ class TradingOrchestrator: ) predictions.append(pred) + # Store inference record in CNN model for long-term training (fallback method) + if hasattr(model.model, "store_inference_record"): + try: + # Get current price for metadata + current_price = self._get_current_price(symbol) + + # Create metadata with price information for long-term training + metadata = { + "symbol": symbol, + "current_price": current_price, + "timestamp": datetime.now().isoformat(), + "prediction_action": best_action, + "prediction_confidence": float(confidence), + "actual_price_changes": {}, # Will be populated during training + "time_diffs": {} # Will be populated during training + } + + # Store the inference record in the CNN model + model.model.store_inference_record( + input_data=features_tensor, + prediction_output=None, # Not available in fallback method + metadata=metadata + ) + + logger.debug(f"Stored CNN inference record for long-term training (fallback)") + except Exception as e: + logger.debug(f"Error storing CNN inference record (fallback): {e}") + # Note: Inference data will be stored in main prediction loop to avoid duplication # Capture for dashboard diff --git a/data/ui_state.json b/data/ui_state.json index a468778..7498d2c 100644 --- a/data/ui_state.json +++ b/data/ui_state.json @@ -21,9 +21,9 @@ "training_enabled": true }, "dqn_agent": { - "inference_enabled": true, - "training_enabled": true + "inference_enabled": "inference_enabled", + "training_enabled": false } }, - "timestamp": "2025-07-30T00:41:19.241862" + "timestamp": "2025-07-30T09:19:11.731827" } \ No newline at end of file