#!/usr/bin/env bash

# Distributed training script for SwiftFormerTemporal
# Usage: ./dist_temporal_train.sh <DATA_PATH> <NUM_GPUS> [OPTIONS]

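# Optional hardening (assumption, not in the original; left commented out so
# behavior stays unchanged):
# set -euo pipefail
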
DATA_PATH=$1
NUM_GPUS=$2

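# Optional sanity check (not in the original script): fail fast when the two
# required positional arguments are missing.
if [[ -z "$DATA_PATH" || -z "$NUM_GPUS" ]]; then
    echo "Usage: $0 <DATA_PATH> <NUM_GPUS> [OPTIONS]" >&2
    exit 1
fi
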
# Shift arguments to pass the remaining options through to the Python script
shift 2

# Default parameters (each can be overridden via the environment)
MODEL=${MODEL:-"SwiftFormerTemporal_XS"}
BATCH_SIZE=${BATCH_SIZE:-128}
EPOCHS=${EPOCHS:-100}
# LR=${LR:-1e-3}
LR=${LR:-0.01}
OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"}

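# Example override (hypothetical path and values): because the defaults above
# fall back to the environment, they can be set inline at invocation time:
#   BATCH_SIZE=64 LR=5e-4 ./dist_temporal_train.sh /data/imagenet 4
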
echo "Starting distributed training with $NUM_GPUS GPUs"
|
|
echo "Data path: $DATA_PATH"
|
|
echo "Model: $MODEL"
|
|
echo "Batch size: $BATCH_SIZE"
|
|
echo "Epochs: $EPOCHS"
|
|
echo "Output dir: $OUTPUT_DIR"
|
|
|
|
# Check whether torch.distributed.launch or torchrun should be used.
# For newer PyTorch versions (>=1.10), torchrun is the recommended launcher.
TORCH_VERSION=$(python -c "import torch; print(torch.__version__)")
echo "PyTorch version: $TORCH_VERSION"

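# Alternative detection (assumption, not in the original): instead of parsing
# the version string, one could test whether the torchrun entry point is on PATH:
#   if command -v torchrun >/dev/null 2>&1; then ... fi
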
# Use torchrun for newer PyTorch versions (2.x, or 1.10 through 1.19)
if [[ "$TORCH_VERSION" =~ ^2\. ]] || [[ "$TORCH_VERSION" =~ ^1\.1[0-9]\. ]]; then
    echo "Using torchrun (PyTorch >=1.10)"
    torchrun --nproc_per_node=$NUM_GPUS --master_port=12345 main_temporal.py \
        --data-path "$DATA_PATH" \
        --model "$MODEL" \
        --batch-size $BATCH_SIZE \
        --epochs $EPOCHS \
        --lr $LR \
        --output-dir "$OUTPUT_DIR" \
        "$@"
else
echo "Using torch.distributed.launch"
|
|
python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS --master_port=12345 --use_env main_temporal.py \
|
|
--data-path "$DATA_PATH" \
|
|
--model "$MODEL" \
|
|
--batch-size $BATCH_SIZE \
|
|
--epochs $EPOCHS \
|
|
--lr $LR \
|
|
--output-dir "$OUTPUT_DIR" \
|
|
"$@"
|
|
fi
|
|
|
|
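# Multi-node sketch (assumption, not exercised by this script): torchrun can run
# the same command across machines; the node count and master host below are
# placeholders.
#   torchrun --nnodes=2 --node_rank=0 --master_addr=<MASTER_HOST> --master_port=12345 \
#       --nproc_per_node=$NUM_GPUS main_temporal.py --data-path "$DATA_PATH" ...
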
# For single-node multi-GPU training with specific options:
# --world-size 1 --rank 0 --dist-url 'tcp://localhost:12345'

echo "Training completed. Check logs in $OUTPUT_DIR" |