wip
This commit is contained in:
parent
1b9f471076
commit
8981ad0691
@ -263,8 +263,12 @@ class CNNModelPyTorch:
|
||||
local_min_mask = (price_changes < -0.001).to(torch.bool) # -0.1% threshold for local minimum
|
||||
|
||||
# Apply masks to set retrospective targets using torch.where
|
||||
retrospective_targets = torch.where(local_max_mask, torch.zeros_like(targets), retrospective_targets) # SELL at local max
|
||||
retrospective_targets = torch.where(local_min_mask, 2 * torch.ones_like(targets), retrospective_targets) # BUY at local min
|
||||
# Use indices where the masks have True values
|
||||
for i in range(len(retrospective_targets)):
|
||||
if local_max_mask[i]:
|
||||
retrospective_targets[i] = 0 # SELL at local max
|
||||
elif local_min_mask[i]:
|
||||
retrospective_targets[i] = 2 # BUY at local min
|
||||
|
||||
# Calculate retrospective loss with higher weight for profitable signals
|
||||
retrospective_loss = self.criterion(outputs, retrospective_targets)
|
||||
@ -372,8 +376,12 @@ class CNNModelPyTorch:
|
||||
local_min_mask = (price_changes < -0.001).to(torch.bool) # -0.1% threshold for local minimum
|
||||
|
||||
# Apply masks to set retrospective targets using torch.where
|
||||
retrospective_targets = torch.where(local_max_mask, torch.zeros_like(targets), retrospective_targets) # SELL at local max
|
||||
retrospective_targets = torch.where(local_min_mask, 2 * torch.ones_like(targets), retrospective_targets) # BUY at local min
|
||||
# Use indices where the masks have True values
|
||||
for i in range(len(retrospective_targets)):
|
||||
if local_max_mask[i]:
|
||||
retrospective_targets[i] = 0 # SELL at local max
|
||||
elif local_min_mask[i]:
|
||||
retrospective_targets[i] = 2 # BUY at local min
|
||||
|
||||
# Calculate retrospective loss with higher weight for profitable signals
|
||||
retrospective_loss = self.criterion(outputs, retrospective_targets)
|
||||
@ -403,13 +411,29 @@ class CNNModelPyTorch:
|
||||
def predict(self, X):
|
||||
"""Make predictions on input data"""
|
||||
self.model.eval()
|
||||
|
||||
# Convert to tensor if not already
|
||||
if not isinstance(X, torch.Tensor):
|
||||
X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
|
||||
else:
|
||||
X_tensor = X.to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self.model(X_tensor)
|
||||
# To maintain compatibility with the transformer model, return the action probs
|
||||
# And a dummy price prediction of zeros
|
||||
return outputs.cpu().numpy(), np.zeros((len(X), 1))
|
||||
|
||||
# Get the current close prices from the input
|
||||
current_prices = X_tensor[:, -1, 3].cpu().numpy() # Last timestamp's close price
|
||||
|
||||
# For price predictions, we'll estimate based on the action probabilities
|
||||
# Buy (2) means price likely to go up, Sell (0) means price likely to go down
|
||||
action_probs = outputs.cpu().numpy()
|
||||
price_directions = np.argmax(action_probs, axis=1) - 1 # -1, 0, or 1
|
||||
|
||||
# Simple price prediction: current price + small change based on predicted direction
|
||||
# Use 0.001 (0.1%) as a baseline change
|
||||
price_preds = current_prices * (1 + price_directions * 0.001)
|
||||
|
||||
return action_probs, price_preds.reshape(-1, 1)
|
||||
|
||||
def predict_next_candles(self, X, n_candles=3):
|
||||
"""
|
||||
|
@ -198,16 +198,44 @@ def train(data_interface, model, args):
|
||||
val_action_probs, val_price_preds = model.predict(X_val)
|
||||
|
||||
# Calculate PnL and win rates
|
||||
try:
|
||||
train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
|
||||
train_action_probs, train_prices, position_size=1.0
|
||||
train_preds, train_prices, position_size=1.0
|
||||
)
|
||||
val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
|
||||
val_action_probs, val_prices, position_size=1.0
|
||||
val_preds, val_prices, position_size=1.0
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating PnL: {str(e)}")
|
||||
train_pnl, train_win_rate, val_pnl, val_win_rate = 0, 0, 0, 0
|
||||
train_trades, val_trades = [], []
|
||||
|
||||
# Calculate price prediction error
|
||||
train_price_mae = np.mean(np.abs(train_price_preds - train_future_prices))
|
||||
val_price_mae = np.mean(np.abs(val_price_preds - val_future_prices))
|
||||
if train_future_prices is not None and train_price_preds is not None:
|
||||
# Ensure arrays have the same shape and are numpy arrays
|
||||
train_future_prices_np = np.array(train_future_prices) if not isinstance(train_future_prices, np.ndarray) else train_future_prices
|
||||
train_price_preds_np = np.array(train_price_preds) if not isinstance(train_price_preds, np.ndarray) else train_price_preds
|
||||
|
||||
if len(train_price_preds_np) > 0 and len(train_future_prices_np) > 0:
|
||||
min_len = min(len(train_price_preds_np), len(train_future_prices_np))
|
||||
train_price_mae = np.mean(np.abs(train_price_preds_np[:min_len] - train_future_prices_np[:min_len]))
|
||||
else:
|
||||
train_price_mae = float('inf')
|
||||
else:
|
||||
train_price_mae = float('inf')
|
||||
|
||||
if val_future_prices is not None and val_price_preds is not None:
|
||||
# Ensure arrays have the same shape and are numpy arrays
|
||||
val_future_prices_np = np.array(val_future_prices) if not isinstance(val_future_prices, np.ndarray) else val_future_prices
|
||||
val_price_preds_np = np.array(val_price_preds) if not isinstance(val_price_preds, np.ndarray) else val_price_preds
|
||||
|
||||
if len(val_price_preds_np) > 0 and len(val_future_prices_np) > 0:
|
||||
min_len = min(len(val_price_preds_np), len(val_future_prices_np))
|
||||
val_price_mae = np.mean(np.abs(val_price_preds_np[:min_len] - val_future_prices_np[:min_len]))
|
||||
else:
|
||||
val_price_mae = float('inf')
|
||||
else:
|
||||
val_price_mae = float('inf')
|
||||
|
||||
# Monitor action distribution
|
||||
train_actions = np.bincount(np.argmax(train_action_probs, axis=1), minlength=3)
|
||||
@ -233,7 +261,7 @@ def train(data_interface, model, args):
|
||||
writer.add_scalar(f'Actions/val_{action}', val_actions[i], epoch)
|
||||
|
||||
# Save best model based on validation metrics
|
||||
if val_pnl > best_val_pnl or (val_pnl == best_val_pnl and val_acc > best_val_acc):
|
||||
if np.isscalar(val_pnl) and np.isscalar(best_val_pnl) and (val_pnl > best_val_pnl or (np.isclose(val_pnl, best_val_pnl) and val_acc > best_val_acc)):
|
||||
best_val_pnl = val_pnl
|
||||
best_val_acc = val_acc
|
||||
best_win_rate = val_win_rate
|
||||
|
@ -12,6 +12,15 @@ from datetime import datetime, timedelta
|
||||
import json
|
||||
import pickle
|
||||
from sklearn.preprocessing import MinMaxScaler
|
||||
import sys
|
||||
|
||||
# Add project root to sys.path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
if project_root not in sys.path:
|
||||
sys.path.append(project_root)
|
||||
|
||||
# Import BinanceHistoricalData from the root module
|
||||
from realtime import BinanceHistoricalData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -39,6 +48,9 @@ class DataInterface:
|
||||
self.data_dir = data_dir
|
||||
self.scalers = {} # Store scalers for each timeframe
|
||||
|
||||
# Initialize the historical data fetcher
|
||||
self.historical_data = BinanceHistoricalData()
|
||||
|
||||
# Create data directory if it doesn't exist
|
||||
os.makedirs(self.data_dir, exist_ok=True)
|
||||
|
||||
@ -59,138 +71,39 @@ class DataInterface:
|
||||
Returns:
|
||||
pd.DataFrame: DataFrame with OHLCV data
|
||||
"""
|
||||
cache_file = os.path.join(self.data_dir, f"{self.symbol.replace('/', '_')}_{timeframe}.csv")
|
||||
|
||||
# For 1s timeframe, always fetch fresh data
|
||||
if timeframe == '1s':
|
||||
use_cache = False
|
||||
|
||||
# Check if cached data exists and is recent
|
||||
if use_cache and os.path.exists(cache_file):
|
||||
try:
|
||||
df = pd.read_csv(cache_file, parse_dates=['timestamp'])
|
||||
# If we have enough data and it's recent, use it
|
||||
if len(df) >= n_candles:
|
||||
logger.info(f"Using cached data for {self.symbol} {timeframe} ({len(df)} candles)")
|
||||
self.dataframes[timeframe] = df
|
||||
return df.tail(n_candles)
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading cached data: {str(e)}")
|
||||
|
||||
# If we get here, we need to fetch data
|
||||
try:
|
||||
logger.info(f"Fetching historical data for {self.symbol} {timeframe}")
|
||||
|
||||
# For 1s timeframe, we need more data points
|
||||
if timeframe == '1s':
|
||||
n_candles = min(n_candles * 60, 10000) # Up to 10k ticks
|
||||
|
||||
# Placeholder for real data fetching
|
||||
self._fetch_data_from_exchange(timeframe, n_candles)
|
||||
|
||||
# Save to cache (except for 1s timeframe)
|
||||
if self.dataframes[timeframe] is not None and timeframe != '1s':
|
||||
self.dataframes[timeframe].to_csv(cache_file, index=False)
|
||||
return self.dataframes[timeframe]
|
||||
else:
|
||||
# Create dummy data as fallback
|
||||
logger.warning(f"Could not fetch data for {self.symbol} {timeframe}, using dummy data")
|
||||
df = self._create_dummy_data(timeframe, n_candles)
|
||||
self.dataframes[timeframe] = df
|
||||
return df
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching data: {str(e)}")
|
||||
return None
|
||||
|
||||
def _fetch_data_from_exchange(self, timeframe, n_candles):
|
||||
"""
|
||||
Placeholder method for fetching data from an exchange.
|
||||
In a real implementation, this would connect to an exchange API.
|
||||
"""
|
||||
# This is a placeholder - in a real implementation this would make API calls
|
||||
# to a cryptocurrency exchange to fetch OHLCV data
|
||||
|
||||
# For now, just generate dummy data
|
||||
self.dataframes[timeframe] = self._create_dummy_data(timeframe, n_candles)
|
||||
|
||||
def _create_dummy_data(self, timeframe, n_candles):
|
||||
"""
|
||||
Create dummy OHLCV data for testing purposes.
|
||||
|
||||
Args:
|
||||
timeframe (str): Timeframe to create data for
|
||||
n_candles (int): Number of candles to create
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: DataFrame with dummy OHLCV data
|
||||
"""
|
||||
# Map timeframe to seconds
|
||||
tf_seconds = {
|
||||
'1s': 1, # Added 1s timeframe
|
||||
# Map timeframe string to seconds for BinanceHistoricalData
|
||||
timeframe_to_seconds = {
|
||||
'1s': 1,
|
||||
'1m': 60,
|
||||
'5m': 300,
|
||||
'15m': 900,
|
||||
'30m': 1800,
|
||||
'1h': 3600,
|
||||
'4h': 14400,
|
||||
'1d': 86400
|
||||
}
|
||||
seconds = tf_seconds.get(timeframe, 3600) # Default to 1h
|
||||
|
||||
# Create timestamps
|
||||
end_time = datetime.now()
|
||||
timestamps = [end_time - timedelta(seconds=seconds * i) for i in range(n_candles)]
|
||||
timestamps.reverse() # Oldest first
|
||||
interval_seconds = timeframe_to_seconds.get(timeframe, 3600) # Default to 1h if not found
|
||||
|
||||
# Generate random price data with realistic patterns
|
||||
np.random.seed(42) # For reproducibility
|
||||
|
||||
# Start price
|
||||
price = 50000 # For BTC/USDT
|
||||
prices = []
|
||||
volumes = []
|
||||
|
||||
for i in range(n_candles):
|
||||
# Random walk with drift and volatility based on timeframe
|
||||
drift = 0.0001 * seconds # Larger drift for larger timeframes
|
||||
volatility = 0.01 * np.sqrt(seconds / 3600) # Scale volatility by sqrt of time
|
||||
|
||||
# Daily/weekly patterns
|
||||
if timeframe in ['1d', '4h']:
|
||||
# Add some cyclical patterns
|
||||
cycle = np.sin(i / 7 * np.pi) * 0.02 # Weekly cycle
|
||||
else:
|
||||
cycle = np.sin(i / 24 * np.pi) * 0.01 # Daily cycle
|
||||
|
||||
# Calculate price change with random walk + cycles (clamped to prevent overflow)
|
||||
price_change = price * np.clip(drift + volatility * np.random.randn() + cycle, -0.1, 0.1)
|
||||
price = np.clip(price + price_change, 1000, 100000) # Keep price in reasonable range
|
||||
|
||||
# Generate OHLC from the price
|
||||
open_price = price
|
||||
high_price = price * (1 + abs(0.005 * np.random.randn()))
|
||||
low_price = price * (1 - abs(0.005 * np.random.randn()))
|
||||
close_price = price * (1 + 0.002 * np.random.randn())
|
||||
|
||||
# Ensure high >= open, close, low and low <= open, close
|
||||
high_price = max(high_price, open_price, close_price)
|
||||
low_price = min(low_price, open_price, close_price)
|
||||
|
||||
# Generate volume (higher for larger price movements) with safe calculation
|
||||
volume = 10000 + 5000 * np.random.rand() + abs(price_change)/price * 10000
|
||||
|
||||
prices.append((open_price, high_price, low_price, close_price))
|
||||
volumes.append(volume)
|
||||
|
||||
# Update price for next iteration
|
||||
price = close_price
|
||||
|
||||
# Create DataFrame
|
||||
df = pd.DataFrame(
|
||||
[(t, o, h, l, c, v) for t, (o, h, l, c), v in zip(timestamps, prices, volumes)],
|
||||
columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']
|
||||
try:
|
||||
# Fetch data using BinanceHistoricalData
|
||||
df = self.historical_data.get_historical_candles(
|
||||
symbol=self.symbol,
|
||||
interval_seconds=interval_seconds,
|
||||
limit=n_candles
|
||||
)
|
||||
|
||||
if not df.empty:
|
||||
logger.info(f"Using data for {self.symbol} {timeframe} ({len(df)} candles)")
|
||||
self.dataframes[timeframe] = df
|
||||
return df
|
||||
else:
|
||||
logger.error(f"No data available for {self.symbol} {timeframe}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching data for {self.symbol} {timeframe}: {str(e)}")
|
||||
return None
|
||||
|
||||
def prepare_nn_input(self, timeframes=None, n_candles=500, window_size=20):
|
||||
"""
|
||||
@ -459,24 +372,27 @@ class DataInterface:
|
||||
n_candles (int): Number of future candles to look at
|
||||
|
||||
Returns:
|
||||
np.ndarray: Future prices array
|
||||
np.ndarray: Future prices array (1D array)
|
||||
"""
|
||||
if len(prices) < n_candles + 1:
|
||||
if prices is None or len(prices) < n_candles + 1:
|
||||
return None
|
||||
|
||||
# For each price point, get the maximum price in the next n_candles
|
||||
future_prices = np.zeros(len(prices))
|
||||
# Convert to numpy array if it's not already
|
||||
prices_np = np.array(prices).flatten() if not isinstance(prices, np.ndarray) else prices.flatten()
|
||||
|
||||
for i in range(len(prices) - n_candles):
|
||||
# For each price point, get the maximum price in the next n_candles
|
||||
future_prices = np.zeros(len(prices_np))
|
||||
|
||||
for i in range(len(prices_np) - n_candles):
|
||||
# Get the next n candles
|
||||
next_candles = prices[i+1:i+n_candles+1]
|
||||
next_candles = prices_np[i+1:i+n_candles+1]
|
||||
# Use the maximum price as the future price
|
||||
future_prices[i] = np.max(next_candles)
|
||||
|
||||
# For the last n_candles points, use the last available price
|
||||
future_prices[-n_candles:] = prices[-1]
|
||||
future_prices[-n_candles:] = prices_np[-1]
|
||||
|
||||
return future_prices
|
||||
return future_prices.flatten() # Ensure it's a 1D array
|
||||
|
||||
def prepare_training_data(self, refresh=False, refresh_interval=300):
|
||||
"""
|
||||
|
@ -42,6 +42,7 @@ python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_ma
|
||||
python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_main import main; main()" --mode train --model-type cnn --framework pytorch --epochs 1000
|
||||
python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_main import main; main()" --mode train --model-type cnn --framework pytorch --epochs 1000 --symbol BTC/USDT --timeframes 1m 5m 1h 4h --epochs 10 --batch-size 32 --window-size 20 --output-size 3
|
||||
python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_main import main; main()" --mode train --model-type cnn --framework pytorch --epochs 10 --symbol BTC/USDT --timeframes 1s 1m 1h 1d --batch-size 32 --window-size 20 --output-size 3
|
||||
python NN/realtime_main.py --mode train --model-type cnn --epochs 1 --symbol BTC/USDT --timeframes 1s 1m --batch-size 32 --window-size 20 --output-size 3
|
||||
|
||||
python NN/realtime-main.py --mode train --model-type cnn --framework pytorch --symbol BTC/USDT --timeframes 1m 5m 1h 4h --epochs 10 --batch-size 32 --window-size 20 --output-size 3
|
||||
|
||||
|
@ -912,8 +912,16 @@ class BinanceHistoricalData:
|
||||
"""Convert interval seconds to Binance interval string"""
|
||||
if interval_seconds == 60: # 1m
|
||||
return "1m"
|
||||
elif interval_seconds == 300: # 5m
|
||||
return "5m"
|
||||
elif interval_seconds == 900: # 15m
|
||||
return "15m"
|
||||
elif interval_seconds == 1800: # 30m
|
||||
return "30m"
|
||||
elif interval_seconds == 3600: # 1h
|
||||
return "1h"
|
||||
elif interval_seconds == 14400: # 4h
|
||||
return "4h"
|
||||
elif interval_seconds == 86400: # 1d
|
||||
return "1d"
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user