CNN training first working

This commit is contained in:
Dobromir Popov
2025-07-23 22:39:00 +03:00
parent 26e6ba2e1d
commit 94ee7389c4
2 changed files with 456 additions and 8 deletions

View File

@ -51,10 +51,68 @@ class EnhancedCNNAdapter:
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
# Initialize model
# Initialize the model
self._initialize_model()
logger.info(f"EnhancedCNNAdapter initialized with device: {self.device}")
# Load checkpoint if available
if model_path and os.path.exists(model_path):
self._load_checkpoint(model_path)
else:
self._load_best_checkpoint()
logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
def _load_checkpoint(self, checkpoint_path: str) -> bool:
"""Load model from checkpoint path"""
try:
if self.model and os.path.exists(checkpoint_path):
success = self.model.load(checkpoint_path)
if success:
logger.info(f"Loaded model from {checkpoint_path}")
return True
else:
logger.warning(f"Failed to load model from {checkpoint_path}")
return False
else:
logger.warning(f"Checkpoint path does not exist: {checkpoint_path}")
return False
except Exception as e:
logger.error(f"Error loading checkpoint: {e}")
return False
def _load_best_checkpoint(self) -> bool:
"""Load the best available checkpoint"""
try:
return self.load_best_checkpoint()
except Exception as e:
logger.error(f"Error loading best checkpoint: {e}")
return False
def _create_default_output(self, symbol: str) -> ModelOutput:
"""Create default output when prediction fails"""
return create_model_output(
model_type='cnn',
model_name=self.model_name,
symbol=symbol,
action='HOLD',
confidence=0.0,
metadata={'error': 'Prediction failed, using default output'}
)
def _process_hidden_states(self, hidden_states: Dict[str, Any]) -> Dict[str, Any]:
"""Process hidden states for cross-model feeding"""
processed_states = {}
for key, value in hidden_states.items():
if isinstance(value, torch.Tensor):
# Convert tensor to numpy array
processed_states[key] = value.cpu().numpy().tolist()
else:
processed_states[key] = value
return processed_states
def _initialize_model(self):
"""Initialize the EnhancedCNN model"""

View File

@ -259,6 +259,10 @@ class CleanTradingDashboard:
self.data_provider.start_cob_collection()
logger.info("Started COB collection in data provider")
# Start CNN real-time prediction loop
self._start_cnn_prediction_loop()
logger.info("Started CNN real-time prediction loop")
# Then subscribe to updates
self.data_provider.subscribe_to_cob(self._on_cob_data_update)
logger.info("Subscribed to COB data updates from data provider")
@ -2718,6 +2722,82 @@ class CleanTradingDashboard:
logger.debug(f"Error getting enhanced training stats: {e}")
return {}
def _update_cnn_model_panel(self) -> Dict[str, Any]:
"""Update CNN model panel with real-time data and performance metrics"""
try:
if not self.cnn_adapter:
return {
'status': 'NOT_AVAILABLE',
'parameters': '0M',
'current_loss': 0.0,
'accuracy': 0.0,
'confidence': 0.0,
'last_prediction': 'N/A',
'training_samples': 0,
'inference_rate': '0.00/s'
}
# Get CNN prediction for ETH/USDT
prediction = self._get_cnn_prediction('ETH/USDT')
# Get model performance metrics
model_info = self.cnn_adapter.get_model_info() if hasattr(self.cnn_adapter, 'get_model_info') else {}
# Calculate inference rate
inference_times = getattr(self.cnn_adapter, 'inference_times', [])
if len(inference_times) > 0:
avg_inference_time = sum(inference_times[-10:]) / min(len(inference_times), 10)
inference_rate = f"{1.0/avg_inference_time:.2f}/s" if avg_inference_time > 0 else "0.00/s"
else:
inference_rate = "0.00/s"
# Get training data count
training_samples = len(getattr(self.cnn_adapter, 'training_data', []))
# Format last prediction
if prediction:
last_prediction = f"{prediction['action']} ({prediction['confidence']:.1%})"
current_confidence = prediction['confidence']
else:
last_prediction = "No prediction"
current_confidence = 0.0
# Get model status
if hasattr(self.cnn_adapter, 'model') and self.cnn_adapter.model:
if training_samples > 100:
status = 'TRAINED'
elif training_samples > 0:
status = 'TRAINING'
else:
status = 'FRESH'
else:
status = 'NOT_LOADED'
return {
'status': status,
'parameters': model_info.get('parameters', '50.0M'),
'current_loss': model_info.get('current_loss', 0.0),
'accuracy': model_info.get('accuracy', 0.0),
'confidence': current_confidence,
'last_prediction': last_prediction,
'training_samples': training_samples,
'inference_rate': inference_rate,
'last_update': datetime.now().strftime('%H:%M:%S')
}
except Exception as e:
logger.error(f"Error updating CNN model panel: {e}")
return {
'status': 'ERROR',
'parameters': '0M',
'current_loss': 0.0,
'accuracy': 0.0,
'confidence': 0.0,
'last_prediction': f'Error: {str(e)}',
'training_samples': 0,
'inference_rate': '0.00/s'
}
def _get_training_metrics(self) -> Dict:
"""Get training metrics from unified orchestrator - using orchestrator as SSOT"""
try:
@ -2751,6 +2831,19 @@ class CleanTradingDashboard:
latest_predictions = self._get_latest_model_predictions()
cnn_prediction = self._get_cnn_pivot_prediction()
# Get enhanced CNN model panel data
cnn_panel_data = self._update_cnn_model_panel()
# Update CNN model in loaded_models with real-time data
if cnn_panel_data:
model_states['cnn'].update({
'status': cnn_panel_data.get('status', 'FRESH'),
'confidence': cnn_panel_data.get('confidence', 0.0),
'last_prediction': cnn_panel_data.get('last_prediction', 'No prediction'),
'training_samples': cnn_panel_data.get('training_samples', 0),
'inference_rate': cnn_panel_data.get('inference_rate', '0.00/s')
})
# Get enhanced training statistics if available
enhanced_training_stats = self._get_enhanced_training_stats()
@ -5534,14 +5627,311 @@ class CleanTradingDashboard:
self.training_system = None
def _initialize_standardized_cnn(self):
"""Initialize StandardizedCNN model for the dashboard"""
"""Initialize Enhanced CNN model with standardized input format for the dashboard"""
try:
from NN.models.standardized_cnn import StandardizedCNN
self.standardized_cnn = StandardizedCNN(model_name="dashboard_standardized_cnn")
logger.info("StandardizedCNN model initialized for dashboard")
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
# Initialize the enhanced CNN adapter
self.cnn_adapter = EnhancedCNNAdapter(
checkpoint_dir="models/enhanced_cnn"
)
# For backward compatibility
self.standardized_cnn = self.cnn_adapter
logger.info("Enhanced CNN adapter initialized for dashboard with standardized input format")
except Exception as e:
logger.warning(f"StandardizedCNN initialization failed: {e}")
self.standardized_cnn = None
logger.warning(f"Enhanced CNN adapter initialization failed: {e}")
# Fallback to original StandardizedCNN
try:
from NN.models.standardized_cnn import StandardizedCNN
self.standardized_cnn = StandardizedCNN(model_name="dashboard_standardized_cnn")
self.cnn_adapter = None
logger.info("Fallback to StandardizedCNN model initialized for dashboard")
except Exception as e2:
logger.warning(f"StandardizedCNN fallback initialization failed: {e2}")
self.standardized_cnn = None
self.cnn_adapter = None
def _get_cnn_prediction(self, symbol: str = 'ETH/USDT') -> Optional[Dict[str, Any]]:
"""Get CNN prediction using standardized input format"""
try:
if not self.cnn_adapter:
return None
# Get standardized input data from data provider
base_data_input = self._get_base_data_input(symbol)
if not base_data_input:
logger.debug(f"No base data input available for {symbol}")
return None
# Make prediction using CNN adapter
model_output = self.cnn_adapter.predict(base_data_input)
# Convert to dictionary for dashboard use
prediction = {
'action': model_output.predictions.get('action', 'HOLD'),
'confidence': model_output.confidence,
'buy_probability': model_output.predictions.get('buy_probability', 0.0),
'sell_probability': model_output.predictions.get('sell_probability', 0.0),
'hold_probability': model_output.predictions.get('hold_probability', 0.0),
'timestamp': model_output.timestamp,
'hidden_states': model_output.hidden_states,
'metadata': model_output.metadata
}
logger.debug(f"CNN prediction for {symbol}: {prediction['action']} ({prediction['confidence']:.3f})")
return prediction
except Exception as e:
logger.error(f"Error getting CNN prediction: {e}")
return None
def _get_base_data_input(self, symbol: str = 'ETH/USDT') -> Optional['BaseDataInput']:
"""Get standardized BaseDataInput from data provider"""
try:
# Check if data provider supports standardized input
if hasattr(self.data_provider, 'get_base_data_input'):
return self.data_provider.get_base_data_input(symbol)
# Fallback: create BaseDataInput from available data
from core.data_models import BaseDataInput, OHLCVBar, COBData
# Get OHLCV data for different timeframes
ohlcv_1s = self._get_ohlcv_bars(symbol, '1s', 300)
ohlcv_1m = self._get_ohlcv_bars(symbol, '1m', 300)
ohlcv_1h = self._get_ohlcv_bars(symbol, '1h', 300)
ohlcv_1d = self._get_ohlcv_bars(symbol, '1d', 300)
# Get BTC reference data
btc_ohlcv_1s = self._get_ohlcv_bars('BTC/USDT', '1s', 300)
# Get COB data if available
cob_data = self._get_cob_data(symbol)
# Create BaseDataInput
base_data_input = BaseDataInput(
symbol=symbol,
timestamp=datetime.now(),
ohlcv_1s=ohlcv_1s,
ohlcv_1m=ohlcv_1m,
ohlcv_1h=ohlcv_1h,
ohlcv_1d=ohlcv_1d,
btc_ohlcv_1s=btc_ohlcv_1s,
cob_data=cob_data,
technical_indicators=self._get_technical_indicators(symbol),
pivot_points=self._get_pivot_points(symbol),
last_predictions={} # TODO: Add cross-model predictions
)
return base_data_input
except Exception as e:
logger.error(f"Error creating base data input: {e}")
return None
def _get_ohlcv_bars(self, symbol: str, timeframe: str, count: int) -> List['OHLCVBar']:
"""Get OHLCV bars from data provider"""
try:
from core.data_models import OHLCVBar
# Get data from data provider
df = self.data_provider.get_candles(symbol, timeframe)
if df is None or len(df) == 0:
return []
# Convert to OHLCVBar objects
bars = []
for idx, row in df.tail(count).iterrows():
bar = OHLCVBar(
symbol=symbol,
timestamp=idx if isinstance(idx, datetime) else datetime.now(),
open=float(row['open']),
high=float(row['high']),
low=float(row['low']),
close=float(row['close']),
volume=float(row['volume']),
timeframe=timeframe,
indicators={} # TODO: Add technical indicators
)
bars.append(bar)
return bars
except Exception as e:
logger.error(f"Error getting OHLCV bars for {symbol} {timeframe}: {e}")
return []
def _get_cob_data(self, symbol: str) -> Optional['COBData']:
"""Get COB data from latest cache"""
try:
if not hasattr(self, 'latest_cob_data') or symbol not in self.latest_cob_data:
return None
from core.data_models import COBData
cob_raw = self.latest_cob_data[symbol]
if not isinstance(cob_raw, dict) or 'stats' not in cob_raw:
return None
stats = cob_raw['stats']
current_price = stats.get('mid_price', 0.0)
# Create price buckets (simplified for now)
bucket_size = 1.0 if 'ETH' in symbol else 10.0
price_buckets = {}
# Create ±20 buckets around current price
for i in range(-20, 21):
price = current_price + (i * bucket_size)
price_buckets[price] = {
'bid_volume': 0.0,
'ask_volume': 0.0,
'total_volume': 0.0,
'imbalance': stats.get('imbalance', 0.0)
}
cob_data = COBData(
symbol=symbol,
timestamp=cob_raw.get('timestamp', datetime.now()),
current_price=current_price,
bucket_size=bucket_size,
price_buckets=price_buckets,
bid_ask_imbalance={current_price: stats.get('imbalance', 0.0)},
volume_weighted_prices={current_price: current_price},
order_flow_metrics=stats,
ma_1s_imbalance={current_price: stats.get('imbalance', 0.0)},
ma_5s_imbalance={current_price: stats.get('imbalance_5s', 0.0)},
ma_15s_imbalance={current_price: stats.get('imbalance_15s', 0.0)},
ma_60s_imbalance={current_price: stats.get('imbalance_60s', 0.0)}
)
return cob_data
except Exception as e:
logger.error(f"Error creating COB data for {symbol}: {e}")
return None
def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
"""Get technical indicators for symbol"""
try:
# TODO: Implement technical indicators calculation
return {}
except Exception as e:
logger.error(f"Error getting technical indicators for {symbol}: {e}")
return {}
def _get_pivot_points(self, symbol: str) -> List['PivotPoint']:
"""Get pivot points for symbol"""
try:
# TODO: Implement pivot points calculation
return []
except Exception as e:
logger.error(f"Error getting pivot points for {symbol}: {e}")
return []
def _start_cnn_prediction_loop(self):
"""Start CNN real-time prediction loop"""
try:
if not self.cnn_adapter:
logger.warning("CNN adapter not available, skipping prediction loop")
return
def cnn_prediction_worker():
"""Worker thread for CNN predictions"""
logger.info("CNN prediction worker started")
while True:
try:
# Make predictions for primary symbols
for symbol in ['ETH/USDT', 'BTC/USDT']:
prediction = self._get_cnn_prediction(symbol)
if prediction:
# Store prediction for dashboard display
if not hasattr(self, 'cnn_predictions'):
self.cnn_predictions = {}
self.cnn_predictions[symbol] = prediction
# Add to training data if confidence is high enough
if prediction['confidence'] > 0.7:
self._add_cnn_training_sample(symbol, prediction)
logger.debug(f"CNN prediction for {symbol}: {prediction['action']} ({prediction['confidence']:.3f})")
# Sleep for 1 second (1Hz prediction rate)
time.sleep(1.0)
except Exception as e:
logger.error(f"Error in CNN prediction worker: {e}")
time.sleep(5.0) # Wait longer on error
# Start the worker thread
import threading
import time
prediction_thread = threading.Thread(target=cnn_prediction_worker, daemon=True)
prediction_thread.start()
logger.info("CNN real-time prediction loop started")
except Exception as e:
logger.error(f"Error starting CNN prediction loop: {e}")
def _add_cnn_training_sample(self, symbol: str, prediction: Dict[str, Any]):
"""Add CNN training sample based on prediction outcome"""
try:
if not self.cnn_adapter or not hasattr(self.cnn_adapter, 'add_training_sample'):
return
# Get current price for reward calculation
current_price = self._get_current_price(symbol)
if not current_price:
return
# Calculate reward based on prediction accuracy (simplified)
# In a real implementation, this would be based on actual market movement
action = prediction['action']
confidence = prediction['confidence']
# Simple reward: higher confidence predictions get higher rewards
base_reward = confidence * 0.1
# Add some market context (price movement direction)
price_history = self._get_recent_price_history(symbol, 10)
if len(price_history) >= 2:
price_change = (price_history[-1] - price_history[-2]) / price_history[-2]
# Reward if prediction aligns with price movement
if (action == 'BUY' and price_change > 0) or (action == 'SELL' and price_change < 0):
reward = base_reward * 1.5 # Bonus for correct direction
else:
reward = base_reward * 0.5 # Penalty for wrong direction
else:
reward = base_reward
# Add training sample
self.cnn_adapter.add_training_sample(symbol, action, reward)
logger.debug(f"Added CNN training sample: {symbol} {action} (reward: {reward:.4f})")
except Exception as e:
logger.error(f"Error adding CNN training sample: {e}")
def _get_recent_price_history(self, symbol: str, count: int) -> List[float]:
"""Get recent price history for reward calculation"""
try:
df = self.data_provider.get_candles(symbol, '1s')
if df is None or len(df) == 0:
return []
return df['close'].tail(count).tolist()
except Exception as e:
logger.error(f"Error getting price history for {symbol}: {e}")
return []
def _initialize_enhanced_position_sync(self):
"""Initialize enhanced position synchronization system"""