Serve an LLM using a single-host TPU on GKE with JetStream and MaxText
Background
This tutorial shows you how to serve a large language model (LLM) using Tensor Processing Units (TPUs) on Google Kubernetes Engine (GKE) with JetStream and MaxText.
Setup
Set default environment variables
gcloud config set project [PROJECT_ID]
export PROJECT_ID=$(gcloud config get project)
export REGION=[COMPUTE_REGION]
export ZONE=[ZONE]
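For example (these values are hypothetical; pick a region and zone that offer v5e capacity):
gcloud config set project my-project
export PROJECT_ID=$(gcloud config get project)
export REGION=us-west4
export ZONE=us-west4-a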
Create GKE cluster and node pool
# Create zonal cluster with 2 CPU nodes
gcloud container clusters create jetstream-maxtext \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--workload-pool=${PROJECT_ID}.svc.id.goog \
--release-channel=rapid \
--num-nodes=2
# Create a v5e TPU node pool with 2x4 topology (2 TPU nodes, each with 8 chips)
gcloud container node-pools create tpu \
--cluster=jetstream-maxtext \
--zone=${ZONE} \
--num-nodes=2 \
--machine-type=ct5lp-hightpu-8t \
--project=${PROJECT_ID}
You have created the following resources:
- Standard cluster with 2 CPU nodes.
- One v5e TPU node pool with 2 nodes, each with 8 chips.
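To confirm both node pools exist before proceeding, list them:
gcloud container node-pools list \
--cluster=jetstream-maxtext \
--zone=${ZONE}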
Configure Applications to use Workload Identity
Prerequisite: make sure you have the following roles:
- roles/container.admin
- roles/iam.serviceAccountAdmin
Follow these steps to configure the IAM and Kubernetes service account:
# Get credentials for your cluster
gcloud container clusters get-credentials jetstream-maxtext \
--zone=${ZONE}
# Create an IAM service account.
gcloud iam service-accounts create jetstream-iam-sa
# Ensure the IAM service account has the necessary roles. Here we add roles/storage.objectUser for Cloud Storage bucket access.
gcloud projects add-iam-policy-binding ${PROJECT_ID} \
--member "serviceAccount:jetstream-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com" \
--role roles/storage.objectUser
gcloud projects add-iam-policy-binding ${PROJECT_ID} \
--member "serviceAccount:jetstream-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com" \
--role roles/storage.insightsCollectorService
# Allow the Kubernetes default service account to impersonate the IAM service account
gcloud iam service-accounts add-iam-policy-binding jetstream-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com \
--role roles/iam.workloadIdentityUser \
--member "serviceAccount:${PROJECT_ID}.svc.id.goog[default/default]"
# Annotate the Kubernetes service account with the email address of the IAM service account.
kubectl annotate serviceaccount default \
iam.gke.io/gcp-service-account=jetstream-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com
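To verify the annotation was applied, inspect the Kubernetes service account:
kubectl get serviceaccount default -o yaml
The output should include the iam.gke.io/gcp-service-account annotation pointing at jetstream-iam-sa.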
Create a Cloud Storage bucket to store the Gemma-7b model checkpoint
export BUCKET_NAME=[BUCKET_NAME]
gcloud storage buckets create gs://${BUCKET_NAME} --location=${REGION}
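Confirm the bucket was created:
gcloud storage buckets describe gs://${BUCKET_NAME}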
Get access to the model
Access the model consent page and request access with your Kaggle account. Accept the Terms and Conditions.
Obtain a Kaggle API token by going to your Kaggle settings and, under the API section, clicking Create New Token. A kaggle.json file will be downloaded.
Create a Secret to store the Kaggle credentials
kubectl create secret generic kaggle-secret \
--from-file=kaggle.json
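To verify the Secret contains the credentials file:
kubectl describe secret kaggle-secret
The Data section of the output should list a kaggle.json entry.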
Convert the Gemma-7b checkpoint
To convert the Gemma-7b checkpoint, we have created a job checkpoint-job.yaml that does the following:
1. Download the base Orbax checkpoint from Kaggle
2. Upload the checkpoint to a Cloud Storage bucket
3. Convert the checkpoint to a MaxText-compatible checkpoint
4. Unscan the checkpoint for use in inference
In the manifest, ensure the value of the BUCKET_NAME environment variable is the name of the Cloud Storage bucket you created above. Do not include the gs:// prefix.
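If the manifest references the bucket via a literal BUCKET_NAME placeholder (the job output below suggests it does), one quick way to substitute your bucket name is the following sketch, assuming GNU sed:
sed -i "s|BUCKET_NAME|${BUCKET_NAME}|g" checkpoint-job.yaml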
Apply the manifest:
kubectl apply -f checkpoint-job.yaml
Observe the logs:
kubectl logs -f jobs/data-loader-7b
You should see the following output once the job has completed. This will take around 10 minutes:
Successfully generated decode checkpoint at: gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items
+ echo -e '\nCompleted unscanning checkpoint to gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items'
Completed unscanning checkpoint to gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items
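You can also confirm the converted checkpoint exists in your bucket:
gcloud storage ls gs://${BUCKET_NAME}/final/unscanned/gemma_7b-it/0/checkpoints/0/items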
Deploy Maxengine Server and HTTP Server
Next, deploy a Maxengine server hosting the Gemma-7b model. You can use the provided Maxengine server and HTTP server images or build your own. Depending on your needs and constraints, you can elect to deploy either via Terraform or via kubectl.
Deploy via Kubectl
See the JetStream component README for start-to-finish instructions on how to deploy JetStream to your cluster. Ensure the value of PARAMETERS_PATH is the path where the checkpoint-converter job uploaded the converted checkpoint; in this case it should be gs://$BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items, where $BUCKET_NAME is the same as above.
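For reference, you can construct that value from the bucket created earlier:
export PARAMETERS_PATH=gs://${BUCKET_NAME}/final/unscanned/gemma_7b-it/0/checkpoints/0/items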
This README also includes instructions for setting up autoscaling. Follow those instructions to install the required components for autoscaling and to configure your HPAs appropriately.
Deploy via Terraform
Navigate to the ./terraform directory and run terraform init. The deployment requires some inputs; an example sample-terraform.tfvars is provided as a starting point. Run cp sample-terraform.tfvars terraform.tfvars and modify the resulting terraform.tfvars as needed. Since we're using Gemma-7b, the maxengine_deployment_settings.parameters_path Terraform variable should be set to the following: gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items. Finally, run terraform apply to apply these resources to your cluster.
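As an illustrative terraform.tfvars fragment (other required inputs are omitted; see sample-terraform.tfvars for the full set, and substitute your own bucket name):
maxengine_deployment_settings = {
  parameters_path = "gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items"
}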
To deploy the autoscaling components via Terraform, a few more variables must be set; setting them and rerunning terraform apply will deploy the components. The following variables should be set:
maxengine_deployment_settings = {
  metrics = {
    port            = <same as above> # which port to scrape server metrics from
    scrape_interval = 5s              # how often to scrape
  }
}
hpa_config = {
  metrics_adapter = <either 'prometheus-adapter' (recommended) or 'custom-metrics-stackdriver-adapter'>
  max_replicas
  min_replicas
  rules = [{
    target_query = <see the [jetstream-maxtext-module README](https://github.com/GoogleCloudPlatform/ai-on-gke/tree/main/modules/jetstream-maxtext-deployment/README.md) for a list of valid values>
    average_value_target
  }]
}
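For example, a filled-in sketch; the port, replica counts, and target query below are hypothetical, so check the module README for values valid in your setup:
maxengine_deployment_settings = {
  metrics = {
    port            = 9100  # hypothetical metrics port
    scrape_interval = "5s"
  }
}
hpa_config = {
  metrics_adapter = "prometheus-adapter"
  max_replicas    = 4
  min_replicas    = 1
  rules = [{
    target_query         = "jetstream_prefill_backlog_size"  # hypothetical; see the module README
    average_value_target = 10
  }]
}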
Verify the deployment
Wait for the containers to finish creating:
kubectl get deployment
NAME READY UP-TO-DATE AVAILABLE AGE
maxengine-server 2/2 2 2 ##s
Check the Maxengine pod's logs and verify that compilation is done. You should see logs similar to the following:
kubectl logs deploy/maxengine-server -f -c maxengine-server
2024-03-29 17:09:08,047 - jax._src.dispatch - DEBUG - Finished XLA compilation of jit(initialize) in 0.26236414909362793 sec
2024-03-29 17:09:08,150 - root - INFO - ---------Generate params 0 loaded.---------
Check the HTTP server logs; this can take a couple of minutes:
kubectl logs deploy/maxengine-server -f -c jetstream-http
INFO: Started server process [1]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
Send sample requests
Run the following command to set up port forwarding to the http server:
kubectl port-forward svc/jetstream-svc 8000:8000
In a new terminal, send a request to the server:
curl --request POST --header "Content-type: application/json" -s localhost:8000/generate --data '{
"prompt": "What are the top 5 programming languages",
"max_tokens": 200
}'
The output should be similar to the following:
{
"response": " in 2021?\n\nThe answer to this question is not as simple as it may seem. There are many factors that go into determining the most popular programming languages, and they can change from year to year.\n\nIn this blog post, we will discuss the top 5 programming languages in 2021 and why they are so popular.\n\n<h2><strong>1. Python</strong></h2>\n\nPython is a high-level programming language that is used for web development, data analysis, and machine learning. It is one of the most popular languages in the world and is used by many companies such as Google, Facebook, and Instagram.\n\nPython is easy to learn and has a large community of developers who are always willing to help out.\n\n<h2><strong>2. Java</strong></h2>\n\nJava is a general-purpose programming language that is used for web development, mobile development, and game development. It is one of the most popular languages in the"
}
Other optional steps
Build and upload Maxengine Server image
Build the Maxengine Server image from here and upload it to your project:
docker build -t maxengine-server .
docker tag maxengine-server gcr.io/${PROJECT_ID}/jetstream/maxtext/maxengine-server:latest
docker push gcr.io/${PROJECT_ID}/jetstream/maxtext/maxengine-server:latest
Build and upload HTTP Server image
Build the HTTP Server image from the Dockerfile here and upload it to your project:
docker build -t jetstream-http .
docker tag jetstream-http gcr.io/${PROJECT_ID}/jetstream/maxtext/jetstream-http:latest
docker push gcr.io/${PROJECT_ID}/jetstream/maxtext/jetstream-http:latest
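If you deploy with your own images, one way to point an already-deployed Maxengine server at them is kubectl set image; the container names here are taken from the logs commands in this guide:
kubectl set image deployment/maxengine-server \
maxengine-server=gcr.io/${PROJECT_ID}/jetstream/maxtext/maxengine-server:latest \
jetstream-http=gcr.io/${PROJECT_ID}/jetstream/maxtext/jetstream-http:latest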
Interact with the Maxengine server directly using gRPC
The Jetstream HTTP server is great for initial testing and validating end-to-end requests and responses. If you would like to interact directly with the Maxengine server directly for use cases such as benchmarking, you can do so by following the Jetstream benchmarking setup and applying the deployment.yaml
manifest file and interacting with the Jetstream gRPC server at port 9000.
kubectl apply -f kubectl/deployment.yaml
kubectl port-forward svc/jetstream-svc 9000:9000
To run benchmarking, pass in the flag --server 127.0.0.1
when running the benchmarking script.
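A hypothetical invocation might look like the following; the script name and every flag other than --server depend on the JetStream benchmarking setup linked above:
python benchmark_serving.py \
--server 127.0.0.1 \
--port 9000 \
--num-prompts 100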
Observe custom metrics
This step assumes you specified a metrics port to your jetstream deployment via prometheus_port
. If you would like to probe the metrics manually, cURL
your maxengine-server container on the metrics port you set and you should see something similar to the following:
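A minimal sketch, assuming you set prometheus_port to 9090 (substitute your own port):
kubectl port-forward deploy/maxengine-server 9090:9090
# In a separate terminal:
curl -s localhost:9090
You should see something similar to the following: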
# HELP jetstream_prefill_backlog_size Size of prefill queue
# TYPE jetstream_prefill_backlog_size gauge
jetstream_prefill_backlog_size{id="<SOME-HOSTNAME-HERE>"} 0.0
# HELP jetstream_slots_used_percentage The percentage of decode slots currently being used
# TYPE jetstream_slots_used_percentage gauge
jetstream_slots_used_percentage{id="<SOME-HOSTNAME-HERE>",idx="0"} 0.04166666666666663