better pivots and CNN wip training

This commit is contained in:
Dobromir Popov
2025-05-30 17:14:06 +03:00
parent 2a148b0ac6
commit 774debbf75
8 changed files with 314 additions and 32 deletions

View File

@ -36,11 +36,50 @@ from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from enum import Enum
# Setup logger immediately after logging import
logger = logging.getLogger(__name__)
try:
from NN.models.cnn_model import CNNModel
except ImportError:
CNNModel = None # Allow running without TF/CNN if not installed or path issue
print("Warning: CNNModel could not be imported. CNN-based pivot prediction/training will be disabled.")
try:
# Fallback import path
import sys
import os
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)
from NN.models.cnn_model import CNNModel
except ImportError:
# Create fallback CNN model for development/testing
class CNNModel:
def __init__(self, input_shape=(900, 50), output_size=10):
self.input_shape = input_shape
self.output_size = output_size
self.model = None
logger.info(f"Fallback CNN Model initialized: input_shape={input_shape}, output_size={output_size}")
def build_model(self, **kwargs):
logger.info("Fallback CNN Model build_model called - using dummy model")
return self
def predict(self, X):
# Return dummy predictions for testing
batch_size = X.shape[0] if hasattr(X, 'shape') else 1
if self.output_size == 1:
pred_class = np.random.choice([0, 1], size=batch_size)
pred_proba = np.random.random(batch_size)
else:
pred_class = np.random.randint(0, self.output_size, size=batch_size)
pred_proba = np.random.random((batch_size, self.output_size))
logger.debug(f"Fallback CNN prediction: class={pred_class}, proba_shape={np.array(pred_proba).shape}")
return pred_class, pred_proba
def fit(self, X, y, **kwargs):
logger.info(f"Fallback CNN training: X_shape={X.shape}, y_shape={y.shape}")
return self
logger.warning("Using fallback CNN model - CNN training will work but with dummy predictions")
try:
from core.unified_data_stream import TrainingDataPacket
@ -48,7 +87,6 @@ except ImportError:
TrainingDataPacket = None
print("Warning: TrainingDataPacket could not be imported. Using fallback interface.")
logger = logging.getLogger(__name__)
class TrendDirection(Enum):
UP = "up"
@ -137,6 +175,8 @@ class WilliamsMarketStructure:
self.trend_cache = {}
self.enable_cnn_feature = enable_cnn_feature and CNNModel is not None
# Force enable CNN for development - always True now with fallback model
self.enable_cnn_feature = True
self.cnn_model: Optional[CNNModel] = None
self.previous_pivot_details_for_cnn: Optional[Dict[str, Any]] = None # Stores {'features': X, 'pivot': SwingPoint}
self.training_data_provider = training_data_provider # Access to TrainingDataPacket