filter anotations by symbol
This commit is contained in:
@@ -338,8 +338,17 @@ class AnnotationManager:
|
||||
logger.info(f"Generated lightweight test case: {test_case['test_case_id']} (OHLCV data will be fetched dynamically)")
|
||||
return test_case
|
||||
|
||||
def get_all_test_cases(self) -> List[Dict]:
|
||||
"""Load all test cases from disk"""
|
||||
def get_all_test_cases(self, symbol: Optional[str] = None) -> List[Dict]:
|
||||
"""
|
||||
Load all test cases from disk
|
||||
|
||||
Args:
|
||||
symbol: Optional symbol filter (e.g., 'ETH/USDT'). If provided, only returns
|
||||
test cases for that symbol. Critical for avoiding cross-symbol training.
|
||||
|
||||
Returns:
|
||||
List of test case dictionaries
|
||||
"""
|
||||
test_cases = []
|
||||
|
||||
if not self.test_cases_dir.exists():
|
||||
@@ -349,11 +358,22 @@ class AnnotationManager:
|
||||
try:
|
||||
with open(test_case_file, 'r') as f:
|
||||
test_case = json.load(f)
|
||||
|
||||
# CRITICAL: Filter by symbol to avoid training on wrong symbol
|
||||
if symbol:
|
||||
test_case_symbol = test_case.get('symbol', '')
|
||||
if test_case_symbol != symbol:
|
||||
logger.debug(f"Skipping {test_case_file.name}: symbol {test_case_symbol} != {symbol}")
|
||||
continue
|
||||
|
||||
test_cases.append(test_case)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading test case {test_case_file}: {e}")
|
||||
|
||||
logger.info(f"Loaded {len(test_cases)} test cases from disk")
|
||||
if symbol:
|
||||
logger.info(f"Loaded {len(test_cases)} test cases for symbol {symbol}")
|
||||
else:
|
||||
logger.info(f"Loaded {len(test_cases)} test cases (all symbols)")
|
||||
return test_cases
|
||||
|
||||
def _calculate_holding_period(self, annotation: TradeAnnotation) -> float:
|
||||
|
||||
@@ -1740,13 +1740,17 @@ class RealTrainingAdapter:
|
||||
logger.info(f"Using orchestrator's TradingTransformerTrainer")
|
||||
logger.info(f" Trainer type: {type(trainer).__name__}")
|
||||
|
||||
# Import torch at function level (not inside try block)
|
||||
import torch
|
||||
import gc
|
||||
|
||||
# Load best checkpoint if available to continue training
|
||||
try:
|
||||
checkpoint_dir = "models/checkpoints/transformer"
|
||||
best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')
|
||||
|
||||
if best_checkpoint_path and os.path.exists(best_checkpoint_path):
|
||||
checkpoint = torch.load(best_checkpoint_path)
|
||||
checkpoint = torch.load(best_checkpoint_path, map_location='cpu')
|
||||
trainer.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
|
||||
@@ -1167,7 +1167,9 @@ class AnnotationDashboard:
|
||||
]
|
||||
|
||||
# Load test cases from disk (they were auto-generated when annotations were saved)
|
||||
all_test_cases = self.annotation_manager.get_all_test_cases()
|
||||
# CRITICAL: Filter by current symbol to avoid cross-symbol training
|
||||
current_symbol = data.get('symbol', 'ETH/USDT')
|
||||
all_test_cases = self.annotation_manager.get_all_test_cases(symbol=current_symbol)
|
||||
|
||||
# Filter to selected annotations
|
||||
test_cases = [
|
||||
|
||||
@@ -9,9 +9,107 @@ class ChartManager {
|
||||
this.charts = {};
|
||||
this.annotations = {};
|
||||
this.syncedTime = null;
|
||||
this.updateTimers = {}; // Track auto-update timers
|
||||
this.autoUpdateEnabled = false; // Auto-update state
|
||||
|
||||
console.log('ChartManager initialized with timeframes:', timeframes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Start auto-updating charts
|
||||
*/
|
||||
startAutoUpdate() {
|
||||
if (this.autoUpdateEnabled) {
|
||||
console.log('Auto-update already enabled');
|
||||
return;
|
||||
}
|
||||
|
||||
this.autoUpdateEnabled = true;
|
||||
console.log('Starting chart auto-update...');
|
||||
|
||||
// Update 1s chart every 20 seconds
|
||||
if (this.timeframes.includes('1s')) {
|
||||
this.updateTimers['1s'] = setInterval(() => {
|
||||
this.updateChart('1s');
|
||||
}, 20000); // 20 seconds
|
||||
}
|
||||
|
||||
// Update 1m chart - sync to whole minutes + every 20s
|
||||
if (this.timeframes.includes('1m')) {
|
||||
// Calculate ms until next whole minute
|
||||
const now = new Date();
|
||||
const msUntilNextMinute = (60 - now.getSeconds()) * 1000 - now.getMilliseconds();
|
||||
|
||||
// Update on next whole minute
|
||||
setTimeout(() => {
|
||||
this.updateChart('1m');
|
||||
|
||||
// Then update every 20s
|
||||
this.updateTimers['1m'] = setInterval(() => {
|
||||
this.updateChart('1m');
|
||||
}, 20000); // 20 seconds
|
||||
}, msUntilNextMinute);
|
||||
}
|
||||
|
||||
console.log('Auto-update enabled for:', Object.keys(this.updateTimers));
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop auto-updating charts
|
||||
*/
|
||||
stopAutoUpdate() {
|
||||
if (!this.autoUpdateEnabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.autoUpdateEnabled = false;
|
||||
|
||||
// Clear all timers
|
||||
Object.values(this.updateTimers).forEach(timer => clearInterval(timer));
|
||||
this.updateTimers = {};
|
||||
|
||||
console.log('Auto-update stopped');
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a single chart with fresh data
|
||||
*/
|
||||
async updateChart(timeframe) {
|
||||
try {
|
||||
const response = await fetch(`/api/chart-data?timeframe=${timeframe}&limit=1000`);
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP ${response.status}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (data.success && data.data && data.data[timeframe]) {
|
||||
const chartData = data.data[timeframe];
|
||||
const plotId = `plot-${timeframe}`;
|
||||
|
||||
// Update chart using Plotly.react (efficient update)
|
||||
const candlestickUpdate = {
|
||||
x: [chartData.timestamps],
|
||||
open: [chartData.open],
|
||||
high: [chartData.high],
|
||||
low: [chartData.low],
|
||||
close: [chartData.close]
|
||||
};
|
||||
|
||||
const volumeUpdate = {
|
||||
x: [chartData.timestamps],
|
||||
y: [chartData.volume]
|
||||
};
|
||||
|
||||
Plotly.restyle(plotId, candlestickUpdate, [0]);
|
||||
Plotly.restyle(plotId, volumeUpdate, [1]);
|
||||
|
||||
console.log(`Updated ${timeframe} chart at ${new Date().toLocaleTimeString()}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`Error updating ${timeframe} chart:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize charts for all timeframes with pivot bounds
|
||||
|
||||
Reference in New Issue
Block a user