Text classification eval
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Text Classification Eval Recipe
This Eval Recipe demonstrates how to compare the performance of two models on a text classification prompt using the Vertex AI Evaluation Service.
Use case: given a product description, find the most relevant product category from a predefined list of categories.
Metric: this eval uses a single deterministic metric "Accuracy" calculated by comparing model responses with ground truth labels.
The labeled evaluation dataset `dataset.jsonl` is based on the MAVE dataset from Google Research. It includes 6 records that represent products from different categories. Each record includes two attributes:
- `product`: product name and description
- `reference`: correct product category name, which serves as the ground truth label
The prompt template is a zero-shot prompt located in `prompt_template.txt` with just one prompt variable, `product`, that maps to the `product` attribute in the dataset.
Step 1 of 4: Configure eval settings
%%writefile .env
PROJECT_ID=your-project-id # Google Cloud Project ID
LOCATION=us-central1 # Region for all required Google Cloud services
EXPERIMENT_NAME=eval-recipe-demo # Creates Vertex AI Experiment to track the eval runs
MODEL_BASELINE=gemini-1.0-pro-002 # Name of your current model
MODEL_CANDIDATE=gemini-2.0-flash-001 # This model will be compared to the baseline model
DATASET_URI="gs://gemini_assets/classification_vertex/dataset.jsonl" # Evaluation dataset in Google Cloud Storage
PROMPT_TEMPLATE_URI=gs://gemini_assets/classification_vertex/prompt_template.txt # Text file in Google Cloud Storage
Step 2 of 4: Install Python libraries
%pip install --upgrade --user --quiet google-cloud-aiplatform[evaluation] plotly python-dotenv
# Restart the Python kernel so the packages installed by the %pip cell above
# are picked up. The "session crashed" error that follows is expected;
# ignore it and proceed to the next cell.
import IPython

app = IPython.Application.instance()
app.kernel.do_shutdown(True)  # True requests a restart rather than a plain shutdown
Step 3 of 4: Authenticate and initialize Vertex AI
import os
import sys
import vertexai
from dotenv import load_dotenv
from google.cloud import storage
# Load the settings written by the .env cell. override=True makes values in
# .env win over variables already present in the process environment, so
# editing and re-running the .env cell takes effect.
load_dotenv(override=True)

project_id = os.getenv("PROJECT_ID")
# Fail fast when the project is still the placeholder OR was never set
# (e.g. the .env cell was skipped), rather than calling vertexai.init with
# project=None and silently targeting some other project.
if (project_id or "").strip() in ("", "your-project-id"):
    raise ValueError("Please configure your Google Cloud Project ID in the first cell.")

# In Colab, authenticate the notebook user interactively; outside Colab no
# explicit authentication step is performed by this cell.
if "google.colab" in sys.modules:
    from google.colab import auth

    auth.authenticate_user()

vertexai.init(project=project_id, location=os.getenv("LOCATION"))
Step 4 of 4: Run the eval on both models and compare the Accuracy scores
from datetime import datetime
from IPython.display import clear_output
from vertexai.evaluation import EvalTask, EvalResult, CustomMetric, MetricPromptTemplateExamples
from vertexai.generative_models import GenerativeModel, HarmBlockThreshold, HarmCategory
def case_insensitive_match(record: dict[str, str]) -> dict[str, float]:
    """Score one eval record by exact, case-insensitive string comparison.

    Both the model ``response`` and the ground-truth ``reference`` are
    stripped of surrounding whitespace and lower-cased before comparing.

    Args:
        record: One evaluation row; must contain the keys ``"response"``
            and ``"reference"``.

    Returns:
        ``{"accuracy": 1.0}`` when the normalized strings are equal,
        otherwise ``{"accuracy": 0.0}``.
    """
    def normalize(text: str) -> str:
        return text.strip().lower()

    predicted = normalize(record["response"])
    expected = normalize(record["reference"])
    return {"accuracy": float(predicted == expected)}
def load_prompt_template() -> str:
    """Download the prompt template text from Cloud Storage.

    Reads the ``gs://`` URI from the ``PROMPT_TEMPLATE_URI`` environment
    variable and returns the object's contents decoded as UTF-8.

    Returns:
        The prompt template as a ``str``.

    Raises:
        ValueError: If ``PROMPT_TEMPLATE_URI`` is unset or empty.
    """
    uri = os.getenv("PROMPT_TEMPLATE_URI")
    if not uri:
        # Without this check Blob.from_string(None) fails with an unhelpful error.
        raise ValueError("PROMPT_TEMPLATE_URI is not set; configure it in the .env cell.")
    blob = storage.Blob.from_string(uri, storage.Client())
    # download_as_string() is deprecated; download_as_text() downloads and
    # decodes in one call.
    return blob.download_as_text(encoding="utf-8")
def run_eval(model: str) -> EvalResult:
    """Evaluate one model on the labeled dataset and log it as an experiment run.

    Builds an ``EvalTask`` over the dataset at ``DATASET_URI``, scored with the
    custom ``accuracy`` metric (``case_insensitive_match``) and attached to the
    Vertex AI experiment named by ``EXPERIMENT_NAME``. It then runs
    ``EvalTask.evaluate`` with ``model`` and the template from
    ``load_prompt_template()``.

    Args:
        model: Model name passed to ``GenerativeModel``, e.g. the value of
            ``MODEL_BASELINE`` or ``MODEL_CANDIDATE``.

    Returns:
        The ``EvalResult`` returned by ``EvalTask.evaluate``; the notebook
        reads ``summary_metrics["accuracy/mean"]`` from it.

    Raises:
        ValueError: If ``model`` is ``None`` or empty (e.g. the env var is unset).
    """
    import re  # function-scope import keeps this cell self-contained

    if not model:
        raise ValueError(
            "Model name is empty; set MODEL_BASELINE and MODEL_CANDIDATE in the .env cell."
        )

    timestamp = datetime.now().strftime("%b-%d-%H-%M-%S")
    # Vertex AI experiment run names are limited to lowercase letters, digits
    # and hyphens. Map every other character (e.g. the "." in
    # "gemini-2.0-flash-001", "/" in resource paths) to "-".
    run_name = re.sub(r"[^a-z0-9-]", "-", f"{timestamp}-{model}".lower())

    return EvalTask(
        dataset=os.getenv("DATASET_URI"),
        metrics=[CustomMetric(name="accuracy", metric_function=case_insensitive_match)],
        experiment=os.getenv("EXPERIMENT_NAME"),
    ).evaluate(
        model=GenerativeModel(model),
        prompt_template=load_prompt_template(),
        experiment_run_name=run_name,
    )
# Evaluate the baseline model first, then the candidate. Clear the cell output
# produced while evaluating, then print each model's mean accuracy.
baseline_model_name = os.getenv("MODEL_BASELINE")
baseline = run_eval(baseline_model_name)
candidate_model_name = os.getenv("MODEL_CANDIDATE")
candidate = run_eval(candidate_model_name)

clear_output()
baseline_accuracy = baseline.summary_metrics["accuracy/mean"]
candidate_accuracy = candidate.summary_metrics["accuracy/mean"]
print("Baseline model accuracy:", baseline_accuracy)
print("Candidate model accuracy:", candidate_accuracy)
This Eval Recipe is intended to be a simple starting point. Please use our documentation to learn about all available metrics and customization options.