gogo2/NN/main.py
2025-03-31 14:22:33 +03:00

244 lines
9.5 KiB
Python

"""
Neural Network Trading System Main Module (Compatibility Layer)
This module serves as a compatibility layer for the realtime.py module.
It re-exports the functionality from realtime_main.py that is needed by realtime.py.
"""
import os
import sys
import logging
from datetime import datetime
import numpy as np
# Module-wide "NN" logger used by the orchestrator and the dummy model;
# its own level is set to INFO so INFO-and-above records are processed.
logger = logging.getLogger(name='NN')
logger.setLevel(level=logging.INFO)
# Re-export the realtime_main.py entry points that realtime.py expects to find here
from .realtime_main import (
parse_arguments,
realtime,
train,
predict
)
# Create a class that realtime.py expects
class NeuralNetworkOrchestrator:
    """
    Orchestrates the neural network operations.

    Connects a ``DataInterface`` (market data source) to a model exposing
    ``predict(X) -> (action_probs, price_pred)`` and ``load(path)``. Whenever
    either piece is unavailable, it falls back to ``DummyModel`` or to a dummy
    prediction dict, so callers always receive a result.
    """

    # Labels for the model's action output, in output-vector order.
    ACTION_NAMES = ('SELL', 'HOLD', 'BUY')

    def __init__(self, config):
        """
        Initialize the orchestrator with configuration.

        Args:
            config (dict): Configuration parameters. Keys read (defaults):
                symbol ('BTC/USDT'), timeframes (['1m', '5m', '1h', '4h']),
                window_size (20), n_features (5; used when the data interface
                cannot report its feature count), output_size (3),
                model_dir ('NN/models/saved'), data_dir ('NN/data').

        Construction never raises for import/initialization failures: errors
        are logged and ``self.model`` falls back to ``DummyModel``.
        """
        self.config = config
        self.symbol = config.get('symbol', 'BTC/USDT')
        self.timeframes = config.get('timeframes', ['1m', '5m', '1h', '4h'])
        self.window_size = config.get('window_size', 20)
        self.n_features = config.get('n_features', 5)
        self.output_size = config.get('output_size', 3)
        self.model_dir = config.get('model_dir', 'NN/models/saved')
        self.data_dir = config.get('data_dir', 'NN/data')
        self.model = None
        self.data_interface = None
        # Both flags stay False unless the corresponding component really came up.
        self.model_initialized = False
        self.data_initialized = False
        try:
            # Imported here (not at module level) so an import failure is
            # caught below and degrades to DummyModel instead of crashing.
            from .utils.data_interface import DataInterface
            self.data_interface = DataInterface(
                symbol=self.symbol,
                timeframes=self.timeframes
            )
            self.data_initialized = True
            logger.info(f"Data interface initialized for {self.symbol}")
            self._init_model()
            logger.info(f"NeuralNetworkOrchestrator initialized with config: {config}")
        except Exception as e:
            logger.exception(f"Error initializing NeuralNetworkOrchestrator: {str(e)}")
            self.model = DummyModel()

    def _init_model(self):
        """Build the CNN model and load saved weights, falling back to DummyModel.

        Sets ``self.model_initialized`` to True only when a real (non-dummy)
        model ends up in ``self.model``. Never raises.
        """
        try:
            from .models.cnn_model_pytorch import CNNModelPyTorch as Model
            self.model = self._build_model(Model)
            # Only a real model is worth loading weights into; DummyModel.load
            # merely pretends to succeed and would produce a misleading log.
            if not isinstance(self.model, DummyModel):
                self._load_model()
        except Exception as e:
            logger.exception(f"Error initializing model: {str(e)}")
            self.model = DummyModel()
        self.model_initialized = not isinstance(self.model, DummyModel)
        if self.model_initialized:
            logger.info("Model initialized successfully")
        else:
            logger.warning("Model unavailable; using dummy model")

    def _build_model(self, model_cls):
        """Instantiate ``model_cls`` using either of its two known constructor signatures.

        Args:
            model_cls: Model class to instantiate.

        Returns:
            The model instance, or a ``DummyModel`` if both signatures fail.
        """
        get_feature_count = getattr(self.data_interface, 'get_feature_count', None)
        feature_count = get_feature_count() if callable(get_feature_count) else self.n_features
        try:
            return model_cls(
                window_size=self.window_size,
                num_features=feature_count,
                output_size=self.output_size,
                timeframes=self.timeframes
            )
        except TypeError as e:
            # The constructor may use the alternate (input_shape, output_size) signature.
            logger.warning(f"TypeError in model initialization with num_features: {str(e)}")
        try:
            model = model_cls(
                input_shape=(self.window_size, feature_count),
                output_size=self.output_size
            )
        except Exception as ex:
            logger.error(f"Failed to initialize model with alternate parameters: {str(ex)}")
            return DummyModel()
        logger.info("Model initialized with alternate parameters")
        return model

    def _load_model(self):
        """Load the best trained model from available files.

        Tries each candidate path in order and stops at the first successful
        ``self.model.load``. If none succeeds, ``self.model`` is replaced with
        ``DummyModel``.

        Returns:
            bool: True if weights were loaded, False otherwise.
        """
        try:
            candidates = (
                os.path.join(self.model_dir, "dqn_agent_best_policy.pt"),
                os.path.join(self.model_dir, "cnn_model_best.pt"),
                os.path.join("models/saved", "dqn_agent_best_policy.pt"),
                os.path.join("models/saved", "cnn_model_best.pt"),
            )
            # dict.fromkeys drops duplicates (e.g. model_dir == "models/saved")
            # while preserving priority order, so a failing file isn't retried.
            for model_path in dict.fromkeys(candidates):
                if not os.path.exists(model_path):
                    continue
                try:
                    self.model.load(model_path)
                except Exception as e:
                    logger.warning(f"Failed to load model from {model_path}: {str(e)}")
                    continue
                logger.info(f"Loaded model from {model_path}")
                return True
            logger.warning("No trained model found, using dummy model")
            self.model = DummyModel()
            return False
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            self.model = DummyModel()
            return False

    def run_inference_pipeline(self, model_type='cnn', timeframe='1h'):
        """
        Run the inference pipeline using the trained model.

        Args:
            model_type (str): Type of model to use (cnn, transformer, etc.).
                Currently not used; ``self.model`` is always used.
            timeframe (str): Timeframe to use for inference.

        Returns:
            dict: Keys 'timestamp' (ISO string), 'action', 'action_index',
            'probability', 'probabilities' (per action name) and
            'price_prediction' (float or None). On any failure the dummy
            prediction (with 'is_dummy': True) is returned instead.
        """
        try:
            if getattr(self, 'model', None) is None:
                logger.warning("No model available, initializing dummy model")
                self.model = DummyModel()
            if getattr(self, 'data_interface', None) is None:
                logger.warning("No data interface available")
                return self._get_dummy_prediction()
            X, timestamp = self.data_interface.prepare_realtime_input(
                timeframe=timeframe,
                n_candles=self.window_size + 10,  # Extra candles for safety
                window_size=self.window_size
            )
            if X is None:
                logger.warning(f"No data available for {self.symbol}")
                return self._get_dummy_prediction()

            action_probs, price_pred = self.model.predict(X)
            # Normalize batched / list / tensor output to a flat float vector,
            # so indexing and zipping below operate on scalars, not rows.
            probs = self._flatten_action_probs(action_probs)
            if probs.size != len(self.ACTION_NAMES):
                logger.warning(
                    f"Unexpected action output size {probs.size}; "
                    f"expected {len(self.ACTION_NAMES)}"
                )
                return self._get_dummy_prediction()
            action_idx = int(np.argmax(probs))

            price = self._to_price(price_pred)
            if price is None and price_pred is not None:
                logger.warning("Price prediction is not a single value; reporting None")

            result = {
                'timestamp': self._format_timestamp(timestamp),
                'action': self.ACTION_NAMES[action_idx],
                'action_index': action_idx,
                'probability': float(probs[action_idx]),
                'probabilities': {name: float(p) for name, p in zip(self.ACTION_NAMES, probs)},
                'price_prediction': price,
            }
            logger.info(f"Inference result: {result}")
            return result
        except Exception as e:
            logger.exception(f"Error in inference pipeline: {str(e)}")
            return self._get_dummy_prediction()

    @staticmethod
    def _flatten_action_probs(action_probs):
        """Return the model's action output as a 1-D float NumPy array.

        Accepts lists, NumPy arrays, or tensor-like objects exposing
        ``detach().cpu().numpy()``. Output with more than one dimension is
        treated as (batch, n_actions) and reduced to its last row (for a
        batch of one, its only row).
        """
        if hasattr(action_probs, 'detach'):
            action_probs = action_probs.detach().cpu().numpy()
        probs = np.asarray(action_probs, dtype=float)
        if probs.ndim > 1:
            probs = probs.reshape(-1, probs.shape[-1])[-1]
        return probs.reshape(-1)

    @staticmethod
    def _to_price(price_pred):
        """Convert a price prediction to float.

        Returns None when ``price_pred`` is None or does not hold exactly one value.
        """
        if price_pred is None:
            return None
        if hasattr(price_pred, 'detach'):
            price_pred = price_pred.detach().cpu().numpy()
        values = np.asarray(price_pred, dtype=float).reshape(-1)
        return float(values[0]) if values.size == 1 else None

    @staticmethod
    def _format_timestamp(timestamp):
        """Return ``timestamp`` as an ISO-8601 string.

        Strings pass through unchanged; objects with ``isoformat`` use it;
        other values are treated as epoch milliseconds. Anything that cannot
        be converted (None, out-of-range numbers, ...) becomes the current
        local time.
        """
        if isinstance(timestamp, str):
            return timestamp
        try:
            if hasattr(timestamp, 'isoformat'):  # datetime-like object
                return timestamp.isoformat()
            return datetime.fromtimestamp(float(timestamp) / 1000).isoformat()
        except (TypeError, ValueError, OverflowError, OSError):
            # fromtimestamp raises OverflowError/OSError for out-of-range values.
            return datetime.now().isoformat()

    def _get_dummy_prediction(self):
        """Return a dummy HOLD prediction when model or data is unavailable.

        The result has the same keys as a real inference result, plus
        'is_dummy': True.
        """
        return {
            'timestamp': datetime.now().isoformat(),
            'action': 'HOLD',
            'action_index': 1,
            'probability': 0.8,
            'probabilities': {'SELL': 0.1, 'HOLD': 0.8, 'BUY': 0.1},
            'price_prediction': None,
            'is_dummy': True
        }
class DummyModel:
    """Stand-in model used when no real model could be built or loaded.

    Offers the same ``predict``/``load`` methods the orchestrator calls, but
    its predictions are constant (not random): a HOLD-heavy action
    distribution and no price estimate.
    """

    def __init__(self):
        logger.info("Initializing dummy model")

    def predict(self, X):
        """Ignore ``X`` and return ``(array([0.1, 0.8, 0.1]), None)``.

        The array holds (SELL, HOLD, BUY) probabilities and is freshly
        allocated on every call; the second element is the (absent) price
        prediction.
        """
        hold_biased = np.array([0.1, 0.8, 0.1])
        return hold_biased, None

    def load(self, model_path):
        """Log the requested path and report success without reading anything."""
        message = f"Dummy model pretending to load from {model_path}"
        logger.info(message)
        return True