combined edits

This commit is contained in:
Dobromir Popov 2025-05-24 00:59:29 +03:00
parent c0872248ab
commit 477e5dca39
13 changed files with 1378 additions and 15 deletions

View File

@ -21,7 +21,7 @@ if project_root not in sys.path:
sys.path.append(project_root) sys.path.append(project_root)
# Import BinanceHistoricalData from the root module # Import BinanceHistoricalData from the root module
from realtime import BinanceHistoricalData from dataprovider_realtime import BinanceHistoricalData
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,7 +1,7 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
import random import random
import time import time
from realtime import RealTimeChart from dataprovider_realtime import RealTimeChart
# Create a standalone chart instance # Create a standalone chart instance
chart = RealTimeChart('BTC/USDT') chart = RealTimeChart('BTC/USDT')

View File

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from realtime import RealTimeChart from dataprovider_realtime import RealTimeChart
chart = RealTimeChart('BTC/USDT') chart = RealTimeChart('BTC/USDT')

View File

@ -29,7 +29,7 @@ from PIL import Image
import matplotlib.pyplot as mpf import matplotlib.pyplot as mpf
import matplotlib.gridspec as gridspec import matplotlib.gridspec as gridspec
import datetime import datetime
from realtime import BinanceWebSocket, BinanceHistoricalData from dataprovider_realtime import BinanceWebSocket, BinanceHistoricalData
from datetime import datetime as dt from datetime import datetime as dt
# Add Dash-related imports # Add Dash-related imports
import dash import dash
@ -376,7 +376,7 @@ def main():
# Initialize real-time charts and data interfaces # Initialize real-time charts and data interfaces
try: try:
from realtime import RealTimeChart from dataprovider_realtime import RealTimeChart
# Create a real-time chart for each symbol # Create a real-time chart for each symbol
charts = {} charts = {}
@ -1152,7 +1152,7 @@ from PIL import Image
import matplotlib.pyplot as mpf import matplotlib.pyplot as mpf
import matplotlib.gridspec as gridspec import matplotlib.gridspec as gridspec
import datetime import datetime
from realtime import BinanceWebSocket, BinanceHistoricalData from dataprovider_realtime import BinanceWebSocket, BinanceHistoricalData
from datetime import datetime as dt from datetime import datetime as dt
# Add Dash-related imports # Add Dash-related imports
import dash import dash

116
test_positions.py Normal file
View File

@ -0,0 +1,116 @@
from NN.environments.trading_env import TradingEnvironment
import logging
import numpy as np
import pandas as pd
import os
import sys
from datetime import datetime, timedelta
# Add the project root directory to the path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Create a mock data interface class
class MockDataInterface:
def __init__(self, symbol, timeframes):
self.symbol = symbol
self.timeframes = timeframes
self.dataframes = {}
# Create mock data for each timeframe
for tf in timeframes:
self.dataframes[tf] = self._create_mock_data(tf)
def _create_mock_data(self, timeframe):
# Generate timestamps
end_time = datetime.now()
if timeframe == '1m':
start_time = end_time - timedelta(minutes=1000)
freq = 'T' # minute frequency
elif timeframe == '5m':
start_time = end_time - timedelta(minutes=5000)
freq = '5T'
else: # '15m'
start_time = end_time - timedelta(minutes=15000)
freq = '15T'
dates = pd.date_range(start=start_time, end=end_time, freq=freq)
# Create price data with some random walk behavior
np.random.seed(42) # For reproducibility
price = 1000.0
prices = [price]
for _ in range(len(dates) - 1):
price = price * (1 + np.random.normal(0, 0.005)) # 0.5% daily volatility
prices.append(price)
# Calculate OHLCV data
df = pd.DataFrame(index=dates)
df['close'] = prices
df['open'] = df['close'].shift(1).fillna(df['close'].iloc[0] * 0.999)
df['high'] = df['close'] * (1 + abs(np.random.normal(0, 0.001, len(df))))
df['low'] = df['open'] * (1 - abs(np.random.normal(0, 0.001, len(df))))
df['volume'] = np.random.normal(1000, 100, len(df))
return df
# Create mock data interface
di = MockDataInterface('ETH/USDT', ['1m', '5m', '15m'])
# Create environment
env = TradingEnvironment(di, initial_balance=1000.0, max_position=0.1)
# Run multiple episodes to accumulate some trade history
for episode in range(3):
logger.info(f"Episode {episode+1}/3")
# Reset environment
observation = env.reset()
# Run episode
for step in range(100):
# Choose action: 0=Buy, 1=Sell, 2=Hold
# Use a more deliberate pattern to generate trades
if step % 10 == 0:
action = 0 # Buy
elif step % 10 == 5:
action = 1 # Sell
else:
action = 2 # Hold
# Take action
observation, reward, done, info = env.step(action)
# Print trade information if a trade was made
if 'trade_result' in info:
trade = info['trade_result']
print(f"\nTrade executed:")
print(f"Action: {['BUY', 'SELL', 'HOLD'][trade['action']]}")
print(f"Price: {trade['price']:.4f}")
print(f"Position change: {trade['prev_position']:.4f} -> {trade['new_position']:.4f}")
print(f"Entry price: {trade.get('entry_price', 0):.4f}")
if trade.get('realized_pnl', 0) != 0:
print(f"Realized PnL: {trade['realized_pnl']:.4f}")
print(f"Balance: {trade['balance_before']:.2f} -> {trade['balance_after']:.2f}")
# End episode if done
if done:
break
# Render environment with final state
print("\n\nFinal environment state:")
env.render()
# Print detailed information about the last 5 positions
positions = env.get_last_positions(5)
print("\nDetailed position history:")
for i, pos in enumerate(positions):
print(f"\nPosition {i+1}:")
for key, value in pos.items():
if isinstance(value, float):
print(f" {key}: {value:.4f}")
else:
print(f" {key}: {value}")

View File

@ -1,5 +1,5 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from realtime import RealTimeChart from dataprovider_realtime import RealTimeChart
# Create a chart instance # Create a chart instance
chart = RealTimeChart('BTC/USDT') chart = RealTimeChart('BTC/USDT')

View File

@ -72,7 +72,7 @@ def main():
# Initialize real-time charts and data interfaces # Initialize real-time charts and data interfaces
try: try:
from realtime import RealTimeChart from dataprovider_realtime import RealTimeChart
# Create a real-time chart for each symbol # Create a real-time chart for each symbol
charts = {} charts = {}

View File

@ -35,7 +35,7 @@ logger = logging.getLogger('realtime_training')
# Import the model and data interfaces # Import the model and data interfaces
from NN.models.cnn_model_pytorch import CNNModelPyTorch from NN.models.cnn_model_pytorch import CNNModelPyTorch
from realtime import MultiTimeframeDataInterface from dataprovider_realtime import MultiTimeframeDataInterface
from NN.utils.signal_interpreter import SignalInterpreter from NN.utils.signal_interpreter import SignalInterpreter
# Global variables for graceful shutdown # Global variables for graceful shutdown

View File

@ -30,7 +30,7 @@ import train_config
# Import key components # Import key components
from NN.models.dqn_agent import DQNAgent from NN.models.dqn_agent import DQNAgent
from realtime import MultiTimeframeDataInterface from dataprovider_realtime import MultiTimeframeDataInterface
# Configure logging # Configure logging
log_dir = Path("logs") log_dir = Path("logs")

View File

@ -39,7 +39,7 @@ import train_config
# Import key components # Import key components
from NN.models.cnn_model_pytorch import CNNModelPyTorch from NN.models.cnn_model_pytorch import CNNModelPyTorch
from NN.models.dqn_agent import DQNAgent from NN.models.dqn_agent import DQNAgent
from realtime import MultiTimeframeDataInterface, RealTimeChart from dataprovider_realtime import MultiTimeframeDataInterface, RealTimeChart
from NN.utils.signal_interpreter import SignalInterpreter from NN.utils.signal_interpreter import SignalInterpreter
# Global variables for graceful shutdown # Global variables for graceful shutdown
@ -241,7 +241,7 @@ class HybridModel:
def _initialize_chart(self): def _initialize_chart(self):
"""Initialize the RealTimeChart for visualization""" """Initialize the RealTimeChart for visualization"""
try: try:
from realtime import RealTimeChart from dataprovider_realtime import RealTimeChart
symbol = self.config['market_data']['symbol'] symbol = self.config['market_data']['symbol']
self.logger.info(f"Initializing RealTimeChart for {symbol}") self.logger.info(f"Initializing RealTimeChart for {symbol}")

1247
train_hybrid_fixed.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -27,7 +27,7 @@ if project_root not in sys.path:
from NN.models.dqn_agent import DQNAgent from NN.models.dqn_agent import DQNAgent
from NN.utils.trading_env import TradingEnvironment from NN.utils.trading_env import TradingEnvironment
from NN.utils.data_interface import DataInterface from NN.utils.data_interface import DataInterface
from realtime import BinanceHistoricalData, RealTimeChart from dataprovider_realtime import BinanceHistoricalData, RealTimeChart
# Configure logging # Configure logging
log_filename = f'improved_rl_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log' log_filename = f'improved_rl_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'

View File

@ -1055,7 +1055,7 @@ async def start_realtime_chart(symbol="ETH/USDT", port=8050, manual_mode=False):
Returns: Returns:
tuple: (RealTimeChart instance, websocket task) tuple: (RealTimeChart instance, websocket task)
""" """
from realtime import RealTimeChart from dataprovider_realtime import RealTimeChart
try: try:
logger.info(f"Initializing RealTimeChart for {symbol}") logger.info(f"Initializing RealTimeChart for {symbol}")
@ -1105,7 +1105,7 @@ async def start_realtime_chart(symbol="ETH/USDT", port=8050, manual_mode=False):
def _add_trade_compat(chart, price, timestamp, amount, pnl=0.0, action="BUY"): def _add_trade_compat(chart, price, timestamp, amount, pnl=0.0, action="BUY"):
"""Compatibility function for adding trades to the chart""" """Compatibility function for adding trades to the chart"""
from realtime import Position from dataprovider_realtime import Position
try: try:
# Create a new position # Create a new position