try fixing GPU (torch)
This commit is contained in:
209
scripts/setup-pytorch.sh
Normal file
209
scripts/setup-pytorch.sh
Normal file
@@ -0,0 +1,209 @@
|
||||
#!/bin/bash
|
||||
# Automatic PyTorch installation script
|
||||
# Detects hardware and installs the appropriate PyTorch build
|
||||
# Works with: NVIDIA (CUDA), AMD (ROCm), or CPU-only
|
||||
|
||||
set -e
|
||||
|
||||
echo "=================================================="
|
||||
echo " PyTorch Auto-Setup for Trading System"
|
||||
echo "=================================================="
|
||||
echo ""
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Detect GPU hardware
|
||||
detect_hardware() {
|
||||
echo "Detecting GPU hardware..."
|
||||
|
||||
# Check for NVIDIA GPU
|
||||
if command -v nvidia-smi &> /dev/null; then
|
||||
if nvidia-smi &> /dev/null; then
|
||||
echo -e "${GREEN}✓ NVIDIA GPU detected${NC}"
|
||||
CUDA_VERSION=$(nvidia-smi | grep "CUDA Version" | awk '{print $9}' | cut -d. -f1,2)
|
||||
echo " CUDA Version: $CUDA_VERSION"
|
||||
GPU_TYPE="nvidia"
|
||||
return
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check for AMD GPU
|
||||
if lspci 2>/dev/null | grep -iE "VGA|3D|Display" | grep -iq "AMD\|ATI"; then
|
||||
echo -e "${GREEN}✓ AMD GPU detected${NC}"
|
||||
GPU_MODEL=$(lspci | grep -iE "VGA|3D|Display" | grep -i "AMD\|ATI" | head -1)
|
||||
echo " $GPU_MODEL"
|
||||
|
||||
# Check if ROCm is available
|
||||
if command -v rocm-smi &> /dev/null; then
|
||||
ROCM_VERSION=$(rocm-smi --version 2>/dev/null | grep "ROCm" | awk '{print $3}' || echo "unknown")
|
||||
echo " ROCm installed: $ROCM_VERSION"
|
||||
else
|
||||
echo -e "${YELLOW} ⚠ ROCm not detected - will install ROCm PyTorch anyway${NC}"
|
||||
fi
|
||||
|
||||
GPU_TYPE="amd"
|
||||
return
|
||||
fi
|
||||
|
||||
# No GPU detected
|
||||
echo -e "${YELLOW}⚠ No GPU detected - will use CPU-only build${NC}"
|
||||
GPU_TYPE="cpu"
|
||||
}
|
||||
|
||||
# Check if PyTorch is already installed
|
||||
check_existing_pytorch() {
|
||||
if python -c "import torch" 2>/dev/null; then
|
||||
TORCH_VERSION=$(python -c "import torch; print(torch.__version__)")
|
||||
GPU_AVAILABLE=$(python -c "import torch; print(torch.cuda.is_available())")
|
||||
|
||||
echo ""
|
||||
echo "PyTorch is already installed:"
|
||||
echo " Version: $TORCH_VERSION"
|
||||
echo " GPU available: $GPU_AVAILABLE"
|
||||
echo ""
|
||||
|
||||
read -p "Reinstall PyTorch? (y/N): " -n 1 -r
|
||||
echo
|
||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||
echo "Keeping existing PyTorch installation"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Uninstalling existing PyTorch..."
|
||||
pip uninstall -y torch 2>/dev/null || true
|
||||
fi
|
||||
}
|
||||
|
||||
# Install PyTorch based on hardware
|
||||
install_pytorch() {
|
||||
echo ""
|
||||
echo "Installing PyTorch for $GPU_TYPE..."
|
||||
echo ""
|
||||
|
||||
case $GPU_TYPE in
|
||||
nvidia)
|
||||
# Determine CUDA version to use
|
||||
if [[ "$CUDA_VERSION" == "12.1" ]] || [[ "$CUDA_VERSION" == "12.2" ]] || [[ "$CUDA_VERSION" == "12.3" ]]; then
|
||||
CUDA_BUILD="cu121"
|
||||
elif [[ "$CUDA_VERSION" == "12.4" ]] || [[ "$CUDA_VERSION" == "12.5" ]] || [[ "$CUDA_VERSION" == "12.6" ]]; then
|
||||
CUDA_BUILD="cu124"
|
||||
elif [[ "$CUDA_VERSION" == "11."* ]]; then
|
||||
CUDA_BUILD="cu118"
|
||||
else
|
||||
echo -e "${YELLOW}⚠ Unknown CUDA version, using CUDA 12.1 build${NC}"
|
||||
CUDA_BUILD="cu121"
|
||||
fi
|
||||
|
||||
echo "Installing PyTorch with CUDA $CUDA_BUILD support..."
|
||||
pip install torch --index-url https://download.pytorch.org/whl/$CUDA_BUILD
|
||||
;;
|
||||
|
||||
amd)
|
||||
echo "Installing PyTorch with ROCm 6.2 support..."
|
||||
echo "(This works with RDNA 2, RDNA 3, and newer AMD GPUs)"
|
||||
pip install torch --index-url https://download.pytorch.org/whl/rocm6.2
|
||||
;;
|
||||
|
||||
cpu)
|
||||
echo "Installing CPU-only PyTorch..."
|
||||
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
# Verify installation
|
||||
verify_installation() {
|
||||
echo ""
|
||||
echo "Verifying installation..."
|
||||
echo ""
|
||||
|
||||
if ! python -c "import torch" 2>/dev/null; then
|
||||
echo -e "${RED}✗ PyTorch installation failed!${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
TORCH_VERSION=$(python -c "import torch; print(torch.__version__)")
|
||||
GPU_AVAILABLE=$(python -c "import torch; print(torch.cuda.is_available())")
|
||||
|
||||
echo -e "${GREEN}✓ PyTorch installed successfully!${NC}"
|
||||
echo " Version: $TORCH_VERSION"
|
||||
echo " GPU available: $GPU_AVAILABLE"
|
||||
|
||||
if [[ "$GPU_AVAILABLE" == "True" ]]; then
|
||||
DEVICE_NAME=$(python -c "import torch; print(torch.cuda.get_device_name(0))")
|
||||
DEVICE_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
|
||||
MEMORY_GB=$(python -c "import torch; print(f'{torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}')")
|
||||
|
||||
echo -e "${GREEN} Device: $DEVICE_NAME${NC}"
|
||||
echo " Count: $DEVICE_COUNT"
|
||||
echo " Memory: ${MEMORY_GB} GB"
|
||||
|
||||
case $GPU_TYPE in
|
||||
nvidia)
|
||||
echo ""
|
||||
echo "🚀 Training will be 5-10x faster with NVIDIA GPU!"
|
||||
;;
|
||||
amd)
|
||||
echo ""
|
||||
echo "🚀 Training will be 2-3x faster with AMD GPU!"
|
||||
;;
|
||||
esac
|
||||
else
|
||||
if [[ "$GPU_TYPE" != "cpu" ]]; then
|
||||
echo -e "${YELLOW}⚠ GPU detected but not available in PyTorch${NC}"
|
||||
echo " This might mean:"
|
||||
echo " - GPU drivers need to be installed/updated"
|
||||
echo " - Wrong PyTorch build was installed"
|
||||
echo " - GPU is not supported"
|
||||
else
|
||||
echo " CPU-only mode (slower training)"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=================================================="
|
||||
echo "✓ Setup complete!"
|
||||
echo "=================================================="
|
||||
echo ""
|
||||
echo "Test your setup:"
|
||||
echo " python -c \"import torch; print(f'GPU: {torch.cuda.is_available()}')\""
|
||||
echo ""
|
||||
echo "Start ANNOTATE:"
|
||||
echo " python ANNOTATE/web/app.py"
|
||||
echo ""
|
||||
}
|
||||
|
||||
# Main execution
|
||||
main() {
|
||||
# Check if we're in a virtual environment
|
||||
if [[ -z "$VIRTUAL_ENV" ]]; then
|
||||
echo -e "${YELLOW}⚠ Not in a virtual environment${NC}"
|
||||
echo ""
|
||||
echo "It's recommended to use a virtual environment:"
|
||||
echo " python -m venv venv"
|
||||
echo " source venv/bin/activate # Linux/Mac"
|
||||
echo " .\\venv\\Scripts\\activate # Windows"
|
||||
echo ""
|
||||
read -p "Continue anyway? (y/N): " -n 1 -r
|
||||
echo
|
||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo -e "${GREEN}✓ Virtual environment active: $VIRTUAL_ENV${NC}"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
detect_hardware
|
||||
check_existing_pytorch
|
||||
install_pytorch
|
||||
verify_installation
|
||||
}
|
||||
|
||||
# Run main function
|
||||
main
|
||||
|
||||
Reference in New Issue
Block a user