Running distributed ML training workloads on GKE using the JobSet API
In this guide you will run a distributed ML training workload on GKE using the JobSet API. Specifically, you will train a handwritten digit image classifier on the classic MNIST dataset using PyTorch. The training computation will be distributed across 4 nodes in a GKE cluster.
Prerequisites
- A Google Cloud account.
- The gcloud CLI installed and configured to use your GCP project.
- The kubectl command line tool installed.
1. Create a GKE cluster with 4 nodes
Run the command:
gcloud container clusters create jobset-cluster --zone us-central1-c --num-nodes=4
You should see output indicating the cluster is being created (this can take around 10 minutes).
2. Install the JobSet CRD on your cluster
Follow the JobSet installation guide.
3. Apply the PyTorch MNIST example JobSet
Run the command:
$ kubectl apply -f https://raw.githubusercontent.com/kubernetes-sigs/jobset/main/examples/pytorch/cnn-mnist/mnist.yaml
jobset.jobset.x-k8s.io/pytorch created
You should see 4 pods created (note the container image is large and may take a few minutes to pull before the container starts running):
$ kubectl get pods
NAME READY STATUS RESTARTS AGE
pytorch-workers-0-0-ph645 0/1 ContainerCreating 0 6s
pytorch-workers-0-1-mddhj 0/1 ContainerCreating 0 6s
pytorch-workers-0-2-z9ffc 0/1 ContainerCreating 0 6s
pytorch-workers-0-3-f9ps4 0/1 ContainerCreating 0 6s
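The four pods come from a single JobSet object: the manifest you applied declares one replicated Job named `workers` whose parallelism matches the node count. The sketch below is a trimmed illustration of the shape of such a manifest, not the exact contents of mnist.yaml (field values like `backoffLimit` and the container details are illustrative; consult the linked mnist.yaml for the real spec):

```yaml
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: pytorch            # pods are named pytorch-workers-<jobIndex>-<podIndex>-<hash>
spec:
  replicatedJobs:
  - name: workers
    template:
      spec:
        parallelism: 4     # one worker pod per node
        completions: 4     # all four workers must finish successfully
        template:
          spec:
            containers:
            - name: pytorch
              # image and command omitted; see the example's mnist.yaml
```

The pod names in the listing above follow directly from this structure: `<jobset>-<replicatedJob>-<jobIndex>-<podIndex>-<random suffix>`.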
4. Observe training logs
You can follow the training logs of any worker by examining its pod logs (use one of the pod names from the `kubectl get pods` output above; the random suffix will differ in your cluster).
$ kubectl logs -f pytorch-workers-0-1-mddhj
+ torchrun --rdzv_id=123 --nnodes=4 --nproc_per_node=1 --master_addr=pytorch-workers-0-0.pytorch --master_port=3389 --node_rank=1 mnist.py --epochs=1 --log-interval=1
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 90162259.46it/s]
Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 33279036.76it/s]
Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 23474415.33it/s]
Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 19165521.90it/s]
Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
Train Epoch: 1 [0/60000 (0%)] Loss: 2.297087
Train Epoch: 1 [64/60000 (0%)] Loss: 2.550339
Train Epoch: 1 [128/60000 (1%)] Loss: 2.361300
...
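Note the `--master_addr=pytorch-workers-0-0.pytorch` argument in the torchrun command above: JobSet creates a headless Service (named after the JobSet by default) so each worker pod gets a stable DNS hostname of the form `<jobset>-<replicatedJob>-<jobIndex>-<podIndex>.<jobset>`, and the rendezvous point is simply worker 0's hostname. A minimal sketch of that naming convention (`worker_hostname` is a hypothetical helper for illustration, not part of JobSet):

```python
def worker_hostname(jobset: str, replicated_job: str,
                    job_index: int = 0, pod_index: int = 0) -> str:
    """Build the stable DNS name a JobSet worker pod gets from the
    headless Service, assuming the default service name (the JobSet name)."""
    return f"{jobset}-{replicated_job}-{job_index}-{pod_index}.{jobset}"

# Worker 0 of the 'workers' replicated Job in the 'pytorch' JobSet,
# as seen in the --master_addr flag of the training logs:
print(worker_hostname("pytorch", "workers"))  # pytorch-workers-0-0.pytorch
```

Because these hostnames are deterministic, every worker can compute the master address before any pod is scheduled, which is what lets the manifest hard-code `--master_addr` rather than discovering it at runtime.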