wip
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -47,3 +47,4 @@ chrome_user_data/*
|
|||||||
!.aider.model.metadata.json
|
!.aider.model.metadata.json
|
||||||
|
|
||||||
.env
|
.env
|
||||||
|
.env
|
||||||
|
189
balance_trading_signals.py
Normal file
189
balance_trading_signals.py
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Balance Trading Signals - Analyze and fix SHORT signal bias
|
||||||
|
|
||||||
|
This script analyzes the trading signals from the orchestrator and adjusts
|
||||||
|
the model weights to balance BUY and SELL signals.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
project_root = Path(__file__).parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from core.config import get_config, setup_logging
|
||||||
|
from core.orchestrator import TradingOrchestrator
|
||||||
|
from core.data_provider import DataProvider
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
setup_logging()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def analyze_trading_signals():
|
||||||
|
"""Analyze trading signals from the orchestrator"""
|
||||||
|
logger.info("Analyzing trading signals...")
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
data_provider = DataProvider()
|
||||||
|
orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
|
||||||
|
|
||||||
|
# Get recent decisions
|
||||||
|
symbols = orchestrator.symbols
|
||||||
|
all_decisions = {}
|
||||||
|
|
||||||
|
for symbol in symbols:
|
||||||
|
decisions = orchestrator.get_recent_decisions(symbol)
|
||||||
|
all_decisions[symbol] = decisions
|
||||||
|
|
||||||
|
# Count actions
|
||||||
|
action_counts = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
|
||||||
|
for decision in decisions:
|
||||||
|
action_counts[decision.action] += 1
|
||||||
|
|
||||||
|
total_decisions = sum(action_counts.values())
|
||||||
|
if total_decisions > 0:
|
||||||
|
buy_percent = action_counts['BUY'] / total_decisions * 100
|
||||||
|
sell_percent = action_counts['SELL'] / total_decisions * 100
|
||||||
|
hold_percent = action_counts['HOLD'] / total_decisions * 100
|
||||||
|
|
||||||
|
logger.info(f"Symbol: {symbol}")
|
||||||
|
logger.info(f" Total decisions: {total_decisions}")
|
||||||
|
logger.info(f" BUY: {action_counts['BUY']} ({buy_percent:.1f}%)")
|
||||||
|
logger.info(f" SELL: {action_counts['SELL']} ({sell_percent:.1f}%)")
|
||||||
|
logger.info(f" HOLD: {action_counts['HOLD']} ({hold_percent:.1f}%)")
|
||||||
|
|
||||||
|
# Check for bias
|
||||||
|
if sell_percent > buy_percent * 2: # If SELL signals are more than twice BUY signals
|
||||||
|
logger.warning(f" SELL bias detected: {sell_percent:.1f}% vs {buy_percent:.1f}%")
|
||||||
|
|
||||||
|
# Adjust model weights to balance signals
|
||||||
|
logger.info(" Adjusting model weights to balance signals...")
|
||||||
|
|
||||||
|
# Get current model weights
|
||||||
|
model_weights = orchestrator.model_weights
|
||||||
|
logger.info(f" Current model weights: {model_weights}")
|
||||||
|
|
||||||
|
# Identify models with SELL bias
|
||||||
|
model_predictions = {}
|
||||||
|
for model_name in model_weights:
|
||||||
|
model_predictions[model_name] = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
|
||||||
|
|
||||||
|
# Analyze recent decisions to identify biased models
|
||||||
|
for decision in decisions:
|
||||||
|
reasoning = decision.reasoning
|
||||||
|
if 'models_used' in reasoning:
|
||||||
|
for model_name in reasoning['models_used']:
|
||||||
|
if model_name in model_predictions:
|
||||||
|
model_predictions[model_name][decision.action] += 1
|
||||||
|
|
||||||
|
# Calculate bias for each model
|
||||||
|
model_bias = {}
|
||||||
|
for model_name, actions in model_predictions.items():
|
||||||
|
total = sum(actions.values())
|
||||||
|
if total > 0:
|
||||||
|
buy_pct = actions['BUY'] / total * 100
|
||||||
|
sell_pct = actions['SELL'] / total * 100
|
||||||
|
|
||||||
|
# Calculate bias score (-100 to 100, negative = SELL bias, positive = BUY bias)
|
||||||
|
bias_score = buy_pct - sell_pct
|
||||||
|
model_bias[model_name] = bias_score
|
||||||
|
|
||||||
|
logger.info(f" Model {model_name}: Bias score = {bias_score:.1f} (BUY: {buy_pct:.1f}%, SELL: {sell_pct:.1f}%)")
|
||||||
|
|
||||||
|
# Adjust weights based on bias
|
||||||
|
adjusted_weights = {}
|
||||||
|
for model_name, weight in model_weights.items():
|
||||||
|
if model_name in model_bias:
|
||||||
|
bias = model_bias[model_name]
|
||||||
|
|
||||||
|
# If model has strong SELL bias, reduce its weight
|
||||||
|
if bias < -30: # Strong SELL bias
|
||||||
|
adjusted_weights[model_name] = max(0.05, weight * 0.7) # Reduce weight by 30%
|
||||||
|
logger.info(f" Reducing weight of {model_name} from {weight:.2f} to {adjusted_weights[model_name]:.2f} due to SELL bias")
|
||||||
|
# If model has BUY bias, increase its weight to balance
|
||||||
|
elif bias > 10: # BUY bias
|
||||||
|
adjusted_weights[model_name] = min(0.5, weight * 1.3) # Increase weight by 30%
|
||||||
|
logger.info(f" Increasing weight of {model_name} from {weight:.2f} to {adjusted_weights[model_name]:.2f} to balance SELL bias")
|
||||||
|
else:
|
||||||
|
adjusted_weights[model_name] = weight
|
||||||
|
else:
|
||||||
|
adjusted_weights[model_name] = weight
|
||||||
|
|
||||||
|
# Save adjusted weights
|
||||||
|
save_adjusted_weights(adjusted_weights)
|
||||||
|
|
||||||
|
logger.info(f" Adjusted weights: {adjusted_weights}")
|
||||||
|
logger.info(" Weights saved to 'adjusted_model_weights.json'")
|
||||||
|
|
||||||
|
# Recommend next steps
|
||||||
|
logger.info("\nRecommended actions:")
|
||||||
|
logger.info("1. Update the model weights in the orchestrator")
|
||||||
|
logger.info("2. Monitor trading signals for balance")
|
||||||
|
logger.info("3. Consider retraining models with balanced data")
|
||||||
|
|
||||||
|
def save_adjusted_weights(weights):
|
||||||
|
"""Save adjusted weights to a file"""
|
||||||
|
output = {
|
||||||
|
'timestamp': datetime.now().isoformat(),
|
||||||
|
'weights': weights,
|
||||||
|
'notes': 'Adjusted to balance BUY/SELL signals'
|
||||||
|
}
|
||||||
|
|
||||||
|
with open('adjusted_model_weights.json', 'w') as f:
|
||||||
|
json.dump(output, f, indent=2)
|
||||||
|
|
||||||
|
def apply_balanced_weights():
|
||||||
|
"""Apply balanced weights to the orchestrator"""
|
||||||
|
try:
|
||||||
|
# Check if weights file exists
|
||||||
|
if not os.path.exists('adjusted_model_weights.json'):
|
||||||
|
logger.error("Adjusted weights file not found. Run analyze_trading_signals() first.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Load adjusted weights
|
||||||
|
with open('adjusted_model_weights.json', 'r') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
weights = data.get('weights', {})
|
||||||
|
if not weights:
|
||||||
|
logger.error("No weights found in the file.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info(f"Loaded adjusted weights: {weights}")
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
data_provider = DataProvider()
|
||||||
|
orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
|
||||||
|
|
||||||
|
# Apply weights
|
||||||
|
for model_name, weight in weights.items():
|
||||||
|
if model_name in orchestrator.model_weights:
|
||||||
|
orchestrator.model_weights[model_name] = weight
|
||||||
|
|
||||||
|
# Save updated weights
|
||||||
|
orchestrator._save_orchestrator_state()
|
||||||
|
|
||||||
|
logger.info("Applied balanced weights to orchestrator.")
|
||||||
|
logger.info("Restart the trading system for changes to take effect.")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error applying balanced weights: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logger.info("=" * 70)
|
||||||
|
logger.info("TRADING SIGNAL BALANCE ANALYZER")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
|
||||||
|
if len(sys.argv) > 1 and sys.argv[1] == 'apply':
|
||||||
|
apply_balanced_weights()
|
||||||
|
else:
|
||||||
|
analyze_trading_signals()
|
@ -4,13 +4,10 @@ import logging
|
|||||||
import importlib
|
import importlib
|
||||||
import asyncio
|
import asyncio
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from safe_logging import setup_safe_logging
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
setup_safe_logging()
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
|
||||||
handlers=[logging.StreamHandler()]
|
|
||||||
)
|
|
||||||
logger = logging.getLogger("check_live_trading")
|
logger = logging.getLogger("check_live_trading")
|
||||||
|
|
||||||
def check_dependencies():
|
def check_dependencies():
|
||||||
|
@ -8,6 +8,7 @@ It loads settings from config.yaml and provides easy access to all components.
|
|||||||
import os
|
import os
|
||||||
import yaml
|
import yaml
|
||||||
import logging
|
import logging
|
||||||
|
from safe_logging import setup_safe_logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Any, Optional
|
from typing import Dict, List, Any, Optional
|
||||||
|
|
||||||
@ -247,23 +248,11 @@ def load_config(config_path: str = "config.yaml") -> Dict[str, Any]:
|
|||||||
|
|
||||||
def setup_logging(config: Optional[Config] = None):
|
def setup_logging(config: Optional[Config] = None):
|
||||||
"""Setup logging based on configuration"""
|
"""Setup logging based on configuration"""
|
||||||
|
setup_safe_logging()
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
config = get_config()
|
config = get_config()
|
||||||
|
|
||||||
log_config = config.logging
|
log_config = config.logging
|
||||||
|
|
||||||
# Create logs directory
|
logger.info("Logging configured successfully with SafeFormatter")
|
||||||
log_file = Path(log_config.get('file', 'logs/trading.log'))
|
|
||||||
log_file.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Setup logging
|
|
||||||
logging.basicConfig(
|
|
||||||
level=getattr(logging, log_config.get('level', 'INFO')),
|
|
||||||
format=log_config.get('format', '%(asctime)s - %(name)s - %(levelname)s - %(message)s'),
|
|
||||||
handlers=[
|
|
||||||
logging.FileHandler(log_file),
|
|
||||||
logging.StreamHandler()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Logging configured successfully")
|
|
||||||
|
3
main.py
3
main.py
@ -24,6 +24,7 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
import time
|
import time
|
||||||
|
from safe_logging import setup_safe_logging
|
||||||
|
|
||||||
# Add project root to path
|
# Add project root to path
|
||||||
project_root = Path(__file__).parent
|
project_root = Path(__file__).parent
|
||||||
@ -395,7 +396,7 @@ async def main():
|
|||||||
# Setup logging and ensure directories exist
|
# Setup logging and ensure directories exist
|
||||||
Path("logs").mkdir(exist_ok=True)
|
Path("logs").mkdir(exist_ok=True)
|
||||||
Path("NN/models/saved").mkdir(parents=True, exist_ok=True)
|
Path("NN/models/saved").mkdir(parents=True, exist_ok=True)
|
||||||
setup_logging()
|
setup_safe_logging()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info("=" * 70)
|
logger.info("=" * 70)
|
||||||
|
@ -1,306 +1,193 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
Enhanced Position Synchronization System
|
Position Sync Enhancement - Fix P&L and Win Rate Calculation
|
||||||
Addresses the gap between dashboard position display and actual exchange account state
|
|
||||||
|
This script enhances the position synchronization and P&L calculation
|
||||||
|
to properly account for leverage in the trading system.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
import logging
|
import logging
|
||||||
import time
|
from pathlib import Path
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime
|
||||||
from typing import Dict, List, Optional, Any
|
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
project_root = Path(__file__).parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from core.config import get_config, setup_logging
|
||||||
|
from core.trading_executor import TradingExecutor, TradeRecord
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
setup_logging()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class EnhancedPositionSync:
|
def analyze_trade_records():
|
||||||
"""Enhanced position synchronization to ensure dashboard matches actual exchange state"""
|
"""Analyze trade records for P&L calculation issues"""
|
||||||
|
logger.info("Analyzing trade records for P&L calculation issues...")
|
||||||
|
|
||||||
def __init__(self, trading_executor, dashboard):
|
# Initialize trading executor
|
||||||
self.trading_executor = trading_executor
|
trading_executor = TradingExecutor()
|
||||||
self.dashboard = dashboard
|
|
||||||
self.last_sync_time = 0
|
|
||||||
self.sync_interval = 10 # Sync every 10 seconds
|
|
||||||
self.position_history = [] # Track position changes
|
|
||||||
|
|
||||||
def sync_all_positions(self) -> Dict[str, Any]:
|
# Get trade records
|
||||||
"""Comprehensive position sync for all symbols"""
|
trade_records = trading_executor.trade_records
|
||||||
try:
|
|
||||||
sync_results = {}
|
|
||||||
|
|
||||||
# 1. Get actual exchange positions
|
if not trade_records:
|
||||||
exchange_positions = self._get_actual_exchange_positions()
|
logger.warning("No trade records found.")
|
||||||
|
|
||||||
# 2. Get dashboard positions
|
|
||||||
dashboard_positions = self._get_dashboard_positions()
|
|
||||||
|
|
||||||
# 3. Compare and sync
|
|
||||||
for symbol in ['ETH/USDT', 'BTC/USDT']:
|
|
||||||
sync_result = self._sync_symbol_position(
|
|
||||||
symbol,
|
|
||||||
exchange_positions.get(symbol),
|
|
||||||
dashboard_positions.get(symbol)
|
|
||||||
)
|
|
||||||
sync_results[symbol] = sync_result
|
|
||||||
|
|
||||||
# 4. Update closed trades list from exchange
|
|
||||||
self._sync_closed_trades()
|
|
||||||
|
|
||||||
return {
|
|
||||||
'sync_time': datetime.now().isoformat(),
|
|
||||||
'results': sync_results,
|
|
||||||
'total_synced': len(sync_results),
|
|
||||||
'issues_found': sum(1 for r in sync_results.values() if not r['in_sync'])
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in comprehensive position sync: {e}")
|
|
||||||
return {'error': str(e)}
|
|
||||||
|
|
||||||
def _get_actual_exchange_positions(self) -> Dict[str, Dict]:
|
|
||||||
"""Get actual positions from exchange account"""
|
|
||||||
try:
|
|
||||||
positions = {}
|
|
||||||
|
|
||||||
if not self.trading_executor:
|
|
||||||
return positions
|
|
||||||
|
|
||||||
# Get account balances
|
|
||||||
if hasattr(self.trading_executor, 'get_account_balance'):
|
|
||||||
balances = self.trading_executor.get_account_balance()
|
|
||||||
|
|
||||||
for symbol in ['ETH/USDT', 'BTC/USDT']:
|
|
||||||
# Parse symbol to get base asset
|
|
||||||
base_asset = symbol.split('/')[0]
|
|
||||||
|
|
||||||
# Get balance for base asset
|
|
||||||
base_balance = balances.get(base_asset, {}).get('total', 0.0)
|
|
||||||
|
|
||||||
if base_balance > 0.001: # Minimum threshold
|
|
||||||
positions[symbol] = {
|
|
||||||
'side': 'LONG',
|
|
||||||
'size': base_balance,
|
|
||||||
'value': base_balance * self._get_current_price(symbol),
|
|
||||||
'source': 'exchange_balance'
|
|
||||||
}
|
|
||||||
|
|
||||||
# Also check trading executor's position tracking
|
|
||||||
if hasattr(self.trading_executor, 'get_positions'):
|
|
||||||
executor_positions = self.trading_executor.get_positions()
|
|
||||||
for symbol, position in executor_positions.items():
|
|
||||||
if position and hasattr(position, 'quantity') and position.quantity > 0:
|
|
||||||
positions[symbol] = {
|
|
||||||
'side': position.side,
|
|
||||||
'size': position.quantity,
|
|
||||||
'entry_price': position.entry_price,
|
|
||||||
'value': position.quantity * self._get_current_price(symbol),
|
|
||||||
'source': 'executor_tracking'
|
|
||||||
}
|
|
||||||
|
|
||||||
return positions
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting actual exchange positions: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def _get_dashboard_positions(self) -> Dict[str, Dict]:
|
|
||||||
"""Get positions as shown on dashboard"""
|
|
||||||
try:
|
|
||||||
positions = {}
|
|
||||||
|
|
||||||
# Get from dashboard's current_position
|
|
||||||
if self.dashboard.current_position:
|
|
||||||
symbol = self.dashboard.current_position.get('symbol', 'ETH/USDT')
|
|
||||||
positions[symbol] = {
|
|
||||||
'side': self.dashboard.current_position.get('side'),
|
|
||||||
'size': self.dashboard.current_position.get('size'),
|
|
||||||
'entry_price': self.dashboard.current_position.get('price'),
|
|
||||||
'value': self.dashboard.current_position.get('size', 0) * self._get_current_price(symbol),
|
|
||||||
'source': 'dashboard_display'
|
|
||||||
}
|
|
||||||
|
|
||||||
return positions
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting dashboard positions: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def _sync_symbol_position(self, symbol: str, exchange_pos: Optional[Dict], dashboard_pos: Optional[Dict]) -> Dict[str, Any]:
|
|
||||||
"""Sync position for a specific symbol"""
|
|
||||||
try:
|
|
||||||
sync_result = {
|
|
||||||
'symbol': symbol,
|
|
||||||
'exchange_position': exchange_pos,
|
|
||||||
'dashboard_position': dashboard_pos,
|
|
||||||
'in_sync': True,
|
|
||||||
'action_taken': 'none'
|
|
||||||
}
|
|
||||||
|
|
||||||
# Case 1: Exchange has position, dashboard doesn't
|
|
||||||
if exchange_pos and not dashboard_pos:
|
|
||||||
logger.warning(f"SYNC ISSUE: Exchange has {symbol} position but dashboard shows none")
|
|
||||||
|
|
||||||
# Update dashboard to reflect exchange position
|
|
||||||
self.dashboard.current_position = {
|
|
||||||
'symbol': symbol,
|
|
||||||
'side': exchange_pos['side'],
|
|
||||||
'size': exchange_pos['size'],
|
|
||||||
'price': exchange_pos.get('entry_price', self._get_current_price(symbol)),
|
|
||||||
'entry_time': datetime.now(),
|
|
||||||
'leverage': self.dashboard.current_leverage,
|
|
||||||
'source': 'sync_correction'
|
|
||||||
}
|
|
||||||
|
|
||||||
sync_result['in_sync'] = False
|
|
||||||
sync_result['action_taken'] = 'updated_dashboard_from_exchange'
|
|
||||||
|
|
||||||
# Case 2: Dashboard has position, exchange doesn't
|
|
||||||
elif dashboard_pos and not exchange_pos:
|
|
||||||
logger.warning(f"SYNC ISSUE: Dashboard shows {symbol} position but exchange has none")
|
|
||||||
|
|
||||||
# Clear dashboard position
|
|
||||||
self.dashboard.current_position = None
|
|
||||||
|
|
||||||
sync_result['in_sync'] = False
|
|
||||||
sync_result['action_taken'] = 'cleared_dashboard_position'
|
|
||||||
|
|
||||||
# Case 3: Both have positions but they differ
|
|
||||||
elif exchange_pos and dashboard_pos:
|
|
||||||
if (exchange_pos['side'] != dashboard_pos['side'] or
|
|
||||||
abs(exchange_pos['size'] - dashboard_pos['size']) > 0.001):
|
|
||||||
|
|
||||||
logger.warning(f"SYNC ISSUE: {symbol} position mismatch - Exchange: {exchange_pos['side']} {exchange_pos['size']:.3f}, Dashboard: {dashboard_pos['side']} {dashboard_pos['size']:.3f}")
|
|
||||||
|
|
||||||
# Update dashboard to match exchange
|
|
||||||
self.dashboard.current_position.update({
|
|
||||||
'side': exchange_pos['side'],
|
|
||||||
'size': exchange_pos['size'],
|
|
||||||
'price': exchange_pos.get('entry_price', dashboard_pos['entry_price'])
|
|
||||||
})
|
|
||||||
|
|
||||||
sync_result['in_sync'] = False
|
|
||||||
sync_result['action_taken'] = 'updated_dashboard_to_match_exchange'
|
|
||||||
|
|
||||||
return sync_result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error syncing position for {symbol}: {e}")
|
|
||||||
return {'symbol': symbol, 'error': str(e), 'in_sync': False}
|
|
||||||
|
|
||||||
def _sync_closed_trades(self):
|
|
||||||
"""Sync closed trades list with actual exchange trade history"""
|
|
||||||
try:
|
|
||||||
if not self.trading_executor:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get trade history from executor
|
logger.info(f"Found {len(trade_records)} trade records.")
|
||||||
if hasattr(self.trading_executor, 'get_trade_history'):
|
|
||||||
executor_trades = self.trading_executor.get_trade_history()
|
|
||||||
|
|
||||||
# Clear and rebuild closed_trades list
|
# Analyze P&L calculation
|
||||||
self.dashboard.closed_trades = []
|
total_pnl = 0.0
|
||||||
|
total_gross_pnl = 0.0
|
||||||
|
total_fees = 0.0
|
||||||
|
winning_trades = 0
|
||||||
|
losing_trades = 0
|
||||||
|
breakeven_trades = 0
|
||||||
|
|
||||||
for trade in executor_trades:
|
for trade in trade_records:
|
||||||
# Convert to dashboard format
|
# Calculate correct P&L with leverage
|
||||||
trade_record = {
|
entry_value = trade.entry_price * trade.quantity
|
||||||
'symbol': getattr(trade, 'symbol', 'ETH/USDT'),
|
exit_value = trade.exit_price * trade.quantity
|
||||||
'side': getattr(trade, 'side', 'UNKNOWN'),
|
|
||||||
'quantity': getattr(trade, 'quantity', 0),
|
|
||||||
'entry_price': getattr(trade, 'entry_price', 0),
|
|
||||||
'exit_price': getattr(trade, 'exit_price', 0),
|
|
||||||
'entry_time': getattr(trade, 'entry_time', datetime.now()),
|
|
||||||
'exit_time': getattr(trade, 'exit_time', datetime.now()),
|
|
||||||
'pnl': getattr(trade, 'pnl', 0),
|
|
||||||
'fees': getattr(trade, 'fees', 0),
|
|
||||||
'confidence': getattr(trade, 'confidence', 1.0),
|
|
||||||
'trade_type': 'synced_from_executor'
|
|
||||||
}
|
|
||||||
|
|
||||||
# Only add completed trades (with exit_time)
|
if trade.side == 'LONG':
|
||||||
if trade_record['exit_time']:
|
gross_pnl = (exit_value - entry_value) * trade.leverage
|
||||||
self.dashboard.closed_trades.append(trade_record)
|
else: # SHORT
|
||||||
|
gross_pnl = (entry_value - exit_value) * trade.leverage
|
||||||
|
|
||||||
# Update session PnL
|
# Calculate fees
|
||||||
self.dashboard.session_pnl = sum(trade['pnl'] for trade in self.dashboard.closed_trades)
|
fees = (entry_value + exit_value) * 0.001 # 0.1% fee on both entry and exit
|
||||||
|
|
||||||
logger.info(f"Synced {len(self.dashboard.closed_trades)} closed trades from executor")
|
# Calculate net P&L
|
||||||
|
net_pnl = gross_pnl - fees
|
||||||
|
|
||||||
except Exception as e:
|
# Compare with stored values
|
||||||
logger.error(f"Error syncing closed trades: {e}")
|
pnl_diff = abs(net_pnl - trade.pnl)
|
||||||
|
if pnl_diff > 0.01: # More than 1 cent difference
|
||||||
|
logger.warning(f"P&L calculation issue detected for trade {trade.entry_time}:")
|
||||||
|
logger.warning(f" Stored P&L: ${trade.pnl:.2f}")
|
||||||
|
logger.warning(f" Calculated P&L: ${net_pnl:.2f}")
|
||||||
|
logger.warning(f" Difference: ${pnl_diff:.2f}")
|
||||||
|
logger.warning(f" Leverage used: {trade.leverage}x")
|
||||||
|
|
||||||
def _get_current_price(self, symbol: str) -> float:
|
# Update statistics
|
||||||
"""Get current price for a symbol"""
|
total_pnl += net_pnl
|
||||||
try:
|
total_gross_pnl += gross_pnl
|
||||||
return self.dashboard._get_current_price(symbol) or 3500.0
|
total_fees += fees
|
||||||
except:
|
|
||||||
return 3500.0 # Fallback price
|
|
||||||
|
|
||||||
def should_sync(self) -> bool:
|
if net_pnl > 0.01: # More than 1 cent profit
|
||||||
"""Check if sync is needed based on time interval"""
|
winning_trades += 1
|
||||||
current_time = time.time()
|
elif net_pnl < -0.01: # More than 1 cent loss
|
||||||
if current_time - self.last_sync_time >= self.sync_interval:
|
losing_trades += 1
|
||||||
self.last_sync_time = current_time
|
else:
|
||||||
return True
|
breakeven_trades += 1
|
||||||
return False
|
|
||||||
|
|
||||||
def create_sync_status_display(self) -> Dict[str, Any]:
|
# Calculate win rate
|
||||||
"""Create detailed sync status for dashboard display"""
|
total_trades = winning_trades + losing_trades + breakeven_trades
|
||||||
try:
|
win_rate = (winning_trades / total_trades * 100) if total_trades > 0 else 0.0
|
||||||
# Get current sync status
|
|
||||||
sync_results = self.sync_all_positions()
|
|
||||||
|
|
||||||
# Create display-friendly format
|
logger.info("\nTrade Analysis Results:")
|
||||||
status_display = {
|
logger.info(f" Total trades: {total_trades}")
|
||||||
'last_sync': datetime.now().strftime('%H:%M:%S'),
|
logger.info(f" Winning trades: {winning_trades}")
|
||||||
'sync_healthy': sync_results.get('issues_found', 0) == 0,
|
logger.info(f" Losing trades: {losing_trades}")
|
||||||
'positions': {},
|
logger.info(f" Breakeven trades: {breakeven_trades}")
|
||||||
'closed_trades_count': len(self.dashboard.closed_trades),
|
logger.info(f" Win rate: {win_rate:.1f}%")
|
||||||
'session_pnl': self.dashboard.session_pnl
|
logger.info(f" Total P&L: ${total_pnl:.2f}")
|
||||||
}
|
logger.info(f" Total gross P&L: ${total_gross_pnl:.2f}")
|
||||||
|
logger.info(f" Total fees: ${total_fees:.2f}")
|
||||||
|
|
||||||
# Add position details
|
# Check for leverage issues
|
||||||
for symbol, result in sync_results.get('results', {}).items():
|
leverage_issues = False
|
||||||
status_display['positions'][symbol] = {
|
for trade in trade_records:
|
||||||
'in_sync': result['in_sync'],
|
if trade.leverage <= 1.0:
|
||||||
'action_taken': result.get('action_taken', 'none'),
|
leverage_issues = True
|
||||||
'has_exchange_position': result['exchange_position'] is not None,
|
logger.warning(f"Low leverage detected: {trade.leverage}x for trade at {trade.entry_time}")
|
||||||
'has_dashboard_position': result['dashboard_position'] is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
return status_display
|
if leverage_issues:
|
||||||
|
logger.warning("\nLeverage issues detected. Consider fixing the leverage calculation.")
|
||||||
|
logger.info("Recommended fix: Ensure leverage is properly set in the trading executor.")
|
||||||
|
else:
|
||||||
|
logger.info("\nNo leverage issues detected.")
|
||||||
|
|
||||||
except Exception as e:
|
def fix_leverage_calculation():
|
||||||
logger.error(f"Error creating sync status display: {e}")
|
"""Fix leverage calculation in the trading executor"""
|
||||||
return {'error': str(e)}
|
logger.info("Fixing leverage calculation in the trading executor...")
|
||||||
|
|
||||||
|
# Initialize trading executor
|
||||||
|
trading_executor = TradingExecutor()
|
||||||
|
|
||||||
# Integration with existing dashboard
|
# Get current leverage
|
||||||
def integrate_enhanced_sync(dashboard):
|
current_leverage = trading_executor.current_leverage
|
||||||
"""Integrate enhanced sync with existing dashboard"""
|
logger.info(f"Current leverage setting: {current_leverage}x")
|
||||||
|
|
||||||
# Create enhanced sync instance
|
# Check if leverage is properly set
|
||||||
enhanced_sync = EnhancedPositionSync(dashboard.trading_executor, dashboard)
|
if current_leverage <= 1:
|
||||||
|
logger.warning("Leverage is set too low. Updating to 20x...")
|
||||||
|
trading_executor.current_leverage = 20
|
||||||
|
logger.info(f"Updated leverage to {trading_executor.current_leverage}x")
|
||||||
|
else:
|
||||||
|
logger.info("Leverage is already set correctly.")
|
||||||
|
|
||||||
# Add to dashboard
|
# Update trade records with correct leverage
|
||||||
dashboard.enhanced_sync = enhanced_sync
|
updated_count = 0
|
||||||
|
for i, trade in enumerate(trading_executor.trade_records):
|
||||||
|
if trade.leverage <= 1.0:
|
||||||
|
# Create updated trade record
|
||||||
|
updated_trade = TradeRecord(
|
||||||
|
symbol=trade.symbol,
|
||||||
|
side=trade.side,
|
||||||
|
quantity=trade.quantity,
|
||||||
|
entry_price=trade.entry_price,
|
||||||
|
exit_price=trade.exit_price,
|
||||||
|
entry_time=trade.entry_time,
|
||||||
|
exit_time=trade.exit_time,
|
||||||
|
pnl=trade.pnl,
|
||||||
|
fees=trade.fees,
|
||||||
|
confidence=trade.confidence,
|
||||||
|
hold_time_seconds=trade.hold_time_seconds,
|
||||||
|
leverage=trading_executor.current_leverage, # Use current leverage setting
|
||||||
|
position_size_usd=trade.position_size_usd,
|
||||||
|
gross_pnl=trade.gross_pnl,
|
||||||
|
net_pnl=trade.net_pnl
|
||||||
|
)
|
||||||
|
|
||||||
# Modify existing metrics update to include sync
|
# Recalculate P&L with correct leverage
|
||||||
original_update_metrics = dashboard.update_metrics
|
entry_value = updated_trade.entry_price * updated_trade.quantity
|
||||||
|
exit_value = updated_trade.exit_price * updated_trade.quantity
|
||||||
|
|
||||||
def enhanced_update_metrics(n):
|
if updated_trade.side == 'LONG':
|
||||||
"""Enhanced metrics update with position sync"""
|
updated_trade.gross_pnl = (exit_value - entry_value) * updated_trade.leverage
|
||||||
try:
|
else: # SHORT
|
||||||
# Perform periodic sync
|
updated_trade.gross_pnl = (entry_value - exit_value) * updated_trade.leverage
|
||||||
if enhanced_sync.should_sync():
|
|
||||||
sync_results = enhanced_sync.sync_all_positions()
|
|
||||||
if sync_results.get('issues_found', 0) > 0:
|
|
||||||
logger.info(f"Position sync performed: {sync_results['issues_found']} issues corrected")
|
|
||||||
|
|
||||||
# Call original metrics update
|
# Recalculate fees
|
||||||
return original_update_metrics(n)
|
updated_trade.fees = (entry_value + exit_value) * 0.001 # 0.1% fee on both entry and exit
|
||||||
|
|
||||||
except Exception as e:
|
# Recalculate net P&L
|
||||||
logger.error(f"Error in enhanced metrics update: {e}")
|
updated_trade.net_pnl = updated_trade.gross_pnl - updated_trade.fees
|
||||||
return original_update_metrics(n)
|
updated_trade.pnl = updated_trade.net_pnl
|
||||||
|
|
||||||
# Replace the update method
|
# Update trade record
|
||||||
dashboard.update_metrics = enhanced_update_metrics
|
trading_executor.trade_records[i] = updated_trade
|
||||||
|
updated_count += 1
|
||||||
|
|
||||||
return enhanced_sync
|
logger.info(f"Updated {updated_count} trade records with correct leverage.")
|
||||||
|
|
||||||
|
# Save updated trade records
|
||||||
|
# Note: This is a placeholder. In a real implementation, you would need to
|
||||||
|
# persist the updated trade records to storage.
|
||||||
|
logger.info("Changes will take effect on next dashboard restart.")
|
||||||
|
|
||||||
|
return updated_count > 0
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logger.info("=" * 70)
|
||||||
|
logger.info("POSITION SYNC ENHANCEMENT")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
|
||||||
|
if len(sys.argv) > 1 and sys.argv[1] == 'fix':
|
||||||
|
fix_leverage_calculation()
|
||||||
|
else:
|
||||||
|
analyze_trade_records()
|
@ -16,6 +16,7 @@ matplotlib.use('Agg') # Use non-interactive Agg backend
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
from safe_logging import setup_safe_logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -32,7 +33,7 @@ from utils.checkpoint_manager import get_checkpoint_manager
|
|||||||
from utils.training_integration import get_training_integration
|
from utils.training_integration import get_training_integration
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
setup_logging()
|
setup_safe_logging()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
async def start_training_pipeline(orchestrator, trading_executor):
|
async def start_training_pipeline(orchestrator, trading_executor):
|
||||||
|
@ -23,6 +23,7 @@ import os
|
|||||||
import subprocess
|
import subprocess
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from safe_logging import setup_safe_logging
|
||||||
|
|
||||||
# Add project root to path
|
# Add project root to path
|
||||||
project_root = Path(__file__).parent
|
project_root = Path(__file__).parent
|
||||||
@ -149,7 +150,7 @@ def run_all_tests():
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Main test runner"""
|
"""Main test runner"""
|
||||||
setup_logging()
|
setup_safe_logging()
|
||||||
|
|
||||||
# Parse command line arguments
|
# Parse command line arguments
|
||||||
if len(sys.argv) > 1:
|
if len(sys.argv) > 1:
|
||||||
|
112
safe_logging.py
Normal file
112
safe_logging.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import platform
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
class SafeFormatter(logging.Formatter):
|
||||||
|
"""Custom formatter that safely handles non-ASCII characters"""
|
||||||
|
|
||||||
|
def format(self, record):
|
||||||
|
# Handle message string safely
|
||||||
|
if hasattr(record, 'msg') and record.msg is not None:
|
||||||
|
if isinstance(record.msg, str):
|
||||||
|
# Strip non-ASCII characters to prevent encoding errors
|
||||||
|
record.msg = record.msg.encode("ascii", "ignore").decode()
|
||||||
|
elif isinstance(record.msg, bytes):
|
||||||
|
# Handle bytes objects
|
||||||
|
record.msg = record.msg.decode("utf-8", "ignore")
|
||||||
|
|
||||||
|
# Handle args tuple if present
|
||||||
|
if hasattr(record, 'args') and record.args:
|
||||||
|
safe_args = []
|
||||||
|
for arg in record.args:
|
||||||
|
if isinstance(arg, str):
|
||||||
|
safe_args.append(arg.encode("ascii", "ignore").decode())
|
||||||
|
elif isinstance(arg, bytes):
|
||||||
|
safe_args.append(arg.decode("utf-8", "ignore"))
|
||||||
|
else:
|
||||||
|
safe_args.append(str(arg))
|
||||||
|
record.args = tuple(safe_args)
|
||||||
|
|
||||||
|
# Handle exc_text if present
|
||||||
|
if hasattr(record, 'exc_text') and record.exc_text:
|
||||||
|
if isinstance(record.exc_text, str):
|
||||||
|
record.exc_text = record.exc_text.encode("ascii", "ignore").decode()
|
||||||
|
|
||||||
|
return super().format(record)
|
||||||
|
|
||||||
|
class SafeStreamHandler(logging.StreamHandler):
|
||||||
|
"""Stream handler that forces UTF-8 encoding where supported"""
|
||||||
|
|
||||||
|
def __init__(self, stream=None):
|
||||||
|
super().__init__(stream)
|
||||||
|
# Try to set UTF-8 encoding on stdout/stderr if supported
|
||||||
|
if hasattr(self.stream, 'reconfigure'):
|
||||||
|
try:
|
||||||
|
if platform.system() == "Windows":
|
||||||
|
# On Windows, use errors='ignore'
|
||||||
|
self.stream.reconfigure(encoding='utf-8', errors='ignore')
|
||||||
|
else:
|
||||||
|
# On Unix-like systems, use backslashreplace
|
||||||
|
self.stream.reconfigure(encoding='utf-8', errors='backslashreplace')
|
||||||
|
except (AttributeError, OSError):
|
||||||
|
# If reconfigure is not available or fails, continue silently
|
||||||
|
pass
|
||||||
|
|
||||||
|
def setup_safe_logging(log_level=logging.INFO, log_file='logs/safe_logging.log'):
|
||||||
|
"""Setup logging with SafeFormatter and UTF-8 encoding
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_level: Logging level (default: INFO)
|
||||||
|
log_file: Path to log file (default: logs/safe_logging.log)
|
||||||
|
"""
|
||||||
|
# Ensure logs directory exists
|
||||||
|
log_path = Path(log_file)
|
||||||
|
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Clear existing handlers to avoid duplicates
|
||||||
|
root_logger = logging.getLogger()
|
||||||
|
for handler in root_logger.handlers[:]:
|
||||||
|
root_logger.removeHandler(handler)
|
||||||
|
|
||||||
|
# Create handlers with proper encoding
|
||||||
|
handlers = []
|
||||||
|
|
||||||
|
# Console handler with safe UTF-8 handling
|
||||||
|
console_handler = SafeStreamHandler(sys.stdout)
|
||||||
|
console_handler.setFormatter(SafeFormatter(
|
||||||
|
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
))
|
||||||
|
handlers.append(console_handler)
|
||||||
|
|
||||||
|
# File handler with UTF-8 encoding and error handling
|
||||||
|
try:
|
||||||
|
encoding_kwargs = {
|
||||||
|
"encoding": "utf-8",
|
||||||
|
"errors": "ignore" if platform.system() == "Windows" else "backslashreplace"
|
||||||
|
}
|
||||||
|
|
||||||
|
file_handler = logging.FileHandler(log_file, **encoding_kwargs)
|
||||||
|
file_handler.setFormatter(SafeFormatter(
|
||||||
|
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
))
|
||||||
|
handlers.append(file_handler)
|
||||||
|
except (OSError, IOError) as e:
|
||||||
|
# If file handler fails, just use console handler
|
||||||
|
print(f"Warning: Could not create log file {log_file}: {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
# Configure root logger
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
handlers=handlers,
|
||||||
|
force=True # Force reconfiguration
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply SafeFormatter to all existing loggers
|
||||||
|
safe_formatter = SafeFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
for logger_name in logging.Logger.manager.loggerDict:
|
||||||
|
logger = logging.getLogger(logger_name)
|
||||||
|
for handler in logger.handlers:
|
||||||
|
handler.setFormatter(safe_formatter)
|
||||||
|
|
69
test_safe_logging.py
Normal file
69
test_safe_logging.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test script to verify safe_logging functionality
|
||||||
|
|
||||||
|
This script tests that the safe logging module properly handles:
|
||||||
|
1. Non-ASCII characters (emojis, smart quotes)
|
||||||
|
2. UTF-8 encoding
|
||||||
|
3. Error handling on Windows
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from safe_logging import setup_safe_logging
|
||||||
|
|
||||||
|
def test_safe_logging():
|
||||||
|
"""Test the safe logging module with various character types"""
|
||||||
|
# Setup safe logging
|
||||||
|
setup_safe_logging()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
print("Testing Safe Logging Module...")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Test regular ASCII messages
|
||||||
|
logger.info("Regular ASCII message - this should work fine")
|
||||||
|
|
||||||
|
# Test messages with emojis
|
||||||
|
logger.info("Testing emojis: 🚀 💰 📈 📊 🔥")
|
||||||
|
|
||||||
|
# Test messages with smart quotes and special characters
|
||||||
|
logger.info("Testing smart quotes: Hello World Test")
|
||||||
|
|
||||||
|
# Test messages with various Unicode characters
|
||||||
|
logger.info("Testing Unicode: café résumé naïve Ω α β γ δ")
|
||||||
|
|
||||||
|
# Test messages with mixed content
|
||||||
|
logger.info("Mixed content: Regular text with emojis 🎉 and quotes like this")
|
||||||
|
|
||||||
|
# Test error messages with special characters
|
||||||
|
logger.error("Error with special chars: ❌ Failed to process €100.50")
|
||||||
|
|
||||||
|
# Test warning messages
|
||||||
|
logger.warning("Warning with symbols: ⚠️ Temperature is 37°C")
|
||||||
|
|
||||||
|
# Test debug messages
|
||||||
|
logger.debug("Debug info: Processing file data.txt at 95% completion ✓")
|
||||||
|
|
||||||
|
# Test exception handling with special characters
|
||||||
|
try:
|
||||||
|
raise ValueError("Error with emoji: 💥 Something went wrong!")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Exception caught with special chars: %s", str(e))
|
||||||
|
|
||||||
|
# Test formatting with special characters
|
||||||
|
symbol = "ETH/USDT"
|
||||||
|
price = 2500.50
|
||||||
|
change = 2.3
|
||||||
|
logger.info(f"Price update for {symbol}: ${price:.2f} (+{change}% 📈)")
|
||||||
|
|
||||||
|
# Test large message with many special characters
|
||||||
|
large_msg = "Large message: " + "🔄" * 50 + " Processing complete ✅"
|
||||||
|
logger.info(large_msg)
|
||||||
|
|
||||||
|
print("=" * 50)
|
||||||
|
print("✅ Safe logging test completed!")
|
||||||
|
print("If you see this message, all logging calls were successful.")
|
||||||
|
print("Check the log file at logs/safe_logging.log for the complete output.")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_safe_logging()
|
@ -1,400 +1,118 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
Test Training Data Collection System
|
Test Training Data Collection and Checkpoint Storage
|
||||||
|
|
||||||
This script demonstrates and tests the comprehensive training data collection
|
This script tests if the training system is working correctly and storing checkpoints.
|
||||||
system with data validation, rapid change detection, and profitable setup replay.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import os
|
||||||
|
import sys
|
||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import asyncio
|
||||||
import pandas as pd
|
|
||||||
import time
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
project_root = Path(__file__).parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from core.config import get_config, setup_logging
|
||||||
|
from core.orchestrator import TradingOrchestrator
|
||||||
|
from core.data_provider import DataProvider
|
||||||
|
from utils.checkpoint_manager import get_checkpoint_manager
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logging.basicConfig(
|
setup_logging()
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
||||||
)
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Import our training system components
|
async def test_training_system():
|
||||||
from core.training_data_collector import (
|
"""Test if the training system is working and storing checkpoints"""
|
||||||
TrainingDataCollector,
|
logger.info("Testing training system and checkpoint storage...")
|
||||||
RapidChangeDetector,
|
|
||||||
ModelInputPackage,
|
|
||||||
TrainingOutcome,
|
|
||||||
TrainingEpisode
|
|
||||||
)
|
|
||||||
from core.cnn_training_pipeline import (
|
|
||||||
CNNPivotPredictor,
|
|
||||||
CNNTrainer
|
|
||||||
)
|
|
||||||
from core.data_provider import DataProvider
|
|
||||||
|
|
||||||
def create_sample_ohlcv_data() -> Dict[str, pd.DataFrame]:
|
# Initialize components
|
||||||
"""Create sample OHLCV data for testing"""
|
data_provider = DataProvider()
|
||||||
timeframes = ['1s', '1m', '5m', '15m', '1h']
|
orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
|
||||||
ohlcv_data = {}
|
|
||||||
|
|
||||||
for timeframe in timeframes:
|
# Get checkpoint manager
|
||||||
# Create sample data
|
checkpoint_manager = get_checkpoint_manager()
|
||||||
dates = pd.date_range(start='2024-01-01', periods=300, freq='1min')
|
|
||||||
|
|
||||||
# Generate realistic price data
|
# Check if checkpoint directory exists
|
||||||
base_price = 3000.0 # ETH price
|
checkpoint_dir = Path("models/saved")
|
||||||
price_data = []
|
if not checkpoint_dir.exists():
|
||||||
current_price = base_price
|
logger.warning(f"Checkpoint directory {checkpoint_dir} does not exist. Creating...")
|
||||||
|
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
for i in range(300):
|
# Check for existing checkpoints
|
||||||
# Add some randomness
|
checkpoint_stats = checkpoint_manager.get_checkpoint_stats()
|
||||||
change = np.random.normal(0, 0.002) # 0.2% std dev
|
logger.info(f"Found {checkpoint_stats['total_checkpoints']} existing checkpoints.")
|
||||||
current_price *= (1 + change)
|
logger.info(f"Total checkpoint size: {checkpoint_stats['total_size_mb']:.2f} MB")
|
||||||
|
|
||||||
# OHLCV for this period
|
# List checkpoint files
|
||||||
open_price = current_price
|
checkpoint_files = list(checkpoint_dir.glob("*.pt"))
|
||||||
high_price = current_price * (1 + abs(np.random.normal(0, 0.001)))
|
if checkpoint_files:
|
||||||
low_price = current_price * (1 - abs(np.random.normal(0, 0.001)))
|
logger.info("Recent checkpoint files:")
|
||||||
close_price = current_price * (1 + np.random.normal(0, 0.0005))
|
for i, file in enumerate(sorted(checkpoint_files, key=lambda f: f.stat().st_mtime, reverse=True)[:5]):
|
||||||
volume = np.random.uniform(100, 1000)
|
file_size = file.stat().st_size / (1024 * 1024) # Convert to MB
|
||||||
|
modified_time = datetime.fromtimestamp(file.stat().st_mtime).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
price_data.append({
|
logger.info(f" {i+1}. {file.name} ({file_size:.2f} MB, modified: {modified_time})")
|
||||||
'timestamp': dates[i],
|
|
||||||
'open': open_price,
|
|
||||||
'high': high_price,
|
|
||||||
'low': low_price,
|
|
||||||
'close': close_price,
|
|
||||||
'volume': volume
|
|
||||||
})
|
|
||||||
|
|
||||||
current_price = close_price
|
|
||||||
|
|
||||||
df = pd.DataFrame(price_data)
|
|
||||||
df.set_index('timestamp', inplace=True)
|
|
||||||
ohlcv_data[timeframe] = df
|
|
||||||
|
|
||||||
return ohlcv_data
|
|
||||||
|
|
||||||
def create_sample_tick_data() -> List[Dict[str, Any]]:
|
|
||||||
"""Create sample tick data for testing"""
|
|
||||||
tick_data = []
|
|
||||||
base_price = 3000.0
|
|
||||||
|
|
||||||
for i in range(100):
|
|
||||||
tick_data.append({
|
|
||||||
'timestamp': datetime.now() - timedelta(seconds=100-i),
|
|
||||||
'price': base_price + np.random.normal(0, 5),
|
|
||||||
'volume': np.random.uniform(0.1, 10.0),
|
|
||||||
'side': 'buy' if np.random.random() > 0.5 else 'sell',
|
|
||||||
'trade_id': f'trade_{i}',
|
|
||||||
'quantity': np.random.uniform(0.1, 5.0)
|
|
||||||
})
|
|
||||||
|
|
||||||
return tick_data
|
|
||||||
|
|
||||||
def create_sample_cob_data() -> Dict[str, Any]:
|
|
||||||
"""Create sample COB data for testing"""
|
|
||||||
return {
|
|
||||||
'timestamp': datetime.now(),
|
|
||||||
'bid_levels': [3000 - i for i in range(10)],
|
|
||||||
'ask_levels': [3000 + i for i in range(10)],
|
|
||||||
'bid_volumes': [np.random.uniform(1, 10) for _ in range(10)],
|
|
||||||
'ask_volumes': [np.random.uniform(1, 10) for _ in range(10)],
|
|
||||||
'spread': 1.0,
|
|
||||||
'depth': 100.0
|
|
||||||
}
|
|
||||||
|
|
||||||
def test_rapid_change_detector():
|
|
||||||
"""Test the rapid change detection system"""
|
|
||||||
logger.info("=== Testing Rapid Change Detector ===")
|
|
||||||
|
|
||||||
detector = RapidChangeDetector(
|
|
||||||
velocity_threshold=0.5,
|
|
||||||
volatility_multiplier=3.0,
|
|
||||||
lookback_minutes=5
|
|
||||||
)
|
|
||||||
|
|
||||||
symbol = 'ETHUSDT'
|
|
||||||
base_price = 3000.0
|
|
||||||
|
|
||||||
# Add normal price points
|
|
||||||
for i in range(120): # 2 minutes of data
|
|
||||||
timestamp = datetime.now() - timedelta(seconds=120-i)
|
|
||||||
price = base_price + np.random.normal(0, 1) # Small changes
|
|
||||||
detector.add_price_point(symbol, timestamp, price)
|
|
||||||
|
|
||||||
# Check for rapid change (should be False)
|
|
||||||
is_rapid, velocity, volatility_spike = detector.detect_rapid_change(symbol)
|
|
||||||
logger.info(f"Normal conditions - Rapid change: {is_rapid}, Velocity: {velocity:.3f}")
|
|
||||||
|
|
||||||
# Add rapid price change
|
|
||||||
for i in range(60): # 1 minute of rapid changes
|
|
||||||
timestamp = datetime.now() - timedelta(seconds=60-i)
|
|
||||||
price = base_price + 50 + i * 0.5 # Rapid increase
|
|
||||||
detector.add_price_point(symbol, timestamp, price)
|
|
||||||
|
|
||||||
# Check for rapid change (should be True)
|
|
||||||
is_rapid, velocity, volatility_spike = detector.detect_rapid_change(symbol)
|
|
||||||
logger.info(f"Rapid change conditions - Rapid change: {is_rapid}, Velocity: {velocity:.3f}")
|
|
||||||
|
|
||||||
return detector
|
|
||||||
|
|
||||||
def test_training_data_collector():
|
|
||||||
"""Test the training data collection system"""
|
|
||||||
logger.info("=== Testing Training Data Collector ===")
|
|
||||||
|
|
||||||
# Initialize collector
|
|
||||||
collector = TrainingDataCollector(
|
|
||||||
storage_dir="test_training_data",
|
|
||||||
max_episodes_per_symbol=100
|
|
||||||
)
|
|
||||||
|
|
||||||
collector.start_collection()
|
|
||||||
|
|
||||||
symbol = 'ETHUSDT'
|
|
||||||
|
|
||||||
# Create sample data
|
|
||||||
ohlcv_data = create_sample_ohlcv_data()
|
|
||||||
tick_data = create_sample_tick_data()
|
|
||||||
cob_data = create_sample_cob_data()
|
|
||||||
technical_indicators = {
|
|
||||||
'rsi_14': 65.5,
|
|
||||||
'macd': 0.5,
|
|
||||||
'sma_20': 3000.0,
|
|
||||||
'ema_12': 3005.0,
|
|
||||||
'bollinger_upper': 3050.0,
|
|
||||||
'bollinger_lower': 2950.0
|
|
||||||
}
|
|
||||||
pivot_points = [
|
|
||||||
{'timestamp': datetime.now(), 'price': 3020.0, 'type': 'high'},
|
|
||||||
{'timestamp': datetime.now() - timedelta(minutes=30), 'price': 2980.0, 'type': 'low'}
|
|
||||||
]
|
|
||||||
|
|
||||||
# Create CNN and RL features
|
|
||||||
cnn_features = np.random.randn(2000).astype(np.float32)
|
|
||||||
rl_state = np.random.randn(2000).astype(np.float32)
|
|
||||||
orchestrator_context = {
|
|
||||||
'market_session': 'european',
|
|
||||||
'volatility_regime': 'medium',
|
|
||||||
'trend_direction': 'uptrend'
|
|
||||||
}
|
|
||||||
|
|
||||||
# Collect training data
|
|
||||||
episode_id = collector.collect_training_data(
|
|
||||||
symbol=symbol,
|
|
||||||
ohlcv_data=ohlcv_data,
|
|
||||||
tick_data=tick_data,
|
|
||||||
cob_data=cob_data,
|
|
||||||
technical_indicators=technical_indicators,
|
|
||||||
pivot_points=pivot_points,
|
|
||||||
cnn_features=cnn_features,
|
|
||||||
rl_state=rl_state,
|
|
||||||
orchestrator_context=orchestrator_context
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Created training episode: {episode_id}")
|
|
||||||
|
|
||||||
# Test data validation
|
|
||||||
validation_results = collector.validate_data_integrity()
|
|
||||||
logger.info(f"Data integrity validation: {validation_results}")
|
|
||||||
|
|
||||||
# Get statistics
|
|
||||||
stats = collector.get_collection_statistics()
|
|
||||||
logger.info(f"Collection statistics: {stats}")
|
|
||||||
|
|
||||||
collector.stop_collection()
|
|
||||||
|
|
||||||
return collector
|
|
||||||
|
|
||||||
def test_cnn_training_pipeline():
|
|
||||||
"""Test the CNN training pipeline"""
|
|
||||||
logger.info("=== Testing CNN Training Pipeline ===")
|
|
||||||
|
|
||||||
# Initialize CNN model and trainer
|
|
||||||
model = CNNPivotPredictor(
|
|
||||||
input_channels=10,
|
|
||||||
sequence_length=300,
|
|
||||||
hidden_dim=128, # Smaller for testing
|
|
||||||
num_pivot_classes=3
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer = CNNTrainer(
|
|
||||||
model=model,
|
|
||||||
device='cpu', # Use CPU for testing
|
|
||||||
learning_rate=0.001,
|
|
||||||
storage_dir="test_cnn_training"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create sample training episodes
|
|
||||||
episodes = []
|
|
||||||
for i in range(50): # Create 50 sample episodes
|
|
||||||
# Create sample input package
|
|
||||||
input_package = ModelInputPackage(
|
|
||||||
timestamp=datetime.now() - timedelta(minutes=i),
|
|
||||||
symbol='ETHUSDT',
|
|
||||||
ohlcv_data=create_sample_ohlcv_data(),
|
|
||||||
tick_data=create_sample_tick_data(),
|
|
||||||
cob_data=create_sample_cob_data(),
|
|
||||||
technical_indicators={'rsi': 50.0, 'macd': 0.0},
|
|
||||||
pivot_points=[],
|
|
||||||
cnn_features=np.random.randn(2000).astype(np.float32),
|
|
||||||
rl_state=np.random.randn(2000).astype(np.float32),
|
|
||||||
orchestrator_context={}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create sample outcome
|
|
||||||
outcome = TrainingOutcome(
|
|
||||||
input_package_hash=input_package.data_hash,
|
|
||||||
timestamp=input_package.timestamp,
|
|
||||||
symbol='ETHUSDT',
|
|
||||||
price_change_1m=np.random.normal(0, 0.01),
|
|
||||||
price_change_5m=np.random.normal(0, 0.02),
|
|
||||||
price_change_15m=np.random.normal(0, 0.03),
|
|
||||||
price_change_1h=np.random.normal(0, 0.05),
|
|
||||||
max_profit_potential=abs(np.random.normal(0, 0.02)),
|
|
||||||
max_loss_potential=abs(np.random.normal(0, 0.015)),
|
|
||||||
optimal_entry_price=3000.0,
|
|
||||||
optimal_exit_price=3000.0 + np.random.normal(0, 10),
|
|
||||||
optimal_holding_time=timedelta(minutes=np.random.randint(5, 60)),
|
|
||||||
is_profitable=np.random.random() > 0.4, # 60% profitable
|
|
||||||
profitability_score=np.random.uniform(0.3, 1.0),
|
|
||||||
risk_reward_ratio=np.random.uniform(1.0, 3.0),
|
|
||||||
is_rapid_change=np.random.random() > 0.8, # 20% rapid changes
|
|
||||||
change_velocity=np.random.uniform(0.1, 2.0),
|
|
||||||
volatility_spike=np.random.random() > 0.9,
|
|
||||||
outcome_validated=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create training episode
|
|
||||||
episode = TrainingEpisode(
|
|
||||||
episode_id=f"test_episode_{i}",
|
|
||||||
input_package=input_package,
|
|
||||||
model_predictions={},
|
|
||||||
actual_outcome=outcome,
|
|
||||||
episode_type='normal'
|
|
||||||
)
|
|
||||||
|
|
||||||
episodes.append(episode)
|
|
||||||
|
|
||||||
# Test training on episodes
|
|
||||||
results = trainer._train_on_episodes(episodes, training_mode='test_batch')
|
|
||||||
logger.info(f"Training results: {results}")
|
|
||||||
|
|
||||||
# Test profitable episode training
|
|
||||||
profitable_results = trainer.train_on_profitable_episodes(
|
|
||||||
symbol='ETHUSDT',
|
|
||||||
min_profitability=0.7,
|
|
||||||
max_episodes=20
|
|
||||||
)
|
|
||||||
logger.info(f"Profitable training results: {profitable_results}")
|
|
||||||
|
|
||||||
# Get training statistics
|
|
||||||
stats = trainer.get_training_statistics()
|
|
||||||
logger.info(f"Training statistics: {stats}")
|
|
||||||
|
|
||||||
return trainer
|
|
||||||
|
|
||||||
def test_integration():
|
|
||||||
"""Test the complete integration"""
|
|
||||||
logger.info("=== Testing Complete Integration ===")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Test individual components
|
|
||||||
detector = test_rapid_change_detector()
|
|
||||||
collector = test_training_data_collector()
|
|
||||||
trainer = test_cnn_training_pipeline()
|
|
||||||
|
|
||||||
logger.info("✅ All components tested successfully!")
|
|
||||||
|
|
||||||
# Test data flow
|
|
||||||
logger.info("Testing data flow integration...")
|
|
||||||
|
|
||||||
# Simulate real-time data collection and training
|
|
||||||
symbol = 'ETHUSDT'
|
|
||||||
|
|
||||||
# Collect multiple data points
|
|
||||||
for i in range(10):
|
|
||||||
ohlcv_data = create_sample_ohlcv_data()
|
|
||||||
tick_data = create_sample_tick_data()
|
|
||||||
cob_data = create_sample_cob_data()
|
|
||||||
|
|
||||||
episode_id = collector.collect_training_data(
|
|
||||||
symbol=symbol,
|
|
||||||
ohlcv_data=ohlcv_data,
|
|
||||||
tick_data=tick_data,
|
|
||||||
cob_data=cob_data,
|
|
||||||
technical_indicators={'rsi': 50.0 + i},
|
|
||||||
pivot_points=[],
|
|
||||||
cnn_features=np.random.randn(2000).astype(np.float32),
|
|
||||||
rl_state=np.random.randn(2000).astype(np.float32),
|
|
||||||
orchestrator_context={}
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Collected episode {i+1}: {episode_id}")
|
|
||||||
time.sleep(0.1) # Small delay
|
|
||||||
|
|
||||||
# Get final statistics
|
|
||||||
final_stats = collector.get_collection_statistics()
|
|
||||||
logger.info(f"Final collection statistics: {final_stats}")
|
|
||||||
|
|
||||||
logger.info("✅ Integration test completed successfully!")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ Integration test failed: {e}")
|
|
||||||
import traceback
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return False
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Main test function"""
|
|
||||||
logger.info("=" * 80)
|
|
||||||
logger.info("COMPREHENSIVE TRAINING DATA COLLECTION SYSTEM TEST")
|
|
||||||
logger.info("=" * 80)
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Run integration test
|
|
||||||
success = test_integration()
|
|
||||||
|
|
||||||
end_time = time.time()
|
|
||||||
duration = end_time - start_time
|
|
||||||
|
|
||||||
logger.info("=" * 80)
|
|
||||||
if success:
|
|
||||||
logger.info("✅ ALL TESTS PASSED!")
|
|
||||||
else:
|
else:
|
||||||
logger.info("❌ SOME TESTS FAILED!")
|
logger.warning("No checkpoint files found.")
|
||||||
|
|
||||||
logger.info(f"Test duration: {duration:.2f} seconds")
|
# Test training by making trading decisions
|
||||||
logger.info("=" * 80)
|
logger.info("\nTesting training by making trading decisions...")
|
||||||
|
symbols = orchestrator.symbols
|
||||||
|
|
||||||
# Display summary
|
for symbol in symbols:
|
||||||
logger.info("\n📊 SYSTEM CAPABILITIES DEMONSTRATED:")
|
logger.info(f"Making trading decision for {symbol}...")
|
||||||
logger.info("✓ Comprehensive training data collection with validation")
|
decision = await orchestrator.make_trading_decision(symbol)
|
||||||
logger.info("✓ Rapid price change detection for premium training examples")
|
|
||||||
logger.info("✓ Data integrity validation and completeness checking")
|
|
||||||
logger.info("✓ CNN training pipeline with backpropagation data storage")
|
|
||||||
logger.info("✓ Profitable episode prioritization and replay")
|
|
||||||
logger.info("✓ Training session value calculation and ranking")
|
|
||||||
logger.info("✓ Real-time data integration capabilities")
|
|
||||||
|
|
||||||
logger.info("\n🎯 NEXT STEPS:")
|
if decision:
|
||||||
logger.info("1. Integrate with existing DataProvider for real market data")
|
logger.info(f"Decision for {symbol}: {decision.action} (confidence: {decision.confidence:.3f})")
|
||||||
logger.info("2. Connect with actual CNN and RL models")
|
else:
|
||||||
logger.info("3. Implement outcome validation with real price data")
|
logger.warning(f"No decision made for {symbol}.")
|
||||||
logger.info("4. Add dashboard integration for monitoring")
|
|
||||||
logger.info("5. Scale up for production deployment")
|
|
||||||
|
|
||||||
except Exception as e:
|
# Check if new checkpoints were created
|
||||||
logger.error(f"❌ Test execution failed: {e}")
|
new_checkpoint_stats = checkpoint_manager.get_checkpoint_stats()
|
||||||
import traceback
|
new_checkpoints = new_checkpoint_stats['total_checkpoints'] - checkpoint_stats['total_checkpoints']
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
if new_checkpoints > 0:
|
||||||
|
logger.info(f"\nSuccess! {new_checkpoints} new checkpoints were created.")
|
||||||
|
logger.info("Training system is working correctly.")
|
||||||
|
else:
|
||||||
|
logger.warning("\nNo new checkpoints were created.")
|
||||||
|
logger.warning("This could be normal if the training threshold wasn't met.")
|
||||||
|
logger.warning("Check the orchestrator's checkpoint saving logic.")
|
||||||
|
|
||||||
|
# Check model states
|
||||||
|
model_states = orchestrator.get_model_states()
|
||||||
|
logger.info("\nModel states:")
|
||||||
|
for model_name, state in model_states.items():
|
||||||
|
checkpoint_loaded = state.get('checkpoint_loaded', False)
|
||||||
|
checkpoint_filename = state.get('checkpoint_filename', 'none')
|
||||||
|
current_loss = state.get('current_loss', None)
|
||||||
|
|
||||||
|
status = "LOADED" if checkpoint_loaded else "FRESH"
|
||||||
|
loss_str = f"{current_loss:.4f}" if current_loss is not None else "N/A"
|
||||||
|
|
||||||
|
logger.info(f" {model_name}: {status}, Loss: {loss_str}, Checkpoint: {checkpoint_filename}")
|
||||||
|
|
||||||
|
return new_checkpoints > 0
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Main function"""
|
||||||
|
logger.info("=" * 70)
|
||||||
|
logger.info("TRAINING SYSTEM TEST")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
|
||||||
|
success = await test_training_system()
|
||||||
|
|
||||||
|
if success:
|
||||||
|
logger.info("\nTraining system test passed!")
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
logger.warning("\nTraining system test completed with warnings.")
|
||||||
|
logger.info("Check the logs for details.")
|
||||||
|
return 1
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
sys.exit(asyncio.run(main()))
|
@ -11,6 +11,7 @@ import sys
|
|||||||
import asyncio
|
import asyncio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from safe_logging import setup_safe_logging
|
||||||
|
|
||||||
# Add project root to path
|
# Add project root to path
|
||||||
project_root = Path(__file__).parent
|
project_root = Path(__file__).parent
|
||||||
|
@ -1052,6 +1052,8 @@ class CleanTradingDashboard:
|
|||||||
"""Handle clear session button"""
|
"""Handle clear session button"""
|
||||||
if n_clicks:
|
if n_clicks:
|
||||||
self._clear_session()
|
self._clear_session()
|
||||||
|
# Return a visual confirmation that the session was cleared
|
||||||
|
return [html.I(className="fas fa-check me-1 text-success"), "Cleared"]
|
||||||
return [html.I(className="fas fa-trash me-1"), "Clear Session"]
|
return [html.I(className="fas fa-trash me-1"), "Clear Session"]
|
||||||
|
|
||||||
@self.app.callback(
|
@self.app.callback(
|
||||||
|
Reference in New Issue
Block a user