wip cnn training and cob
This commit is contained in:
@ -2726,6 +2726,7 @@ class CleanTradingDashboard:
|
||||
"""Update CNN model panel with real-time data and performance metrics"""
|
||||
try:
|
||||
if not self.cnn_adapter:
|
||||
logger.debug("CNN adapter not available for model panel update")
|
||||
return {
|
||||
'status': 'NOT_AVAILABLE',
|
||||
'parameters': '0M',
|
||||
@ -2744,8 +2745,15 @@ class CleanTradingDashboard:
|
||||
'last_training_loss': 0.0
|
||||
}
|
||||
|
||||
logger.debug(f"CNN adapter available: {type(self.cnn_adapter)}")
|
||||
|
||||
# Get CNN prediction for ETH/USDT
|
||||
prediction = self._get_cnn_prediction('ETH/USDT')
|
||||
logger.debug(f"CNN prediction result: {prediction}")
|
||||
|
||||
# Debug: Check CNN adapter attributes
|
||||
logger.debug(f"CNN adapter attributes: inference_count={getattr(self.cnn_adapter, 'inference_count', 'MISSING')}, training_count={getattr(self.cnn_adapter, 'training_count', 'MISSING')}")
|
||||
logger.debug(f"CNN adapter training data length: {len(getattr(self.cnn_adapter, 'training_data', []))}")
|
||||
|
||||
# Get model performance metrics
|
||||
model_info = self.cnn_adapter.get_model_info() if hasattr(self.cnn_adapter, 'get_model_info') else {}
|
||||
@ -2804,9 +2812,15 @@ class CleanTradingDashboard:
|
||||
pivot_price_str = "N/A"
|
||||
last_prediction = "No prediction"
|
||||
|
||||
# Get model status
|
||||
# Get model status - enhanced for cold start mode
|
||||
if hasattr(self.cnn_adapter, 'model') and self.cnn_adapter.model:
|
||||
if training_samples > 100:
|
||||
# Check if model is actively training (cold start mode)
|
||||
if training_count > 0 and training_samples > 0:
|
||||
if training_samples > 100:
|
||||
status = 'TRAINED'
|
||||
else:
|
||||
status = 'TRAINING' # Cold start training mode
|
||||
elif training_samples > 100:
|
||||
status = 'TRAINED'
|
||||
elif training_samples > 0:
|
||||
status = 'TRAINING'
|
||||
@ -5730,14 +5744,17 @@ class CleanTradingDashboard:
|
||||
"""Get CNN prediction using standardized input format"""
|
||||
try:
|
||||
if not self.cnn_adapter:
|
||||
logger.debug(f"CNN adapter not available for prediction")
|
||||
return None
|
||||
|
||||
# Get standardized input data from data provider
|
||||
base_data_input = self._get_base_data_input(symbol)
|
||||
if not base_data_input:
|
||||
logger.debug(f"No base data input available for {symbol}")
|
||||
logger.warning(f"No base data input available for {symbol} - this will prevent CNN predictions")
|
||||
return None
|
||||
|
||||
logger.debug(f"Base data input created successfully for {symbol}")
|
||||
|
||||
# Make prediction using CNN adapter
|
||||
model_output = self.cnn_adapter.predict(base_data_input)
|
||||
|
||||
@ -5770,15 +5787,48 @@ class CleanTradingDashboard:
|
||||
# Fallback: create BaseDataInput from available data
|
||||
from core.data_models import BaseDataInput, OHLCVBar, COBData
|
||||
|
||||
# Get OHLCV data for different timeframes
|
||||
# Get OHLCV data for different timeframes - ensure we have enough data
|
||||
ohlcv_1s = self._get_ohlcv_bars(symbol, '1s', 300)
|
||||
ohlcv_1m = self._get_ohlcv_bars(symbol, '1m', 300)
|
||||
ohlcv_1m = self._get_ohlcv_bars(symbol, '1m', 300)
|
||||
ohlcv_1h = self._get_ohlcv_bars(symbol, '1h', 300)
|
||||
ohlcv_1d = self._get_ohlcv_bars(symbol, '1d', 300)
|
||||
|
||||
# Get BTC reference data
|
||||
btc_ohlcv_1s = self._get_ohlcv_bars('BTC/USDT', '1s', 300)
|
||||
|
||||
# Ensure we have minimum required data (pad if necessary)
|
||||
def pad_ohlcv_data(bars, target_count=300):
|
||||
if len(bars) < target_count:
|
||||
# Pad with the last bar repeated
|
||||
if len(bars) > 0:
|
||||
last_bar = bars[-1]
|
||||
while len(bars) < target_count:
|
||||
bars.append(last_bar)
|
||||
else:
|
||||
# Create dummy bars if no data
|
||||
from core.data_models import OHLCVBar
|
||||
dummy_bar = OHLCVBar(
|
||||
symbol=symbol,
|
||||
timestamp=datetime.now(),
|
||||
open=3500.0,
|
||||
high=3510.0,
|
||||
low=3490.0,
|
||||
close=3505.0,
|
||||
volume=1000.0,
|
||||
timeframe="1s"
|
||||
)
|
||||
bars = [dummy_bar] * target_count
|
||||
return bars[:target_count] # Ensure exactly target_count
|
||||
|
||||
# Pad all data to required length
|
||||
ohlcv_1s = pad_ohlcv_data(ohlcv_1s, 300)
|
||||
ohlcv_1m = pad_ohlcv_data(ohlcv_1m, 300)
|
||||
ohlcv_1h = pad_ohlcv_data(ohlcv_1h, 300)
|
||||
ohlcv_1d = pad_ohlcv_data(ohlcv_1d, 300)
|
||||
btc_ohlcv_1s = pad_ohlcv_data(btc_ohlcv_1s, 300)
|
||||
|
||||
logger.debug(f"OHLCV data lengths: 1s={len(ohlcv_1s)}, 1m={len(ohlcv_1m)}, 1h={len(ohlcv_1h)}, 1d={len(ohlcv_1d)}, BTC={len(btc_ohlcv_1s)}")
|
||||
|
||||
# Get COB data if available
|
||||
cob_data = self._get_cob_data(symbol)
|
||||
|
||||
@ -5942,41 +5992,65 @@ class CleanTradingDashboard:
|
||||
}
|
||||
|
||||
def _start_cnn_prediction_loop(self):
|
||||
"""Start CNN real-time prediction loop"""
|
||||
"""Start CNN real-time prediction loop with cold start training mode"""
|
||||
try:
|
||||
if not self.cnn_adapter:
|
||||
logger.warning("CNN adapter not available, skipping prediction loop")
|
||||
return
|
||||
|
||||
def cnn_prediction_worker():
|
||||
"""Worker thread for CNN predictions"""
|
||||
logger.info("CNN prediction worker started")
|
||||
"""Worker thread for CNN predictions with cold start training"""
|
||||
logger.info("CNN prediction worker started in COLD START mode")
|
||||
logger.info("Mode: Inference every 10s + Training after each inference")
|
||||
|
||||
previous_predictions = {} # Store previous predictions for training
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Make predictions for primary symbols
|
||||
for symbol in ['ETH/USDT', 'BTC/USDT']:
|
||||
prediction = self._get_cnn_prediction(symbol)
|
||||
# Get current prediction
|
||||
current_prediction = self._get_cnn_prediction(symbol)
|
||||
|
||||
if prediction:
|
||||
if current_prediction:
|
||||
# Store prediction for dashboard display
|
||||
if not hasattr(self, 'cnn_predictions'):
|
||||
self.cnn_predictions = {}
|
||||
|
||||
self.cnn_predictions[symbol] = prediction
|
||||
self.cnn_predictions[symbol] = current_prediction
|
||||
|
||||
# Add to training data if confidence is high enough
|
||||
if prediction['confidence'] > 0.7:
|
||||
self._add_cnn_training_sample(symbol, prediction)
|
||||
logger.info(f"CNN prediction for {symbol}: {current_prediction['action']} ({current_prediction['confidence']:.3f}) @ {current_prediction.get('pivot_price', 'N/A')}")
|
||||
|
||||
logger.debug(f"CNN prediction for {symbol}: {prediction['action']} ({prediction['confidence']:.3f})")
|
||||
# COLD START TRAINING: Train with previous prediction if available
|
||||
if symbol in previous_predictions:
|
||||
prev_prediction = previous_predictions[symbol]
|
||||
|
||||
# Calculate reward based on price movement since last prediction
|
||||
reward = self._calculate_prediction_reward(symbol, prev_prediction, current_prediction)
|
||||
|
||||
# Add training sample with previous prediction and calculated reward
|
||||
self._add_cnn_training_sample_with_reward(symbol, prev_prediction, reward)
|
||||
|
||||
# Train the model immediately (cold start mode)
|
||||
if len(self.cnn_adapter.training_data) >= 2: # Need at least 2 samples
|
||||
training_result = self.cnn_adapter.train(epochs=1)
|
||||
logger.info(f"CNN trained for {symbol}: loss={training_result.get('loss', 0.0):.6f}, samples={training_result.get('samples', 0)}")
|
||||
|
||||
# Store current prediction for next iteration
|
||||
previous_predictions[symbol] = {
|
||||
'action': current_prediction['action'],
|
||||
'confidence': current_prediction['confidence'],
|
||||
'pivot_price': current_prediction.get('pivot_price'),
|
||||
'timestamp': current_prediction['timestamp'],
|
||||
'price_at_prediction': self._get_current_price(symbol)
|
||||
}
|
||||
|
||||
# Sleep for 1 second (1Hz prediction rate)
|
||||
time.sleep(1.0)
|
||||
# Sleep for 10 seconds (0.1Hz prediction rate for cold start)
|
||||
time.sleep(10.0)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in CNN prediction worker: {e}")
|
||||
time.sleep(5.0) # Wait longer on error
|
||||
time.sleep(10.0) # Wait same interval on error
|
||||
|
||||
# Start the worker thread
|
||||
import threading
|
||||
@ -5984,7 +6058,7 @@ class CleanTradingDashboard:
|
||||
prediction_thread = threading.Thread(target=cnn_prediction_worker, daemon=True)
|
||||
prediction_thread.start()
|
||||
|
||||
logger.info("CNN real-time prediction loop started")
|
||||
logger.info("CNN real-time prediction loop started in COLD START mode (10s intervals)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting CNN prediction loop: {e}")
|
||||
@ -6041,6 +6115,67 @@ class CleanTradingDashboard:
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting price history for {symbol}: {e}")
|
||||
return []
|
||||
|
||||
def _calculate_prediction_reward(self, symbol: str, prev_prediction: Dict[str, Any], current_prediction: Dict[str, Any]) -> float:
|
||||
"""Calculate reward based on prediction accuracy for cold start training"""
|
||||
try:
|
||||
# Get price at previous prediction and current price
|
||||
prev_price = prev_prediction.get('price_at_prediction', 0.0)
|
||||
current_price = self._get_current_price(symbol)
|
||||
|
||||
if not prev_price or not current_price or prev_price <= 0 or current_price <= 0:
|
||||
return 0.0 # No reward if prices are invalid
|
||||
|
||||
# Calculate actual price movement
|
||||
price_change_pct = (current_price - prev_price) / prev_price
|
||||
|
||||
# Get previous prediction details
|
||||
prev_action = prev_prediction.get('action', 'HOLD')
|
||||
prev_confidence = prev_prediction.get('confidence', 0.0)
|
||||
|
||||
# Calculate base reward based on prediction accuracy
|
||||
base_reward = 0.0
|
||||
|
||||
if prev_action == 'BUY' and price_change_pct > 0.001: # Price went up (>0.1%)
|
||||
base_reward = price_change_pct * prev_confidence * 10.0 # Reward for correct BUY
|
||||
elif prev_action == 'SELL' and price_change_pct < -0.001: # Price went down (<-0.1%)
|
||||
base_reward = abs(price_change_pct) * prev_confidence * 10.0 # Reward for correct SELL
|
||||
elif prev_action == 'HOLD' and abs(price_change_pct) < 0.001: # Price stayed stable
|
||||
base_reward = prev_confidence * 0.5 # Small reward for correct HOLD
|
||||
else:
|
||||
# Wrong prediction - negative reward
|
||||
base_reward = -abs(price_change_pct) * prev_confidence * 5.0
|
||||
|
||||
# Bonus for high confidence correct predictions
|
||||
if base_reward > 0 and prev_confidence > 0.8:
|
||||
base_reward *= 1.5
|
||||
|
||||
# Clamp reward to reasonable range
|
||||
reward = max(-1.0, min(1.0, base_reward))
|
||||
|
||||
logger.debug(f"Reward calculation for {symbol}: {prev_action} @ {prev_price:.2f} -> {current_price:.2f} ({price_change_pct:.3%}) = {reward:.4f}")
|
||||
|
||||
return reward
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating prediction reward: {e}")
|
||||
return 0.0
|
||||
|
||||
def _add_cnn_training_sample_with_reward(self, symbol: str, prediction: Dict[str, Any], reward: float):
|
||||
"""Add CNN training sample with calculated reward for cold start training"""
|
||||
try:
|
||||
if not self.cnn_adapter or not hasattr(self.cnn_adapter, 'add_training_sample'):
|
||||
return
|
||||
|
||||
action = prediction.get('action', 'HOLD')
|
||||
|
||||
# Add training sample with calculated reward
|
||||
self.cnn_adapter.add_training_sample(symbol, action, reward)
|
||||
|
||||
logger.debug(f"Added CNN training sample with reward: {symbol} {action} (reward: {reward:.4f})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding CNN training sample with reward: {e}")
|
||||
|
||||
def _initialize_enhanced_position_sync(self):
|
||||
"""Initialize enhanced position synchronization system"""
|
||||
|
Reference in New Issue
Block a user