#!/bin/bash
# Automatic PyTorch installation script
# Detects hardware and installs the appropriate PyTorch build
# Works with: NVIDIA (CUDA), AMD (ROCm), or CPU-only

# Abort on the first unhandled command failure.
set -e

echo "=================================================="
echo " PyTorch Auto-Setup for Trading System"
echo "=================================================="
echo ""

# Colors for output.
# These are stored as backslash escapes (not raw ESC bytes), so they only take
# effect when printed with `echo -e` or `printf '%b'`.
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Detect GPU hardware
#
# Probes for an NVIDIA GPU first (a working `nvidia-smi`), then for an AMD GPU
# (via `lspci`), and otherwise falls back to a CPU-only configuration.
#
# Globals written:
#   GPU_TYPE      - "nvidia", "amd" or "cpu"
#   CUDA_VERSION  - "major.minor" parsed from nvidia-smi (NVIDIA only; empty
#                   when the version cannot be parsed)
#   GPU_MODEL     - first matching lspci line (AMD only)
#   ROCM_VERSION  - version reported by rocm-smi, or "unknown" (AMD only, and
#                   only when rocm-smi is on PATH)
# Globals read:   GREEN, YELLOW, NC
# Outputs:        progress messages on stdout
# Returns:        0
detect_hardware() {
  local smi_output amd_devices
  local cuda_re='CUDA Version:[[:space:]]*([0-9]+\.[0-9]+)'

  echo "Detecting GPU hardware..."

  # Check for NVIDIA GPU. nvidia-smi must exist AND run successfully; its
  # output is captured once and reused for version parsing.
  if command -v nvidia-smi &> /dev/null && smi_output=$(nvidia-smi 2> /dev/null); then
    echo -e "${GREEN}✓ NVIDIA GPU detected${NC}"
    # Match the labelled value rather than a whitespace-separated column so
    # the parse does not depend on the banner's exact layout.
    if [[ $smi_output =~ $cuda_re ]]; then
      CUDA_VERSION=${BASH_REMATCH[1]}
    else
      CUDA_VERSION=""
    fi
    echo " CUDA Version: $CUDA_VERSION"
    GPU_TYPE="nvidia"
    return 0
  fi

  # Check for AMD GPU. The vendor tokens are matched as whole words (-w):
  # a plain case-insensitive "ATI" also matches "compATIble"/"CorporATIon",
  # which appear on virtually every lspci VGA line regardless of vendor.
  # `|| true` keeps a no-match result from aborting the script under set -e.
  amd_devices=$(lspci 2> /dev/null | grep -iE "VGA|3D|Display" | grep -iwE "AMD|ATI" || true)
  if [[ -n "$amd_devices" ]]; then
    echo -e "${GREEN}✓ AMD GPU detected${NC}"
    GPU_MODEL=${amd_devices%%$'\n'*} # first matching device only
    echo " $GPU_MODEL"

    # Check if ROCm is available
    if command -v rocm-smi &> /dev/null; then
      ROCM_VERSION=$(rocm-smi --version 2> /dev/null | grep "ROCm" | awk '{print $3}' || true)
      # The pipeline's exit status is awk's, which is 0 even when nothing
      # matched, so the fallback has to be applied to the (empty) result.
      ROCM_VERSION=${ROCM_VERSION:-unknown}
      echo " ROCm installed: $ROCM_VERSION"
    else
      echo -e "${YELLOW} ⚠ ROCm not detected - will install ROCm PyTorch anyway${NC}"
    fi

    GPU_TYPE="amd"
    return 0
  fi

  # No GPU detected
  echo -e "${YELLOW}⚠ No GPU detected - will use CPU-only build${NC}"
  GPU_TYPE="cpu"
}

# Check if PyTorch is already installed
#
# When `python` can import torch, reports the installed version and whether
# torch.cuda is available, then asks whether to reinstall. Anything other
# than a single "y"/"Y" keeps the current install and exits the script with
# status 0. On "y" the existing torch package is uninstalled.
#
# Globals written: TORCH_VERSION, GPU_AVAILABLE, REPLY
# Outputs:         status messages and the prompt
# Returns:         0 when torch is not installed or was uninstalled;
#                  exits 0 when the user keeps the existing installation
check_existing_pytorch() {
  # Nothing to do when torch cannot be imported.
  if ! python -c "import torch" 2> /dev/null; then
    return 0
  fi

  TORCH_VERSION=$(python -c "import torch; print(torch.__version__)")
  GPU_AVAILABLE=$(python -c "import torch; print(torch.cuda.is_available())")

  echo ""
  echo "PyTorch is already installed:"
  echo " Version: $TORCH_VERSION"
  echo " GPU available: $GPU_AVAILABLE"
  echo ""

  # `read` returns non-zero at EOF (e.g. stdin is closed or /dev/null). Under
  # `set -e` that would abort the script without any message, so treat it as
  # an empty answer, i.e. the documented default "N".
  read -p "Reinstall PyTorch? (y/N): " -n 1 -r || REPLY=""
  echo
  if [[ ! $REPLY =~ ^[Yy]$ ]]; then
    echo "Keeping existing PyTorch installation"
    exit 0
  fi

  echo "Uninstalling existing PyTorch..."
  # Best effort: a failed uninstall must not stop the reinstall.
  pip uninstall -y torch 2> /dev/null || true
}

# Install PyTorch based on hardware
#
# Runs `pip install torch` against the PyTorch wheel index matching GPU_TYPE.
# For NVIDIA the wheel flavour is chosen from CUDA_VERSION:
#   12.1-12.3 -> cu121, 12.4-12.6 -> cu124, 11.x -> cu118,
#   anything else (including empty) -> cu121 with a warning.
#
# Globals read:    GPU_TYPE, CUDA_VERSION, YELLOW, RED, NC
# Globals written: CUDA_BUILD (NVIDIA only)
# Outputs:         progress messages on stdout, pip's own output
# Returns:         pip's exit status; 1 for an unrecognised GPU_TYPE
install_pytorch() {
  local index_base="https://download.pytorch.org/whl"

  echo ""
  echo "Installing PyTorch for $GPU_TYPE..."
  echo ""

  case "$GPU_TYPE" in
    nvidia)
      # Determine CUDA version to use
      case "$CUDA_VERSION" in
        12.1 | 12.2 | 12.3)
          CUDA_BUILD="cu121"
          ;;
        12.4 | 12.5 | 12.6)
          CUDA_BUILD="cu124"
          ;;
        11.*)
          CUDA_BUILD="cu118"
          ;;
        *)
          echo -e "${YELLOW}⚠ Unknown CUDA version, using CUDA 12.1 build${NC}"
          CUDA_BUILD="cu121"
          ;;
      esac

      echo "Installing PyTorch with CUDA $CUDA_BUILD support..."
      pip install torch --index-url "${index_base}/${CUDA_BUILD}"
      ;;

    amd)
      echo "Installing PyTorch with ROCm 6.2 support..."
      echo "(This works with RDNA 2, RDNA 3, and newer AMD GPUs)"
      pip install torch --index-url "${index_base}/rocm6.2"
      ;;

    cpu)
      echo "Installing CPU-only PyTorch..."
      pip install torch --index-url "${index_base}/cpu"
      ;;

    *)
      # Without this arm an unexpected GPU_TYPE would install nothing and
      # still report success to the caller.
      echo -e "${RED}✗ Unknown GPU type '${GPU_TYPE}' - nothing installed${NC}" >&2
      return 1
      ;;
  esac
}

# Verify installation
#
# Confirms that `python` can import torch, then prints the installed version,
# whether torch.cuda reports a usable device and, if so, the name, count and
# total memory of device 0. Finishes with a short "next steps" banner.
#
# Globals read:    GPU_TYPE, RED, GREEN, YELLOW, NC
# Globals written: TORCH_VERSION, GPU_AVAILABLE, and (when a device is
#                  available) DEVICE_NAME, DEVICE_COUNT, MEMORY_GB
# Outputs:         report on stdout
# Returns:         0 on success; exits 1 when torch cannot be imported
verify_installation() {
  echo ""
  echo "Verifying installation..."
  echo ""

  if ! python -c "import torch" 2> /dev/null; then
    echo -e "${RED}✗ PyTorch installation failed!${NC}"
    exit 1
  fi

  TORCH_VERSION=$(python -c "import torch; print(torch.__version__)")
  GPU_AVAILABLE=$(python -c "import torch; print(torch.cuda.is_available())")

  echo -e "${GREEN}✓ PyTorch installed successfully!${NC}"
  echo " Version: $TORCH_VERSION"
  echo " GPU available: $GPU_AVAILABLE"

  # The comparison is against Python's textual bool, hence "True".
  if [[ "$GPU_AVAILABLE" == "True" ]]; then
    DEVICE_NAME=$(python -c "import torch; print(torch.cuda.get_device_name(0))")
    DEVICE_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
    MEMORY_GB=$(python -c "import torch; print(f'{torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}')")

    echo -e "${GREEN} Device: $DEVICE_NAME${NC}"
    echo " Count: $DEVICE_COUNT"
    echo " Memory: ${MEMORY_GB} GB"

    case "$GPU_TYPE" in
      nvidia)
        echo ""
        echo "🚀 Training will be 5-10x faster with NVIDIA GPU!"
        ;;
      amd)
        echo ""
        echo "🚀 Training will be 2-3x faster with AMD GPU!"
        ;;
    esac
  elif [[ "$GPU_TYPE" != "cpu" ]]; then
    # Hardware was detected earlier but torch cannot use it.
    echo -e "${YELLOW}⚠ GPU detected but not available in PyTorch${NC}"
    echo " This might mean:"
    echo " - GPU drivers need to be installed/updated"
    echo " - Wrong PyTorch build was installed"
    echo " - GPU is not supported"
  else
    echo " CPU-only mode (slower training)"
  fi

  echo ""
  echo "=================================================="
  echo "✓ Setup complete!"
  echo "=================================================="
  echo ""
  echo "Test your setup:"
  echo " python -c \"import torch; print(f'GPU: {torch.cuda.is_available()}')\""
  echo ""
  echo "Start ANNOTATE:"
  echo " python ANNOTATE/web/app.py"
  echo ""
}

# Main execution
#
# Warns when no virtual environment is active (VIRTUAL_ENV empty or unset) and
# asks for confirmation before continuing, then runs the four stages in order:
# detect_hardware, check_existing_pytorch, install_pytorch,
# verify_installation.
#
# Globals read:    VIRTUAL_ENV, GREEN, YELLOW, NC
# Globals written: REPLY
# Returns:         status of the last stage; exits 1 when the user declines
#                  to continue outside a virtual environment
main() {
  # Check if we're in a virtual environment
  if [[ -z "$VIRTUAL_ENV" ]]; then
    echo -e "${YELLOW}⚠ Not in a virtual environment${NC}"
    echo ""
    echo "It's recommended to use a virtual environment:"
    echo " python -m venv venv"
    echo " source venv/bin/activate # Linux/Mac"
    echo " .\\venv\\Scripts\\activate # Windows"
    echo ""
    # `read` fails at EOF; handle that explicitly as "no" instead of relying
    # on `set -e` to kill the script without the trailing newline.
    read -p "Continue anyway? (y/N): " -n 1 -r || REPLY=""
    echo
    if [[ ! $REPLY =~ ^[Yy]$ ]]; then
      exit 1
    fi
  else
    echo -e "${GREEN}✓ Virtual environment active: $VIRTUAL_ENV${NC}"
    echo ""
  fi

  detect_hardware
  check_existing_pytorch
  install_pytorch
  verify_installation
}

# Run main function (command-line arguments are forwarded; main currently
# ignores them).
main "$@"