combined edits
This commit is contained in:
parent
c0872248ab
commit
477e5dca39
@ -21,7 +21,7 @@ if project_root not in sys.path:
|
|||||||
sys.path.append(project_root)
|
sys.path.append(project_root)
|
||||||
|
|
||||||
# Import BinanceHistoricalData from the root module
|
# Import BinanceHistoricalData from the root module
|
||||||
from realtime import BinanceHistoricalData
|
from dataprovider_realtime import BinanceHistoricalData
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from realtime import RealTimeChart
|
from dataprovider_realtime import RealTimeChart
|
||||||
|
|
||||||
# Create a standalone chart instance
|
# Create a standalone chart instance
|
||||||
chart = RealTimeChart('BTC/USDT')
|
chart = RealTimeChart('BTC/USDT')
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from realtime import RealTimeChart
|
from dataprovider_realtime import RealTimeChart
|
||||||
|
|
||||||
chart = RealTimeChart('BTC/USDT')
|
chart = RealTimeChart('BTC/USDT')
|
||||||
|
|
||||||
|
6
main.py
6
main.py
@ -29,7 +29,7 @@ from PIL import Image
|
|||||||
import matplotlib.pyplot as mpf
|
import matplotlib.pyplot as mpf
|
||||||
import matplotlib.gridspec as gridspec
|
import matplotlib.gridspec as gridspec
|
||||||
import datetime
|
import datetime
|
||||||
from realtime import BinanceWebSocket, BinanceHistoricalData
|
from dataprovider_realtime import BinanceWebSocket, BinanceHistoricalData
|
||||||
from datetime import datetime as dt
|
from datetime import datetime as dt
|
||||||
# Add Dash-related imports
|
# Add Dash-related imports
|
||||||
import dash
|
import dash
|
||||||
@ -376,7 +376,7 @@ def main():
|
|||||||
|
|
||||||
# Initialize real-time charts and data interfaces
|
# Initialize real-time charts and data interfaces
|
||||||
try:
|
try:
|
||||||
from realtime import RealTimeChart
|
from dataprovider_realtime import RealTimeChart
|
||||||
|
|
||||||
# Create a real-time chart for each symbol
|
# Create a real-time chart for each symbol
|
||||||
charts = {}
|
charts = {}
|
||||||
@ -1152,7 +1152,7 @@ from PIL import Image
|
|||||||
import matplotlib.pyplot as mpf
|
import matplotlib.pyplot as mpf
|
||||||
import matplotlib.gridspec as gridspec
|
import matplotlib.gridspec as gridspec
|
||||||
import datetime
|
import datetime
|
||||||
from realtime import BinanceWebSocket, BinanceHistoricalData
|
from dataprovider_realtime import BinanceWebSocket, BinanceHistoricalData
|
||||||
from datetime import datetime as dt
|
from datetime import datetime as dt
|
||||||
# Add Dash-related imports
|
# Add Dash-related imports
|
||||||
import dash
|
import dash
|
||||||
|
116
test_positions.py
Normal file
116
test_positions.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
from NN.environments.trading_env import TradingEnvironment
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
# Add the project root directory to the path
|
||||||
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Create a mock data interface class
|
||||||
|
class MockDataInterface:
|
||||||
|
def __init__(self, symbol, timeframes):
|
||||||
|
self.symbol = symbol
|
||||||
|
self.timeframes = timeframes
|
||||||
|
self.dataframes = {}
|
||||||
|
|
||||||
|
# Create mock data for each timeframe
|
||||||
|
for tf in timeframes:
|
||||||
|
self.dataframes[tf] = self._create_mock_data(tf)
|
||||||
|
|
||||||
|
def _create_mock_data(self, timeframe):
|
||||||
|
# Generate timestamps
|
||||||
|
end_time = datetime.now()
|
||||||
|
if timeframe == '1m':
|
||||||
|
start_time = end_time - timedelta(minutes=1000)
|
||||||
|
freq = 'T' # minute frequency
|
||||||
|
elif timeframe == '5m':
|
||||||
|
start_time = end_time - timedelta(minutes=5000)
|
||||||
|
freq = '5T'
|
||||||
|
else: # '15m'
|
||||||
|
start_time = end_time - timedelta(minutes=15000)
|
||||||
|
freq = '15T'
|
||||||
|
|
||||||
|
dates = pd.date_range(start=start_time, end=end_time, freq=freq)
|
||||||
|
|
||||||
|
# Create price data with some random walk behavior
|
||||||
|
np.random.seed(42) # For reproducibility
|
||||||
|
price = 1000.0
|
||||||
|
prices = [price]
|
||||||
|
for _ in range(len(dates) - 1):
|
||||||
|
price = price * (1 + np.random.normal(0, 0.005)) # 0.5% daily volatility
|
||||||
|
prices.append(price)
|
||||||
|
|
||||||
|
# Calculate OHLCV data
|
||||||
|
df = pd.DataFrame(index=dates)
|
||||||
|
df['close'] = prices
|
||||||
|
df['open'] = df['close'].shift(1).fillna(df['close'].iloc[0] * 0.999)
|
||||||
|
df['high'] = df['close'] * (1 + abs(np.random.normal(0, 0.001, len(df))))
|
||||||
|
df['low'] = df['open'] * (1 - abs(np.random.normal(0, 0.001, len(df))))
|
||||||
|
df['volume'] = np.random.normal(1000, 100, len(df))
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
# Create mock data interface
|
||||||
|
di = MockDataInterface('ETH/USDT', ['1m', '5m', '15m'])
|
||||||
|
|
||||||
|
# Create environment
|
||||||
|
env = TradingEnvironment(di, initial_balance=1000.0, max_position=0.1)
|
||||||
|
|
||||||
|
# Run multiple episodes to accumulate some trade history
|
||||||
|
for episode in range(3):
|
||||||
|
logger.info(f"Episode {episode+1}/3")
|
||||||
|
|
||||||
|
# Reset environment
|
||||||
|
observation = env.reset()
|
||||||
|
|
||||||
|
# Run episode
|
||||||
|
for step in range(100):
|
||||||
|
# Choose action: 0=Buy, 1=Sell, 2=Hold
|
||||||
|
# Use a more deliberate pattern to generate trades
|
||||||
|
if step % 10 == 0:
|
||||||
|
action = 0 # Buy
|
||||||
|
elif step % 10 == 5:
|
||||||
|
action = 1 # Sell
|
||||||
|
else:
|
||||||
|
action = 2 # Hold
|
||||||
|
|
||||||
|
# Take action
|
||||||
|
observation, reward, done, info = env.step(action)
|
||||||
|
|
||||||
|
# Print trade information if a trade was made
|
||||||
|
if 'trade_result' in info:
|
||||||
|
trade = info['trade_result']
|
||||||
|
print(f"\nTrade executed:")
|
||||||
|
print(f"Action: {['BUY', 'SELL', 'HOLD'][trade['action']]}")
|
||||||
|
print(f"Price: {trade['price']:.4f}")
|
||||||
|
print(f"Position change: {trade['prev_position']:.4f} -> {trade['new_position']:.4f}")
|
||||||
|
print(f"Entry price: {trade.get('entry_price', 0):.4f}")
|
||||||
|
if trade.get('realized_pnl', 0) != 0:
|
||||||
|
print(f"Realized PnL: {trade['realized_pnl']:.4f}")
|
||||||
|
print(f"Balance: {trade['balance_before']:.2f} -> {trade['balance_after']:.2f}")
|
||||||
|
|
||||||
|
# End episode if done
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Render environment with final state
|
||||||
|
print("\n\nFinal environment state:")
|
||||||
|
env.render()
|
||||||
|
|
||||||
|
# Print detailed information about the last 5 positions
|
||||||
|
positions = env.get_last_positions(5)
|
||||||
|
print("\nDetailed position history:")
|
||||||
|
for i, pos in enumerate(positions):
|
||||||
|
print(f"\nPosition {i+1}:")
|
||||||
|
for key, value in pos.items():
|
||||||
|
if isinstance(value, float):
|
||||||
|
print(f" {key}: {value:.4f}")
|
||||||
|
else:
|
||||||
|
print(f" {key}: {value}")
|
@ -1,5 +1,5 @@
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from realtime import RealTimeChart
|
from dataprovider_realtime import RealTimeChart
|
||||||
|
|
||||||
# Create a chart instance
|
# Create a chart instance
|
||||||
chart = RealTimeChart('BTC/USDT')
|
chart = RealTimeChart('BTC/USDT')
|
||||||
|
@ -72,7 +72,7 @@ def main():
|
|||||||
|
|
||||||
# Initialize real-time charts and data interfaces
|
# Initialize real-time charts and data interfaces
|
||||||
try:
|
try:
|
||||||
from realtime import RealTimeChart
|
from dataprovider_realtime import RealTimeChart
|
||||||
|
|
||||||
# Create a real-time chart for each symbol
|
# Create a real-time chart for each symbol
|
||||||
charts = {}
|
charts = {}
|
||||||
|
@ -35,7 +35,7 @@ logger = logging.getLogger('realtime_training')
|
|||||||
|
|
||||||
# Import the model and data interfaces
|
# Import the model and data interfaces
|
||||||
from NN.models.cnn_model_pytorch import CNNModelPyTorch
|
from NN.models.cnn_model_pytorch import CNNModelPyTorch
|
||||||
from realtime import MultiTimeframeDataInterface
|
from dataprovider_realtime import MultiTimeframeDataInterface
|
||||||
from NN.utils.signal_interpreter import SignalInterpreter
|
from NN.utils.signal_interpreter import SignalInterpreter
|
||||||
|
|
||||||
# Global variables for graceful shutdown
|
# Global variables for graceful shutdown
|
||||||
|
@ -30,7 +30,7 @@ import train_config
|
|||||||
|
|
||||||
# Import key components
|
# Import key components
|
||||||
from NN.models.dqn_agent import DQNAgent
|
from NN.models.dqn_agent import DQNAgent
|
||||||
from realtime import MultiTimeframeDataInterface
|
from dataprovider_realtime import MultiTimeframeDataInterface
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
log_dir = Path("logs")
|
log_dir = Path("logs")
|
||||||
|
@ -39,7 +39,7 @@ import train_config
|
|||||||
# Import key components
|
# Import key components
|
||||||
from NN.models.cnn_model_pytorch import CNNModelPyTorch
|
from NN.models.cnn_model_pytorch import CNNModelPyTorch
|
||||||
from NN.models.dqn_agent import DQNAgent
|
from NN.models.dqn_agent import DQNAgent
|
||||||
from realtime import MultiTimeframeDataInterface, RealTimeChart
|
from dataprovider_realtime import MultiTimeframeDataInterface, RealTimeChart
|
||||||
from NN.utils.signal_interpreter import SignalInterpreter
|
from NN.utils.signal_interpreter import SignalInterpreter
|
||||||
|
|
||||||
# Global variables for graceful shutdown
|
# Global variables for graceful shutdown
|
||||||
@ -241,7 +241,7 @@ class HybridModel:
|
|||||||
def _initialize_chart(self):
|
def _initialize_chart(self):
|
||||||
"""Initialize the RealTimeChart for visualization"""
|
"""Initialize the RealTimeChart for visualization"""
|
||||||
try:
|
try:
|
||||||
from realtime import RealTimeChart
|
from dataprovider_realtime import RealTimeChart
|
||||||
|
|
||||||
symbol = self.config['market_data']['symbol']
|
symbol = self.config['market_data']['symbol']
|
||||||
self.logger.info(f"Initializing RealTimeChart for {symbol}")
|
self.logger.info(f"Initializing RealTimeChart for {symbol}")
|
||||||
|
1247
train_hybrid_fixed.py
Normal file
1247
train_hybrid_fixed.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -27,7 +27,7 @@ if project_root not in sys.path:
|
|||||||
from NN.models.dqn_agent import DQNAgent
|
from NN.models.dqn_agent import DQNAgent
|
||||||
from NN.utils.trading_env import TradingEnvironment
|
from NN.utils.trading_env import TradingEnvironment
|
||||||
from NN.utils.data_interface import DataInterface
|
from NN.utils.data_interface import DataInterface
|
||||||
from realtime import BinanceHistoricalData, RealTimeChart
|
from dataprovider_realtime import BinanceHistoricalData, RealTimeChart
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
log_filename = f'improved_rl_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
|
log_filename = f'improved_rl_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
|
||||||
|
@ -1055,7 +1055,7 @@ async def start_realtime_chart(symbol="ETH/USDT", port=8050, manual_mode=False):
|
|||||||
Returns:
|
Returns:
|
||||||
tuple: (RealTimeChart instance, websocket task)
|
tuple: (RealTimeChart instance, websocket task)
|
||||||
"""
|
"""
|
||||||
from realtime import RealTimeChart
|
from dataprovider_realtime import RealTimeChart
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Initializing RealTimeChart for {symbol}")
|
logger.info(f"Initializing RealTimeChart for {symbol}")
|
||||||
@ -1105,7 +1105,7 @@ async def start_realtime_chart(symbol="ETH/USDT", port=8050, manual_mode=False):
|
|||||||
|
|
||||||
def _add_trade_compat(chart, price, timestamp, amount, pnl=0.0, action="BUY"):
|
def _add_trade_compat(chart, price, timestamp, amount, pnl=0.0, action="BUY"):
|
||||||
"""Compatibility function for adding trades to the chart"""
|
"""Compatibility function for adding trades to the chart"""
|
||||||
from realtime import Position
|
from dataprovider_realtime import Position
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Create a new position
|
# Create a new position
|
||||||
|
Loading…
x
Reference in New Issue
Block a user