204 lines
7.0 KiB
Python
204 lines
7.0 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Reset Models and Fix Action Mapping
|
|
|
|
This script:
|
|
1. Deletes existing model files
|
|
2. Creates new model files with consistent action mapping
|
|
3. Updates action mapping in key files
|
|
"""
|
|
|
|
import os
|
|
import shutil
|
|
import logging
|
|
import sys
|
|
import torch
|
|
import numpy as np
|
|
from datetime import datetime
|
|
|
|
# Configure logging: INFO and above, each record stamped with time, logger name and level
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)  # module-level logger shared by this script's functions
|
|
|
|
def ensure_directory(directory):
    """Create ``directory`` (and any missing parents) if it does not exist.

    A message is logged only when the directory is actually created. A path
    that already exists is left untouched without error, even if it is a
    regular file.

    Args:
        directory: Path of the directory to create.
    """
    try:
        # EAFP instead of exists()-then-makedirs(): avoids the race where the
        # directory is created by someone else between the check and the call.
        os.makedirs(directory)
    except FileExistsError:
        return
    logger.info(f"Created directory: {directory}")
|
|
|
|
def delete_directory_contents(directory):
    """Delete every entry (files, symlinks, subdirectories) inside ``directory``.

    The directory itself is kept. A path that does not exist, or is not a
    directory, has no contents and is left alone. A failure on one entry is
    logged and does not stop the remaining deletions.

    Args:
        directory: Path of the directory whose contents should be removed.
    """
    # isdir (not exists): os.listdir on a regular file raises NotADirectoryError.
    if not os.path.isdir(directory):
        return

    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)
        try:
            # Symlinks, even ones pointing at directories, are unlinked and never
            # followed, so rmtree cannot reach outside this directory.
            if os.path.isdir(file_path) and not os.path.islink(file_path):
                shutil.rmtree(file_path)
            else:
                # Regular files, symlinks and special files (FIFOs, sockets).
                os.unlink(file_path)
            logger.info(f"Deleted: {file_path}")
        except OSError as e:
            logger.error(f"Failed to delete {file_path}. Reason: {e}")
|
|
|
|
def create_backup_directory():
    """Make a timestamped backup folder under ``models`` and return its path.

    The folder name embeds the current local time at one-second resolution,
    e.g. ``models/backup_20240131_235959``.

    Returns:
        str: Path of the backup folder.
    """
    # datetime.__format__ delegates to strftime, so this matches the
    # "%Y%m%d_%H%M%S" timestamp layout exactly.
    path = f"models/backup_{datetime.now():%Y%m%d_%H%M%S}"
    ensure_directory(path)
    return path
|
|
|
|
def backup_models():
    """Copy the current model directories into a new timestamped backup folder.

    Each existing directory in ``model_dirs`` is copied to
    ``<backup_dir>/<basename>``. Both plain files and subdirectories are
    copied. ``delete_directory_contents`` removes subdirectories as well, so
    copying only files would silently lose nested model data. Copy errors are
    not caught: they propagate, so a caller that deletes after backing up
    never reaches the deletion step with an incomplete backup.

    Returns:
        str: Path of the backup directory that was created.
    """
    backup_dir = create_backup_directory()

    # List of model directories to backup
    model_dirs = [
        "models/enhanced_rl",
        "models/enhanced_cnn",
        "models/realtime_rl_cob",
        "models/rl",
        "models/cnn",
    ]

    for model_dir in model_dirs:
        # isdir (not exists): a stray file at this path cannot be listed.
        if not os.path.isdir(model_dir):
            continue

        dest_dir = os.path.join(backup_dir, os.path.basename(model_dir))
        ensure_directory(dest_dir)

        for filename in os.listdir(model_dir):
            file_path = os.path.join(model_dir, filename)
            if os.path.isfile(file_path):
                shutil.copy2(file_path, dest_dir)
            elif os.path.isdir(file_path):
                # Nested directories are deleted by the reset too, so back them up.
                shutil.copytree(file_path, os.path.join(dest_dir, filename))
            else:
                # e.g. broken symlinks: nothing to copy.
                continue
            logger.info(f"Backed up: {file_path} to {dest_dir}")

    return backup_dir
|
|
|
|
def initialize_dqn_model():
    """Create fresh BTC_USDT and ETH_USDT DQN agents and save their weights.

    Each agent is built with ``n_actions=3`` (BUY=0, SELL=1, HOLD=2) and
    ``state_shape=(100,)``. Its policy and target network ``state_dict`` are
    written to ``models/enhanced_rl/<SYMBOL>_dqn_policy.pth`` and
    ``models/enhanced_rl/<SYMBOL>_dqn_target.pth``.

    Returns:
        bool: True if every agent was created and saved. False if anything
        failed; the exception is logged with its traceback.
    """
    try:
        # Make the project-local ``NN`` package importable. The membership check
        # keeps repeated calls from appending duplicate sys.path entries.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        if script_dir not in sys.path:
            sys.path.append(script_dir)
        from NN.models.dqn_agent import DQNAgent

        state_shape = (100,)  # Default feature dimension
        model_dir = "models/enhanced_rl"
        ensure_directory(model_dir)

        # Build every agent before writing any file.
        agents = {
            symbol: DQNAgent(
                state_shape=state_shape,
                n_actions=3,  # BUY=0, SELL=1, HOLD=2
                learning_rate=0.001,
                epsilon=0.5,  # Start with moderate exploration
                epsilon_min=0.01,
                epsilon_decay=0.995,
                model_name=f"{symbol}_dqn",
            )
            for symbol in ("BTC_USDT", "ETH_USDT")
        }

        # Save initial models (BTC policy, BTC target, ETH policy, ETH target).
        for symbol, agent in agents.items():
            torch.save(agent.policy_net.state_dict(), f"{model_dir}/{symbol}_dqn_policy.pth")
            torch.save(agent.target_net.state_dict(), f"{model_dir}/{symbol}_dqn_target.pth")

        logger.info("Initialized new DQN models with consistent action mapping (BUY=0, SELL=1, HOLD=2)")
        return True
    except Exception as e:
        # Broad catch is deliberate: this is a boolean-status boundary for main().
        logger.exception(f"Failed to initialize DQN models: {e}")
        return False
|
|
|
|
def initialize_cnn_model():
    """Create fresh BTC_USDT and ETH_USDT EnhancedCNN models and save their weights.

    Each model is built as ``EnhancedCNN(input_dim=100, n_actions=3)``, with
    actions BUY=0, SELL=1, HOLD=2. Its ``state_dict`` is written to
    ``models/enhanced_cnn/<SYMBOL>_cnn.pth``.

    Returns:
        bool: True if every model was created and saved. False if anything
        failed; the exception is logged with its traceback.
    """
    try:
        # Make the project-local ``NN`` package importable. The membership check
        # keeps repeated calls from appending duplicate sys.path entries.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        if script_dir not in sys.path:
            sys.path.append(script_dir)
        from NN.models.enhanced_cnn import EnhancedCNN

        input_dim = 100  # Default feature dimension
        n_actions = 3  # BUY=0, SELL=1, HOLD=2
        model_dir = "models/enhanced_cnn"
        ensure_directory(model_dir)

        # Build every model before writing any file.
        models = {
            symbol: EnhancedCNN(input_dim, n_actions)
            for symbol in ("BTC_USDT", "ETH_USDT")
        }

        # Save initial models (BTC first, then ETH).
        for symbol, model in models.items():
            torch.save(model.state_dict(), f"{model_dir}/{symbol}_cnn.pth")

        logger.info("Initialized new CNN models with consistent action mapping (BUY=0, SELL=1, HOLD=2)")
        return True
    except Exception as e:
        # Broad catch is deliberate: this is a boolean-status boundary for main().
        logger.exception(f"Failed to initialize CNN models: {e}")
        return False
|
|
|
|
def initialize_realtime_rl_model():
    """Prepare ``models/realtime_rl_cob`` for the realtime RL COB models.

    No model is created here. The directory is ensured, and a ``README.txt``
    recording the action mapping (BUY=0, SELL=1, HOLD=2) is written into it,
    overwriting any existing README.

    Returns:
        bool: True on success. False if the directory or README could not be
        written; the exception is logged with its traceback.
    """
    try:
        ensure_directory("models/realtime_rl_cob")

        # Write a README so the directory is not empty and documents the mapping.
        with open("models/realtime_rl_cob/README.txt", "w", encoding="utf-8") as f:
            f.write("Realtime RL COB models will be saved here.\n")
            f.write("Action mapping: BUY=0, SELL=1, HOLD=2\n")

        logger.info("Initialized realtime RL model directory")
        return True
    except OSError as e:
        # Only filesystem operations happen here, so OSError covers the failure modes.
        logger.exception(f"Failed to initialize realtime RL models: {e}")
        return False
|
|
|
|
def main():
    """Back up, wipe and re-initialize the models with a consistent action mapping.

    Steps:
    1. Back up the model directories.
    2. Clear ``models/enhanced_rl``, ``models/enhanced_cnn`` and
       ``models/realtime_rl_cob``.
    3. Create new DQN and CNN models and prepare the realtime RL directory.

    Returns:
        int: Process exit status. 0 if every initialization step succeeded,
        1 otherwise, so shell scripts and CI can detect a failed reset.
    """
    logger.info("Starting model reset and action mapping fix")

    # Back up first. backup_models() catches nothing, so an error here aborts
    # before any model files are deleted.
    backup_dir = backup_models()
    logger.info(f"Backed up existing models to {backup_dir}")

    # Delete existing model files
    model_dirs = [
        "models/enhanced_rl",
        "models/enhanced_cnn",
        "models/realtime_rl_cob",
    ]
    for model_dir in model_dirs:
        delete_directory_contents(model_dir)
        logger.info(f"Deleted contents of {model_dir}")

    # Run every initializer even if an earlier one fails, so one broken model
    # type does not leave the others uninitialized.
    dqn_success = initialize_dqn_model()
    cnn_success = initialize_cnn_model()
    rl_success = initialize_realtime_rl_model()

    success = dqn_success and cnn_success and rl_success
    if success:
        logger.info("Successfully reset models and fixed action mapping")
        logger.info("New action mapping: BUY=0, SELL=1, HOLD=2")
    else:
        logger.error("Failed to reset models and fix action mapping")

    logger.info("Model reset complete")
    return 0 if success else 1


if __name__ == "__main__":
    # Propagate main()'s status so a failed reset yields a non-zero exit code.
    sys.exit(main())