stored
This commit is contained in:
@@ -7,10 +7,11 @@ Handles storage, retrieval, and test case generation from manual trade annotatio
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Optional, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
import logging
|
||||
import pytz
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -205,49 +206,20 @@ class AnnotationManager:
|
||||
}
|
||||
}
|
||||
|
||||
# Populate market state with ±5 minutes of data for negative examples
|
||||
# Populate market state with ±5 minutes of data for training
|
||||
if data_provider:
|
||||
try:
|
||||
entry_time = datetime.fromisoformat(annotation.entry['timestamp'].replace('Z', '+00:00'))
|
||||
exit_time = datetime.fromisoformat(annotation.exit['timestamp'].replace('Z', '+00:00'))
|
||||
|
||||
# Calculate time window: ±5 minutes around entry
|
||||
time_window_before = timedelta(minutes=5)
|
||||
time_window_after = timedelta(minutes=5)
|
||||
logger.info(f"Fetching market state for {annotation.symbol} at {entry_time} (±5min around entry)")
|
||||
|
||||
start_time = entry_time - time_window_before
|
||||
end_time = entry_time + time_window_after
|
||||
|
||||
logger.info(f"Fetching market data from {start_time} to {end_time} (±5min around entry)")
|
||||
|
||||
# Fetch OHLCV data for all timeframes
|
||||
timeframes = ['1s', '1m', '1h', '1d']
|
||||
market_state = {}
|
||||
|
||||
for tf in timeframes:
|
||||
# Get data for the time window
|
||||
df = data_provider.get_historical_data(
|
||||
symbol=annotation.symbol,
|
||||
timeframe=tf,
|
||||
limit=1000 # Get enough data to cover ±5 minutes
|
||||
)
|
||||
|
||||
if df is not None and not df.empty:
|
||||
# Filter to time window
|
||||
df_window = df[(df.index >= start_time) & (df.index <= end_time)]
|
||||
|
||||
if not df_window.empty:
|
||||
# Convert to list format
|
||||
market_state[f'ohlcv_{tf}'] = {
|
||||
'timestamps': df_window.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
||||
'open': df_window['open'].tolist(),
|
||||
'high': df_window['high'].tolist(),
|
||||
'low': df_window['low'].tolist(),
|
||||
'close': df_window['close'].tolist(),
|
||||
'volume': df_window['volume'].tolist()
|
||||
}
|
||||
|
||||
logger.info(f" {tf}: {len(df_window)} candles in ±5min window")
|
||||
# Use the new data provider method to get market state at the entry time
|
||||
market_state = data_provider.get_market_state_at_time(
|
||||
symbol=annotation.symbol,
|
||||
timestamp=entry_time,
|
||||
context_window_minutes=5
|
||||
)
|
||||
|
||||
# Add training labels for each timestamp
|
||||
# This helps model learn WHERE to signal and WHERE NOT to signal
|
||||
@@ -330,6 +302,9 @@ class AnnotationManager:
|
||||
for ts_str in timestamps:
|
||||
try:
|
||||
ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
|
||||
# Make timezone-aware to match entry_time
|
||||
if ts.tzinfo is None:
|
||||
ts = pytz.UTC.localize(ts)
|
||||
|
||||
# Determine label based on position relative to entry/exit
|
||||
if abs((ts - entry_time).total_seconds()) < 60: # Within 1 minute of entry
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
{
|
||||
"test_case_id": "annotation_844508ec-fd73-46e9-861e-b7c401448693",
|
||||
"symbol": "ETH/USDT",
|
||||
"timestamp": "2025-04-16",
|
||||
"action": "BUY",
|
||||
"market_state": {},
|
||||
"expected_outcome": {
|
||||
"direction": "LONG",
|
||||
"profit_loss_pct": 185.7520575218433,
|
||||
"holding_period_seconds": 11491200.0,
|
||||
"exit_price": 4506.71,
|
||||
"entry_price": 1577.14
|
||||
},
|
||||
"annotation_metadata": {
|
||||
"annotator": "manual",
|
||||
"confidence": 1.0,
|
||||
"notes": "",
|
||||
"created_at": "2025-10-20T13:53:02.710405",
|
||||
"timeframe": "1d"
|
||||
}
|
||||
}
|
||||
@@ -84,7 +84,29 @@ class AnnotationDashboard:
|
||||
def __init__(self):
|
||||
"""Initialize the dashboard"""
|
||||
# Load configuration
|
||||
self.config = get_config() if get_config else {}
|
||||
try:
|
||||
# Always try YAML loading first since get_config might not work in standalone mode
|
||||
import yaml
|
||||
with open('config.yaml', 'r') as f:
|
||||
self.config = yaml.safe_load(f)
|
||||
logger.info(f"Loaded config via YAML: {len(self.config)} keys")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load config via YAML: {e}")
|
||||
try:
|
||||
# Fallback to get_config if available
|
||||
if get_config:
|
||||
self.config = get_config()
|
||||
logger.info(f"Loaded config via get_config: {len(self.config)} keys")
|
||||
else:
|
||||
raise Exception("get_config not available")
|
||||
except Exception as e2:
|
||||
logger.warning(f"Could not load config via get_config: {e2}")
|
||||
# Final fallback config with SOL/USDT
|
||||
self.config = {
|
||||
'symbols': ['ETH/USDT', 'BTC/USDT', 'SOL/USDT'],
|
||||
'timeframes': ['1s', '1m', '1h', '1d']
|
||||
}
|
||||
logger.info("Using fallback config")
|
||||
|
||||
# Initialize Flask app
|
||||
self.server = Flask(
|
||||
@@ -207,10 +229,15 @@ class AnnotationDashboard:
|
||||
|
||||
logger.info(f"Loading dashboard with {len(annotations_data)} existing annotations")
|
||||
|
||||
# Get symbols and timeframes from config
|
||||
symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT'])
|
||||
timeframes = self.config.get('timeframes', ['1s', '1m', '1h', '1d'])
|
||||
|
||||
# Prepare template data
|
||||
template_data = {
|
||||
'current_symbol': 'ETH/USDT',
|
||||
'timeframes': ['1s', '1m', '1h', '1d'],
|
||||
'current_symbol': symbols[0] if symbols else 'ETH/USDT', # Use first symbol as default
|
||||
'symbols': symbols,
|
||||
'timeframes': timeframes,
|
||||
'annotations': annotations_data
|
||||
}
|
||||
|
||||
@@ -328,59 +355,36 @@ class AnnotationDashboard:
|
||||
try:
|
||||
data = request.get_json()
|
||||
|
||||
# Capture market state at entry and exit times
|
||||
# Capture market state at entry and exit times using data provider
|
||||
entry_market_state = {}
|
||||
exit_market_state = {}
|
||||
|
||||
if self.data_loader:
|
||||
if self.data_provider:
|
||||
try:
|
||||
# Parse timestamps
|
||||
entry_time = datetime.fromisoformat(data['entry']['timestamp'].replace('Z', '+00:00'))
|
||||
exit_time = datetime.fromisoformat(data['exit']['timestamp'].replace('Z', '+00:00'))
|
||||
|
||||
# Fetch market data for all timeframes at entry time
|
||||
timeframes = ['1s', '1m', '1h', '1d']
|
||||
for tf in timeframes:
|
||||
df = self.data_loader.get_data(
|
||||
symbol=data['symbol'],
|
||||
timeframe=tf,
|
||||
end_time=entry_time,
|
||||
limit=100
|
||||
)
|
||||
|
||||
if df is not None and not df.empty:
|
||||
entry_market_state[f'ohlcv_{tf}'] = {
|
||||
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
||||
'open': df['open'].tolist(),
|
||||
'high': df['high'].tolist(),
|
||||
'low': df['low'].tolist(),
|
||||
'close': df['close'].tolist(),
|
||||
'volume': df['volume'].tolist()
|
||||
}
|
||||
# Use the new data provider method to get market state at entry time
|
||||
entry_market_state = self.data_provider.get_market_state_at_time(
|
||||
symbol=data['symbol'],
|
||||
timestamp=entry_time,
|
||||
context_window_minutes=5
|
||||
)
|
||||
|
||||
# Fetch market data at exit time
|
||||
for tf in timeframes:
|
||||
df = self.data_loader.get_data(
|
||||
symbol=data['symbol'],
|
||||
timeframe=tf,
|
||||
end_time=exit_time,
|
||||
limit=100
|
||||
)
|
||||
|
||||
if df is not None and not df.empty:
|
||||
exit_market_state[f'ohlcv_{tf}'] = {
|
||||
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
||||
'open': df['open'].tolist(),
|
||||
'high': df['high'].tolist(),
|
||||
'low': df['low'].tolist(),
|
||||
'close': df['close'].tolist(),
|
||||
'volume': df['volume'].tolist()
|
||||
}
|
||||
# Use the new data provider method to get market state at exit time
|
||||
exit_market_state = self.data_provider.get_market_state_at_time(
|
||||
symbol=data['symbol'],
|
||||
timestamp=exit_time,
|
||||
context_window_minutes=5
|
||||
)
|
||||
|
||||
logger.info(f"Captured market state: {len(entry_market_state)} timeframes at entry, {len(exit_market_state)} at exit")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error capturing market state: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
# Create annotation with market context
|
||||
annotation = self.annotation_manager.create_annotation(
|
||||
@@ -456,6 +460,103 @@ class AnnotationDashboard:
|
||||
}
|
||||
})
|
||||
|
||||
@self.server.route('/api/clear-all-annotations', methods=['POST'])
|
||||
def clear_all_annotations():
|
||||
"""Clear all annotations"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
symbol = data.get('symbol', None)
|
||||
|
||||
# Get current annotations count
|
||||
annotations = self.annotation_manager.get_annotations(symbol=symbol)
|
||||
deleted_count = len(annotations)
|
||||
|
||||
if deleted_count == 0:
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'deleted_count': 0,
|
||||
'message': 'No annotations to clear'
|
||||
})
|
||||
|
||||
# Clear all annotations
|
||||
for annotation in annotations:
|
||||
annotation_id = annotation.annotation_id if hasattr(annotation, 'annotation_id') else annotation.get('annotation_id')
|
||||
self.annotation_manager.delete_annotation(annotation_id)
|
||||
|
||||
logger.info(f"Cleared {deleted_count} annotations")
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'deleted_count': deleted_count,
|
||||
'message': f'Cleared {deleted_count} annotations'
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing all annotations: {e}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': {
|
||||
'code': 'CLEAR_ALL_ANNOTATIONS_ERROR',
|
||||
'message': str(e)
|
||||
}
|
||||
})
|
||||
|
||||
@self.server.route('/api/refresh-data', methods=['POST'])
|
||||
def refresh_data():
|
||||
"""Refresh chart data from data provider"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
symbol = data.get('symbol', 'ETH/USDT')
|
||||
timeframes = data.get('timeframes', ['1s', '1m', '1h', '1d'])
|
||||
|
||||
logger.info(f"Refreshing data for {symbol} with timeframes: {timeframes}")
|
||||
|
||||
# Force refresh data from data provider
|
||||
chart_data = {}
|
||||
|
||||
if self.data_provider:
|
||||
for timeframe in timeframes:
|
||||
try:
|
||||
# Force refresh by setting refresh=True
|
||||
df = self.data_provider.get_historical_data(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
limit=1000,
|
||||
refresh=True
|
||||
)
|
||||
|
||||
if df is not None and not df.empty:
|
||||
chart_data[timeframe] = {
|
||||
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
||||
'open': df['open'].tolist(),
|
||||
'high': df['high'].tolist(),
|
||||
'low': df['low'].tolist(),
|
||||
'close': df['close'].tolist(),
|
||||
'volume': df['volume'].tolist()
|
||||
}
|
||||
logger.info(f"Refreshed {timeframe}: {len(df)} candles")
|
||||
else:
|
||||
logger.warning(f"No data available for {timeframe}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing {timeframe} data: {e}")
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'chart_data': chart_data,
|
||||
'message': f'Refreshed data for {symbol}'
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing data: {e}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': {
|
||||
'code': 'REFRESH_DATA_ERROR',
|
||||
'message': str(e)
|
||||
}
|
||||
})
|
||||
|
||||
@self.server.route('/api/generate-test-case', methods=['POST'])
|
||||
def generate_test_case():
|
||||
"""Generate test case from annotation"""
|
||||
|
||||
@@ -4,9 +4,14 @@
|
||||
<i class="fas fa-tags"></i>
|
||||
Annotations
|
||||
</h6>
|
||||
<button class="btn btn-sm btn-outline-light" id="export-annotations-btn" title="Export">
|
||||
<i class="fas fa-download"></i>
|
||||
</button>
|
||||
<div class="btn-group btn-group-sm">
|
||||
<button class="btn btn-sm btn-outline-light" id="export-annotations-btn" title="Export">
|
||||
<i class="fas fa-download"></i>
|
||||
</button>
|
||||
<button class="btn btn-sm btn-outline-danger" id="clear-all-annotations-btn" title="Clear All">
|
||||
<i class="fas fa-trash-alt"></i>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="card-body p-2">
|
||||
<div class="list-group list-group-flush" id="annotations-list">
|
||||
@@ -48,6 +53,48 @@
|
||||
});
|
||||
});
|
||||
|
||||
// Clear all annotations
|
||||
document.getElementById('clear-all-annotations-btn').addEventListener('click', function() {
|
||||
if (appState.annotations.length === 0) {
|
||||
showError('No annotations to clear');
|
||||
return;
|
||||
}
|
||||
|
||||
if (!confirm(`Are you sure you want to delete all ${appState.annotations.length} annotations? This action cannot be undone.`)) {
|
||||
return;
|
||||
}
|
||||
|
||||
fetch('/api/clear-all-annotations', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: JSON.stringify({
|
||||
symbol: appState.currentSymbol
|
||||
})
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.success) {
|
||||
// Clear from app state
|
||||
appState.annotations = [];
|
||||
|
||||
// Update UI
|
||||
renderAnnotationsList(appState.annotations);
|
||||
|
||||
// Clear from chart
|
||||
if (appState.chartManager) {
|
||||
appState.chartManager.clearAllAnnotations();
|
||||
}
|
||||
|
||||
showSuccess(`Cleared ${data.deleted_count} annotations`);
|
||||
} else {
|
||||
showError('Failed to clear annotations: ' + data.error.message);
|
||||
}
|
||||
})
|
||||
.catch(error => {
|
||||
showError('Network error: ' + error.message);
|
||||
});
|
||||
});
|
||||
|
||||
// Function to render annotations list
|
||||
function renderAnnotationsList(annotations) {
|
||||
const listContainer = document.getElementById('annotations-list');
|
||||
|
||||
@@ -10,30 +10,31 @@
|
||||
<div class="mb-3">
|
||||
<label for="symbol-select" class="form-label">Symbol</label>
|
||||
<select class="form-select form-select-sm" id="symbol-select">
|
||||
<option value="ETH/USDT" selected>ETH/USDT</option>
|
||||
<option value="BTC/USDT">BTC/USDT</option>
|
||||
{% for symbol in symbols %}
|
||||
<option value="{{ symbol }}" {% if symbol == current_symbol %}selected{% endif %}>{{ symbol }}</option>
|
||||
{% endfor %}
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<!-- Timeframe Selection -->
|
||||
<div class="mb-3">
|
||||
<label class="form-label">Timeframes</label>
|
||||
{% for timeframe in timeframes %}
|
||||
<div class="form-check">
|
||||
<input class="form-check-input" type="checkbox" id="tf-1s" value="1s" checked>
|
||||
<label class="form-check-label" for="tf-1s">1 Second</label>
|
||||
</div>
|
||||
<div class="form-check">
|
||||
<input class="form-check-input" type="checkbox" id="tf-1m" value="1m" checked>
|
||||
<label class="form-check-label" for="tf-1m">1 Minute</label>
|
||||
</div>
|
||||
<div class="form-check">
|
||||
<input class="form-check-input" type="checkbox" id="tf-1h" value="1h" checked>
|
||||
<label class="form-check-label" for="tf-1h">1 Hour</label>
|
||||
</div>
|
||||
<div class="form-check">
|
||||
<input class="form-check-input" type="checkbox" id="tf-1d" value="1d" checked>
|
||||
<label class="form-check-label" for="tf-1d">1 Day</label>
|
||||
<input class="form-check-input" type="checkbox" id="tf-{{ timeframe }}" value="{{ timeframe }}" checked>
|
||||
<label class="form-check-label" for="tf-{{ timeframe }}">
|
||||
{% if timeframe == '1s' %}1 Second
|
||||
{% elif timeframe == '1m' %}1 Minute
|
||||
{% elif timeframe == '1h' %}1 Hour
|
||||
{% elif timeframe == '1d' %}1 Day
|
||||
{% elif timeframe == '5m' %}5 Minutes
|
||||
{% elif timeframe == '15m' %}15 Minutes
|
||||
{% elif timeframe == '4h' %}4 Hours
|
||||
{% else %}{{ timeframe }}
|
||||
{% endif %}
|
||||
</label>
|
||||
</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
|
||||
<!-- Time Navigation -->
|
||||
@@ -73,6 +74,21 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Data Refresh -->
|
||||
<div class="mb-3">
|
||||
<label class="form-label">Data</label>
|
||||
<div class="btn-group w-100" role="group">
|
||||
<button type="button" class="btn btn-sm btn-outline-success" id="refresh-data-btn" title="Refresh Data">
|
||||
<i class="fas fa-sync-alt"></i>
|
||||
Refresh
|
||||
</button>
|
||||
<button type="button" class="btn btn-sm btn-outline-info" id="auto-refresh-toggle" title="Auto Refresh">
|
||||
<i class="fas fa-play" id="auto-refresh-icon"></i>
|
||||
</button>
|
||||
</div>
|
||||
<small class="text-muted">Refresh chart data from data provider</small>
|
||||
</div>
|
||||
|
||||
<!-- Annotation Mode -->
|
||||
<div class="mb-3">
|
||||
<label class="form-label">Annotation Mode</label>
|
||||
@@ -168,4 +184,87 @@
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Data refresh functionality
|
||||
let autoRefreshInterval = null;
|
||||
let isAutoRefreshEnabled = false;
|
||||
|
||||
// Manual refresh button
|
||||
document.getElementById('refresh-data-btn').addEventListener('click', function() {
|
||||
refreshChartData();
|
||||
});
|
||||
|
||||
// Auto refresh toggle
|
||||
document.getElementById('auto-refresh-toggle').addEventListener('click', function() {
|
||||
toggleAutoRefresh();
|
||||
});
|
||||
|
||||
function refreshChartData() {
|
||||
const refreshBtn = document.getElementById('refresh-data-btn');
|
||||
const icon = refreshBtn.querySelector('i');
|
||||
|
||||
// Show loading state
|
||||
icon.className = 'fas fa-spinner fa-spin';
|
||||
refreshBtn.disabled = true;
|
||||
|
||||
fetch('/api/refresh-data', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: JSON.stringify({
|
||||
symbol: appState.currentSymbol,
|
||||
timeframes: appState.currentTimeframes
|
||||
})
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.success) {
|
||||
// Update charts with new data
|
||||
if (appState.chartManager) {
|
||||
appState.chartManager.updateCharts(data.chart_data);
|
||||
}
|
||||
showSuccess('Chart data refreshed successfully');
|
||||
} else {
|
||||
showError('Failed to refresh data: ' + data.error.message);
|
||||
}
|
||||
})
|
||||
.catch(error => {
|
||||
showError('Network error: ' + error.message);
|
||||
})
|
||||
.finally(() => {
|
||||
// Reset button state
|
||||
icon.className = 'fas fa-sync-alt';
|
||||
refreshBtn.disabled = false;
|
||||
});
|
||||
}
|
||||
|
||||
function toggleAutoRefresh() {
|
||||
const toggleBtn = document.getElementById('auto-refresh-toggle');
|
||||
const icon = document.getElementById('auto-refresh-icon');
|
||||
|
||||
if (isAutoRefreshEnabled) {
|
||||
// Disable auto refresh
|
||||
if (autoRefreshInterval) {
|
||||
clearInterval(autoRefreshInterval);
|
||||
autoRefreshInterval = null;
|
||||
}
|
||||
isAutoRefreshEnabled = false;
|
||||
icon.className = 'fas fa-play';
|
||||
toggleBtn.title = 'Enable Auto Refresh';
|
||||
showSuccess('Auto refresh disabled');
|
||||
} else {
|
||||
// Enable auto refresh (every 30 seconds)
|
||||
autoRefreshInterval = setInterval(refreshChartData, 30000);
|
||||
isAutoRefreshEnabled = true;
|
||||
icon.className = 'fas fa-pause';
|
||||
toggleBtn.title = 'Disable Auto Refresh';
|
||||
showSuccess('Auto refresh enabled (30s interval)');
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up interval when page unloads
|
||||
window.addEventListener('beforeunload', function() {
|
||||
if (autoRefreshInterval) {
|
||||
clearInterval(autoRefreshInterval);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
@@ -22,11 +22,24 @@ from typing import Dict, List, Optional, Any, Tuple
|
||||
from collections import deque
|
||||
import random
|
||||
import math
|
||||
import sys
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
# Import prediction database
|
||||
try:
|
||||
from core.prediction_database import get_prediction_db
|
||||
except ImportError:
|
||||
# Fallback if prediction database is not available
|
||||
def get_prediction_db():
|
||||
logger.warning("Prediction database not available, using mock")
|
||||
return None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -2364,7 +2377,13 @@ class EnhancedRealtimeTrainingSystem:
|
||||
# Use DQN model to predict action (if available)
|
||||
if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
|
||||
and self.orchestrator.rl_agent):
|
||||
|
||||
# Get prediction from DQN agent
|
||||
state = self.orchestrator._create_rl_state(symbol)
|
||||
if state is not None:
|
||||
action, confidence, q_values = self.orchestrator.rl_agent.act_with_confidence(state)
|
||||
else:
|
||||
# Fallback if state creation fails
|
||||
action, q_values, confidence = self._technical_analysis_prediction(symbol)
|
||||
else:
|
||||
# Fallback to technical analysis-based prediction
|
||||
action, q_values, confidence = self._technical_analysis_prediction(symbol)
|
||||
@@ -2725,8 +2744,9 @@ class EnhancedRealtimeTrainingSystem:
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error estimating price change: {e}")
|
||||
return 0.0 d
|
||||
ef _save_model_checkpoint(self, model_name: str, model_obj, loss: float):
|
||||
return 0.0
|
||||
|
||||
def _save_model_checkpoint(self, model_name: str, model_obj, loss: float):
|
||||
"""
|
||||
Save model checkpoint after training if performance improved
|
||||
|
||||
|
||||
@@ -72,10 +72,11 @@ exchanges:
|
||||
|
||||
# Trading Symbols Configuration
|
||||
# Primary trading pair: ETH/USDT (main signals generation)
|
||||
# Reference pair: BTC/USDT (correlation analysis only, no trading signals)
|
||||
# Reference pairs: BTC/USDT, SOL/USDT (correlation analysis and trading)
|
||||
symbols:
|
||||
- "ETH/USDT" # MAIN TRADING PAIR - Generate signals and execute trades
|
||||
- "BTC/USDT" # REFERENCE ONLY - For correlation analysis, no direct trading
|
||||
- "BTC/USDT" # REFERENCE - For correlation analysis and trading
|
||||
- "SOL/USDT" # REFERENCE - For correlation analysis and trading
|
||||
|
||||
# Timeframes for ultra-fast scalping (500x leverage)
|
||||
timeframes:
|
||||
|
||||
@@ -1269,6 +1269,110 @@ class DataProvider:
|
||||
logger.error(f"Error getting price range for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def get_historical_data_replay(self, symbol: str, start_time: datetime, end_time: datetime,
|
||||
timeframes: List[str] = None) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
Get historical data for a specific time period for replay/training purposes.
|
||||
This method allows "going back in time" to replay market moves.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'ETH/USDT')
|
||||
start_time: Start of the time period
|
||||
end_time: End of the time period
|
||||
timeframes: List of timeframes to fetch (default: ['1s', '1m', '1h', '1d'])
|
||||
|
||||
Returns:
|
||||
Dict mapping timeframe to DataFrame with OHLCV data for the period
|
||||
"""
|
||||
if timeframes is None:
|
||||
timeframes = ['1s', '1m', '1h', '1d']
|
||||
|
||||
logger.info(f"Replaying historical data for {symbol} from {start_time} to {end_time}")
|
||||
|
||||
replay_data = {}
|
||||
|
||||
for timeframe in timeframes:
|
||||
try:
|
||||
# Calculate how many candles we need for the time period
|
||||
if timeframe == '1s':
|
||||
limit = int((end_time - start_time).total_seconds()) + 100 # Extra buffer
|
||||
elif timeframe == '1m':
|
||||
limit = int((end_time - start_time).total_seconds() / 60) + 10
|
||||
elif timeframe == '1h':
|
||||
limit = int((end_time - start_time).total_seconds() / 3600) + 5
|
||||
elif timeframe == '1d':
|
||||
limit = int((end_time - start_time).total_seconds() / 86400) + 2
|
||||
else:
|
||||
limit = 1000
|
||||
|
||||
# Fetch historical data
|
||||
df = self.get_historical_data(symbol, timeframe, limit=limit, refresh=True)
|
||||
|
||||
if df is not None and not df.empty:
|
||||
# Filter to the exact time period
|
||||
df_filtered = df[(df.index >= start_time) & (df.index <= end_time)]
|
||||
|
||||
if not df_filtered.empty:
|
||||
replay_data[timeframe] = df_filtered
|
||||
logger.info(f" {timeframe}: {len(df_filtered)} candles in replay period")
|
||||
else:
|
||||
logger.warning(f" {timeframe}: No data in replay period")
|
||||
replay_data[timeframe] = pd.DataFrame()
|
||||
else:
|
||||
logger.warning(f" {timeframe}: No data available")
|
||||
replay_data[timeframe] = pd.DataFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching {timeframe} data for replay: {e}")
|
||||
replay_data[timeframe] = pd.DataFrame()
|
||||
|
||||
logger.info(f"Historical replay data prepared: {len([tf for tf, df in replay_data.items() if not df.empty])} timeframes")
|
||||
return replay_data
|
||||
|
||||
def get_market_state_at_time(self, symbol: str, timestamp: datetime,
|
||||
context_window_minutes: int = 5) -> Dict[str, Any]:
|
||||
"""
|
||||
Get complete market state at a specific point in time for training.
|
||||
This includes OHLCV data ±context_window_minutes around the timestamp.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timestamp: The specific point in time
|
||||
context_window_minutes: Minutes before/after the timestamp to include
|
||||
|
||||
Returns:
|
||||
Dict with market state data in training format
|
||||
"""
|
||||
try:
|
||||
start_time = timestamp - timedelta(minutes=context_window_minutes)
|
||||
end_time = timestamp + timedelta(minutes=context_window_minutes)
|
||||
|
||||
logger.info(f"Getting market state for {symbol} at {timestamp} (±{context_window_minutes}min)")
|
||||
|
||||
# Get replay data for the time window
|
||||
replay_data = self.get_historical_data_replay(symbol, start_time, end_time)
|
||||
|
||||
# Convert to training format
|
||||
market_state = {}
|
||||
|
||||
for timeframe, df in replay_data.items():
|
||||
if not df.empty:
|
||||
market_state[f'ohlcv_{timeframe}'] = {
|
||||
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
||||
'open': df['open'].tolist(),
|
||||
'high': df['high'].tolist(),
|
||||
'low': df['low'].tolist(),
|
||||
'close': df['close'].tolist(),
|
||||
'volume': df['volume'].tolist()
|
||||
}
|
||||
|
||||
logger.info(f"Market state prepared with {len(market_state)} timeframes")
|
||||
return market_state
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting market state at time: {e}")
|
||||
return {}
|
||||
|
||||
def get_historical_data(self, symbol: str, timeframe: str, limit: int = 1000, refresh: bool = False) -> Optional[pd.DataFrame]:
|
||||
"""Get historical OHLCV data.
|
||||
- Prefer cached data for low latency.
|
||||
|
||||
Reference in New Issue
Block a user