dash and training wip
This commit is contained in:
4
.vscode/launch.json
vendored
4
.vscode/launch.json
vendored
@@ -47,6 +47,9 @@
|
||||
"env": {
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"ENABLE_REALTIME_CHARTS": "1"
|
||||
},
|
||||
"linux": {
|
||||
"python": "${workspaceFolder}/venv/bin/python"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -156,6 +159,7 @@
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "run_clean_dashboard.py",
|
||||
"python": "${workspaceFolder}/venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"env": {
|
||||
|
@@ -1,104 +1,3 @@
|
||||
{
|
||||
"decision": [
|
||||
{
|
||||
"checkpoint_id": "decision_20250704_082022",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250704_082022.pt",
|
||||
"created_at": "2025-07-04T08:20:22.416087",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 102.79971076963062,
|
||||
"accuracy": null,
|
||||
"loss": 2.8923120591883844e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "decision_20250704_082021",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250704_082021.pt",
|
||||
"created_at": "2025-07-04T08:20:21.900854",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 102.79970038321,
|
||||
"accuracy": null,
|
||||
"loss": 2.996176877014177e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "decision_20250704_082022",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250704_082022.pt",
|
||||
"created_at": "2025-07-04T08:20:22.294191",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 102.79969219038436,
|
||||
"accuracy": null,
|
||||
"loss": 3.0781056310808756e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "decision_20250704_134829",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250704_134829.pt",
|
||||
"created_at": "2025-07-04T13:48:29.903250",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 102.79967532851693,
|
||||
"accuracy": null,
|
||||
"loss": 3.2467253719811344e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "decision_20250704_214714",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250704_214714.pt",
|
||||
"created_at": "2025-07-04T21:47:14.427187",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 102.79966325731509,
|
||||
"accuracy": null,
|
||||
"loss": 3.3674381887394134e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
}
|
||||
]
|
||||
"decision": []
|
||||
}
|
@@ -1969,7 +1969,17 @@ class EnhancedRealtimeTrainingSystem:
|
||||
|
||||
self.last_prediction_time[symbol] = int(current_time)
|
||||
|
||||
logger.info(f"Forward DQN prediction: {symbol} action={['BUY','SELL','HOLD'][action]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
|
||||
# Robust action labeling
|
||||
if action is None:
|
||||
action_label = 'HOLD'
|
||||
elif action == 0:
|
||||
action_label = 'SELL'
|
||||
elif action == 1:
|
||||
action_label = 'BUY'
|
||||
else:
|
||||
action_label = 'UNKNOWN'
|
||||
|
||||
logger.info(f"Forward DQN prediction: {symbol} action={action_label} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating forward DQN prediction: {e}")
|
||||
|
@@ -81,8 +81,8 @@ orchestrator:
|
||||
# Model weights for decision combination
|
||||
cnn_weight: 0.7 # Weight for CNN predictions
|
||||
rl_weight: 0.3 # Weight for RL decisions
|
||||
confidence_threshold: 0.15
|
||||
confidence_threshold_close: 0.08
|
||||
confidence_threshold: 0.45
|
||||
confidence_threshold_close: 0.30
|
||||
decision_frequency: 30
|
||||
|
||||
# Multi-symbol coordination
|
||||
|
@@ -349,7 +349,8 @@ class TradingOrchestrator:
|
||||
try:
|
||||
self.cob_rl_agent.load_model() # This loads the state into the model
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
result = load_best_checkpoint("cob_rl_model")
|
||||
# Use consistent model name with checkpoint manager and get_model_states
|
||||
result = load_best_checkpoint("cob_rl")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
self.model_states['cob_rl']['initial_loss'] = getattr(metadata, 'initial_loss', None)
|
||||
@@ -1592,13 +1593,16 @@ class TradingOrchestrator:
|
||||
logger.warning("EnhancedRealtimeTrainingSystem not available - training disabled")
|
||||
self.training_enabled = False
|
||||
return
|
||||
|
||||
# Initialize the enhanced training system
|
||||
self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
|
||||
# Initialize unified training manager
|
||||
from utils.training_integration import get_unified_training_manager
|
||||
self.training_manager = get_unified_training_manager(
|
||||
orchestrator=self,
|
||||
data_provider=self.data_provider,
|
||||
dashboard=None # Will be set by dashboard when available
|
||||
dashboard=None
|
||||
)
|
||||
self.training_manager.initialize()
|
||||
# Keep backward-compatible attribute
|
||||
self.enhanced_training_system = getattr(self.training_manager, 'training_system', None)
|
||||
|
||||
logger.info("Enhanced real-time training system initialized")
|
||||
logger.info(" - Real-time model training: ENABLED")
|
||||
@@ -1614,11 +1618,11 @@ class TradingOrchestrator:
|
||||
def start_enhanced_training(self):
|
||||
"""Start the enhanced real-time training system"""
|
||||
try:
|
||||
if not self.training_enabled or not self.enhanced_training_system:
|
||||
if not self.training_enabled or not getattr(self, 'training_manager', None):
|
||||
logger.warning("Enhanced training system not available")
|
||||
return False
|
||||
|
||||
self.enhanced_training_system.start_training()
|
||||
self.training_manager.start()
|
||||
logger.info("Enhanced real-time training started")
|
||||
return True
|
||||
|
||||
@@ -1629,8 +1633,8 @@ class TradingOrchestrator:
|
||||
def stop_enhanced_training(self):
|
||||
"""Stop the enhanced real-time training system"""
|
||||
try:
|
||||
if self.enhanced_training_system:
|
||||
self.enhanced_training_system.stop_training()
|
||||
if getattr(self, 'training_manager', None):
|
||||
self.training_manager.stop()
|
||||
logger.info("Enhanced real-time training stopped")
|
||||
return True
|
||||
return False
|
||||
|
@@ -731,7 +731,8 @@ class RealtimeRLCOBTrader:
|
||||
with self.training_lock:
|
||||
# Check if we have enough data for training
|
||||
predictions = list(self.prediction_history[symbol])
|
||||
if len(predictions) < 10:
|
||||
# Train with fewer samples to kickstart learning
|
||||
if len(predictions) < 6:
|
||||
return
|
||||
|
||||
# Calculate rewards for recent predictions
|
||||
@@ -739,11 +740,11 @@ class RealtimeRLCOBTrader:
|
||||
|
||||
# Filter predictions with calculated rewards
|
||||
training_predictions = [p for p in predictions if p.reward is not None]
|
||||
if len(training_predictions) < 5:
|
||||
if len(training_predictions) < 3:
|
||||
return
|
||||
|
||||
# Prepare training batch
|
||||
batch_size = min(32, len(training_predictions))
|
||||
batch_size = min(16, len(training_predictions))
|
||||
batch_predictions = training_predictions[-batch_size:]
|
||||
|
||||
# Train model
|
||||
|
8
enhanced_realtime_training.py
Normal file
8
enhanced_realtime_training.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Shim module to expose EnhancedRealtimeTrainingSystem at project root.
|
||||
This avoids import issues when modules do `from enhanced_realtime_training import EnhancedRealtimeTrainingSystem`.
|
||||
"""
|
||||
|
||||
from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem
|
||||
|
||||
__all__ = ["EnhancedRealtimeTrainingSystem"]
|
@@ -10,7 +10,6 @@ tensorboard>=2.15.0
|
||||
scikit-learn>=1.3.0
|
||||
matplotlib>=3.7.0
|
||||
seaborn>=0.12.0
|
||||
wandb>=0.16.0
|
||||
|
||||
ta>=0.11.0
|
||||
ccxt>=4.0.0
|
||||
|
@@ -3,6 +3,34 @@
|
||||
Clean Trading Dashboard Runner with Enhanced Stability and Error Handling
|
||||
"""
|
||||
|
||||
# Ensure we run with the project's virtual environment Python
|
||||
try:
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import platform
|
||||
|
||||
def _ensure_project_venv():
|
||||
try:
|
||||
project_root = Path(__file__).resolve().parent
|
||||
if platform.system().lower().startswith('win'):
|
||||
venv_python = project_root / 'venv' / 'Scripts' / 'python.exe'
|
||||
else:
|
||||
venv_python = project_root / 'venv' / 'bin' / 'python'
|
||||
|
||||
if venv_python.exists():
|
||||
current = Path(sys.executable).resolve()
|
||||
target = venv_python.resolve()
|
||||
if current != target:
|
||||
os.execv(str(target), [str(target), *sys.argv])
|
||||
except Exception:
|
||||
# If anything goes wrong, continue with current interpreter
|
||||
pass
|
||||
|
||||
_ensure_project_venv()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import traceback
|
||||
|
@@ -14,11 +14,7 @@ from collections import defaultdict
|
||||
import torch
|
||||
import random
|
||||
|
||||
try:
|
||||
import wandb
|
||||
WANDB_AVAILABLE = True
|
||||
except ImportError:
|
||||
WANDB_AVAILABLE = False
|
||||
WANDB_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -58,13 +54,13 @@ class CheckpointManager:
|
||||
base_checkpoint_dir: str = "NN/models/saved",
|
||||
max_checkpoints_per_model: int = 5,
|
||||
metadata_file: str = "checkpoint_metadata.json",
|
||||
enable_wandb: bool = True):
|
||||
enable_wandb: bool = False):
|
||||
self.base_dir = Path(base_checkpoint_dir)
|
||||
self.base_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.max_checkpoints = max_checkpoints_per_model
|
||||
self.metadata_file = self.base_dir / metadata_file
|
||||
self.enable_wandb = enable_wandb and WANDB_AVAILABLE
|
||||
self.enable_wandb = False
|
||||
|
||||
self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
|
||||
self._load_metadata()
|
||||
@@ -115,10 +111,7 @@ class CheckpointManager:
|
||||
total_parameters=training_metadata.get('total_parameters') if training_metadata else None
|
||||
)
|
||||
|
||||
if self.enable_wandb and wandb.run is not None:
|
||||
artifact_name = self._upload_to_wandb(checkpoint_path, metadata)
|
||||
metadata.wandb_run_id = wandb.run.id
|
||||
metadata.wandb_artifact_name = artifact_name
|
||||
# W&B disabled
|
||||
|
||||
self.checkpoints[model_name].append(metadata)
|
||||
self._rotate_checkpoints(model_name)
|
||||
@@ -273,18 +266,6 @@ class CheckpointManager:
|
||||
logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")
|
||||
|
||||
def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]:
|
||||
try:
|
||||
if not self.enable_wandb or wandb.run is None:
|
||||
return None
|
||||
|
||||
artifact_name = f"{metadata.model_name}_checkpoint"
|
||||
artifact = wandb.Artifact(artifact_name, type="model")
|
||||
artifact.add_file(str(file_path))
|
||||
wandb.log_artifact(artifact)
|
||||
|
||||
return artifact_name
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading to W&B: {e}")
|
||||
return None
|
||||
|
||||
def _load_metadata(self):
|
||||
@@ -404,6 +385,56 @@ class CheckpointManager:
|
||||
logger.debug(f"Found legacy model file in subdirectory: {candidate_path}")
|
||||
return candidate_path
|
||||
|
||||
# Extended search: scan common project model directories for best checkpoints
|
||||
try:
|
||||
# Attempt to infer project root from base_dir (NN/models/saved -> root)
|
||||
project_root = base_dir.resolve().parent.parent.parent
|
||||
except Exception:
|
||||
project_root = Path(".").resolve()
|
||||
additional_dirs = [
|
||||
project_root / "models",
|
||||
project_root / "models" / "archive",
|
||||
project_root / "models" / "backtest",
|
||||
]
|
||||
|
||||
def _match_legacy_name(candidate: Path, model: str) -> bool:
|
||||
name = candidate.name.lower()
|
||||
model_keys = {
|
||||
'dqn_agent': ['dqn', 'agent', 'policy'],
|
||||
'enhanced_cnn': ['cnn', 'optimized_short_term'],
|
||||
'extrema_trainer': ['supervised', 'extrema'],
|
||||
'cob_rl': ['cob', 'rl', 'policy'],
|
||||
'decision': ['decision', 'transformer']
|
||||
}.get(model, [model])
|
||||
return any(k in name for k in model_keys)
|
||||
|
||||
candidates: List[Path] = []
|
||||
for adir in additional_dirs:
|
||||
if not adir.exists():
|
||||
continue
|
||||
try:
|
||||
for pt in adir.rglob('*.pt'):
|
||||
# Prefer files that indicate "best" and match model hints
|
||||
lname = pt.name.lower()
|
||||
if 'best' in lname and _match_legacy_name(pt, model_name):
|
||||
candidates.append(pt)
|
||||
# Do not add generic fallbacks to avoid mismatched model types
|
||||
except Exception:
|
||||
# Ignore directory traversal issues
|
||||
pass
|
||||
|
||||
if candidates:
|
||||
# Pick the most recently modified candidate
|
||||
try:
|
||||
best = max(candidates, key=lambda p: p.stat().st_mtime)
|
||||
logger.debug(f"Found legacy model file in project models dir: {best}")
|
||||
return best
|
||||
except Exception:
|
||||
# If stat fails, just return the first one deterministically
|
||||
candidates.sort()
|
||||
logger.debug(f"Found legacy model file in project models dir: {candidates[0]}")
|
||||
return candidates[0]
|
||||
|
||||
return None
|
||||
|
||||
def _create_legacy_metadata(self, model_name: str, file_path: Path) -> CheckpointMetadata:
|
||||
|
@@ -75,15 +75,18 @@ class RewardCalculator:
|
||||
def calculate_basic_reward(self, pnl, confidence):
|
||||
"""Calculate basic training reward based on P&L and confidence"""
|
||||
try:
|
||||
# Reward based on net PnL after fees and confidence alignment
|
||||
base_reward = pnl
|
||||
if pnl < 0 and confidence > 0.7:
|
||||
confidence_adjustment = -confidence * 2
|
||||
elif pnl > 0 and confidence > 0.7:
|
||||
confidence_adjustment = confidence * 1.5
|
||||
# Stronger penalty for confident wrong decisions
|
||||
if pnl < 0 and confidence >= 0.6:
|
||||
confidence_adjustment = -confidence * 3.0
|
||||
elif pnl > 0 and confidence >= 0.6:
|
||||
confidence_adjustment = confidence * 1.0
|
||||
else:
|
||||
confidence_adjustment = 0
|
||||
confidence_adjustment = 0.0
|
||||
final_reward = base_reward + confidence_adjustment
|
||||
normalized_reward = np.tanh(final_reward / 10.0)
|
||||
# Reduce tanh compression so small PnL changes are not flattened
|
||||
normalized_reward = np.tanh(final_reward / 2.5)
|
||||
logger.debug(f"Basic reward calculation: P&L={pnl:.4f}, confidence={confidence:.2f}, reward={normalized_reward:.4f}")
|
||||
return float(normalized_reward)
|
||||
except Exception as e:
|
||||
|
@@ -14,7 +14,7 @@ from .checkpoint_manager import get_checkpoint_manager, save_checkpoint, load_be
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TrainingIntegration:
|
||||
def __init__(self, enable_wandb: bool = True):
|
||||
def __init__(self, enable_wandb: bool = False):
|
||||
self.checkpoint_manager = get_checkpoint_manager()
|
||||
self.enable_wandb = enable_wandb
|
||||
|
||||
@@ -22,24 +22,8 @@ class TrainingIntegration:
|
||||
self._init_wandb()
|
||||
|
||||
def _init_wandb(self):
|
||||
try:
|
||||
import wandb
|
||||
|
||||
if wandb.run is None:
|
||||
wandb.init(
|
||||
project="gogo2-trading",
|
||||
name=f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
||||
config={
|
||||
"max_checkpoints_per_model": self.checkpoint_manager.max_checkpoints,
|
||||
"checkpoint_dir": str(self.checkpoint_manager.base_dir)
|
||||
}
|
||||
)
|
||||
logger.info(f"Initialized W&B run: {wandb.run.id}")
|
||||
|
||||
except ImportError:
|
||||
logger.warning("W&B not available - checkpoint management will work without it")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing W&B: {e}")
|
||||
# Disabled by default to avoid CLI prompts
|
||||
pass
|
||||
|
||||
def save_cnn_checkpoint(self,
|
||||
cnn_model,
|
||||
@@ -64,19 +48,7 @@ class TrainingIntegration:
|
||||
'total_parameters': self._count_parameters(cnn_model)
|
||||
}
|
||||
|
||||
if self.enable_wandb:
|
||||
try:
|
||||
import wandb
|
||||
if wandb.run is not None:
|
||||
wandb.log({
|
||||
f"{model_name}/train_accuracy": train_accuracy,
|
||||
f"{model_name}/val_accuracy": val_accuracy,
|
||||
f"{model_name}/train_loss": train_loss,
|
||||
f"{model_name}/val_loss": val_loss,
|
||||
f"{model_name}/epoch": epoch
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Error logging to W&B: {e}")
|
||||
# W&B disabled
|
||||
|
||||
metadata = save_checkpoint(
|
||||
model=cnn_model,
|
||||
@@ -120,22 +92,7 @@ class TrainingIntegration:
|
||||
'total_parameters': self._count_parameters(rl_agent)
|
||||
}
|
||||
|
||||
if self.enable_wandb:
|
||||
try:
|
||||
import wandb
|
||||
if wandb.run is not None:
|
||||
wandb.log({
|
||||
f"{model_name}/avg_reward": avg_reward,
|
||||
f"{model_name}/best_reward": best_reward,
|
||||
f"{model_name}/epsilon": epsilon,
|
||||
f"{model_name}/episode": episode
|
||||
})
|
||||
|
||||
if total_pnl is not None:
|
||||
wandb.log({f"{model_name}/total_pnl": total_pnl})
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error logging to W&B: {e}")
|
||||
# W&B disabled
|
||||
|
||||
metadata = save_checkpoint(
|
||||
model=rl_agent,
|
||||
@@ -202,3 +159,75 @@ def get_training_integration() -> TrainingIntegration:
|
||||
if _training_integration is None:
|
||||
_training_integration = TrainingIntegration()
|
||||
return _training_integration
|
||||
|
||||
# ---------------- Unified Training Manager ----------------
|
||||
|
||||
class UnifiedTrainingManager:
|
||||
"""Single entry point to manage all training in the system.
|
||||
|
||||
Coordinates EnhancedRealtimeTrainingSystem and provides start/stop/status.
|
||||
"""
|
||||
|
||||
def __init__(self, orchestrator, data_provider, dashboard=None):
|
||||
self.orchestrator = orchestrator
|
||||
self.data_provider = data_provider
|
||||
self.dashboard = dashboard
|
||||
self.training_system = None
|
||||
self.started = False
|
||||
|
||||
def initialize(self) -> bool:
|
||||
try:
|
||||
# Import via project root shim to avoid path issues
|
||||
from enhanced_realtime_training import EnhancedRealtimeTrainingSystem
|
||||
self.training_system = EnhancedRealtimeTrainingSystem(
|
||||
orchestrator=self.orchestrator,
|
||||
data_provider=self.data_provider,
|
||||
dashboard=self.dashboard
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"UnifiedTrainingManager: failed to initialize training system: {e}")
|
||||
self.training_system = None
|
||||
return False
|
||||
|
||||
def start(self) -> bool:
|
||||
try:
|
||||
if self.training_system is None:
|
||||
if not self.initialize():
|
||||
return False
|
||||
self.training_system.start_training()
|
||||
self.started = True
|
||||
logger.info("UnifiedTrainingManager: training started")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"UnifiedTrainingManager: error starting training: {e}")
|
||||
return False
|
||||
|
||||
def stop(self) -> bool:
|
||||
try:
|
||||
if self.training_system and self.started:
|
||||
self.training_system.stop_training()
|
||||
self.started = False
|
||||
logger.info("UnifiedTrainingManager: training stopped")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"UnifiedTrainingManager: error stopping training: {e}")
|
||||
return False
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
try:
|
||||
if self.training_system and hasattr(self.training_system, 'get_training_stats'):
|
||||
return self.training_system.get_training_stats()
|
||||
return {}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
_unified_training_manager = None
|
||||
|
||||
def get_unified_training_manager(orchestrator=None, data_provider=None, dashboard=None) -> UnifiedTrainingManager:
|
||||
global _unified_training_manager
|
||||
if _unified_training_manager is None:
|
||||
if orchestrator is None or data_provider is None:
|
||||
raise ValueError("orchestrator and data_provider are required for first-time initialization")
|
||||
_unified_training_manager = UnifiedTrainingManager(orchestrator, data_provider, dashboard)
|
||||
return _unified_training_manager
|
||||
|
@@ -642,6 +642,35 @@ class CleanTradingDashboard:
|
||||
logger.error(f"Error updating trades table: {e}")
|
||||
return html.P(f"Error: {str(e)}", className="text-danger")
|
||||
|
||||
@self.app.callback(
|
||||
Output('training-status', 'children'),
|
||||
[Input('start-training-btn', 'n_clicks'),
|
||||
Input('stop-training-btn', 'n_clicks')],
|
||||
prevent_initial_call=True
|
||||
)
|
||||
def control_training(start_clicks, stop_clicks):
|
||||
try:
|
||||
from utils.training_integration import get_unified_training_manager
|
||||
manager = get_unified_training_manager(
|
||||
orchestrator=self.orchestrator,
|
||||
data_provider=self.data_provider,
|
||||
dashboard=self
|
||||
)
|
||||
ctx = dash.callback_context
|
||||
if not ctx.triggered:
|
||||
raise PreventUpdate
|
||||
trigger_id = ctx.triggered[0]['prop_id'].split('.')[0]
|
||||
if trigger_id == 'start-training-btn':
|
||||
ok = manager.start()
|
||||
return 'Running' if ok else 'Error'
|
||||
elif trigger_id == 'stop-training-btn':
|
||||
ok = manager.stop()
|
||||
return 'Stopped' if ok else 'Error'
|
||||
return 'Idle'
|
||||
except Exception as e:
|
||||
logger.error(f"Training control error: {e}")
|
||||
return 'Error'
|
||||
|
||||
@self.app.callback(
|
||||
[Output('eth-cob-content', 'children'),
|
||||
Output('btc-cob-content', 'children')],
|
||||
@@ -5215,7 +5244,12 @@ class CleanTradingDashboard:
|
||||
"""Start the Dash server"""
|
||||
try:
|
||||
logger.info(f"TRADING: Starting Clean Dashboard at http://{host}:{port}")
|
||||
# Run the Dash app normally; launch/activation is handled by the runner
|
||||
if hasattr(self, 'app') and self.app is not None:
|
||||
# Dash 3.x: use app.run
|
||||
self.app.run(host=host, port=port, debug=debug)
|
||||
else:
|
||||
logger.error("Dash app is not initialized")
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting dashboard server: {e}")
|
||||
raise
|
||||
|
@@ -153,6 +153,29 @@ class DashboardLayoutManager:
|
||||
tooltip={"placement": "bottom", "always_visible": False}
|
||||
)
|
||||
], className="mb-2"),
|
||||
# Training Controls
|
||||
html.Div([
|
||||
html.Label([
|
||||
html.I(className="fas fa-play me-1"),
|
||||
"Training Controls"
|
||||
], className="form-label small mb-1"),
|
||||
html.Div([
|
||||
html.Button([
|
||||
html.I(className="fas fa-play me-1"),
|
||||
"Start Training"
|
||||
], id="start-training-btn", className="btn btn-success btn-sm me-2",
|
||||
style={"fontSize": "10px", "padding": "2px 8px"}),
|
||||
html.Button([
|
||||
html.I(className="fas fa-stop me-1"),
|
||||
"Stop Training"
|
||||
], id="stop-training-btn", className="btn btn-danger btn-sm",
|
||||
style={"fontSize": "10px", "padding": "2px 8px"})
|
||||
], className="d-flex align-items-center mb-1"),
|
||||
html.Div([
|
||||
html.Span("Training:", className="small me-1"),
|
||||
html.Span(id="training-status", children="Idle", className="badge bg-secondary small")
|
||||
])
|
||||
], className="mb-2"),
|
||||
|
||||
# Entry Aggressiveness Control
|
||||
html.Div([
|
||||
|
Reference in New Issue
Block a user