filter anotations by symbol

This commit is contained in:
Dobromir Popov
2025-11-13 15:40:21 +02:00
parent 25287d0e9e
commit bf2a6cf96e
4 changed files with 129 additions and 5 deletions

View File

@@ -338,8 +338,17 @@ class AnnotationManager:
logger.info(f"Generated lightweight test case: {test_case['test_case_id']} (OHLCV data will be fetched dynamically)")
return test_case
def get_all_test_cases(self) -> List[Dict]:
"""Load all test cases from disk"""
def get_all_test_cases(self, symbol: Optional[str] = None) -> List[Dict]:
"""
Load all test cases from disk
Args:
symbol: Optional symbol filter (e.g., 'ETH/USDT'). If provided, only returns
test cases for that symbol. Critical for avoiding cross-symbol training.
Returns:
List of test case dictionaries
"""
test_cases = []
if not self.test_cases_dir.exists():
@@ -349,11 +358,22 @@ class AnnotationManager:
try:
with open(test_case_file, 'r') as f:
test_case = json.load(f)
# CRITICAL: Filter by symbol to avoid training on wrong symbol
if symbol:
test_case_symbol = test_case.get('symbol', '')
if test_case_symbol != symbol:
logger.debug(f"Skipping {test_case_file.name}: symbol {test_case_symbol} != {symbol}")
continue
test_cases.append(test_case)
except Exception as e:
logger.error(f"Error loading test case {test_case_file}: {e}")
logger.info(f"Loaded {len(test_cases)} test cases from disk")
if symbol:
logger.info(f"Loaded {len(test_cases)} test cases for symbol {symbol}")
else:
logger.info(f"Loaded {len(test_cases)} test cases (all symbols)")
return test_cases
def _calculate_holding_period(self, annotation: TradeAnnotation) -> float:

View File

@@ -1740,13 +1740,17 @@ class RealTrainingAdapter:
logger.info(f"Using orchestrator's TradingTransformerTrainer")
logger.info(f" Trainer type: {type(trainer).__name__}")
# Import torch at function level (not inside try block)
import torch
import gc
# Load best checkpoint if available to continue training
try:
checkpoint_dir = "models/checkpoints/transformer"
best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')
if best_checkpoint_path and os.path.exists(best_checkpoint_path):
checkpoint = torch.load(best_checkpoint_path)
checkpoint = torch.load(best_checkpoint_path, map_location='cpu')
trainer.model.load_state_dict(checkpoint['model_state_dict'])
trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

View File

@@ -1167,7 +1167,9 @@ class AnnotationDashboard:
]
# Load test cases from disk (they were auto-generated when annotations were saved)
all_test_cases = self.annotation_manager.get_all_test_cases()
# CRITICAL: Filter by current symbol to avoid cross-symbol training
current_symbol = data.get('symbol', 'ETH/USDT')
all_test_cases = self.annotation_manager.get_all_test_cases(symbol=current_symbol)
# Filter to selected annotations
test_cases = [

View File

@@ -9,9 +9,107 @@ class ChartManager {
this.charts = {};
this.annotations = {};
this.syncedTime = null;
this.updateTimers = {}; // Track auto-update timers
this.autoUpdateEnabled = false; // Auto-update state
console.log('ChartManager initialized with timeframes:', timeframes);
}
/**
* Start auto-updating charts
*/
startAutoUpdate() {
if (this.autoUpdateEnabled) {
console.log('Auto-update already enabled');
return;
}
this.autoUpdateEnabled = true;
console.log('Starting chart auto-update...');
// Update 1s chart every 20 seconds
if (this.timeframes.includes('1s')) {
this.updateTimers['1s'] = setInterval(() => {
this.updateChart('1s');
}, 20000); // 20 seconds
}
// Update 1m chart - sync to whole minutes + every 20s
if (this.timeframes.includes('1m')) {
// Calculate ms until next whole minute
const now = new Date();
const msUntilNextMinute = (60 - now.getSeconds()) * 1000 - now.getMilliseconds();
// Update on next whole minute
setTimeout(() => {
this.updateChart('1m');
// Then update every 20s
this.updateTimers['1m'] = setInterval(() => {
this.updateChart('1m');
}, 20000); // 20 seconds
}, msUntilNextMinute);
}
console.log('Auto-update enabled for:', Object.keys(this.updateTimers));
}
/**
* Stop auto-updating charts
*/
stopAutoUpdate() {
if (!this.autoUpdateEnabled) {
return;
}
this.autoUpdateEnabled = false;
// Clear all timers
Object.values(this.updateTimers).forEach(timer => clearInterval(timer));
this.updateTimers = {};
console.log('Auto-update stopped');
}
/**
* Update a single chart with fresh data
*/
async updateChart(timeframe) {
try {
const response = await fetch(`/api/chart-data?timeframe=${timeframe}&limit=1000`);
if (!response.ok) {
throw new Error(`HTTP ${response.status}`);
}
const data = await response.json();
if (data.success && data.data && data.data[timeframe]) {
const chartData = data.data[timeframe];
const plotId = `plot-${timeframe}`;
// Update chart using Plotly.react (efficient update)
const candlestickUpdate = {
x: [chartData.timestamps],
open: [chartData.open],
high: [chartData.high],
low: [chartData.low],
close: [chartData.close]
};
const volumeUpdate = {
x: [chartData.timestamps],
y: [chartData.volume]
};
Plotly.restyle(plotId, candlestickUpdate, [0]);
Plotly.restyle(plotId, volumeUpdate, [1]);
console.log(`Updated ${timeframe} chart at ${new Date().toLocaleTimeString()}`);
}
} catch (error) {
console.error(`Error updating ${timeframe} chart:`, error);
}
}
/**
* Initialize charts for all timeframes with pivot bounds