CNN training first working
@@ -51,10 +51,68 @@ class EnhancedCNNAdapter:
         # Create checkpoint directory if it doesn't exist
         os.makedirs(checkpoint_dir, exist_ok=True)

-        # Initialize model
+        # Initialize the model
         self._initialize_model()

-        logger.info(f"EnhancedCNNAdapter initialized with device: {self.device}")
+        # Load checkpoint if available
+        if model_path and os.path.exists(model_path):
+            self._load_checkpoint(model_path)
+        else:
+            self._load_best_checkpoint()
+
+        logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
+
+    def _load_checkpoint(self, checkpoint_path: str) -> bool:
+        """Load model from checkpoint path"""
+        try:
+            if self.model and os.path.exists(checkpoint_path):
+                success = self.model.load(checkpoint_path)
+                if success:
+                    logger.info(f"Loaded model from {checkpoint_path}")
+                    return True
+                else:
+                    logger.warning(f"Failed to load model from {checkpoint_path}")
+                    return False
+            else:
+                logger.warning(f"Checkpoint path does not exist: {checkpoint_path}")
+                return False
+        except Exception as e:
+            logger.error(f"Error loading checkpoint: {e}")
+            return False
+
+    def _load_best_checkpoint(self) -> bool:
+        """Load the best available checkpoint"""
+        try:
+            return self.load_best_checkpoint()
+        except Exception as e:
+            logger.error(f"Error loading best checkpoint: {e}")
+            return False
+
+    def _create_default_output(self, symbol: str) -> ModelOutput:
+        """Create default output when prediction fails"""
+        return create_model_output(
+            model_type='cnn',
+            model_name=self.model_name,
+            symbol=symbol,
+            action='HOLD',
+            confidence=0.0,
+            metadata={'error': 'Prediction failed, using default output'}
+        )
+
+    def _process_hidden_states(self, hidden_states: Dict[str, Any]) -> Dict[str, Any]:
+        """Process hidden states for cross-model feeding"""
+        processed_states = {}
+
+        for key, value in hidden_states.items():
+            if isinstance(value, torch.Tensor):
+                # Convert tensor to numpy array
+                processed_states[key] = value.cpu().numpy().tolist()
+            else:
+                processed_states[key] = value
+
+        return processed_states
+
     def _initialize_model(self):
         """Initialize the EnhancedCNN model"""
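Not part of the diff above, but for context: `_process_hidden_states` is a plain tensor-to-nested-list conversion so downstream models can consume CNN features without sharing torch objects. A minimal standalone sketch of that round trip (the key names and shapes here are invented for illustration):

# Illustrative sketch only, not part of this commit.
import json
import torch

def serialize_hidden_states(hidden_states: dict) -> dict:
    """Mirror of the adapter's tensor -> nested-list conversion for cross-model feeding."""
    out = {}
    for key, value in hidden_states.items():
        if isinstance(value, torch.Tensor):
            out[key] = value.detach().cpu().numpy().tolist()
        else:
            out[key] = value
    return out

# Producer side: the CNN emits feature tensors.
states = {"cnn_features": torch.randn(1, 128), "layer": "penultimate"}
payload = json.dumps(serialize_hidden_states(states))  # safe to log, cache, or pass between processes

# Consumer side: another model rebuilds a tensor from the nested lists.
decoded = json.loads(payload)
features = torch.tensor(decoded["cnn_features"])  # shape (1, 128) again
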
@@ -259,6 +259,10 @@ class CleanTradingDashboard:
             self.data_provider.start_cob_collection()
             logger.info("Started COB collection in data provider")

+            # Start CNN real-time prediction loop
+            self._start_cnn_prediction_loop()
+            logger.info("Started CNN real-time prediction loop")
+
             # Then subscribe to updates
             self.data_provider.subscribe_to_cob(self._on_cob_data_update)
             logger.info("Subscribed to COB data updates from data provider")
@@ -2718,6 +2722,82 @@ class CleanTradingDashboard:
             logger.debug(f"Error getting enhanced training stats: {e}")
             return {}

+    def _update_cnn_model_panel(self) -> Dict[str, Any]:
+        """Update CNN model panel with real-time data and performance metrics"""
+        try:
+            if not self.cnn_adapter:
+                return {
+                    'status': 'NOT_AVAILABLE',
+                    'parameters': '0M',
+                    'current_loss': 0.0,
+                    'accuracy': 0.0,
+                    'confidence': 0.0,
+                    'last_prediction': 'N/A',
+                    'training_samples': 0,
+                    'inference_rate': '0.00/s'
+                }
+
+            # Get CNN prediction for ETH/USDT
+            prediction = self._get_cnn_prediction('ETH/USDT')
+
+            # Get model performance metrics
+            model_info = self.cnn_adapter.get_model_info() if hasattr(self.cnn_adapter, 'get_model_info') else {}
+
+            # Calculate inference rate
+            inference_times = getattr(self.cnn_adapter, 'inference_times', [])
+            if len(inference_times) > 0:
+                avg_inference_time = sum(inference_times[-10:]) / min(len(inference_times), 10)
+                inference_rate = f"{1.0/avg_inference_time:.2f}/s" if avg_inference_time > 0 else "0.00/s"
+            else:
+                inference_rate = "0.00/s"
+
+            # Get training data count
+            training_samples = len(getattr(self.cnn_adapter, 'training_data', []))
+
+            # Format last prediction
+            if prediction:
+                last_prediction = f"{prediction['action']} ({prediction['confidence']:.1%})"
+                current_confidence = prediction['confidence']
+            else:
+                last_prediction = "No prediction"
+                current_confidence = 0.0
+
+            # Get model status
+            if hasattr(self.cnn_adapter, 'model') and self.cnn_adapter.model:
+                if training_samples > 100:
+                    status = 'TRAINED'
+                elif training_samples > 0:
+                    status = 'TRAINING'
+                else:
+                    status = 'FRESH'
+            else:
+                status = 'NOT_LOADED'
+
+            return {
+                'status': status,
+                'parameters': model_info.get('parameters', '50.0M'),
+                'current_loss': model_info.get('current_loss', 0.0),
+                'accuracy': model_info.get('accuracy', 0.0),
+                'confidence': current_confidence,
+                'last_prediction': last_prediction,
+                'training_samples': training_samples,
+                'inference_rate': inference_rate,
+                'last_update': datetime.now().strftime('%H:%M:%S')
+            }
+
+        except Exception as e:
+            logger.error(f"Error updating CNN model panel: {e}")
+            return {
+                'status': 'ERROR',
+                'parameters': '0M',
+                'current_loss': 0.0,
+                'accuracy': 0.0,
+                'confidence': 0.0,
+                'last_prediction': f'Error: {str(e)}',
+                'training_samples': 0,
+                'inference_rate': '0.00/s'
+            }
+
     def _get_training_metrics(self) -> Dict:
         """Get training metrics from unified orchestrator - using orchestrator as SSOT"""
         try:
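The panel's `inference_rate` string above is simply one over the mean of the last ten measured latencies. A standalone sketch of the same arithmetic, with made-up sample latencies (`format_inference_rate` is an invented name, not in the codebase):

# Illustrative sketch of the panel's inference-rate math, not part of this commit.
from typing import List

def format_inference_rate(inference_times: List[float], window: int = 10) -> str:
    """Average the most recent latencies (in seconds) and report predictions per second."""
    if not inference_times:
        return "0.00/s"
    recent = inference_times[-window:]
    avg_latency = sum(recent) / len(recent)
    return f"{1.0 / avg_latency:.2f}/s" if avg_latency > 0 else "0.00/s"

print(format_inference_rate([0.05, 0.04, 0.05, 0.06]))  # mean latency 0.05s -> "20.00/s"
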
@@ -2751,6 +2831,19 @@ class CleanTradingDashboard:
             latest_predictions = self._get_latest_model_predictions()
             cnn_prediction = self._get_cnn_pivot_prediction()

+            # Get enhanced CNN model panel data
+            cnn_panel_data = self._update_cnn_model_panel()
+
+            # Update CNN model in loaded_models with real-time data
+            if cnn_panel_data:
+                model_states['cnn'].update({
+                    'status': cnn_panel_data.get('status', 'FRESH'),
+                    'confidence': cnn_panel_data.get('confidence', 0.0),
+                    'last_prediction': cnn_panel_data.get('last_prediction', 'No prediction'),
+                    'training_samples': cnn_panel_data.get('training_samples', 0),
+                    'inference_rate': cnn_panel_data.get('inference_rate', '0.00/s')
+                })
+
             # Get enhanced training statistics if available
             enhanced_training_stats = self._get_enhanced_training_stats()

@@ -5534,14 +5627,311 @@ class CleanTradingDashboard:
             self.training_system = None

     def _initialize_standardized_cnn(self):
-        """Initialize StandardizedCNN model for the dashboard"""
+        """Initialize Enhanced CNN model with standardized input format for the dashboard"""
         try:
-            from NN.models.standardized_cnn import StandardizedCNN
-            self.standardized_cnn = StandardizedCNN(model_name="dashboard_standardized_cnn")
-            logger.info("StandardizedCNN model initialized for dashboard")
+            from core.enhanced_cnn_adapter import EnhancedCNNAdapter
+
+            # Initialize the enhanced CNN adapter
+            self.cnn_adapter = EnhancedCNNAdapter(
+                checkpoint_dir="models/enhanced_cnn"
+            )
+
+            # For backward compatibility
+            self.standardized_cnn = self.cnn_adapter
+
+            logger.info("Enhanced CNN adapter initialized for dashboard with standardized input format")
+
         except Exception as e:
-            logger.warning(f"StandardizedCNN initialization failed: {e}")
-            self.standardized_cnn = None
+            logger.warning(f"Enhanced CNN adapter initialization failed: {e}")
+
+            # Fallback to original StandardizedCNN
+            try:
+                from NN.models.standardized_cnn import StandardizedCNN
+                self.standardized_cnn = StandardizedCNN(model_name="dashboard_standardized_cnn")
+                self.cnn_adapter = None
+                logger.info("Fallback to StandardizedCNN model initialized for dashboard")
+            except Exception as e2:
+                logger.warning(f"StandardizedCNN fallback initialization failed: {e2}")
+                self.standardized_cnn = None
+                self.cnn_adapter = None
+
+    def _get_cnn_prediction(self, symbol: str = 'ETH/USDT') -> Optional[Dict[str, Any]]:
+        """Get CNN prediction using standardized input format"""
+        try:
+            if not self.cnn_adapter:
+                return None
+
+            # Get standardized input data from data provider
+            base_data_input = self._get_base_data_input(symbol)
+            if not base_data_input:
+                logger.debug(f"No base data input available for {symbol}")
+                return None
+
+            # Make prediction using CNN adapter
+            model_output = self.cnn_adapter.predict(base_data_input)
+
+            # Convert to dictionary for dashboard use
+            prediction = {
+                'action': model_output.predictions.get('action', 'HOLD'),
+                'confidence': model_output.confidence,
+                'buy_probability': model_output.predictions.get('buy_probability', 0.0),
+                'sell_probability': model_output.predictions.get('sell_probability', 0.0),
+                'hold_probability': model_output.predictions.get('hold_probability', 0.0),
+                'timestamp': model_output.timestamp,
+                'hidden_states': model_output.hidden_states,
+                'metadata': model_output.metadata
+            }
+
+            logger.debug(f"CNN prediction for {symbol}: {prediction['action']} ({prediction['confidence']:.3f})")
+            return prediction
+
+        except Exception as e:
+            logger.error(f"Error getting CNN prediction: {e}")
+            return None
+
+    def _get_base_data_input(self, symbol: str = 'ETH/USDT') -> Optional['BaseDataInput']:
+        """Get standardized BaseDataInput from data provider"""
+        try:
+            # Check if data provider supports standardized input
+            if hasattr(self.data_provider, 'get_base_data_input'):
+                return self.data_provider.get_base_data_input(symbol)
+
+            # Fallback: create BaseDataInput from available data
+            from core.data_models import BaseDataInput, OHLCVBar, COBData
+
+            # Get OHLCV data for different timeframes
+            ohlcv_1s = self._get_ohlcv_bars(symbol, '1s', 300)
+            ohlcv_1m = self._get_ohlcv_bars(symbol, '1m', 300)
+            ohlcv_1h = self._get_ohlcv_bars(symbol, '1h', 300)
+            ohlcv_1d = self._get_ohlcv_bars(symbol, '1d', 300)
+
+            # Get BTC reference data
+            btc_ohlcv_1s = self._get_ohlcv_bars('BTC/USDT', '1s', 300)
+
+            # Get COB data if available
+            cob_data = self._get_cob_data(symbol)
+
+            # Create BaseDataInput
+            base_data_input = BaseDataInput(
+                symbol=symbol,
+                timestamp=datetime.now(),
+                ohlcv_1s=ohlcv_1s,
+                ohlcv_1m=ohlcv_1m,
+                ohlcv_1h=ohlcv_1h,
+                ohlcv_1d=ohlcv_1d,
+                btc_ohlcv_1s=btc_ohlcv_1s,
+                cob_data=cob_data,
+                technical_indicators=self._get_technical_indicators(symbol),
+                pivot_points=self._get_pivot_points(symbol),
+                last_predictions={}  # TODO: Add cross-model predictions
+            )
+
+            return base_data_input
+
+        except Exception as e:
+            logger.error(f"Error creating base data input: {e}")
+            return None
+
+    def _get_ohlcv_bars(self, symbol: str, timeframe: str, count: int) -> List['OHLCVBar']:
+        """Get OHLCV bars from data provider"""
+        try:
+            from core.data_models import OHLCVBar
+
+            # Get data from data provider
+            df = self.data_provider.get_candles(symbol, timeframe)
+            if df is None or len(df) == 0:
+                return []
+
+            # Convert to OHLCVBar objects
+            bars = []
+            for idx, row in df.tail(count).iterrows():
+                bar = OHLCVBar(
+                    symbol=symbol,
+                    timestamp=idx if isinstance(idx, datetime) else datetime.now(),
+                    open=float(row['open']),
+                    high=float(row['high']),
+                    low=float(row['low']),
+                    close=float(row['close']),
+                    volume=float(row['volume']),
+                    timeframe=timeframe,
+                    indicators={}  # TODO: Add technical indicators
+                )
+                bars.append(bar)
+
+            return bars
+
+        except Exception as e:
+            logger.error(f"Error getting OHLCV bars for {symbol} {timeframe}: {e}")
+            return []
+
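The fallback path above constructs `BaseDataInput` and `OHLCVBar` from `core.data_models`, whose definitions are not part of this diff. Judging only from the constructor calls, they look roughly like the dataclass sketch below; this is an assumption for readability, not the project's actual code:

# Rough sketch of the assumed shape of core.data_models; the real definitions are not in this diff.
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional

@dataclass
class OHLCVBar:
    symbol: str
    timestamp: datetime
    open: float
    high: float
    low: float
    close: float
    volume: float
    timeframe: str
    indicators: Dict[str, float] = field(default_factory=dict)

@dataclass
class BaseDataInput:
    symbol: str
    timestamp: datetime
    ohlcv_1s: List[OHLCVBar]
    ohlcv_1m: List[OHLCVBar]
    ohlcv_1h: List[OHLCVBar]
    ohlcv_1d: List[OHLCVBar]
    btc_ohlcv_1s: List[OHLCVBar]
    cob_data: Optional[Any] = None  # COBData in the real module
    technical_indicators: Dict[str, float] = field(default_factory=dict)
    pivot_points: List[Any] = field(default_factory=list)  # PivotPoint in the real module
    last_predictions: Dict[str, Any] = field(default_factory=dict)
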
+    def _get_cob_data(self, symbol: str) -> Optional['COBData']:
+        """Get COB data from latest cache"""
+        try:
+            if not hasattr(self, 'latest_cob_data') or symbol not in self.latest_cob_data:
+                return None
+
+            from core.data_models import COBData
+
+            cob_raw = self.latest_cob_data[symbol]
+            if not isinstance(cob_raw, dict) or 'stats' not in cob_raw:
+                return None
+
+            stats = cob_raw['stats']
+            current_price = stats.get('mid_price', 0.0)
+
+            # Create price buckets (simplified for now)
+            bucket_size = 1.0 if 'ETH' in symbol else 10.0
+            price_buckets = {}
+
+            # Create ±20 buckets around current price
+            for i in range(-20, 21):
+                price = current_price + (i * bucket_size)
+                price_buckets[price] = {
+                    'bid_volume': 0.0,
+                    'ask_volume': 0.0,
+                    'total_volume': 0.0,
+                    'imbalance': stats.get('imbalance', 0.0)
+                }
+
+            cob_data = COBData(
+                symbol=symbol,
+                timestamp=cob_raw.get('timestamp', datetime.now()),
+                current_price=current_price,
+                bucket_size=bucket_size,
+                price_buckets=price_buckets,
+                bid_ask_imbalance={current_price: stats.get('imbalance', 0.0)},
+                volume_weighted_prices={current_price: current_price},
+                order_flow_metrics=stats,
+                ma_1s_imbalance={current_price: stats.get('imbalance', 0.0)},
+                ma_5s_imbalance={current_price: stats.get('imbalance_5s', 0.0)},
+                ma_15s_imbalance={current_price: stats.get('imbalance_15s', 0.0)},
+                ma_60s_imbalance={current_price: stats.get('imbalance_60s', 0.0)}
+            )
+
+            return cob_data
+
+        except Exception as e:
+            logger.error(f"Error creating COB data for {symbol}: {e}")
+            return None
+
+    def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
+        """Get technical indicators for symbol"""
+        try:
+            # TODO: Implement technical indicators calculation
+            return {}
+        except Exception as e:
+            logger.error(f"Error getting technical indicators for {symbol}: {e}")
+            return {}
+
+    def _get_pivot_points(self, symbol: str) -> List['PivotPoint']:
+        """Get pivot points for symbol"""
+        try:
+            # TODO: Implement pivot points calculation
+            return []
+        except Exception as e:
+            logger.error(f"Error getting pivot points for {symbol}: {e}")
+            return []
+
+    def _start_cnn_prediction_loop(self):
+        """Start CNN real-time prediction loop"""
+        try:
+            if not self.cnn_adapter:
+                logger.warning("CNN adapter not available, skipping prediction loop")
+                return
+
+            def cnn_prediction_worker():
+                """Worker thread for CNN predictions"""
+                logger.info("CNN prediction worker started")
+
+                while True:
+                    try:
+                        # Make predictions for primary symbols
+                        for symbol in ['ETH/USDT', 'BTC/USDT']:
+                            prediction = self._get_cnn_prediction(symbol)
+
+                            if prediction:
+                                # Store prediction for dashboard display
+                                if not hasattr(self, 'cnn_predictions'):
+                                    self.cnn_predictions = {}
+
+                                self.cnn_predictions[symbol] = prediction
+
+                                # Add to training data if confidence is high enough
+                                if prediction['confidence'] > 0.7:
+                                    self._add_cnn_training_sample(symbol, prediction)
+
+                                logger.debug(f"CNN prediction for {symbol}: {prediction['action']} ({prediction['confidence']:.3f})")
+
+                        # Sleep for 1 second (1Hz prediction rate)
+                        time.sleep(1.0)
+
+                    except Exception as e:
+                        logger.error(f"Error in CNN prediction worker: {e}")
+                        time.sleep(5.0)  # Wait longer on error
+
+            # Start the worker thread
+            import threading
+            import time
+            prediction_thread = threading.Thread(target=cnn_prediction_worker, daemon=True)
+            prediction_thread.start()
+
+            logger.info("CNN real-time prediction loop started")
+
+        except Exception as e:
+            logger.error(f"Error starting CNN prediction loop: {e}")
+
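The worker above runs a bare `while True` loop in a daemon thread, so it only stops when the process exits. If a clean shutdown were ever needed, a `threading.Event`-based variant along these lines would do it; this is a sketch, and `PredictionLoop` and its wiring are invented rather than taken from the codebase:

# Sketch of a stoppable prediction worker; not part of this commit.
import threading

class PredictionLoop:
    def __init__(self, predict_fn, symbols, interval: float = 1.0):
        self._predict_fn = predict_fn      # e.g. dashboard._get_cnn_prediction (assumed wiring)
        self._symbols = symbols
        self._interval = interval
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def start(self):
        self._thread.start()

    def stop(self, timeout: float = 5.0):
        self._stop.set()
        self._thread.join(timeout)

    def _run(self):
        while not self._stop.is_set():
            for symbol in self._symbols:
                try:
                    self._predict_fn(symbol)
                except Exception:
                    pass  # real code should log and back off, as the dashboard worker does
            self._stop.wait(self._interval)  # sleeps, but wakes immediately when stop() is called

# Usage sketch:
# loop = PredictionLoop(dashboard._get_cnn_prediction, ['ETH/USDT', 'BTC/USDT'])
# loop.start(); ...; loop.stop()
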
+    def _add_cnn_training_sample(self, symbol: str, prediction: Dict[str, Any]):
+        """Add CNN training sample based on prediction outcome"""
+        try:
+            if not self.cnn_adapter or not hasattr(self.cnn_adapter, 'add_training_sample'):
+                return
+
+            # Get current price for reward calculation
+            current_price = self._get_current_price(symbol)
+            if not current_price:
+                return
+
+            # Calculate reward based on prediction accuracy (simplified)
+            # In a real implementation, this would be based on actual market movement
+            action = prediction['action']
+            confidence = prediction['confidence']
+
+            # Simple reward: higher confidence predictions get higher rewards
+            base_reward = confidence * 0.1
+
+            # Add some market context (price movement direction)
+            price_history = self._get_recent_price_history(symbol, 10)
+            if len(price_history) >= 2:
+                price_change = (price_history[-1] - price_history[-2]) / price_history[-2]
+
+                # Reward if prediction aligns with price movement
+                if (action == 'BUY' and price_change > 0) or (action == 'SELL' and price_change < 0):
+                    reward = base_reward * 1.5  # Bonus for correct direction
+                else:
+                    reward = base_reward * 0.5  # Penalty for wrong direction
+            else:
+                reward = base_reward
+
+            # Add training sample
+            self.cnn_adapter.add_training_sample(symbol, action, reward)
+
+            logger.debug(f"Added CNN training sample: {symbol} {action} (reward: {reward:.4f})")
+
+        except Exception as e:
+            logger.error(f"Error adding CNN training sample: {e}")
+
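The reward shaping in `_add_cnn_training_sample` reduces to a small pure function; a sketch with one worked number follows (`compute_reward` is an invented helper, not in the codebase):

# Sketch of the reward shaping used above; not part of this commit.
from typing import List

def compute_reward(action: str, confidence: float, price_history: List[float]) -> float:
    """confidence * 0.1, scaled x1.5 when the action matches the last price move, x0.5 otherwise."""
    base_reward = confidence * 0.1
    if len(price_history) < 2:
        return base_reward
    price_change = (price_history[-1] - price_history[-2]) / price_history[-2]
    aligned = (action == 'BUY' and price_change > 0) or (action == 'SELL' and price_change < 0)
    return base_reward * (1.5 if aligned else 0.5)

# A BUY at 0.80 confidence while price ticked up from 3000 to 3003:
print(compute_reward('BUY', 0.80, [3000.0, 3003.0]))  # 0.80 * 0.1 * 1.5 = 0.12 (up to float rounding)
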
+    def _get_recent_price_history(self, symbol: str, count: int) -> List[float]:
+        """Get recent price history for reward calculation"""
+        try:
+            df = self.data_provider.get_candles(symbol, '1s')
+            if df is None or len(df) == 0:
+                return []
+
+            return df['close'].tail(count).tolist()
+
+        except Exception as e:
+            logger.error(f"Error getting price history for {symbol}: {e}")
+            return []
+
     def _initialize_enhanced_position_sync(self):
         """Initialize enhanced position synchronization system"""