fix model mappings,dash updates, trading
This commit is contained in:
204
reset_models_and_fix_mapping.py
Normal file
204
reset_models_and_fix_mapping.py
Normal file
@ -0,0 +1,204 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Reset Models and Fix Action Mapping
|
||||
|
||||
This script:
|
||||
1. Deletes existing model files
|
||||
2. Creates new model files with consistent action mapping
|
||||
3. Updates action mapping in key files
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import logging
|
||||
import sys
|
||||
import torch
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def ensure_directory(directory):
|
||||
"""Ensure directory exists"""
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
logger.info(f"Created directory: {directory}")
|
||||
|
||||
def delete_directory_contents(directory):
|
||||
"""Delete all files in a directory"""
|
||||
if os.path.exists(directory):
|
||||
for filename in os.listdir(directory):
|
||||
file_path = os.path.join(directory, filename)
|
||||
try:
|
||||
if os.path.isfile(file_path) or os.path.islink(file_path):
|
||||
os.unlink(file_path)
|
||||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path)
|
||||
logger.info(f"Deleted: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete {file_path}. Reason: {e}")
|
||||
|
||||
def create_backup_directory():
|
||||
"""Create a backup directory with timestamp"""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_dir = f"models/backup_{timestamp}"
|
||||
ensure_directory(backup_dir)
|
||||
return backup_dir
|
||||
|
||||
def backup_models():
|
||||
"""Backup existing models"""
|
||||
backup_dir = create_backup_directory()
|
||||
|
||||
# List of model directories to backup
|
||||
model_dirs = [
|
||||
"models/enhanced_rl",
|
||||
"models/enhanced_cnn",
|
||||
"models/realtime_rl_cob",
|
||||
"models/rl",
|
||||
"models/cnn"
|
||||
]
|
||||
|
||||
for model_dir in model_dirs:
|
||||
if os.path.exists(model_dir):
|
||||
dest_dir = os.path.join(backup_dir, os.path.basename(model_dir))
|
||||
ensure_directory(dest_dir)
|
||||
|
||||
# Copy files
|
||||
for filename in os.listdir(model_dir):
|
||||
file_path = os.path.join(model_dir, filename)
|
||||
if os.path.isfile(file_path):
|
||||
shutil.copy2(file_path, dest_dir)
|
||||
logger.info(f"Backed up: {file_path} to {dest_dir}")
|
||||
|
||||
return backup_dir
|
||||
|
||||
def initialize_dqn_model():
|
||||
"""Initialize a new DQN model with consistent action mapping"""
|
||||
try:
|
||||
# Import necessary modules
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
from NN.models.dqn_agent import DQNAgent
|
||||
|
||||
# Define state shape for BTC and ETH
|
||||
state_shape = (100,) # Default feature dimension
|
||||
|
||||
# Create models directory
|
||||
ensure_directory("models/enhanced_rl")
|
||||
|
||||
# Initialize DQN with 3 actions (BUY=0, SELL=1, HOLD=2)
|
||||
dqn_btc = DQNAgent(
|
||||
state_shape=state_shape,
|
||||
n_actions=3, # BUY=0, SELL=1, HOLD=2
|
||||
learning_rate=0.001,
|
||||
epsilon=0.5, # Start with moderate exploration
|
||||
epsilon_min=0.01,
|
||||
epsilon_decay=0.995,
|
||||
model_name="BTC_USDT_dqn"
|
||||
)
|
||||
|
||||
dqn_eth = DQNAgent(
|
||||
state_shape=state_shape,
|
||||
n_actions=3, # BUY=0, SELL=1, HOLD=2
|
||||
learning_rate=0.001,
|
||||
epsilon=0.5, # Start with moderate exploration
|
||||
epsilon_min=0.01,
|
||||
epsilon_decay=0.995,
|
||||
model_name="ETH_USDT_dqn"
|
||||
)
|
||||
|
||||
# Save initial models
|
||||
torch.save(dqn_btc.policy_net.state_dict(), "models/enhanced_rl/BTC_USDT_dqn_policy.pth")
|
||||
torch.save(dqn_btc.target_net.state_dict(), "models/enhanced_rl/BTC_USDT_dqn_target.pth")
|
||||
torch.save(dqn_eth.policy_net.state_dict(), "models/enhanced_rl/ETH_USDT_dqn_policy.pth")
|
||||
torch.save(dqn_eth.target_net.state_dict(), "models/enhanced_rl/ETH_USDT_dqn_target.pth")
|
||||
|
||||
logger.info("Initialized new DQN models with consistent action mapping (BUY=0, SELL=1, HOLD=2)")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize DQN models: {e}")
|
||||
return False
|
||||
|
||||
def initialize_cnn_model():
|
||||
"""Initialize a new CNN model with consistent action mapping"""
|
||||
try:
|
||||
# Import necessary modules
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
from NN.models.enhanced_cnn import EnhancedCNN
|
||||
|
||||
# Define input dimension and number of actions
|
||||
input_dim = 100 # Default feature dimension
|
||||
n_actions = 3 # BUY=0, SELL=1, HOLD=2
|
||||
|
||||
# Create models directory
|
||||
ensure_directory("models/enhanced_cnn")
|
||||
|
||||
# Initialize CNN models for BTC and ETH
|
||||
cnn_btc = EnhancedCNN(input_dim, n_actions)
|
||||
cnn_eth = EnhancedCNN(input_dim, n_actions)
|
||||
|
||||
# Save initial models
|
||||
torch.save(cnn_btc.state_dict(), "models/enhanced_cnn/BTC_USDT_cnn.pth")
|
||||
torch.save(cnn_eth.state_dict(), "models/enhanced_cnn/ETH_USDT_cnn.pth")
|
||||
|
||||
logger.info("Initialized new CNN models with consistent action mapping (BUY=0, SELL=1, HOLD=2)")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize CNN models: {e}")
|
||||
return False
|
||||
|
||||
def initialize_realtime_rl_model():
|
||||
"""Initialize a new realtime RL model with consistent action mapping"""
|
||||
try:
|
||||
# Create models directory
|
||||
ensure_directory("models/realtime_rl_cob")
|
||||
|
||||
# Create empty model files to ensure directory is not empty
|
||||
with open("models/realtime_rl_cob/README.txt", "w") as f:
|
||||
f.write("Realtime RL COB models will be saved here.\n")
|
||||
f.write("Action mapping: BUY=0, SELL=1, HOLD=2\n")
|
||||
|
||||
logger.info("Initialized realtime RL model directory")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize realtime RL models: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main function to reset models and fix action mapping"""
|
||||
logger.info("Starting model reset and action mapping fix")
|
||||
|
||||
# Backup existing models
|
||||
backup_dir = backup_models()
|
||||
logger.info(f"Backed up existing models to {backup_dir}")
|
||||
|
||||
# Delete existing model files
|
||||
model_dirs = [
|
||||
"models/enhanced_rl",
|
||||
"models/enhanced_cnn",
|
||||
"models/realtime_rl_cob"
|
||||
]
|
||||
|
||||
for model_dir in model_dirs:
|
||||
delete_directory_contents(model_dir)
|
||||
logger.info(f"Deleted contents of {model_dir}")
|
||||
|
||||
# Initialize new models with consistent action mapping
|
||||
dqn_success = initialize_dqn_model()
|
||||
cnn_success = initialize_cnn_model()
|
||||
rl_success = initialize_realtime_rl_model()
|
||||
|
||||
if dqn_success and cnn_success and rl_success:
|
||||
logger.info("Successfully reset models and fixed action mapping")
|
||||
logger.info("New action mapping: BUY=0, SELL=1, HOLD=2")
|
||||
else:
|
||||
logger.error("Failed to reset models and fix action mapping")
|
||||
|
||||
logger.info("Model reset complete")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user