wip
This commit is contained in:
189
balance_trading_signals.py
Normal file
189
balance_trading_signals.py
Normal file
@ -0,0 +1,189 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Balance Trading Signals - Analyze and fix SHORT signal bias
|
||||
|
||||
This script analyzes the trading signals from the orchestrator and adjusts
|
||||
the model weights to balance BUY and SELL signals.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import json
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.config import get_config, setup_logging
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
# Setup logging
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def analyze_trading_signals():
|
||||
"""Analyze trading signals from the orchestrator"""
|
||||
logger.info("Analyzing trading signals...")
|
||||
|
||||
# Initialize components
|
||||
data_provider = DataProvider()
|
||||
orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
|
||||
|
||||
# Get recent decisions
|
||||
symbols = orchestrator.symbols
|
||||
all_decisions = {}
|
||||
|
||||
for symbol in symbols:
|
||||
decisions = orchestrator.get_recent_decisions(symbol)
|
||||
all_decisions[symbol] = decisions
|
||||
|
||||
# Count actions
|
||||
action_counts = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
|
||||
for decision in decisions:
|
||||
action_counts[decision.action] += 1
|
||||
|
||||
total_decisions = sum(action_counts.values())
|
||||
if total_decisions > 0:
|
||||
buy_percent = action_counts['BUY'] / total_decisions * 100
|
||||
sell_percent = action_counts['SELL'] / total_decisions * 100
|
||||
hold_percent = action_counts['HOLD'] / total_decisions * 100
|
||||
|
||||
logger.info(f"Symbol: {symbol}")
|
||||
logger.info(f" Total decisions: {total_decisions}")
|
||||
logger.info(f" BUY: {action_counts['BUY']} ({buy_percent:.1f}%)")
|
||||
logger.info(f" SELL: {action_counts['SELL']} ({sell_percent:.1f}%)")
|
||||
logger.info(f" HOLD: {action_counts['HOLD']} ({hold_percent:.1f}%)")
|
||||
|
||||
# Check for bias
|
||||
if sell_percent > buy_percent * 2: # If SELL signals are more than twice BUY signals
|
||||
logger.warning(f" SELL bias detected: {sell_percent:.1f}% vs {buy_percent:.1f}%")
|
||||
|
||||
# Adjust model weights to balance signals
|
||||
logger.info(" Adjusting model weights to balance signals...")
|
||||
|
||||
# Get current model weights
|
||||
model_weights = orchestrator.model_weights
|
||||
logger.info(f" Current model weights: {model_weights}")
|
||||
|
||||
# Identify models with SELL bias
|
||||
model_predictions = {}
|
||||
for model_name in model_weights:
|
||||
model_predictions[model_name] = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
|
||||
|
||||
# Analyze recent decisions to identify biased models
|
||||
for decision in decisions:
|
||||
reasoning = decision.reasoning
|
||||
if 'models_used' in reasoning:
|
||||
for model_name in reasoning['models_used']:
|
||||
if model_name in model_predictions:
|
||||
model_predictions[model_name][decision.action] += 1
|
||||
|
||||
# Calculate bias for each model
|
||||
model_bias = {}
|
||||
for model_name, actions in model_predictions.items():
|
||||
total = sum(actions.values())
|
||||
if total > 0:
|
||||
buy_pct = actions['BUY'] / total * 100
|
||||
sell_pct = actions['SELL'] / total * 100
|
||||
|
||||
# Calculate bias score (-100 to 100, negative = SELL bias, positive = BUY bias)
|
||||
bias_score = buy_pct - sell_pct
|
||||
model_bias[model_name] = bias_score
|
||||
|
||||
logger.info(f" Model {model_name}: Bias score = {bias_score:.1f} (BUY: {buy_pct:.1f}%, SELL: {sell_pct:.1f}%)")
|
||||
|
||||
# Adjust weights based on bias
|
||||
adjusted_weights = {}
|
||||
for model_name, weight in model_weights.items():
|
||||
if model_name in model_bias:
|
||||
bias = model_bias[model_name]
|
||||
|
||||
# If model has strong SELL bias, reduce its weight
|
||||
if bias < -30: # Strong SELL bias
|
||||
adjusted_weights[model_name] = max(0.05, weight * 0.7) # Reduce weight by 30%
|
||||
logger.info(f" Reducing weight of {model_name} from {weight:.2f} to {adjusted_weights[model_name]:.2f} due to SELL bias")
|
||||
# If model has BUY bias, increase its weight to balance
|
||||
elif bias > 10: # BUY bias
|
||||
adjusted_weights[model_name] = min(0.5, weight * 1.3) # Increase weight by 30%
|
||||
logger.info(f" Increasing weight of {model_name} from {weight:.2f} to {adjusted_weights[model_name]:.2f} to balance SELL bias")
|
||||
else:
|
||||
adjusted_weights[model_name] = weight
|
||||
else:
|
||||
adjusted_weights[model_name] = weight
|
||||
|
||||
# Save adjusted weights
|
||||
save_adjusted_weights(adjusted_weights)
|
||||
|
||||
logger.info(f" Adjusted weights: {adjusted_weights}")
|
||||
logger.info(" Weights saved to 'adjusted_model_weights.json'")
|
||||
|
||||
# Recommend next steps
|
||||
logger.info("\nRecommended actions:")
|
||||
logger.info("1. Update the model weights in the orchestrator")
|
||||
logger.info("2. Monitor trading signals for balance")
|
||||
logger.info("3. Consider retraining models with balanced data")
|
||||
|
||||
def save_adjusted_weights(weights):
|
||||
"""Save adjusted weights to a file"""
|
||||
output = {
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'weights': weights,
|
||||
'notes': 'Adjusted to balance BUY/SELL signals'
|
||||
}
|
||||
|
||||
with open('adjusted_model_weights.json', 'w') as f:
|
||||
json.dump(output, f, indent=2)
|
||||
|
||||
def apply_balanced_weights():
|
||||
"""Apply balanced weights to the orchestrator"""
|
||||
try:
|
||||
# Check if weights file exists
|
||||
if not os.path.exists('adjusted_model_weights.json'):
|
||||
logger.error("Adjusted weights file not found. Run analyze_trading_signals() first.")
|
||||
return False
|
||||
|
||||
# Load adjusted weights
|
||||
with open('adjusted_model_weights.json', 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
weights = data.get('weights', {})
|
||||
if not weights:
|
||||
logger.error("No weights found in the file.")
|
||||
return False
|
||||
|
||||
logger.info(f"Loaded adjusted weights: {weights}")
|
||||
|
||||
# Initialize components
|
||||
data_provider = DataProvider()
|
||||
orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
|
||||
|
||||
# Apply weights
|
||||
for model_name, weight in weights.items():
|
||||
if model_name in orchestrator.model_weights:
|
||||
orchestrator.model_weights[model_name] = weight
|
||||
|
||||
# Save updated weights
|
||||
orchestrator._save_orchestrator_state()
|
||||
|
||||
logger.info("Applied balanced weights to orchestrator.")
|
||||
logger.info("Restart the trading system for changes to take effect.")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error applying balanced weights: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("=" * 70)
|
||||
logger.info("TRADING SIGNAL BALANCE ANALYZER")
|
||||
logger.info("=" * 70)
|
||||
|
||||
if len(sys.argv) > 1 and sys.argv[1] == 'apply':
|
||||
apply_balanced_weights()
|
||||
else:
|
||||
analyze_trading_signals()
|
Reference in New Issue
Block a user