filter anotations by symbol
This commit is contained in:
@@ -338,8 +338,17 @@ class AnnotationManager:
|
|||||||
logger.info(f"Generated lightweight test case: {test_case['test_case_id']} (OHLCV data will be fetched dynamically)")
|
logger.info(f"Generated lightweight test case: {test_case['test_case_id']} (OHLCV data will be fetched dynamically)")
|
||||||
return test_case
|
return test_case
|
||||||
|
|
||||||
def get_all_test_cases(self) -> List[Dict]:
|
def get_all_test_cases(self, symbol: Optional[str] = None) -> List[Dict]:
|
||||||
"""Load all test cases from disk"""
|
"""
|
||||||
|
Load all test cases from disk
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Optional symbol filter (e.g., 'ETH/USDT'). If provided, only returns
|
||||||
|
test cases for that symbol. Critical for avoiding cross-symbol training.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of test case dictionaries
|
||||||
|
"""
|
||||||
test_cases = []
|
test_cases = []
|
||||||
|
|
||||||
if not self.test_cases_dir.exists():
|
if not self.test_cases_dir.exists():
|
||||||
@@ -349,11 +358,22 @@ class AnnotationManager:
|
|||||||
try:
|
try:
|
||||||
with open(test_case_file, 'r') as f:
|
with open(test_case_file, 'r') as f:
|
||||||
test_case = json.load(f)
|
test_case = json.load(f)
|
||||||
|
|
||||||
|
# CRITICAL: Filter by symbol to avoid training on wrong symbol
|
||||||
|
if symbol:
|
||||||
|
test_case_symbol = test_case.get('symbol', '')
|
||||||
|
if test_case_symbol != symbol:
|
||||||
|
logger.debug(f"Skipping {test_case_file.name}: symbol {test_case_symbol} != {symbol}")
|
||||||
|
continue
|
||||||
|
|
||||||
test_cases.append(test_case)
|
test_cases.append(test_case)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading test case {test_case_file}: {e}")
|
logger.error(f"Error loading test case {test_case_file}: {e}")
|
||||||
|
|
||||||
logger.info(f"Loaded {len(test_cases)} test cases from disk")
|
if symbol:
|
||||||
|
logger.info(f"Loaded {len(test_cases)} test cases for symbol {symbol}")
|
||||||
|
else:
|
||||||
|
logger.info(f"Loaded {len(test_cases)} test cases (all symbols)")
|
||||||
return test_cases
|
return test_cases
|
||||||
|
|
||||||
def _calculate_holding_period(self, annotation: TradeAnnotation) -> float:
|
def _calculate_holding_period(self, annotation: TradeAnnotation) -> float:
|
||||||
|
|||||||
@@ -1740,13 +1740,17 @@ class RealTrainingAdapter:
|
|||||||
logger.info(f"Using orchestrator's TradingTransformerTrainer")
|
logger.info(f"Using orchestrator's TradingTransformerTrainer")
|
||||||
logger.info(f" Trainer type: {type(trainer).__name__}")
|
logger.info(f" Trainer type: {type(trainer).__name__}")
|
||||||
|
|
||||||
|
# Import torch at function level (not inside try block)
|
||||||
|
import torch
|
||||||
|
import gc
|
||||||
|
|
||||||
# Load best checkpoint if available to continue training
|
# Load best checkpoint if available to continue training
|
||||||
try:
|
try:
|
||||||
checkpoint_dir = "models/checkpoints/transformer"
|
checkpoint_dir = "models/checkpoints/transformer"
|
||||||
best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')
|
best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')
|
||||||
|
|
||||||
if best_checkpoint_path and os.path.exists(best_checkpoint_path):
|
if best_checkpoint_path and os.path.exists(best_checkpoint_path):
|
||||||
checkpoint = torch.load(best_checkpoint_path)
|
checkpoint = torch.load(best_checkpoint_path, map_location='cpu')
|
||||||
trainer.model.load_state_dict(checkpoint['model_state_dict'])
|
trainer.model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||||
trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||||
|
|||||||
@@ -1167,7 +1167,9 @@ class AnnotationDashboard:
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Load test cases from disk (they were auto-generated when annotations were saved)
|
# Load test cases from disk (they were auto-generated when annotations were saved)
|
||||||
all_test_cases = self.annotation_manager.get_all_test_cases()
|
# CRITICAL: Filter by current symbol to avoid cross-symbol training
|
||||||
|
current_symbol = data.get('symbol', 'ETH/USDT')
|
||||||
|
all_test_cases = self.annotation_manager.get_all_test_cases(symbol=current_symbol)
|
||||||
|
|
||||||
# Filter to selected annotations
|
# Filter to selected annotations
|
||||||
test_cases = [
|
test_cases = [
|
||||||
|
|||||||
@@ -9,9 +9,107 @@ class ChartManager {
|
|||||||
this.charts = {};
|
this.charts = {};
|
||||||
this.annotations = {};
|
this.annotations = {};
|
||||||
this.syncedTime = null;
|
this.syncedTime = null;
|
||||||
|
this.updateTimers = {}; // Track auto-update timers
|
||||||
|
this.autoUpdateEnabled = false; // Auto-update state
|
||||||
|
|
||||||
console.log('ChartManager initialized with timeframes:', timeframes);
|
console.log('ChartManager initialized with timeframes:', timeframes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Start auto-updating charts
|
||||||
|
*/
|
||||||
|
startAutoUpdate() {
|
||||||
|
if (this.autoUpdateEnabled) {
|
||||||
|
console.log('Auto-update already enabled');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.autoUpdateEnabled = true;
|
||||||
|
console.log('Starting chart auto-update...');
|
||||||
|
|
||||||
|
// Update 1s chart every 20 seconds
|
||||||
|
if (this.timeframes.includes('1s')) {
|
||||||
|
this.updateTimers['1s'] = setInterval(() => {
|
||||||
|
this.updateChart('1s');
|
||||||
|
}, 20000); // 20 seconds
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 1m chart - sync to whole minutes + every 20s
|
||||||
|
if (this.timeframes.includes('1m')) {
|
||||||
|
// Calculate ms until next whole minute
|
||||||
|
const now = new Date();
|
||||||
|
const msUntilNextMinute = (60 - now.getSeconds()) * 1000 - now.getMilliseconds();
|
||||||
|
|
||||||
|
// Update on next whole minute
|
||||||
|
setTimeout(() => {
|
||||||
|
this.updateChart('1m');
|
||||||
|
|
||||||
|
// Then update every 20s
|
||||||
|
this.updateTimers['1m'] = setInterval(() => {
|
||||||
|
this.updateChart('1m');
|
||||||
|
}, 20000); // 20 seconds
|
||||||
|
}, msUntilNextMinute);
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log('Auto-update enabled for:', Object.keys(this.updateTimers));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stop auto-updating charts
|
||||||
|
*/
|
||||||
|
stopAutoUpdate() {
|
||||||
|
if (!this.autoUpdateEnabled) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.autoUpdateEnabled = false;
|
||||||
|
|
||||||
|
// Clear all timers
|
||||||
|
Object.values(this.updateTimers).forEach(timer => clearInterval(timer));
|
||||||
|
this.updateTimers = {};
|
||||||
|
|
||||||
|
console.log('Auto-update stopped');
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update a single chart with fresh data
|
||||||
|
*/
|
||||||
|
async updateChart(timeframe) {
|
||||||
|
try {
|
||||||
|
const response = await fetch(`/api/chart-data?timeframe=${timeframe}&limit=1000`);
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`HTTP ${response.status}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.success && data.data && data.data[timeframe]) {
|
||||||
|
const chartData = data.data[timeframe];
|
||||||
|
const plotId = `plot-${timeframe}`;
|
||||||
|
|
||||||
|
// Update chart using Plotly.react (efficient update)
|
||||||
|
const candlestickUpdate = {
|
||||||
|
x: [chartData.timestamps],
|
||||||
|
open: [chartData.open],
|
||||||
|
high: [chartData.high],
|
||||||
|
low: [chartData.low],
|
||||||
|
close: [chartData.close]
|
||||||
|
};
|
||||||
|
|
||||||
|
const volumeUpdate = {
|
||||||
|
x: [chartData.timestamps],
|
||||||
|
y: [chartData.volume]
|
||||||
|
};
|
||||||
|
|
||||||
|
Plotly.restyle(plotId, candlestickUpdate, [0]);
|
||||||
|
Plotly.restyle(plotId, volumeUpdate, [1]);
|
||||||
|
|
||||||
|
console.log(`Updated ${timeframe} chart at ${new Date().toLocaleTimeString()}`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Error updating ${timeframe} chart:`, error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialize charts for all timeframes with pivot bounds
|
* Initialize charts for all timeframes with pivot bounds
|
||||||
|
|||||||
Reference in New Issue
Block a user