Serve an LLM using a single-host TPU on GKE with JetStream and PyTorch/XLA

Background

This tutorial shows you how to serve a large language model (LLM) using Tensor Processing Units (TPUs) on Google Kubernetes Engine (GKE) with JetStream and JetStream-PyTorch.

Set default environment variables

gcloud config set project [PROJECT_ID]
export PROJECT_ID=$(gcloud config get project)
export REGION=[COMPUTE_REGION]
export ZONE=[ZONE]
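# Example values (assumptions, not required; pick a region and zone where
# v5e TPU machine types such as ct5lp-hightpu-8t are available):
# export REGION=us-west4
# export ZONE=us-west4-a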

Create GKE cluster and node pool

# Create zonal cluster with 2 CPU nodes
gcloud container clusters create jetstream-maxtext \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --workload-pool=${PROJECT_ID}.svc.id.goog \
    --release-channel=rapid \
    --addons GcsFuseCsiDriver \
    --num-nodes=2
# Create a v5e TPU node pool with 2x4 topology (2 single-host nodes, each with 8 chips)
gcloud container node-pools create tpu \
    --cluster=jetstream-maxtext \
    --zone=${ZONE} \
    --num-nodes=2 \
    --machine-type=ct5lp-hightpu-8t \
    --project=${PROJECT_ID}

You have created the following resources:

  • Standard cluster with 2 CPU nodes.
  • One v5e TPU node pool with 2 nodes, each with 8 chips.
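
You can confirm the node pools with:

gcloud container node-pools list --cluster=jetstream-maxtext --zone=${ZONE}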

Configure Applications to use Workload Identity

Prerequisite: make sure you have the following roles:

  • roles/container.admin
  • roles/iam.serviceAccountAdmin
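
You can check which roles your account has with a command like the following (replace YOUR_ACCOUNT_EMAIL with the account you use with gcloud):

gcloud projects get-iam-policy ${PROJECT_ID} \
    --flatten="bindings[].members" \
    --format="table(bindings.role)" \
    --filter="bindings.members:YOUR_ACCOUNT_EMAIL"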

Follow these steps to configure the IAM and Kubernetes service account:

# Get credentials for your cluster
gcloud container clusters get-credentials jetstream-maxtext \
    --zone=${ZONE}
# Create an IAM service account.
gcloud iam service-accounts create jetstream-iam-sa
# Ensure the IAM service account has the necessary roles. Here we add roles/storage.objectUser for GCS bucket access.
gcloud projects add-iam-policy-binding ${PROJECT_ID} \
    --member "serviceAccount:jetstream-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com" \
    --role roles/storage.objectUser
gcloud projects add-iam-policy-binding ${PROJECT_ID} \
    --member "serviceAccount:jetstream-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com" \
    --role roles/storage.insightsCollectorService
# Allow the Kubernetes default service account to impersonate the IAM service account
gcloud iam service-accounts add-iam-policy-binding jetstream-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com \
    --role roles/iam.workloadIdentityUser \
    --member "serviceAccount:${PROJECT_ID}.svc.id.goog[default/default]"
# Annotate the Kubernetes service account with the email address of the IAM service account.
kubectl annotate serviceaccount default \
    iam.gke.io/gcp-service-account=jetstream-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com
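
To confirm the annotation took effect, you can inspect the Kubernetes service account:

kubectl describe serviceaccount default | grep gcp-service-account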

Create a Cloud Storage bucket to store your model checkpoint

BUCKET_NAME=<your desired bucket name>
gcloud storage buckets create gs://${BUCKET_NAME}
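
You can verify that the bucket exists with:

gcloud storage buckets describe gs://${BUCKET_NAME}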

Checkpoint conversion

[Option #1] Download weights from GitHub

Follow the instructions here to download the llama-2-7b weights: https://github.com/meta-llama/llama#download

ls llama

llama-2-7b tokenizer.model ...

Upload your weights and tokenizer to your Cloud Storage bucket

gcloud storage cp -r llama-2-7b/* gs://${BUCKET_NAME}/llama-2-7b/base/
gcloud storage cp tokenizer.model gs://${BUCKET_NAME}/llama-2-7b/base/

[Option #2] Download weights from HuggingFace

Accept the terms and conditions from https://huggingface.co/meta-llama/Llama-2-7b-hf.

For llama-3-8b: https://huggingface.co/meta-llama/Meta-Llama-3-8B.

For gemma-2b: https://huggingface.co/google/gemma-2b-pytorch.

Obtain a HuggingFace CLI token by going to your HuggingFace settings and, under Access Tokens, generating a new token. Edit the token's permissions to grant read access to your respective checkpoint repository.

Copy your access token and create a Secret to store the HuggingFace token

kubectl create secret generic huggingface-secret \
  --from-literal=HUGGINGFACE_TOKEN=<access_token>
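
To double-check that the Secret holds the token you expect (note that this prints the token to your terminal):

kubectl get secret huggingface-secret -o jsonpath='{.data.HUGGINGFACE_TOKEN}' | base64 --decode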

Apply the checkpoint conversion job

For the following models, set the corresponding arguments in checkpoint-job.yaml:

Llama-2-7b-hf

- -s=jetstream-pytorch
- -m=meta-llama/Llama-2-7b-hf
- -o=gs://BUCKET_NAME/pytorch/llama-2-7b/final/bf16/
- -n=llama-2
- -q=False
- -h=True

Llama-3-8b

- -s=jetstream-pytorch
- -m=meta-llama/Meta-Llama-3-8B
- -o=gs://BUCKET_NAME/pytorch/llama-3-8b/final/bf16/
- -n=llama-3
- -q=False
- -h=True

Gemma-2b

- -s=jetstream-pytorch
- -m=google/gemma-2b-pytorch
- -o=gs://BUCKET_NAME/pytorch/gemma-2b/final/bf16/
- -n=gemma
- -q=False
- -h=True

Run the checkpoint conversion job. This uses the checkpoint conversion script from JetStream-PyTorch to create a compatible PyTorch checkpoint.

Before applying, make sure you edit checkpoint-job.yaml and replace all occurrences of BUCKET_NAME with the bucket name you set above.
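
For example, with sed (assuming BUCKET_NAME is still set in your shell):

sed -i "s/BUCKET_NAME/${BUCKET_NAME}/g" checkpoint-job.yaml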

kubectl apply -f checkpoint-job.yaml

Observe your checkpoint

kubectl logs -f jobs/checkpoint-converter
# This can take several minutes
...
Completed uploading converted checkpoint from local path /pt-ckpt/ to GSBucket gs://BUCKET_NAME/pytorch/llama-2-7b/final/bf16/

Your converted checkpoint is now located in gs://BUCKET_NAME/pytorch/llama-2-7b/final/bf16/
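
You can list the converted files with:

gcloud storage ls gs://${BUCKET_NAME}/pytorch/llama-2-7b/final/bf16/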

Deploy the JetStream-PyTorch server

The following flags are set in the manifest file:

--size: Size of the model
--model_name: Name of the model (llama-2, llama-3, gemma)
--batch_size: Batch size
--max_cache_length: Maximum length of the KV cache
--tokenizer_path: Path to the model tokenizer file
--checkpoint_path: Path to the checkpoint

Optional flags:
--quantize_weights (default False): Whether the checkpoint is quantized
--quantize_kv_cache (default False): Whether the KV cache is quantized

For llama-3-8b, you can use the following arguments:

- --size=8b
- --model_name=llama-3
- --batch_size=80
- --max_cache_length=2048
- --quantize_weights=False
- --quantize_kv_cache=False
- --tokenizer_path=/models/pytorch/llama-3-8b/final/bf16/tokenizer.model
- --checkpoint_path=/models/pytorch/llama-3-8b/final/bf16/model.safetensors
Then apply the deployment manifest:

kubectl apply -f deployment.yaml
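
Optionally, wait for the rollout to complete:

kubectl rollout status deployment/jetstream-pytorch-server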

Verify the deployment

kubectl get deployment
NAME                       READY   UP-TO-DATE   AVAILABLE   AGE
jetstream-pytorch-server    2/2       2            2         ##s

View the HTTP server logs to check that the model has been loaded and compiled. It may take the server a few minutes to complete this operation.

kubectl logs deploy/jetstream-pytorch-server -f -c jetstream-http
INFO:     Started server process [1]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)

View the JetStream-PyTorch server logs and verify that compilation has completed.

kubectl logs deploy/jetstream-pytorch-server -f -c jetstream-pytorch-server
Started jetstream_server....
2024-04-12 04:33:37,128 - root - INFO - ---------Generate params 0 loaded.---------

Serve the model

kubectl port-forward svc/jetstream-svc 8000:8000

Interact with the model via curl

curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
}'

The initial request can take several seconds to complete due to model warmup. The output is similar to the following:

{
    "response": " for 2019?\nWhat are the top 5 programming languages for 2019? The top 5 programming languages for 2019 are Python, Java, JavaScript, C, and C++.\nWhat are the top 5 programming languages for 2019? The top 5 programming languages for 2019 are Python, Java, JavaScript, C, and C++. These languages are used in a variety of industries and are popular among developers.\nPython is a versatile language that can be used for web development, data analysis, and machine learning. It is easy to learn and has a large community of developers.\nJava is a popular language for enterprise applications and is used by many large companies. It is also used for mobile development and has a large community of developers.\nJavaScript is a popular language for web development and is used by many websites. It is also used for mobile development and has a"
}

Optional steps

Interact with the JetStream-PyTorch server directly using gRPC

The JetStream HTTP server is great for initial testing and for validating requests and responses end to end. For production use cases, it's recommended to interact with the JetStream-PyTorch server directly for better throughput and latency, and to use the streaming decode feature of the JetStream gRPC server.

kubectl port-forward svc/jetstream-svc 9000:9000

Now you can interact with the JetStream gRPC server directly via port 9000.
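
As a hypothetical sketch (not from this tutorial), you could use grpcurl with a local copy of the JetStream proto definition; the exact service, method, and request field names depend on your JetStream version, so verify them against the proto before use:

grpcurl -plaintext \
  -proto jetstream.proto \
  -d '{"text_content": {"text": "What are the top 5 programming languages"}, "max_tokens": 200}' \
  localhost:9000 jetstream_proto.Orchestrator/Decode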

Use a Persistent Disk to host your checkpoint

Create a GCE CPU VM to do your checkpoint conversion. Use the same zone as your cluster (${ZONE}) so the persistent disk can later be attached to both this VM and your GKE nodes.

gcloud compute instances create jetstream-ckpt-converter \
  --zone=${ZONE} \
  --machine-type=n2-standard-32 \
  --scopes=https://www.googleapis.com/auth/cloud-platform \
  --image=projects/ubuntu-os-cloud/global/images/ubuntu-2204-jammy-v20230919 \
  --boot-disk-size=128GB \
  --boot-disk-type=pd-balanced

SSH into the VM and install Python and Git

gcloud compute ssh jetstream-ckpt-converter --zone=${ZONE}
sudo apt update
sudo apt-get install python3-pip
sudo apt-get install git-all

In your CPU VM, follow these instructions to do the following:

  1. Clone the jetstream-pytorch repository
  2. Run the installation script
  3. Download and convert the llama-2-7b weights

After running the weight-to-safetensors conversion, you should see the following files in the directory where you saved them:

ls <directory where checkpoint is stored>
model.safetensors  params.json

Create a persistent disk to store your checkpoint. The disk must be in the same zone as your cluster's nodes so that GKE can attach it later.

gcloud compute disks create jetstream-pytorch-ckpt --zone=${ZONE} --size=100GB --type=pd-balanced
NAME                    ZONE        SIZE_GB  TYPE         STATUS
jetstream-pytorch-ckpt  us-west4-a  100      pd-balanced  READY

Attach the disk to your VM

gcloud compute instances attach-disk jetstream-ckpt-converter \
  --disk jetstream-pytorch-ckpt --project ${PROJECT_ID} --zone ${ZONE}

Identify your disk. It will look similar to the following, though the device name may differ:

lsblk
NAME    MAJ:MIN RM   SIZE RO TYPE MOUNTPOINTS
...
sdc       8:32   0   100G  0 disk 

Format your disk, create a mount point, and mount it

sudo mkfs.ext4 /dev/sdc
sudo mkdir -p /mnt/jetstream-pytorch-ckpt
sudo mount /dev/sdc /mnt/jetstream-pytorch-ckpt
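
To confirm the disk is mounted:

df -h /mnt/jetstream-pytorch-ckpt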

Copy your converted checkpoint folder into /mnt/jetstream-pytorch-ckpt

sudo cp -r <path to converted checkpoint> /mnt/jetstream-pytorch-ckpt

Unmount and detach your persistent disk

sudo umount /mnt/jetstream-pytorch-ckpt
gcloud compute instances detach-disk jetstream-ckpt-converter \
  --disk jetstream-pytorch-ckpt --project ${PROJECT_ID} --zone ${ZONE}

Apply your storage and deployment manifest files
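
If your checkout doesn't already include a storage.yaml for a pre-provisioned disk, a minimal sketch using the GKE Persistent Disk CSI driver might look like the following (the PV/PVC names and 100Gi size are assumptions; replace PROJECT_ID and ZONE with your values, and prefer the repository's own manifest if it provides one):

cat <<'EOF' > storage.yaml
apiVersion: v1
kind: PersistentVolume
metadata:
  name: jetstream-pytorch-ckpt-pv
spec:
  capacity:
    storage: 100Gi
  accessModes:
    - ReadOnlyMany
  storageClassName: ""
  csi:
    driver: pd.csi.storage.gke.io
    volumeHandle: projects/PROJECT_ID/zones/ZONE/disks/jetstream-pytorch-ckpt
    fsType: ext4
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: jetstream-pytorch-ckpt-pvc
spec:
  accessModes:
    - ReadOnlyMany
  storageClassName: ""
  volumeName: jetstream-pytorch-ckpt-pv
  resources:
    requests:
      storage: 100Gi
EOF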

kubectl apply -f storage.yaml
kubectl apply -f pd-deployment.yaml
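
Then check that the claim is bound and the pods come up:

kubectl get pv,pvc
kubectl get pods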