filter annotations by symbol

Dobromir Popov
2025-11-13 15:40:21 +02:00
parent 25287d0e9e
commit bf2a6cf96e
4 changed files with 129 additions and 5 deletions


@@ -338,8 +338,17 @@ class AnnotationManager:
         logger.info(f"Generated lightweight test case: {test_case['test_case_id']} (OHLCV data will be fetched dynamically)")
         return test_case
 
-    def get_all_test_cases(self) -> List[Dict]:
-        """Load all test cases from disk"""
+    def get_all_test_cases(self, symbol: Optional[str] = None) -> List[Dict]:
+        """
+        Load all test cases from disk
+
+        Args:
+            symbol: Optional symbol filter (e.g., 'ETH/USDT'). If provided, only returns
+                    test cases for that symbol. Critical for avoiding cross-symbol training.
+
+        Returns:
+            List of test case dictionaries
+        """
         test_cases = []
 
         if not self.test_cases_dir.exists():
@@ -349,11 +358,22 @@ class AnnotationManager:
             try:
                 with open(test_case_file, 'r') as f:
                     test_case = json.load(f)
+
+                    # CRITICAL: Filter by symbol to avoid training on wrong symbol
+                    if symbol:
+                        test_case_symbol = test_case.get('symbol', '')
+                        if test_case_symbol != symbol:
+                            logger.debug(f"Skipping {test_case_file.name}: symbol {test_case_symbol} != {symbol}")
+                            continue
+
                     test_cases.append(test_case)
             except Exception as e:
                 logger.error(f"Error loading test case {test_case_file}: {e}")
 
-        logger.info(f"Loaded {len(test_cases)} test cases from disk")
+        if symbol:
+            logger.info(f"Loaded {len(test_cases)} test cases for symbol {symbol}")
+        else:
+            logger.info(f"Loaded {len(test_cases)} test cases (all symbols)")
 
         return test_cases
 
     def _calculate_holding_period(self, annotation: TradeAnnotation) -> float:

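A minimal sketch of how a caller would use the new filter. The import path and constructor arguments are assumptions for illustration; get_all_test_cases and its symbol parameter come from the diff above.

# Hypothetical caller; AnnotationManager module path and setup are assumed.
from annotation_manager import AnnotationManager

manager = AnnotationManager()

# Only ETH/USDT test cases come back; annotations for other symbols on
# disk are skipped, so training never mixes symbols.
eth_cases = manager.get_all_test_cases(symbol='ETH/USDT')

# Omitting the argument keeps the old behavior and loads every symbol.
all_cases = manager.get_all_test_cases()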

@@ -1740,13 +1740,17 @@ class RealTrainingAdapter:
             logger.info(f"Using orchestrator's TradingTransformerTrainer")
             logger.info(f"  Trainer type: {type(trainer).__name__}")
 
+            # Import torch at function level (not inside try block)
+            import torch
+            import gc
+
             # Load best checkpoint if available to continue training
             try:
                 checkpoint_dir = "models/checkpoints/transformer"
                 best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')
 
                 if best_checkpoint_path and os.path.exists(best_checkpoint_path):
-                    checkpoint = torch.load(best_checkpoint_path)
+                    checkpoint = torch.load(best_checkpoint_path, map_location='cpu')
                     trainer.model.load_state_dict(checkpoint['model_state_dict'])
                     trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                     trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
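
The map_location='cpu' argument is the substantive fix in this hunk: torch.load otherwise restores tensors to the device they were saved from, so a checkpoint written on a CUDA machine fails to load in a CPU-only process. A self-contained sketch of the pattern; the model, optimizer, and checkpoint path are illustrative stand-ins, not the project's classes.

import torch

model = torch.nn.Linear(8, 2)                     # stand-in for trainer.model
optimizer = torch.optim.Adam(model.parameters())  # stand-in optimizer

ckpt_path = "models/checkpoints/transformer/best.pt"  # assumed filename

# Map all tensors to CPU first; this cannot fail for lack of a GPU.
checkpoint = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Move to the desired device only after the state is loaded.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)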