new cnn nn implementation (wip)

This commit is contained in:
Dobromir Popov 2025-03-25 13:37:34 +02:00
parent 114ced03b7
commit 50eb50696b

View File

@ -19,6 +19,7 @@ import os
from datetime import datetime, timedelta from datetime import datetime, timedelta
import pytz import pytz
import tzlocal import tzlocal
import threading
# Configure logging with more detailed format # Configure logging with more detailed format
logging.basicConfig( logging.basicConfig(
@ -31,6 +32,144 @@ logging.basicConfig(
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Neural Network integration (conditional import)
NN_ENABLED = os.environ.get('ENABLE_NN_MODELS', '0') == '1'
nn_orchestrator = None
nn_inference_thread = None
if NN_ENABLED:
try:
import sys
# Add project root to sys.path if needed
project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
sys.path.append(project_root)
from NN.main import NeuralNetworkOrchestrator
logger.info("Neural Network module enabled")
except ImportError as e:
logger.warning(f"Failed to import Neural Network module, disabling NN features: {str(e)}")
NN_ENABLED = False
# NN utility functions
def setup_neural_network():
"""Initialize the neural network components if enabled"""
global nn_orchestrator, NN_ENABLED
if not NN_ENABLED:
return False
try:
# Get configuration from environment variables or use defaults
symbol = os.environ.get('NN_SYMBOL', 'BTC/USDT')
timeframes = os.environ.get('NN_TIMEFRAMES', '1m,5m,1h,4h,1d').split(',')
output_size = int(os.environ.get('NN_OUTPUT_SIZE', '3')) # 3 for BUY/HOLD/SELL
# Configure the orchestrator
config = {
'symbol': symbol,
'timeframes': timeframes,
'window_size': int(os.environ.get('NN_WINDOW_SIZE', '20')),
'n_features': 5, # OHLCV
'output_size': output_size,
'model_dir': 'NN/models/saved',
'data_dir': 'NN/data'
}
# Initialize the orchestrator
logger.info(f"Initializing Neural Network Orchestrator with config: {config}")
nn_orchestrator = NeuralNetworkOrchestrator(config)
# Start inference thread if enabled
inference_interval = int(os.environ.get('NN_INFERENCE_INTERVAL', '60'))
if inference_interval > 0:
start_nn_inference_thread(inference_interval)
return True
except Exception as e:
logger.error(f"Error setting up neural network: {str(e)}")
import traceback
logger.error(traceback.format_exc())
NN_ENABLED = False
return False
def start_nn_inference_thread(interval_seconds):
"""Start a background thread to periodically run inference with the neural network"""
global nn_inference_thread
if not NN_ENABLED or nn_orchestrator is None:
logger.warning("Cannot start inference thread - Neural Network not enabled or initialized")
return False
def inference_worker():
"""Worker function for the inference thread"""
model_type = os.environ.get('NN_MODEL_TYPE', 'cnn')
timeframe = os.environ.get('NN_TIMEFRAME', '1h')
logger.info(f"Starting neural network inference thread with {interval_seconds}s interval")
logger.info(f"Using model type: {model_type}, timeframe: {timeframe}")
# Wait a bit for charts to initialize
time.sleep(5)
# Track active charts
active_charts = []
while True:
try:
# Find active charts if we don't have them yet
if not active_charts and 'charts' in globals():
active_charts = globals()['charts']
logger.info(f"Found {len(active_charts)} active charts for NN signals")
# Run inference
result = nn_orchestrator.run_inference_pipeline(
model_type=model_type,
timeframe=timeframe
)
if result:
# Log the result
logger.info(f"Neural network inference result: {result}")
# Add signal to charts
if active_charts:
try:
if 'action' in result:
action = result['action']
timestamp = datetime.fromisoformat(result['timestamp'].replace('Z', '+00:00'))
# Get probability if available
probability = None
if 'probability' in result:
probability = result['probability']
elif 'probabilities' in result:
probability = result['probabilities'].get(action, None)
# Add signal to each chart
for chart in active_charts:
if hasattr(chart, 'add_nn_signal'):
chart.add_nn_signal(action, timestamp, probability)
except Exception as e:
logger.error(f"Error adding NN signal to chart: {str(e)}")
import traceback
logger.error(traceback.format_exc())
# Sleep for the interval
time.sleep(interval_seconds)
except Exception as e:
logger.error(f"Error in inference thread: {str(e)}")
import traceback
logger.error(traceback.format_exc())
time.sleep(5) # Wait a bit before retrying
# Create and start the thread
nn_inference_thread = threading.Thread(target=inference_worker, daemon=True)
nn_inference_thread.start()
return True
# Try to get local timezone, default to Sofia/EET if not available # Try to get local timezone, default to Sofia/EET if not available
try: try:
local_timezone = tzlocal.get_localzone() local_timezone = tzlocal.get_localzone()
@ -1125,7 +1264,10 @@ class CandleCache:
return pd.DataFrame() return pd.DataFrame()
class RealTimeChart: class RealTimeChart:
"""Real-time chart using Dash and Plotly with WebSocket data feed"""
def __init__(self, symbol: str): def __init__(self, symbol: str):
"""Initialize the chart with necessary components"""
self.symbol = symbol self.symbol = symbol
# Create a multi-page Dash app instead of a simple Dash app # Create a multi-page Dash app instead of a simple Dash app
self.app = dash.Dash(__name__, self.app = dash.Dash(__name__,
@ -1143,6 +1285,10 @@ class RealTimeChart:
self.historical_data = BinanceHistoricalData() # For fetching historical data self.historical_data = BinanceHistoricalData() # For fetching historical data
self.last_cache_save_time = time.time() # Track last time we saved cache to disk self.last_cache_save_time = time.time() # Track last time we saved cache to disk
self.first_render = True # Flag to track first render self.first_render = True # Flag to track first render
# Storage for NN signals
self.nn_signals = []
logger.info(f"Initializing RealTimeChart for {symbol}") logger.info(f"Initializing RealTimeChart for {symbol}")
# Load historical data for longer timeframes at startup # Load historical data for longer timeframes at startup
@ -1791,6 +1937,57 @@ class RealTimeChart:
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)') fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)')
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)') fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)')
# Add neural network signals as annotations if available
if hasattr(self, 'nn_signals') and self.nn_signals:
for signal in self.nn_signals:
try:
# Skip HOLD signals for clarity
if signal['type'] == 'HOLD':
continue
# Check if this signal's timestamp is within the visible range
signal_ts = signal['timestamp']
# Only add annotations for signals in the visible time range
if df is not None and not df.empty:
if signal_ts < df.index.min() or signal_ts > df.index.max():
continue
# Set color and symbol based on signal type
if signal['type'] == 'BUY':
color = 'green'
symbol = '' # Up triangle
y_position = df['low'].min() * 0.98 # Below the candles
else: # SELL
color = 'red'
symbol = '' # Down triangle
y_position = df['high'].max() * 1.02 # Above the candles
# Add probability if available
text = f"{signal['type']}"
if signal['probability'] is not None:
text += f" ({signal['probability']:.2f})"
# Add annotation
fig.add_annotation(
x=signal_ts,
y=y_position,
text=text,
showarrow=True,
arrowhead=2,
arrowsize=1,
arrowwidth=2,
arrowcolor=color,
font=dict(size=12, color=color),
bgcolor='rgba(255, 255, 255, 0.8)',
bordercolor=color,
borderwidth=1,
borderpad=4,
row=1, col=1
)
except Exception as e:
logger.error(f"Error adding NN signal annotation: {str(e)}")
return fig return fig
except Exception as layout_error: except Exception as layout_error:
logger.error(f"Error updating layout: {str(layout_error)}") logger.error(f"Error updating layout: {str(layout_error)}")
@ -2573,11 +2770,56 @@ class RealTimeChart:
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
def add_nn_signal(self, signal_type, timestamp, probability=None):
"""Add a neural network signal to be displayed on the chart
Args:
signal_type: The type of signal (BUY, SELL, HOLD)
timestamp: The timestamp for the signal
probability: Optional probability/confidence value
"""
if signal_type not in ['BUY', 'SELL', 'HOLD']:
logger.warning(f"Invalid NN signal type: {signal_type}")
return
# Convert timestamp to datetime if it's not already
if not isinstance(timestamp, datetime):
try:
if isinstance(timestamp, str):
timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
elif isinstance(timestamp, (int, float)):
timestamp = datetime.fromtimestamp(timestamp / 1000.0)
except Exception as e:
logger.error(f"Error converting timestamp for NN signal: {str(e)}")
timestamp = datetime.now()
# Add the signal to our list
self.nn_signals.append({
'type': signal_type,
'timestamp': timestamp,
'probability': probability,
'added': datetime.now()
})
# Only keep the most recent 50 signals
if len(self.nn_signals) > 50:
self.nn_signals = self.nn_signals[-50:]
logger.info(f"Added NN signal: {signal_type} at {timestamp}")
async def main(): async def main():
global charts # Make charts globally accessible for NN integration
symbols = ["ETH/USDT", "BTC/USDT"] symbols = ["ETH/USDT", "BTC/USDT"]
logger.info(f"Starting application for symbols: {symbols}") logger.info(f"Starting application for symbols: {symbols}")
# Initialize neural network if enabled
if NN_ENABLED:
logger.info("Initializing Neural Network integration...")
if setup_neural_network():
logger.info("Neural Network integration initialized successfully")
else:
logger.warning("Neural Network integration failed to initialize")
charts = [] charts = []
websocket_tasks = [] websocket_tasks = []