Copyright 2024 Google LLC.¶
#@title Licensed under the Apache License, Version 2.0 (the "License")
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Author(s) | Emmanuel Awa, Dennis Kashkin |
Reviewer(s) | Skander Hannachi, Rajesh Thallam |
Last updated | 2024 12 11: Gemini 2.0 Flash Experimental Release; 2024 12 11: Initial Publication |
Explore Object Detection with Gemini 2.0 in Vertex AI¶
This notebook showcases the power of Gemini 2.0 for object detection and spatial understanding using Vertex AI.
Using the embedded app, you'll learn how to use Gemini to identify and locate objects in images.
Feel free to explore different prompt styles to achieve the desired results. You can start with the pre-defined prompts and image provided, or personalize your experience by uploading your own images and switching to 'Custom Prompt' to craft your own.
IMPORTANT NOTICE: This notebook only showcases functionality available in the gemini-2.0-flash-exp model.
Environment Setup: GCP and Libraries¶
Install Packages and Restart Runtime (If needed)¶
To use newly installed packages in this Jupyter runtime, you must restart the runtime. The cell below checks whether the required package versions are installed, installs any that are missing or mismatched, and restarts the kernel only when something was installed.
import sys
import importlib.metadata
import time
import IPython
def check_package_version(package_name, required_version):
    try:
        installed_version = importlib.metadata.version(package_name)
        if installed_version != required_version:
            print(f"Warning: {package_name} {required_version} "
                  f"required, but {installed_version} is installed.")
            return False  # Indicate version mismatch
        return True  # Indicate correct version
    except importlib.metadata.PackageNotFoundError:
        print(f"Warning: {package_name} is not installed.")
        return False  # Indicate package not found
# List of packages and their required versions
packages_to_check = {
    'google-cloud-aiplatform': '1.74.0',  # Replace with your desired version
    'gradio': '5.8.0',  # Replace with your desired version
    # Add more packages and versions as needed
}
# Check if any required package is missing or has a version mismatch
restart_required = False
for package_name, required_version in packages_to_check.items():
    if not check_package_version(package_name, required_version):
        restart_required = True
        print(f"Installing {package_name}=={required_version}")
        !pip install {package_name}=={required_version} --quiet --user
# Restart the kernel if necessary
if restart_required:
    print("Restarting kernel...")
    time.sleep(5)  # Add time for the environment to update
    app = IPython.Application.instance()
    app.kernel.do_shutdown(True)
Before running the notebook, you will need to provide the following:
- GCP PROJECT_ID: Your Google Cloud Project ID.
- LOCATION: The region for your Vertex AI resources (e.g., 'us-central1').
Make sure you have a Google Cloud Project with billing enabled before proceeding. You can create a new project or use an existing one.
You will be prompted to enter these values in a form below.
PROJECT_ID = '[your-project-id-here]' # @param {type: 'string'}
LOCATION = 'us-central1' # @param {type: 'string'}
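If you'd rather not edit the form, a small helper can fall back to an environment variable. The sketch below is an assumption and not part of the original notebook. The helper name resolve_project_id is hypothetical. It relies on the GOOGLE_CLOUD_PROJECT environment variable, which some Google Cloud environments set but yours may not.

```python
import os

PLACEHOLDER_PROJECT_ID = '[your-project-id-here]'  # The form's default value


def resolve_project_id(form_value: str) -> str:
    """Prefer the form value; otherwise fall back to GOOGLE_CLOUD_PROJECT (an assumption)."""
    form_value = (form_value or '').strip()
    # Use the form value only if the user actually replaced the placeholder
    if form_value and form_value != PLACEHOLDER_PROJECT_ID:
        return form_value
    env_value = os.environ.get('GOOGLE_CLOUD_PROJECT', '').strip()
    if env_value:
        return env_value
    raise ValueError('Set PROJECT_ID in the form or export GOOGLE_CLOUD_PROJECT.')
```

For example, you could call PROJECT_ID = resolve_project_id(PROJECT_ID) before vertexai.init() to fail early with a clear message instead of sending the placeholder to Vertex AI.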
Authentication¶
If you're using Colab, run the code in the next cell. Follow the popups and authenticate with an account that has access to your Google Cloud project.
If you're running this notebook somewhere besides Colab, make sure your environment has the right Google Cloud access. If that's a new concept to you, consider looking into Application Default Credentials for your local environment and initializing the Google Cloud CLI. In many cases, running gcloud auth application-default login in a shell on the machine running the notebook kernel is sufficient.
More authentication options are discussed here.
import sys
if 'google.colab' in sys.modules:
    from google.colab import auth
    auth.authenticate_user()
    print('Authenticated')
Import Libraries¶
import base64
import hashlib
import json
import os
import sys
import re
import gradio as gr
from typing import Optional, Union
from PIL import Image as PILImage
from PIL import ImageDraw, ImageColor, ImageFont, UnidentifiedImageError
import vertexai
from vertexai.generative_models import (GenerativeModel,
HarmBlockThreshold,
HarmCategory,
Part)
MODEL_NAME = 'gemini-2.0-flash-exp'
vertexai.init(project=PROJECT_ID, location=LOCATION)
Leveraging Gemini for Bounding Boxes¶
This section demonstrates how to utilize the power of Google's Gemini model to identify and extract bounding boxes of objects within images. We'll explore prompt engineering techniques to guide Gemini in accurately detecting desired elements, and then visualize the results by overlaying the bounding boxes onto the original image. This showcases Gemini's capabilities for object detection and its potential applications in various computer vision tasks.
Bounding Box Data Class¶
Define an object that represents a bounding box with four coordinates in Gemini format, ordered top, left, bottom, right (ymin, xmin, ymax, xmax), where X and Y are on a 0 to 1000 scale.
class BoundingBox():
    """Create a BoundingBox in Gemini format (X and Y are on 0..1000 scale)"""
    def __init__(self, top: int, left: int, bottom: int, right: int, label: str | None = None):
        if None in [top, left, bottom, right]:
            raise ValueError("BoundingBox requires all coordinates to be set.")
        # Every coordinate must lie on the 0..1000 Gemini scale
        if top < 0 or top > 1000: raise ValueError(f'ymin must be an integer between 0 and 1000 (top={top})')
        if left < 0 or left > 1000: raise ValueError(f'xmin must be an integer between 0 and 1000 (left={left})')
        if bottom < 0 or bottom > 1000: raise ValueError(f'ymax must be an integer between 0 and 1000 (bottom={bottom})')
        if right < 0 or right > 1000: raise ValueError(f'xmax must be an integer between 0 and 1000 (right={right})')
        if right <= left: raise ValueError(f'xmax must be greater than xmin (right={right}, left={left})')
        if bottom <= top: raise ValueError(f'ymax must be greater than ymin (bottom={bottom}, top={top})')
        self.left = left
        self.right = right
        self.top = top
        self.bottom = bottom
        self.label = label
        # Derive a stable color from the label (or the coordinates) so equal labels share a color
        signature = label or f'{top}-{left}-{bottom}-{right}'
        int_hash = int(hashlib.sha256(signature.encode('utf-8')).hexdigest(), 16) % (10 ** 8)
        colors = list(ImageColor.colormap.keys())
        colors = [color for color in colors if color != 'grey']  # Reserve grey for borders
        stable_color_index = int_hash % len(colors)
        self.color = colors[stable_color_index]
        print(f'Stable color index: {stable_color_index} based on int_hash={int_hash}: {self.color}, {len(colors)} colors')
    def __repr__(self):
        """Return a string representation of the bounding box."""
        return f'TLBR[{self.top}, {self.left}, {self.bottom}, {self.right}]: {self.label or ""} #{self.color}'
    @staticmethod
    def is_numeric(value: str) -> bool:
        """Check if a string is a number."""
        return value.strip().lstrip('-').replace('.', '', 1).isdigit()
    @staticmethod
    def from_markdown(text: str) -> Union['BoundingBox', None]:
        """Create a bounding box from a markdown string."""
        if not text:
            return None
        for line in text.strip().splitlines():
            line = line.strip().lstrip('-').strip()
            # Extract the numbers from the line after removing brackets and splitting by comma
            if '[' in line and ']' in line:
                numbers = line.split('[')[1].split(']')[0].split(',')
            else:
                numbers = line.split(',')
            if len(numbers) != 4:
                print(f'Skipping response line with {len(numbers)} comma separated parts instead of 4: {line}')
                continue
            # int(float(...)) also accepts decimals such as "12.5", which is_numeric allows but int() alone rejects
            ints = [int(float(num.strip())) for num in numbers if BoundingBox.is_numeric(num)]
            if len(ints) != 4:
                print(f'Skipping response line with only {len(ints)} numeric parts instead of 4: {line}')
                continue
            return BoundingBox(ints[0], ints[1], ints[2], ints[3])  # Using the first bounding box (even if the model returns multiple)
        return None
    @staticmethod
    def from_list_of_ints(array, label: str | None = None) -> 'BoundingBox':
        """Create a bounding box from a list of four integers [ymin, xmin, ymax, xmax]."""
        if not isinstance(array, list) or len(array) != 4:
            raise ValueError(f'Model returned unexpected JSON structure for bounding box coordinates: {json.dumps(array)}')
        for coordinate in array:
            if not isinstance(coordinate, int):
                raise ValueError(f'Model returned unrecognized JSON bounding box coordinate: {coordinate}')
        return BoundingBox(top=array[0], left=array[1], bottom=array[2], right=array[3], label=label)
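BoundingBox picks its color by hashing the label, or the coordinates when there is no label. As a result, every box with the same label gets the same color across runs. The stdlib-only sketch below isolates that idea. The short PALETTE list and the stable_color helper are hypothetical stand-ins for the notebook's use of PIL's ImageColor.colormap.

```python
import hashlib
from typing import List, Optional

# Hypothetical palette; the notebook draws from PIL's ImageColor.colormap (minus 'grey')
PALETTE = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'gold']


def stable_color(top: int, left: int, bottom: int, right: int,
                 label: Optional[str] = None, palette: List[str] = PALETTE) -> str:
    """Map a label (or, without a label, the coordinates) to a deterministic color."""
    # Same signature rule as BoundingBox: prefer the label, else the coordinates
    signature = label or f'{top}-{left}-{bottom}-{right}'
    # sha256 gives the same digest in every process, unlike Python's salted hash() for str
    digest = hashlib.sha256(signature.encode('utf-8')).hexdigest()
    int_hash = int(digest, 16) % (10 ** 8)
    return palette[int_hash % len(palette)]
```

Python's built-in hash() for strings is randomized per process (see PYTHONHASHSEED). That is why a cryptographic digest is the safer choice when colors must stay the same between notebook sessions.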
Utilities - Preprocessing and Postprocessing¶
Helper functions for preprocessing images and postprocessing model responses.
import mimetypes
def read_image(local_path_to_image_file: str) -> PILImage.Image:
    """
    Reads an image from a local file path.
    Args:
        local_path_to_image_file: Path to the local image file.
    Returns:
        The image as a PIL Image object.
    """
    return PILImage.open(local_path_to_image_file)
def encode_image_for_gemini(local_path_to_image_file: str):
    """
    Encodes an image for Gemini.
    Args:
        local_path_to_image_file: Path to the local image file.
    Returns:
        The encoded image as a Part object.
    """
    # Guess the MIME type from the file extension instead of always assuming JPEG (e.g. PNG uploads)
    mime_type = mimetypes.guess_type(local_path_to_image_file)[0] or 'image/jpeg'
    # Part.from_data takes raw bytes, so a base64 encode/decode round trip is unnecessary
    with open(local_path_to_image_file, 'rb') as image_file:
        return Part.from_data(data=image_file.read(), mime_type=mime_type)
def strip_json_code_block(text: str) -> str:
    """Strips the ```json code block markers from a string and returns the JSON.
    Args:
        text: The input string containing the JSON code block.
    Returns:
        The extracted JSON string.
    """
    pattern = r"```json\s*(.*?)\s*```"  # Matches ```json ... ``` with optional whitespace
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1).strip()
    else:
        return text  # Return original text if no code block is found
def download_image_from_gcs(gcs_uri: str, local_path: str) -> None:
    """Downloads an image from Google Cloud Storage (GCS) to a local file.
    Args:
        gcs_uri: The GCS URI of the image to download (e.g., 'gs://bucket-name/image.jpg').
        local_path: The local path where the downloaded image will be saved (e.g., './image.jpg').
    """
    if not os.path.exists(local_path):
        print(f'Local path to image file does not exist: {local_path}')
        print('Downloading sample image from GCS...')
        !gsutil cp "{gcs_uri}" "{local_path}"
    else:
        print(f'Image already exists at: {local_path}')
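The system instructions ask the model not to use code fencing, but a response may still arrive wrapped in a json code fence. That is why strip_json_code_block runs before json.loads. The sketch below combines the two steps into a single function so the behavior can be checked without calling the model. The name parse_model_json is hypothetical, and the sample response string is made up for illustration.

```python
import json
import re

FENCE = '`' * 3  # Built from backticks so the snippet itself never contains a literal fence
FENCE_PATTERN = FENCE + r'json\s*(.*?)\s*' + FENCE  # Same pattern as strip_json_code_block


def parse_model_json(text: str):
    """Strip an optional json code fence, then parse the remaining text as JSON."""
    match = re.search(FENCE_PATTERN, text, re.DOTALL)
    payload = match.group(1).strip() if match else text.strip()
    return json.loads(payload)


# A made-up response, with and without a code fence around it
plain = '[{"box_2d": [100, 200, 300, 400], "label": "kite"}]'
fenced = FENCE + 'json\n' + plain + '\n' + FENCE
```

Both parse_model_json(plain) and parse_model_json(fenced) yield the same list of dictionaries. A fence-free response therefore takes the same path as a fenced one.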
Generating Bounding Boxes with Gemini¶
This section focuses on generating bounding boxes using the Gemini model. It involves the following steps:
- Defining the generate_bounding_boxes function: This function handles the interaction with the Gemini API, sending the image and prompt and receiving the predicted bounding boxes.
- Setting Generation Parameters: We'll define parameters like temperature and top_p to control the model's creativity and diversity.
- Generating Results: We'll use the defined function and parameters to obtain bounding box predictions from Gemini for a given image.
def generate_bounding_boxes(
    model_name: str,
    system_instrs: str,
    user_prompt: str,
    local_path_to_image_file: str,
    generation_config: Optional[dict] = None,
    safety_settings: Optional[dict] = None
) -> list:
    """
    Sends an image and a prompt to Gemini and returns the parsed bounding boxes.
    Args:
        model_name: The name of the generative model to use.
        system_instrs: System-level instructions for the model.
        user_prompt: The user's prompt to send to the model.
        local_path_to_image_file: Path to the local image file.
        generation_config: (Optional) Configuration for the model's generation process.
        safety_settings: (Optional) Safety settings for the model.
    Returns:
        The list parsed from the model's JSON response. Each item is expected to be a
        dictionary with a 'box_2d' key holding [ymin, xmin, ymax, xmax] on the 0..1000
        scale and an optional 'label' key.
    Raises:
        ValueError: If the image file path is invalid or the model returns an unexpected response.
    """
    if not os.path.exists(local_path_to_image_file):
        raise ValueError(f'Local path to image file does not exist: {local_path_to_image_file}')
    generation_config = generation_config or {
        'temperature': 0.36,
        'top_p': 1.0,
        'top_k': 40,
        'max_output_tokens': 8192,
        'candidate_count': 1,
    }
    safety_settings = safety_settings or {
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
    }
    model = GenerativeModel(
        model_name=model_name,
        system_instruction=system_instrs,
        generation_config=generation_config,
        safety_settings=safety_settings,
    )
    response = model.generate_content(contents=[user_prompt, encode_image_for_gemini(local_path_to_image_file)], stream=False)
    try:
        bounding_boxes_list = json.loads(strip_json_code_block(response.text))
    except json.JSONDecodeError as e:
        raise ValueError(f"Error decoding JSON response from model: {e}")
    if not isinstance(bounding_boxes_list, list):
        raise ValueError(f'Model returned unexpected JSON structure instead of an array of objects: {json.dumps(bounding_boxes_list)}')
    return bounding_boxes_list
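generate_bounding_boxes only checks that the response is a JSON array. render_predicted_bounding_boxes later raises on the first malformed item. If you'd prefer to keep the well-formed boxes and drop the rest, a lenient filter like the hypothetical validate_predictions below is one option. It assumes the same schema the rendering code expects: a 'box_2d' list of four integers [ymin, xmin, ymax, xmax] on the 0..1000 scale, plus an optional 'label'.

```python
from typing import Any, List, Optional, Tuple


def _is_plain_int(value: Any) -> bool:
    # bool is a subclass of int in Python, so exclude it explicitly
    return isinstance(value, int) and not isinstance(value, bool)


def validate_predictions(predictions: Any) -> List[Tuple[List[int], Optional[str]]]:
    """Keep well-formed {'box_2d': [ymin, xmin, ymax, xmax], 'label': ...} items, drop the rest."""
    valid = []
    if not isinstance(predictions, list):
        return valid
    for item in predictions:
        if not isinstance(item, dict):
            continue
        box = item.get('box_2d')
        if not (isinstance(box, list) and len(box) == 4 and all(_is_plain_int(v) for v in box)):
            continue
        ymin, xmin, ymax, xmax = box
        # Reject coordinates outside the 0..1000 scale and inverted or empty boxes
        if not all(0 <= v <= 1000 for v in box) or ymax <= ymin or xmax <= xmin:
            continue
        label = item.get('label')
        valid.append((box, label if isinstance(label, str) else None))
    return valid
```

Skipping rather than raising trades strictness for robustness. A single bad box no longer blanks the whole result, but silent drops can hide model errors, so you may want to log what was skipped.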
Utilities - Visualizing Bounding Boxes with Plotting and Rendering¶
This section introduces a set of utility functions designed to visualize the bounding boxes identified by Gemini. These functions handle tasks such as reading images, plotting bounding boxes with distinct colors and labels, and rendering the final output with overlaid boxes onto the source image. These utilities streamline the process of visualizing object detection results and provide a clear representation of Gemini's capabilities.
def plot_bounding_boxes(
    image: PILImage.Image, bounding_boxes: list
) -> PILImage.Image:
    """
    Overlays bounding boxes on an image, using each box's stable color and optional label.
    BoundingBoxes in Gemini format (X and Y are on 0..1000 scale) are auto scaled to the actual image size.
    Args:
        image: source image (drawn on in place)
        bounding_boxes: a list with one or more BoundingBox objects
    Returns: the same PIL Image object, with the bounding boxes drawn on it
    """
    if not image:
        raise ValueError("image is required")
    if not bounding_boxes:
        return image
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default(size=15)
    for bb in bounding_boxes:
        # Convert Gemini coordinates from Gemini scale (0..1000) to absolute pixels
        left = int(bb.left / 1000 * image.width)
        top = int(bb.top / 1000 * image.height)
        right = int(bb.right / 1000 * image.width)
        bottom = int(bb.bottom / 1000 * image.height)
        if right < left + 8:
            right = left + 8  # Otherwise the 5-pixel border would not fit
        if bottom < top + 8:
            bottom = top + 8
        # Border line style: 1 pixel grey + 3 pixels box specific color + 1 pixel grey
        draw.rectangle(((left, top), (right, bottom)), outline=ImageColor.colormap['grey'], width=1)
        draw.rectangle(((left+1, top+1), (right-1, bottom-1)), outline=bb.color, width=3)
        draw.rectangle(((left+4, top+4), (right-4, bottom-4)), outline=ImageColor.colormap['grey'], width=1)
        if bb.label:
            # Check if the label fits inside of the bounding box
            label_left, label_top, label_right, label_bottom = font.getbbox(bb.label)
            print(f'label coordinates: label_left={label_left}, label_top={label_top}, label_right={label_right}, label_bottom={label_bottom}')
            is_box_wide_enough = (label_right - label_left) + 8 < (right - left)
            is_box_tall_enough = (label_bottom - label_top) + 6 < (bottom - top)
            print(f'{bb} is_box_wide_enough={is_box_wide_enough}, is_box_tall_enough={is_box_tall_enough}')
            if is_box_wide_enough and is_box_tall_enough:
                label_offset_x = 7
                label_offset_y = 3
            else:  # Print the label below the bounding box
                label_offset_x = 0
                label_offset_y = bottom - top
            print(f'label_offset_y={label_offset_y}')
            draw.text((left + label_offset_x, top + label_offset_y), bb.label, fill=ImageColor.colormap['red'], font=font)
    return image
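The scaling step in plot_bounding_boxes is the core of the visualization. Gemini's [ymin, xmin, ymax, xmax] coordinates are on a 0..1000 scale, so each one is divided by 1000 and multiplied by the image width or height. The pure-Python to_pixel_box helper below is hypothetical and not part of the notebook. It extracts just that arithmetic, including the 8-pixel minimum size that leaves room for the layered border.

```python
from typing import List, Tuple


def to_pixel_box(box_2d: List[int], width: int, height: int,
                 min_size: int = 8) -> Tuple[int, int, int, int]:
    """Convert Gemini [ymin, xmin, ymax, xmax] (0..1000 scale) to pixel (left, top, right, bottom)."""
    ymin, xmin, ymax, xmax = box_2d
    # X values scale with the image width, Y values with the image height
    left = int(xmin / 1000 * width)
    top = int(ymin / 1000 * height)
    right = int(xmax / 1000 * width)
    bottom = int(ymax / 1000 * height)
    # Enforce a minimum size so the border still fits (mirrors plot_bounding_boxes)
    right = max(right, left + min_size)
    bottom = max(bottom, top + min_size)
    return left, top, right, bottom
```

For a 640x480 image, to_pixel_box([250, 500, 750, 1000], 640, 480) returns (320, 120, 640, 360): the box covers the right half of the image horizontally and the middle half vertically.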
def render_predicted_bounding_boxes(predictions: list, source_image_path: str, result_image_path: str) -> None:
    """Renders predicted bounding boxes onto an image and saves the result.
    This function takes a list of predictions from a model,
    reads the source image, extracts bounding box information
    from the predictions, overlays the boxes onto the image,
    and saves the resulting image to the specified path.
    Args:
        predictions: A list of predictions from the model, expected to
            contain bounding box data and optional labels.
        source_image_path: The path to the source image file.
        result_image_path: The path where the resulting image
            with bounding boxes will be saved.
    Raises:
        ValueError: If the predictions are not in the expected format,
            or if required attributes are missing.
    Returns:
        None. The function saves the resulting image to the
        specified path.
    """
    image = read_image(source_image_path)
    boxes: list[BoundingBox] = []
    if not isinstance(predictions, list):
        raise ValueError(f'Model returned unexpected JSON structure instead of an array of objects: {json.dumps(predictions)}')
    for item in predictions:
        if not isinstance(item, dict):
            raise ValueError(f'Model returned unexpected array item: {json.dumps(item)}')
        if 'box_2d' not in item:
            raise ValueError(f'Model returned bounding box with missing attribute "box_2d": {json.dumps(item)}')
        label = item['label'] if 'label' in item else None
        box_2d = item['box_2d']
        if not isinstance(box_2d, list):
            raise ValueError(f'Model returned unexpected box_2d value instead of an array: {json.dumps(box_2d)}')
        boxes.append(BoundingBox.from_list_of_ints(box_2d, label=label))
    class_labels = {box.label for box in boxes}
    if len(class_labels) == 1:  # no need to display identical labels
        for box in boxes:
            box.label = None
    result = plot_bounding_boxes(image, boxes)
    # JPEG cannot store an alpha channel, so convert to RGB first (e.g. for RGBA PNG uploads)
    result.convert('RGB').save(result_image_path)
Prompt Templates¶
SYSTEM_INSTRUCTIONS = '''
Return bounding boxes as a JSON array with labels. Never return masks or code fencing. Limit to 25 objects.
If an object is present multiple times, name them according to their unique characteristic (colors, size, position, unique characteristics, etc.).
'''
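The three prompts below differ mainly in how many objects or classes they name. To generate prompts for your own object lists, a tiny template helper is enough. build_detection_prompt is a hypothetical example, and its wording is our own rather than a phrasing Gemini requires.

```python
from typing import List


def build_detection_prompt(object_names: List[str]) -> str:
    """Build a bounding-box prompt naming one or more object types (hypothetical helper)."""
    names = [name.strip() for name in object_names if name and name.strip()]
    if not names:
        raise ValueError('At least one object name is required.')
    if len(names) == 1:
        targets = names[0]
    else:
        # Join as "a, b and c" so the prompt reads naturally
        targets = ', '.join(names[:-1]) + ' and ' + names[-1]
    return f'Give me the bounding boxes for all the {targets} in the image.'
```

For example, build_detection_prompt(['cats', 'dogs']) produces a prompt you could pass as user_prompt to generate_bounding_boxes.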
PROMPT_SINGLE_OBJECT = 'Could you display the bounding boxes around the Ferris wheel.'
PROMPT_SINGLE_CLASS = 'Give me the bounding boxes for all the kites in the park.'
PROMPT_MULTIPLE_CLASSES = 'What are the regions defined by the bounding boxes for two types of animals: cats and dogs.'
Testing the Bounding Box Generation¶
Now that we've defined the generate_bounding_boxes function and set the generation parameters, let's test it with a sample image and prompt. This will help us verify that the model is correctly identifying and returning bounding boxes.
NOTE: For a seamless experiment, the next cell downloads a sample image from GCS to ./park.jpg. To test your own image instead, upload it to the Colab filesystem and set local_image_path to its file name.
model_name = MODEL_NAME
GCS_IMAGE_URI = 'gs://public-aaie-genai-samples/gemini_2_0/spatial_understanding/park.jpg'
local_image_path = './park.jpg'
download_image_from_gcs(GCS_IMAGE_URI, local_image_path)
results = generate_bounding_boxes(model_name,
SYSTEM_INSTRUCTIONS,
PROMPT_SINGLE_CLASS,
local_path_to_image_file=local_image_path)
print(results)
Interactive Visualization with Gradio¶
This section integrates bounding box generation into a Gradio application, enabling you to interactively visualize object detection results on uploaded images.
Predefined Prompts:
Start by exploring object detection with our predefined prompts using the provided sample image.
Custom Prompts:
Switch to "Custom Prompt" to unlock the full potential of Gemini. Experiment with your own prompts, such as precisely locating specific objects within an image and retrieving their bounding box information. For example, you can try prompts like "Find the red car" or "Where are the bicycles?". Feel free to upload your own images and tailor your prompts for personalized exploration.
# @title Download Default Image for app
GCS_IMAGE_URI = 'gs://public-aaie-genai-samples/gemini_2_0/spatial_understanding/park.jpg'
DEFAULT_IMAGE_PATH = './park.jpg'
download_image_from_gcs(GCS_IMAGE_URI, DEFAULT_IMAGE_PATH)
# @title Image Processor
def process_image(file_name: str, user_prompt: str = PROMPT_SINGLE_CLASS):
    """
    Copies an image uploaded from local disk into the current working directory,
    asks Gemini for bounding boxes, and returns the image with the boxes drawn on it.
    """
    try:
        current_dir = os.getcwd()
        base_name = os.path.basename(file_name)
        save_path = os.path.join(current_dir, base_name)
        image = PILImage.open(file_name)
        image.save(save_path)
        message = f'Image saved as {save_path} in the current directory.'
        print(message)
        try:
            results = generate_bounding_boxes(MODEL_NAME, SYSTEM_INSTRUCTIONS, user_prompt=user_prompt, local_path_to_image_file=save_path)
        except Exception as e:
            error_message = f"Error generating bounding boxes: {e}"
            raise gr.Error(error_message)
        bb_save_path = os.path.join(current_dir, f'{base_name}_bb.jpg')
        render_predicted_bounding_boxes(results, save_path, bb_save_path)
        return PILImage.open(bb_save_path)
    except gr.Error:
        raise  # Already a user-facing error; don't wrap it again below
    except FileNotFoundError:
        raise gr.Error(f"Error: Image file not found at {file_name}")
    except UnidentifiedImageError:
        raise gr.Error(f"Error: Could not open or read image file {file_name}")
    except Exception as e:  # Catch any other unexpected errors
        raise gr.Error(f"An unexpected error occurred during processing: {e}")
# @title Main Gradio application
with gr.Blocks(title="BoxIt With Gemini 2.0") as demo:
    gr.Markdown('# **BoxIt**')
    with gr.Row():
        image_display = gr.Image(type='filepath', label='Image', value=DEFAULT_IMAGE_PATH)
        with gr.Column():
            prompt_type = gr.Radio(
                choices=['Predefined Prompts', 'Custom Prompt'],
                label='Select Prompt Style',
            )
            predefined_prompts = [PROMPT_SINGLE_OBJECT, PROMPT_SINGLE_CLASS, PROMPT_MULTIPLE_CLASSES]
            prompt_dropdown = gr.Dropdown(
                choices=predefined_prompts,
                label='Choose a predefined prompt:',
                visible=False,
            )
            custom_prompt = gr.Textbox(
                lines=2,
                label='Enter your prompt',
                visible=False,
            )
            submit_btn = gr.Button('Find Bounding Boxes')
    # bounding_box_output = gr.Textbox(label="Bounding Boxes", visible=False)
    original_image = gr.State(DEFAULT_IMAGE_PATH)  # Store the *original* uploaded image
    def toggle_prompt_input(prompt_type, original_img):
        if original_img is not None:
            if prompt_type == 'Predefined Prompts':
                return gr.update(visible=True), gr.update(visible=False), original_img
            elif prompt_type == 'Custom Prompt':
                return gr.update(visible=False), gr.update(visible=True), original_img
            else:  # "Select Prompt Style"
                return gr.update(visible=False), gr.update(visible=False), original_img
        else:
            return gr.update(visible=False), gr.update(visible=False), gr.update()
    prompt_type.change(
        fn=toggle_prompt_input,
        inputs=[prompt_type, original_image],
        outputs=[prompt_dropdown, custom_prompt, image_display],
    )
    def process_and_display(img, prompt_type, selected_prompt, custom_prompt):
        if not img:
            return gr.update()
        if prompt_type == 'Predefined Prompts':
            prompt = selected_prompt
        elif prompt_type == 'Custom Prompt':
            prompt = custom_prompt
        else:  # "Select Prompt Style" - Do nothing
            return img
        print(f'Prompt: {prompt}')
        print(f'Processing image: {img}')
        try:
            processed_image = process_image(file_name=img, user_prompt=prompt)
            return processed_image
        except Exception as e:
            print(f'Error processing image: {e}')
            return img  # Return the original image in case of an error
    image_display.upload(lambda x: x, inputs=image_display, outputs=original_image)
    submit_btn.click(
        fn=process_and_display,
        inputs=[original_image, prompt_type, prompt_dropdown, custom_prompt],
        outputs=image_display
    )
demo.queue()
demo.launch(quiet=True)
Summary¶
This notebook demonstrates how to use Google's Gemini model to identify and extract bounding boxes of elements within an image. It covers the following key aspects:
- Environment Setup: Setting up the Google Cloud Project, installing necessary libraries, and authenticating with Google Cloud.
- Leveraging Gemini for Bounding Boxes: Utilizing the Gemini model to generate bounding boxes for specific objects or classes within an image.
- Prompt Engineering: Defining prompt templates to guide the Gemini model in accurately detecting the desired elements.
- Bounding Box Data Class: Defining a class to represent bounding boxes in the Gemini format, along with methods for parsing and processing them.
- Utilities: Helper functions for reading, encoding, and plotting bounding boxes onto images.
- Visualizing Bounding Boxes: Displaying the predicted bounding boxes overlaid on the source image for clear visualization.
- Interactive Interface: Building a Gradio interface to allow users to upload images, generate bounding boxes, and visualize the results interactively.
The notebook showcases the power of Gemini for object detection tasks and provides a practical example of its potential applications in computer vision. It offers a comprehensive guide to using Gemini, from setup to interactive visualization.