Multi-pair inputs (WIP); training seems to be working

@@ -17,16 +17,15 @@ logger = logging.getLogger(__name__)

 class DataInterface:
     """
-    Handles data collection, processing, and preparation for neural network models.
-
-    This class is responsible for:
-    1. Fetching historical data
-    2. Preprocessing data for neural network input
-    3. Generating training datasets
-    4. Handling real-time data integration
+    Enhanced Data Interface supporting:
+    - Multiple trading pairs (up to 3)
+    - Multiple timeframes per pair (1s, 1m, 1h, 1d + custom)
+    - Technical indicators (up to 20)
+    - Cross-timeframe normalization
+    - Real-time tick streaming
     """

-    def __init__(self, symbol="BTC/USDT", timeframes=None, data_dir="NN/data"):
+    def __init__(self, symbol=None, timeframes=None, data_dir="NN/data"):
         """
         Initialize the data interface.

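For orientation, a minimal usage sketch of the constructor shown above. The import path and the idea of one DataInterface per pair are assumptions; the diff only shows the changed signature (symbol now defaults to None) and the docstring's claim of up to 3 pairs and multiple timeframes.

    # Hypothetical usage sketch -- module path and multi-pair wiring are assumed,
    # not confirmed by this diff.
    from NN.utils.data_interface import DataInterface  # assumed import path

    pairs = ["BTC/USDT", "ETH/USDT", "SOL/USDT"]   # docstring allows up to 3 pairs
    timeframes = ["1m", "1h", "1d"]

    # One interface per pair is one plausible way to feed multi-pair inputs.
    interfaces = {pair: DataInterface(symbol=pair, timeframes=timeframes) for pair in pairs}
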
@@ -157,9 +156,9 @@ class DataInterface:
             else:
                 cycle = np.sin(i / 24 * np.pi) * 0.01  # Daily cycle

-            # Calculate price change with random walk + cycles
-            price_change = price * (drift + volatility * np.random.randn() + cycle)
-            price += price_change
+            # Calculate price change with random walk + cycles (clamped to prevent overflow)
+            price_change = price * np.clip(drift + volatility * np.random.randn() + cycle, -0.1, 0.1)
+            price = np.clip(price + price_change, 1000, 100000)  # Keep price in reasonable range

             # Generate OHLC from the price
             open_price = price
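The clamp added above bounds each step of the synthetic random walk so the price series can no longer overflow. A standalone sketch of the same idea; the parameter values (drift, volatility, price bounds) mirror the diff but are otherwise illustrative.

    import numpy as np

    def simulate_prices(n=500, start=50000.0, drift=0.0001, volatility=0.01):
        """Illustrative clamped random walk with a daily cycle, as in the diff above."""
        prices = [start]
        for i in range(1, n):
            cycle = np.sin(i / 24 * np.pi) * 0.01
            step = drift + volatility * np.random.randn() + cycle
            change = prices[-1] * np.clip(step, -0.1, 0.1)            # cap each move at +/-10%
            prices.append(float(np.clip(prices[-1] + change, 1000, 100000)))
        return np.array(prices)
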
@@ -171,8 +170,8 @@ class DataInterface:
             high_price = max(high_price, open_price, close_price)
             low_price = min(low_price, open_price, close_price)

-            # Generate volume (higher for larger price movements)
-            volume = abs(price_change) * (10000 + 5000 * np.random.rand())
+            # Generate volume (higher for larger price movements) with safe calculation
+            volume = 10000 + 5000 * np.random.rand() + abs(price_change)/price * 10000

             prices.append((open_price, high_price, low_price, close_price))
             volumes.append(volume)
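The new volume formula keeps values in a sane range: the old one multiplied the absolute price change (hundreds of dollars at BTC-sized prices) by a 10k-15k factor, while the new one adds a term proportional to the relative change. Rough numbers for a 1% move at an illustrative price level:

    import numpy as np

    price, price_change = 50000.0, 500.0  # a 1% move (numbers are illustrative only)
    old_volume = abs(price_change) * (10000 + 5000 * np.random.rand())                 # ~5.0e6 - 7.5e6
    new_volume = 10000 + 5000 * np.random.rand() + abs(price_change) / price * 10000   # ~10100 - 15100
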
@@ -217,19 +216,41 @@ class DataInterface:
             logger.error("No data available for feature creation")
             return None, None, None

-        # For simplicity, we'll use just one timeframe for now
-        # In a more complex implementation, we would merge multiple timeframes
-        primary_tf = timeframes[0]
-        if primary_tf not in dfs:
-            logger.error(f"Primary timeframe {primary_tf} not available")
+        # Create features for each timeframe
+        features = []
+        targets = []
+        timestamps = []
+
+        for tf in timeframes:
+            if tf in dfs:
+                X, y, ts = self._create_features(dfs[tf], window_size)
+                features.append(X)
+                if len(targets) == 0:  # Only need targets from one timeframe
+                    targets = y
+                    timestamps = ts
+
+        if not features:
             return None, None, None

-        df = dfs[primary_tf]
+        # Stack features from all timeframes along the time dimension
+        # Reshape each timeframe's features to [samples, window, 1, features]
+        reshaped_features = [f.reshape(f.shape[0], f.shape[1], 1, f.shape[2])
+                             for f in features]
+        # Concatenate along the channel dimension
+        X = np.concatenate(reshaped_features, axis=2)
+        # Reshape to [samples, window, features*timeframes]
+        X = X.reshape(X.shape[0], X.shape[1], -1)
+
+        # Validate data
+        if np.any(np.isnan(X)) or np.any(np.isinf(X)):
+            logger.error("Generated features contain NaN or infinite values")
+            return None, None, None

-        # Create features
-        X, y, timestamps = self._create_features(df, window_size)
-
-        return X, y, timestamps
+        # Ensure all values are finite and normalized
+        X = np.nan_to_num(X, nan=0.0, posinf=1.0, neginf=-1.0)
+        X = np.clip(X, -1e6, 1e6)  # Clip extreme values
+
+        return X, targets, timestamps

     def _create_features(self, df, window_size):
         """
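The new multi-timeframe path builds one feature tensor per timeframe and fuses them by reshaping and concatenating. A shape-only sketch with random data (array sizes are illustrative):

    import numpy as np

    samples, window, n_features, n_timeframes = 100, 20, 5, 3
    per_tf = [np.random.rand(samples, window, n_features) for _ in range(n_timeframes)]

    # [samples, window, features] -> [samples, window, 1, features] per timeframe
    reshaped = [f.reshape(f.shape[0], f.shape[1], 1, f.shape[2]) for f in per_tf]
    X = np.concatenate(reshaped, axis=2)        # [samples, window, timeframes, features]
    X = X.reshape(X.shape[0], X.shape[1], -1)   # [samples, window, features * timeframes]
    assert X.shape == (samples, window, n_features * n_timeframes)

(The concatenation is along a new per-timeframe channel axis that is then flattened; the time axis itself is unchanged, despite the wording of the first comment in the diff.)
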
@@ -248,9 +269,28 @@ class DataInterface:
         # Extract OHLCV columns
         ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values

-        # Scale the data
-        scaler = MinMaxScaler()
-        ohlcv_scaled = scaler.fit_transform(ohlcv)
+        # Validate data before scaling
+        if np.any(np.isnan(ohlcv)) or np.any(np.isinf(ohlcv)):
+            logger.error("Input data contains NaN or infinite values")
+            return None, None, None
+
+        # Handle potential constant columns (avoid division by zero in scaler)
+        ohlcv = np.nan_to_num(ohlcv, nan=0.0)
+        ranges = np.ptp(ohlcv, axis=0)
+        for i in range(len(ranges)):
+            if ranges[i] == 0:  # Constant column
+                ohlcv[:, i] = 1 if i == 3 else 0  # Set close to 1, others to 0
+
+        # Scale the data with safety checks
+        try:
+            scaler = MinMaxScaler()
+            ohlcv_scaled = scaler.fit_transform(ohlcv)
+            if np.any(np.isnan(ohlcv_scaled)) or np.any(np.isinf(ohlcv_scaled)):
+                logger.error("Scaling produced invalid values")
+                return None, None, None
+        except Exception as e:
+            logger.error(f"Scaling failed: {str(e)}")
+            return None, None, None

         # Store the scaler for later use
         timeframe = next((tf for tf in self.timeframes if self.dataframes.get(tf) is not None and
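The guard above pins constant OHLCV columns to fixed values before scaling (close to 1, the rest to 0). scikit-learn's MinMaxScaler already handles zero ranges internally (a constant column scales to the lower bound of the feature range), so the explicit guard mainly makes the scaled values for dead columns predictable. A standalone sketch with synthetic data:

    import numpy as np
    from sklearn.preprocessing import MinMaxScaler

    ohlcv = np.random.rand(100, 5)   # synthetic [open, high, low, close, volume]
    ohlcv[:, 4] = 0.0                # pretend volume is a dead (constant) column

    ranges = np.ptp(ohlcv, axis=0)   # per-column max - min
    for i in range(len(ranges)):
        if ranges[i] == 0:                      # constant column
            ohlcv[:, i] = 1 if i == 3 else 0    # close -> 1, everything else -> 0

    ohlcv_scaled = MinMaxScaler().fit_transform(ohlcv)
    assert not np.any(np.isnan(ohlcv_scaled))
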
@@ -343,6 +383,11 @@ class DataInterface:
         logger.info(f"Dataset generated and saved: {dataset_name}")
         return dataset_info

+    def get_feature_count(self):
+        """Get the number of features per input sample"""
+        # OHLCV (5 features) per timeframe
+        return 5 * len(self.timeframes)
+
     def prepare_realtime_input(self, timeframe='1h', n_candles=30, window_size=20):
         """
         Prepare a single input sample from the most recent data for real-time inference.
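With the per-timeframe stacking above, get_feature_count() simply reports 5 OHLCV features per configured timeframe: for timeframes ['1m', '1h', '1d'] it returns 5 * 3 = 15, matching the last dimension of the flattened [samples, window, features * timeframes] input built earlier in the diff.
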
@@ -387,4 +432,4 @@ class DataInterface:
         # Get timestamp of the most recent candle
         timestamp = df['timestamp'].iloc[-1]

-        return X, timestamp
+        return X, timestamp