"""
|
|
Neural Network Trading System Main Module (Compatibility Layer)
|
|
|
|
This module serves as a compatibility layer for the realtime.py module.
|
|
It re-exports the functionality from realtime_main.py that is needed by realtime.py.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import logging
|
|
from datetime import datetime
|
|
import numpy as np
|
|
|
|
# Configure logging
|
|
logger = logging.getLogger('NN')
|
|
logger.setLevel(logging.INFO)
|
|
|
|

# Re-export everything from realtime_main.py
from .realtime_main import (
    parse_arguments,
    realtime,
    train,
    predict
)
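
# Illustrative usage sketch from the realtime.py side. The import path
# "NN.main" is an assumption (inferred from the 'NN' logger name and the
# "NN/models/saved" default paths), not confirmed by this file:
#
#     from NN.main import NeuralNetworkOrchestrator, parse_arguments
#     orchestrator = NeuralNetworkOrchestrator({'symbol': 'BTC/USDT'})
#     result = orchestrator.run_inference_pipeline(timeframe='1h')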


# Create a class that realtime.py expects
class NeuralNetworkOrchestrator:
    """
    Orchestrates the neural network operations.
    """

    def __init__(self, config):
        """
        Initialize the orchestrator with configuration.

        Args:
            config (dict): Configuration parameters
        """
        self.config = config
        self.symbol = config.get('symbol', 'BTC/USDT')
        self.timeframes = config.get('timeframes', ['1m', '5m', '1h', '4h'])
        self.window_size = config.get('window_size', 20)
        self.n_features = config.get('n_features', 5)
        self.output_size = config.get('output_size', 3)
        self.model_dir = config.get('model_dir', 'NN/models/saved')
        self.data_dir = config.get('data_dir', 'NN/data')
        self.model = None
        self.data_interface = None

        # Initialize with default values in case imports fail
        self.model_initialized = False
        self.data_initialized = False

        # Import necessary modules dynamically
        try:
            from .utils.data_interface import DataInterface

            # Initialize data interface
            self.data_interface = DataInterface(
                symbol=self.symbol,
                timeframes=self.timeframes
            )
            self.data_initialized = True
            logger.info(f"Data interface initialized for {self.symbol}")

            try:
                from .models.cnn_model_pytorch import CNNModelPyTorch as Model

                # Initialize model
                feature_count = self.data_interface.get_feature_count() if hasattr(self.data_interface, 'get_feature_count') else 5
                try:
                    # First try with the expected parameter names
                    self.model = Model(
                        window_size=self.window_size,
                        num_features=feature_count,
                        output_size=self.output_size,
                        timeframes=self.timeframes
                    )
                except TypeError as e:
                    logger.warning(f"TypeError in model initialization with num_features: {str(e)}")
                    # Try alternate parameter naming
                    try:
                        self.model = Model(
                            input_shape=(self.window_size, feature_count),
                            output_size=self.output_size
                        )
                        logger.info("Model initialized with alternate parameters")
                    except Exception as ex:
                        logger.error(f"Failed to initialize model with alternate parameters: {str(ex)}")
                        self.model = DummyModel()

                # Try to load the best model
                self._load_model()
                self.model_initialized = True
                logger.info("Model initialized successfully")
            except Exception as e:
                logger.error(f"Error initializing model: {str(e)}")
                import traceback
                logger.error(traceback.format_exc())
                self.model = DummyModel()

            logger.info(f"NeuralNetworkOrchestrator initialized with config: {config}")
        except Exception as e:
            logger.error(f"Error initializing NeuralNetworkOrchestrator: {str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            self.model = DummyModel()

    def _load_model(self):
        """Load the best trained model from available files"""
        try:
            model_paths = [
                os.path.join(self.model_dir, "dqn_agent_best_policy.pt"),
                os.path.join(self.model_dir, "cnn_model_best.pt"),
                os.path.join("models/saved", "dqn_agent_best_policy.pt"),
                os.path.join("models/saved", "cnn_model_best.pt")
            ]

            for model_path in model_paths:
                if os.path.exists(model_path):
                    try:
                        self.model.load(model_path)
                        logger.info(f"Loaded model from {model_path}")
                        return True
                    except Exception as e:
                        logger.warning(f"Failed to load model from {model_path}: {str(e)}")
                        continue

            logger.warning("No trained model found, using dummy model")
            self.model = DummyModel()
            return False
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            self.model = DummyModel()
            return False

    def run_inference_pipeline(self, model_type='cnn', timeframe='1h'):
        """
        Run the inference pipeline using the trained model.

        Args:
            model_type (str): Type of model to use (cnn, transformer, etc.)
            timeframe (str): Timeframe to use for inference

        Returns:
            dict: Inference result with keys 'timestamp', 'action',
                'action_index', 'probability', 'probabilities' and
                'price_prediction'
        """
        try:
            # Check if we have a model
            if not hasattr(self, 'model') or self.model is None:
                logger.warning("No model available, initializing dummy model")
                self.model = DummyModel()

            # Check if we have a data interface
            if not hasattr(self, 'data_interface') or self.data_interface is None:
                logger.warning("No data interface available")
                # Return a dummy prediction
                return self._get_dummy_prediction()

            # Prepare input data for the selected timeframe
            X, timestamp = self.data_interface.prepare_realtime_input(
                timeframe=timeframe,
                n_candles=self.window_size + 10,  # Extra candles for safety
                window_size=self.window_size
            )

            if X is None:
                logger.warning(f"No data available for {self.symbol}")
                return self._get_dummy_prediction()

            # Get model predictions
            action_probs, price_pred = self.model.predict(X)

            # Convert predictions to an action
            action_idx = np.argmax(action_probs) if hasattr(action_probs, 'argmax') else 1  # Default to HOLD
            action_names = ['SELL', 'HOLD', 'BUY']
            action = action_names[action_idx]

            # Format the timestamp as an ISO-8601 string
            if not isinstance(timestamp, str):
                try:
                    if hasattr(timestamp, 'isoformat'):  # Already a datetime-like object
                        timestamp = timestamp.isoformat()
                    else:  # A numeric timestamp in milliseconds
                        timestamp = datetime.fromtimestamp(float(timestamp) / 1000).isoformat()
                except (TypeError, ValueError):
                    timestamp = datetime.now().isoformat()

            # Build the result
            result = {
                'timestamp': timestamp,
                'action': action,
                'action_index': int(action_idx),
                'probability': float(action_probs[action_idx]) if hasattr(action_probs, '__getitem__') else 0.33,
                'probabilities': {name: float(prob) for name, prob in zip(action_names, action_probs)} if hasattr(action_probs, '__iter__') else {'SELL': 0.33, 'HOLD': 0.34, 'BUY': 0.33},
                'price_prediction': float(price_pred) if price_pred is not None else None
            }

            logger.info(f"Inference result: {result}")
            return result

        except Exception as e:
            logger.error(f"Error in inference pipeline: {str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            return self._get_dummy_prediction()

    def _get_dummy_prediction(self):
        """Return a dummy prediction when model or data is unavailable"""
        action_names = ['SELL', 'HOLD', 'BUY']
        action_idx = 1  # Default to HOLD
        timestamp = datetime.now().isoformat()

        return {
            'timestamp': timestamp,
            'action': 'HOLD',
            'action_index': action_idx,
            'probability': 0.8,
            'probabilities': {'SELL': 0.1, 'HOLD': 0.8, 'BUY': 0.1},
            'price_prediction': None,
            'is_dummy': True
        }


class DummyModel:
    """Fallback model that returns a fixed, HOLD-biased prediction"""

    def __init__(self):
        logger.info("Initializing dummy model")

    def predict(self, X):
        """Return fixed action probabilities and no price prediction"""
        # Fixed probabilities for SELL, HOLD, BUY, biased towards HOLD
        action_probs = np.array([0.1, 0.8, 0.1])

        # The dummy model does not produce a price prediction
        price_pred = None

        return action_probs, price_pred

    def load(self, model_path):
        """Dummy load method that always reports success"""
        logger.info(f"Dummy model pretending to load from {model_path}")
        return True
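

# Minimal usage sketch. The config keys mirror the defaults read in
# NeuralNetworkOrchestrator.__init__. Because of the relative imports above,
# run this with package context, e.g. `python -m NN.main` from the project
# root; the package name "NN" is an assumption inferred from the default paths.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    example_config = {
        'symbol': 'BTC/USDT',
        'timeframes': ['1m', '5m', '1h', '4h'],
        'window_size': 20,
        'n_features': 5,
        'output_size': 3,
        'model_dir': 'NN/models/saved',
        'data_dir': 'NN/data',
    }
    orchestrator = NeuralNetworkOrchestrator(example_config)
    result = orchestrator.run_inference_pipeline(model_type='cnn', timeframe='1h')
    print(result)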