model metadata
.kiro/settings/mcp.json (new file)
@@ -0,0 +1,11 @@
{
  "mcpServers": {
    "fetch": {
      "command": "uvx",
      "args": ["mcp-server-fetch"],
      "env": {},
      "disabled": true,
      "autoApprove": []
    }
  }
}
@@ -1723,15 +1723,9 @@ class RealTrainingAdapter:
else:
    logger.warning(f" Batch {i + 1} returned None result - skipping")

# CRITICAL FIX: Delete batch tensors immediately to free GPU memory
# This prevents memory accumulation during gradient accumulation
for key in list(batch.keys()):
    if isinstance(batch[key], torch.Tensor):
        del batch[key]
del batch

# CRITICAL: Clear CUDA cache after EVERY batch to prevent memory accumulation
# This is essential with large models and limited GPU memory
# NOTE: We don't delete the batch dict itself because it's reused across epochs
if torch.cuda.is_available():
    torch.cuda.empty_cache()
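For context, a minimal sketch of the per-batch memory pattern this hunk settles on: the batch dicts are reused across epochs, so only the CUDA allocator cache is released after each step. `trainer.train_step` is a placeholder name, not the adapter's real API.

```python
import logging
import torch

logger = logging.getLogger(__name__)

def train_epoch(trainer, batches):
    """Hypothetical inner loop illustrating the per-batch cache clearing."""
    for i, batch in enumerate(batches):
        result = trainer.train_step(batch)  # placeholder for the adapter's real step call
        if result is None:
            logger.warning(f" Batch {i + 1} returned None result - skipping")

        # The batch dicts are NOT deleted: they are reused across epochs.
        # Only the CUDA allocator cache is released after every batch so freed
        # activations are returned to the driver instead of accumulating.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
```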
@@ -1786,6 +1780,41 @@ class RealTrainingAdapter:
logger.info(f" Saved checkpoint: {checkpoint_path}")

# Save metadata to database for easy retrieval
try:
    from utils.database_manager import DatabaseManager

    db_manager = DatabaseManager()
    checkpoint_id = f"transformer_e{epoch+1}_{timestamp}"

    # Create metadata object
    from utils.database_manager import CheckpointMetadata
    metadata = CheckpointMetadata(
        checkpoint_id=checkpoint_id,
        model_name="transformer",
        model_type="transformer",
        timestamp=datetime.now(),
        performance_metrics={
            'loss': float(avg_loss),
            'accuracy': float(avg_accuracy),
            'epoch': epoch + 1,
            'learning_rate': float(trainer.scheduler.get_last_lr()[0])
        },
        training_metadata={
            'num_samples': len(training_data),
            'num_batches': num_batches,
            'training_id': training_id
        },
        file_path=checkpoint_path,
        performance_score=float(avg_accuracy),  # Use accuracy as score
        is_active=True
    )

    if db_manager.save_checkpoint_metadata(metadata):
        logger.info(f" Saved checkpoint metadata to database: {checkpoint_id}")
except Exception as meta_error:
    logger.warning(f" Could not save checkpoint metadata: {meta_error}")

# Keep only best 5 checkpoints based on accuracy
self._cleanup_old_checkpoints(checkpoint_dir, keep_best=5, metric='accuracy')
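`_cleanup_old_checkpoints` is called above but its body is not part of this diff. A minimal sketch of what a keep-best-N cleanup could look like, assuming each checkpoint file stores its metrics under a 'metrics' key; the real adapter may rank checkpoints differently, for example via the database metadata.

```python
import glob
import os
import torch

def cleanup_old_checkpoints(checkpoint_dir: str, keep_best: int = 5, metric: str = 'accuracy'):
    """Hypothetical sketch: keep only the keep_best checkpoints ranked by `metric`."""
    scored = []
    for path in glob.glob(os.path.join(checkpoint_dir, '*.pt')):
        try:
            # Assumption: the saved dict carries its metrics under 'metrics'.
            ckpt = torch.load(path, map_location='cpu')
            scored.append((ckpt.get('metrics', {}).get(metric, 0.0), path))
        except Exception:
            scored.append((0.0, path))  # unreadable checkpoints rank last

    scored.sort(key=lambda item: item[0], reverse=True)
    for _, path in scored[keep_best:]:
        os.remove(path)
```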
@@ -193,7 +193,7 @@ class AnnotationDashboard:
def _get_best_checkpoint_info(self, model_name: str) -> Optional[Dict]:
    """
    Get best checkpoint info for a model without loading it
    Uses filename parsing instead of torch.load to avoid crashes
    First tries database, then falls back to filename parsing

    Args:
        model_name: Name of the model
@@ -202,6 +202,41 @@ class AnnotationDashboard:
        Dict with checkpoint info or None if no checkpoint found
    """
    try:
        # Try to get from database first (has full metadata)
        try:
            from utils.database_manager import DatabaseManager
            db_manager = DatabaseManager()

            # Get active checkpoint for this model
            with db_manager._get_connection() as conn:
                cursor = conn.execute("""
                    SELECT checkpoint_id, performance_metrics, timestamp, file_path
                    FROM checkpoint_metadata
                    WHERE model_name = ? AND is_active = TRUE
                    ORDER BY performance_score DESC
                    LIMIT 1
                """, (model_name.lower(),))

                row = cursor.fetchone()
                if row:
                    import json
                    checkpoint_id, metrics_json, timestamp, file_path = row
                    metrics = json.loads(metrics_json) if metrics_json else {}

                    checkpoint_info = {
                        'filename': os.path.basename(file_path) if file_path else checkpoint_id,
                        'epoch': metrics.get('epoch', 0),
                        'loss': metrics.get('loss'),
                        'accuracy': metrics.get('accuracy'),
                        'source': 'database'
                    }

                    logger.info(f"Loaded checkpoint info from database for {model_name}: E{checkpoint_info['epoch']}, Loss={checkpoint_info['loss']}, Acc={checkpoint_info['accuracy']}")
                    return checkpoint_info
        except Exception as db_error:
            logger.debug(f"Could not load from database: {db_error}")

        # Fallback to filename parsing
        import glob
        import re
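The filename-parsing fallback itself is cut off by the hunk. A minimal sketch of what it might do, assuming checkpoint filenames follow the `transformer_e{epoch}_{timestamp}` pattern used when saving above; the regex, directory layout and helper name are assumptions.

```python
import glob
import os
import re
from typing import Dict, Optional

def best_checkpoint_from_filenames(checkpoint_dir: str, model_name: str) -> Optional[Dict]:
    """Hypothetical fallback: recover checkpoint info from filenames without torch.load."""
    best = None
    for path in glob.glob(os.path.join(checkpoint_dir, f"{model_name}_e*")):
        match = re.search(r"_e(\d+)_", os.path.basename(path))
        if not match:
            continue
        epoch = int(match.group(1))
        if best is None or epoch > best['epoch']:
            best = {
                'filename': os.path.basename(path),
                'epoch': epoch,
                'loss': None,       # not recoverable from the filename alone
                'accuracy': None,   # not recoverable from the filename alone
                'source': 'filename'
            }
    return best
```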
@@ -120,4 +120,8 @@ Let's use the mean squared difference between the prediction and the empirical o
----------

Can we check the "live inference" mode now? It should do realtime inference/training every second (as many batches as can pass in 1 s), and the prediction should be the next candle. Training will be retrospective with a 1-candle delay, i.e. called each second, minute, hour and day for the previous candle, once its result is known.

Calculate the angle between the corresponding features of each pair of consecutive candles and train to predict those angles (top-top, open-open, etc.).


Use this for sentiment analysis:
https://www.coinglass.com/LongShortRatio
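A minimal sketch of the "angle between candle features" idea above: for each feature column, the target is the arctangent of the normalized change between consecutive candles. The normalization by the previous value and the use of NumPy are assumptions, not an agreed implementation.

```python
import numpy as np

def candle_feature_angles(candles: np.ndarray) -> np.ndarray:
    """
    candles: shape (T, F), one row per candle with columns such as
             open, high, low, close ('top' in the note is read as high).
    Returns shape (T-1, F): the angle, in radians, of the change in each
    feature between consecutive candles (open-open, high-high, etc.).
    """
    prev = candles[:-1]
    delta = candles[1:] - prev
    # Normalize the delta by the previous value so the angle is scale-free.
    return np.arctan(delta / np.clip(np.abs(prev), 1e-12, None))

# These angles would be the retrospective training targets: the angle for
# candle t -> t+1 only becomes available one candle later, matching the
# 1-candle delay described in the note.
```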