# Personalized Product Descriptions with Weaviate and Gemini

Weaviate is an open-source vector database that lets you build AI-native applications with Gemini. This notebook has four parts:
1. [Part 1: Connect to Weaviate, Define Schema, and Import Data](#part-1-connect-to-weaviate-define-schema-and-import-data)
2. [Part 2: Run Vector Search Queries](#part-2-vector-search)
3. [Part 3: Generative Feedback Loops](#part-3-generative-feedback-loops)
4. [Part 4: Personalized Product Descriptions](#part-4-personalization)


In this demo, you will embed your data, run semantic search, call Gemini to generate text and store the output back in your vector database, and personalize product descriptions based on a user profile. The dataset is a set of Google merch products, and the product descriptions are generated by calling the Gemini API.

# Use Case

We will be working with an e-commerce dataset containing Google merch. We will load the data into the Weaviate vector database and use its semantic search features to retrieve data. Next, we will generate product descriptions and store them back in the database with a vector embedding for retrieval (a pattern known as generative feedback loops). Lastly, we will create a small knowledge graph with uniquely generated product descriptions for the buyer personas Alice and Bob.

### Requirements
1. Weaviate vector database
1. Gemini API key

### Video
**For a walkthrough of this demo, check out [this](https://youtu.be/WORgeRAAN-4?si=-WvqNkPn8oCmnLGQ&t=1138) presentation from Google Cloud Next!**

[![From RAG to autonomous apps with Weaviate and Gemini on Google Kubernetes Engine](http://i3.ytimg.com/vi/WORgeRAAN-4/hqdefault.jpg)](https://youtu.be/WORgeRAAN-4?si=-WvqNkPn8oCmnLGQ&t=1138)

## Install Dependencies and Libraries

In [None]:
!pip install weaviate-client==4.5.5
!pip install google-generativeai
!pip install requests
!pip install python-dotenv

In [None]:
import weaviate
import weaviate.classes.config as wvcc
from weaviate.embedded import EmbeddedOptions
import weaviate.classes as wvc
from weaviate.classes.config import Property, DataType, ReferenceProperty
from weaviate.util import generate_uuid5
from weaviate.classes.query import QueryReference

import os
from dotenv import load_dotenv
import json
import requests
import PIL
import IPython

from PIL import Image
from io import BytesIO
import google.generativeai as genai

# Convert image links to PIL object
def url_to_pil(url):
    response = requests.get(url)
    return Image.open(BytesIO(response.content))

## Part 1: Connect to Weaviate, Define Schema, and Import Data

### Connect to Weaviate
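The cell below connects to a Weaviate instance deployed on Kubernetes. Learn how to install the Weaviate Helm chart [here](https://weaviate.io/developers/weaviate/installation/kubernetes). Set `WEAVIATE_HTTP_URL`, `WEAVIATE_GRPC_URL`, and `WEAVIATE_AUTH` to the values for your deployment before running it.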

In [None]:
client = weaviate.connect_to_custom(
    http_host=WEAVIATE_HTTP_URL,  # URL for the Weaviate HTTP endpoint
    http_port=80,                 # Port number for the Weaviate HTTP endpoint (integer)
    http_secure=False,
    grpc_host=WEAVIATE_GRPC_URL,  # URL for the Weaviate gRPC endpoint
    grpc_port=50051,              # Port number for the Weaviate gRPC endpoint (integer)
    grpc_secure=False,
    auth_credentials=weaviate.auth.AuthApiKey(WEAVIATE_AUTH)  # Authentication credentials
)

### Choose **only one** installation option

Pick one of the three options below to run Weaviate.

#### 1. Weaviate Cloud Service

The first option is the [Weaviate Cloud Service](https://console.weaviate.cloud/). It connects your notebook to a serverless Weaviate instance and keeps the data persistent in the cloud.

In [None]:
load_dotenv()

client = weaviate.connect_to_wcs(
    cluster_url=os.getenv("WCS_DEMO_URL"),  # Environment variable holding your WCS URL
    # Environment variable holding your WCS key; this notebook writes data, so the key needs write access
    auth_credentials=weaviate.auth.AuthApiKey(os.getenv("WCS_DEMO_RO_KEY")),
    headers={"X-PaLM-Api-Key": os.getenv("PALM-API-KEY")},  # Environment variable holding your Gemini API key
)

print(client.is_ready())
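
Because the connection cell reads its settings from environment variables, it can help to check that they are set before connecting. The following is a minimal sketch using only the standard library. The helper name `missing_env_vars` and the example URL are hypothetical; the variable names are the ones used in the cell above.

```python
import os

def missing_env_vars(names, environ=None):
    """Return the names that are unset or empty in the given environment mapping."""
    env = os.environ if environ is None else environ
    return [name for name in names if not env.get(name)]

# Variable names taken from the connection cell above.
required = ["WCS_DEMO_URL", "WCS_DEMO_RO_KEY", "PALM-API-KEY"]

# A made-up environment for illustration; pass nothing to check the real one.
fake_env = {"WCS_DEMO_URL": "https://example.invalid", "WCS_DEMO_RO_KEY": ""}
print(missing_env_vars(required, fake_env))
```

Calling `missing_env_vars(required)` with no second argument checks the real process environment after `load_dotenv()` has run.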

#### 2. Weaviate Embedded

The second option is Weaviate embedded. This runs Weaviate inside your notebook, which is ideal for quick experimentation.

In [None]:
client = weaviate.WeaviateClient(
    embedded_options=EmbeddedOptions(
        version="1.24.8",
        additional_env_vars={
            "ENABLE_MODULES": "text2vec-palm,generative-palm"
        },
    ),
    additional_headers={
        "X-PaLM-Api-Key": "PALM-API-KEY"  # Replace with your Gemini API key
    },
)

client.connect()

#### 3. Local (Docker)

If you would like to run Weaviate yourself, you can download the [Docker files](https://weaviate.io/developers/weaviate/installation/docker-compose) and run it locally on your machine or in the cloud. Make sure to include the Google module in the configurator.

In [None]:
client = weaviate.connect_to_local()

print(client.is_ready())

### Create schema
The schema tells Weaviate how you want to store your data. We will have two collections: Products and Personas. Each collection has metadata (properties) and specifies the embedding and language model.

In [None]:
# Optional: delete existing collections to start from an empty database
result = client.collections.delete("Products")
print(result)
result = client.collections.delete("Personas")
print(result)
result = client.collections.delete("Personalized")
print(result)

In [None]:
# Products Collection
if not client.collections.exists("Products"):
    collection = client.collections.create(
        name="Products",
        vectorizer_config=wvcc.Configure.Vectorizer.text2vec_palm(
            project_id="project-id",  # Only required if you're using Vertex AI. Replace with your project id
            api_endpoint="generativelanguage.googleapis.com",
            model_id="embedding-gecko-001",  # default model. You can switch to another model if desired
        ),
        generative_config=wvcc.Configure.Generative.palm(
            project_id="project-id",  # Only required if you're using Vertex AI. Replace with your project id
            api_endpoint="generativelanguage.googleapis.com",
            model_id="gemini-pro-vision",  # You can switch to another model if desired
        ),
        properties=[  # properties for the Products collection
            Property(name="product_id", data_type=DataType.TEXT),
            Property(name="title", data_type=DataType.TEXT),
            Property(name="category", data_type=DataType.TEXT),
            Property(name="link", data_type=DataType.TEXT),
            Property(name="description", data_type=DataType.TEXT),
            Property(name="brand", data_type=DataType.TEXT),
            Property(name="generated_description", data_type=DataType.TEXT),
        ],
    )

# Personas Collection
if not client.collections.exists("Personas"):
    collection = client.collections.create(
        name="Personas",
        vectorizer_config=wvcc.Configure.Vectorizer.text2vec_palm(
            project_id="project-id",  # Only required if you're using Vertex AI. Replace with your project id
            api_endpoint="generativelanguage.googleapis.com",
            model_id="embedding-gecko-001",  # default model. You can switch to another model if desired
        ),
        generative_config=wvcc.Configure.Generative.palm(
            project_id="project-id",  # Only required if you're using Vertex AI. Replace with your project id
            api_endpoint="generativelanguage.googleapis.com",
            model_id="gemini-pro-vision",  # You can switch to another model if desired
        ),
        properties=[  # properties for the Personas collection
            Property(name="name", data_type=DataType.TEXT),
            Property(name="description", data_type=DataType.TEXT),
        ],
    )

### Import Objects

In [None]:
# URL to the raw JSON file
url = 'https://raw.githubusercontent.com/bkauf/next-store/main/first_99_objects.json'
response = requests.get(url)

# Load the entire JSON content
data = json.loads(response.text)

In [None]:
# Print first object

data[0]

#### Upload to Weaviate
We will use Weaviate's batch import to load the 99 objects into our database.

In [None]:
products = client.collections.get("Products")

with products.batch.dynamic() as batch:
    for item in data:
        batch.add_object(
            properties={
                "product_id": item['product_id'],
                "title": item['title'],
                "category": item['category'],
                "link": item['link'],
                "description": item['description'],
                "brand": item['brand']
            }
        )
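
The import above lets Weaviate assign a random `uuid` to each object, which is why the `uuid` in the next cells has to be copied by hand. One alternative is to derive a stable identifier from `product_id`. The sketch below shows the idea with Python's standard `uuid` module; it is not the `generate_uuid5` helper imported earlier, and the helper name `product_uuid` and the sample product ID are hypothetical.

```python
import uuid

def product_uuid(product_id: str) -> str:
    """Derive a stable, name-based (version 5) UUID from a product ID."""
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, product_id))

# The same input always maps to the same UUID, so re-running an import
# would address the same object instead of creating a new one.
first = product_uuid("SAMPLE-PRODUCT-001")   # hypothetical product ID
second = product_uuid("SAMPLE-PRODUCT-001")
print(first == second)
```

A value like this could then be passed as the `uuid` argument of `batch.add_object` alongside `properties`.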

In [None]:
# count how many objects are in the database
products = client.collections.get("Products")
response = products.aggregate.over_all(total_count=True)
print(response.total_count)

In [None]:
# print the objects uuid and properties

for product in products.iterator():
    print(product.uuid, product.properties)

From the printed list above, select one `uuid` and paste it into the cell below.

Note: If you run the cell below without replacing the `uuid`, it will result in an error, because the placeholder `uuid` will not exist in your database.

In [None]:
product = products.query.fetch_object_by_id(
    "87e5a137-d943-4863-90df-7eed6415fd58", # <== paste a new product UUID here after importing
    include_vector=True
)

print(product.properties["title"], product.vector["default"])

## Part 2: Vector Search

### Vector Search
Vector search returns the objects whose vectors are most similar to the query vector. We will use the `near_text` operator to find the objects whose vectors are nearest to that of an input text.

In [None]:
products = client.collections.get("Products")

response = products.query.near_text(
    query="travel mug",
    return_properties=["title", "description", "link"],  # only return these 3 properties
    limit=3  # limited to 3 objects
)

for product in response.objects:
    print(json.dumps(product.properties, indent=2))
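
To make "nearest vector" concrete, here is a toy sketch that ranks three hand-written vectors by cosine distance to a query vector. The three-dimensional vectors and titles are made up for illustration; real embeddings come from the embedding model and have many more dimensions, and Weaviate uses an approximate index rather than a full scan like this one.

```python
import math

def cosine_distance(a, b):
    """Cosine distance: 0 means same direction, 1 means orthogonal."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)

# Made-up vectors standing in for product embeddings.
toy_vectors = {
    "travel mug": [0.9, 0.1, 0.0],
    "water bottle": [0.7, 0.3, 0.1],
    "hoodie": [0.0, 0.2, 0.9],
}
query_vector = [1.0, 0.0, 0.0]  # made-up embedding of the query text

# Sort titles from smallest to largest distance and keep the closest two.
ranked = sorted(toy_vectors, key=lambda title: cosine_distance(query_vector, toy_vectors[title]))
print(ranked[:2])
```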

### Hybrid Search
[Hybrid search](https://weaviate.io/developers/weaviate/search/hybrid) combines keyword search (BM25) with vector search, so a query can match on exact terms as well as on meaning.

To use hybrid search in Weaviate, set the `alpha` parameter, which determines how the two searches are weighted.

`alpha` = 0 --> pure BM25

`alpha` = 0.5 --> half BM25, half vector search

`alpha` = 1 --> pure vector search

In [None]:
products = client.collections.get("Products")

response = products.query.hybrid(
    query = "dishwasher safe container", # query
    alpha = 0.75, # leaning more towards vector search
    return_properties=["title", "description", "link"], # return these 3 properties
    limit = 3 # limited to only 3 objects
)

for product in response.objects:
    print(json.dumps(product.properties, indent=2))
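
The effect of `alpha` can be illustrated with a simple weighted sum of two score maps. This is a simplified sketch, assuming both score sets are already normalized to the range 0 to 1; it is not Weaviate's actual fusion algorithm, and the function name `fuse_scores` and the scores are made up.

```python
def fuse_scores(keyword_scores, vector_scores, alpha):
    """Blend two score maps: alpha weights the vector side, (1 - alpha) the keyword side."""
    ids = set(keyword_scores) | set(vector_scores)
    return {
        doc_id: (1 - alpha) * keyword_scores.get(doc_id, 0.0)
        + alpha * vector_scores.get(doc_id, 0.0)
        for doc_id in ids
    }

# Made-up normalized scores for two documents.
keyword = {"mug": 1.0, "bottle": 0.25}
vector = {"mug": 0.5, "bottle": 1.0}

for alpha in (0.0, 0.5, 1.0):
    fused = fuse_scores(keyword, vector, alpha)
    best = max(fused, key=fused.get)
    print(alpha, best)
```

With these numbers the keyword favorite wins at `alpha` 0 and 0.5, and the vector favorite wins at `alpha` 1.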

### Autocut
Rather than hard-coding the limit on the number of objects (seen above), we can use [autocut](https://weaviate.io/developers/weaviate/api/graphql/additional-operators#autocut) to cut off the result set. Autocut limits the number of results returned based on significant variations in the result set's metrics, such as vector distance or score.


To use autocut, you must specify the `auto_limit` parameter, which will stop returning results after the specified number of variations, or "jumps," is reached.

We will use the same hybrid search query above but use `auto_limit` rather than `limit`. Notice how there are actually 4 objects retrieved in this case, compared to the 3 objects returned in the previous query.

In [None]:
# auto_limit set to 1

products = client.collections.get("Products")

response = products.query.hybrid(
    query = "dishwasher safe container", # query
    alpha = 0.75, # leaning more towards vector search
    return_properties=["title", "description", "link"], # return these 3 properties
    auto_limit = 1 # autocut after 1 jump
)

for product in response.objects:
    print(json.dumps(product.properties, indent=2))
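
The idea of cutting results at a "jump" can be sketched in a few lines. The rule used here (a gap counts as a jump when it exceeds `factor` times the mean gap) is an assumption chosen for illustration, not Weaviate's actual criterion, and the distances are made up.

```python
def autocut(distances, jumps=1, factor=2.0):
    """Return how many results to keep before the `jumps`-th large gap in sorted distances."""
    gaps = [b - a for a, b in zip(distances, distances[1:])]
    if not gaps:
        return len(distances)
    mean_gap = sum(gaps) / len(gaps)
    seen = 0
    for index, gap in enumerate(gaps):
        if gap > factor * mean_gap:  # assumed definition of a "jump"
            seen += 1
            if seen == jumps:
                return index + 1
    return len(distances)  # no cut point found, keep everything

# Made-up distances: four close results, then a clear jump.
distances = [0.10, 0.11, 0.12, 0.13, 0.40, 0.41]
kept = autocut(distances, jumps=1)
print(distances[:kept])
```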

### Filters
We can narrow down our results by adding a filter to the query.

We will look for objects where `category` is equal to `Drinkware`.

In [None]:
products = client.collections.get("Products")

response = products.query.near_text(
    query="travel cup",
    return_properties=["title", "description", "category", "link"], # returned properties
    filters=wvc.query.Filter.by_property("category").equal("Drinkware"), # filter
    limit=3, # limit to 3 objects
)

for product in response.objects:
    print(product.properties)
    print('===')

## Part 3: Generative Feedback Loops

[Generative Feedback Loops](https://weaviate.io/blog/generative-feedback-loops-with-llms) refer to the process of storing the output of the language model back in the database.

We will generate a description for each product in our database using Gemini and save it to the `generated_description` property in the `Products` collection.

### Connect and configure Gemini model

In [None]:
genai.configure(api_key='gemini-api-key') # gemini api key

# Multimodal model
model_pro_vision = genai.GenerativeModel(model_name='gemini-pro-vision') # multi-modal model (text and image)

# LLM
model_pro = genai.GenerativeModel(model_name='gemini-pro') # text only model

### Generate a description and store it in the `Products` collection

Steps for the cell below:
1. Run a vector search query to find travel jackets, using autocut (`auto_limit`) to limit the results. Learn more about autocut [here](https://weaviate.io/developers/weaviate/api/graphql/additional-operators#autocut).
2. For each returned object, prompt Gemini with the task and the product image, then store the generated text in the `generated_description` property.

In [None]:
response = products.query.near_text( # first find travel jackets
    query="travel jacket",
    return_properties=["title", "description", "category", "link"],
    auto_limit=1, # limit it to 1 close group
)

for product in response.objects:
    if "link" in product.properties:
        product_uuid = product.uuid  # avoid shadowing the built-in `id`
        img_url = product.properties["link"]

        pil_image = url_to_pil(img_url)  # convert image to PIL object
        # Prompt Gemini with the task and the product image
        result = model_pro_vision.generate_content(
            ["Write a short Facebook ad about this product photo.", pil_image]
        )
        generated_description = result.text
        print(img_url)
        print(generated_description)
        print('===')

        # Update the Products collection with the generated description
        products.data.update(uuid=product_uuid, properties={"generated_description": generated_description})
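
The feedback-loop pattern itself is independent of Weaviate and Gemini. The sketch below shows it with an in-memory dictionary as the store and a stand-in function in place of the model call. The names `feedback_loop` and `fake_generate` and the sample records are hypothetical.

```python
def feedback_loop(store, generate, field="generated_description"):
    """Write generate(record) back into each record that lacks `field`; return the updated IDs."""
    updated = []
    for record_id, record in store.items():
        if not record.get(field):
            record[field] = generate(record)
            updated.append(record_id)
    return updated

def fake_generate(record):
    # Stand-in for the Gemini call; a real run would send the prompt and image instead.
    return f"Ad copy for {record['title']}"

# Made-up records; the second one already has a generated description.
store = {
    "p1": {"title": "Travel Jacket"},
    "p2": {"title": "Travel Mug", "generated_description": "Already written"},
}
print(feedback_loop(store, fake_generate))
print(store["p1"]["generated_description"])
```

Skipping records that already have a value means the loop can be re-run without paying for the same generation twice.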

### Vector Search on the `generated_description` property

Since the product description was saved in our `Products` collection, we can run a vector search query on it.

In [None]:
products = client.collections.get("Products")

response = products.query.near_text(
    query="travel jacket",
    return_properties=["generated_description", "description", "title"],
    limit=1
)

for o in response.objects:
    print(o.uuid)
    print(json.dumps(o.properties, indent=2))

## Part 4: Personalization

So far, we've generated product descriptions using Gemini's multi-modal model. In Part 4, we will generate product descriptions tailored to the persona.

We will use [cross-references](https://weaviate.io/developers/weaviate/manage-data/cross-references) to establish directional relationships between collections.

In [None]:
# Personalized Collection

if not client.collections.exists("Personalized"):
    collection = client.collections.create(
        name="Personalized",
        vectorizer_config=wvcc.Configure.Vectorizer.text2vec_palm(
            project_id="project-id",  # Only required if you're using Vertex AI. Replace with your project id
            api_endpoint="generativelanguage.googleapis.com",
            model_id="embedding-gecko-001",  # default model. You can switch to another model if desired
        ),
        generative_config=wvcc.Configure.Generative.palm(
            project_id="project-id",  # Only required if you're using Vertex AI. Replace with your project id
            api_endpoint="generativelanguage.googleapis.com",
            model_id="gemini-pro-vision",  # You can switch to another model if desired
        ),
        properties=[
            Property(name="description", data_type=DataType.TEXT),
        ],
        # cross-references
        references=[
            ReferenceProperty(
                name="ofProduct",
                target_collection="Products"  # connect personalized to the products collection
            ),
            ReferenceProperty(
                name="ofPersona",
                target_collection="Personas"  # connect personalized to the personas collection
            ),
        ],
    )

### Create two personas (Alice and Bob)

In [None]:
personas = client.collections.get("Personas")

for persona in ['Alice', 'Bob']:
    # Use gemini-pro to generate the persona description
    generated_description = model_pro.generate_content(
        ["Create a fictional buyer persona named " + persona + ", write a short description about them"]
    )
    uuid = personas.data.insert({
        "name": persona,
        "description": generated_description.text
    })
    print(uuid)
    print(generated_description.text)
    print("===")

In [None]:
# print objects in the Personas collection

personas = client.collections.get("Personas")

for persona in personas.iterator():
    print(persona.uuid, persona.properties)

### Generate a product description tailored to the persona

Grab the product `uuid` from Part 1 and paste it below.

In [None]:
personalized = client.collections.get("Personalized")

product = products.query.fetch_object_by_id("87e5a137-d943-4863-90df-7eed6415fd58")  # <== paste a new product UUID here after importing
print(product.properties['link'])
print('===')

personas = client.collections.get("Personas")

for persona in personas.iterator():
    # Generate a description tailored to the persona
    generated_description = model_pro.generate_content([
        "Create a product description tailored to the following person, make sure to use the name (",
        persona.properties["name"],
        ") of the persona.\n\n",
        "# Product Description\n",
        product.properties["description"],
        "\n\n# Persona\n",  # separate the persona section from the product description
        persona.properties["description"],
    ])
    print(generated_description.text)
    # Add the personalized description to the `description` property in the Personalized collection
    new_uuid = personalized.data.insert(
        properties={
            "description": generated_description.text },
        references={
            "ofProduct": product.uuid, # add cross-reference to the Product collection
            "ofPersona": persona.uuid # add cross-reference to the Persona collection
        },
    )
    print("New UUID", new_uuid)
    print('===')
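
Assembling the prompt in a small helper makes it easier to inspect and reuse. The following is a minimal sketch; the function name `build_persona_prompt` and the sample texts are hypothetical, and the wording is adapted from the prompt used above.

```python
def build_persona_prompt(name, product_description, persona_description):
    """Join the instruction, product section, and persona section with blank lines."""
    return "\n\n".join([
        f"Create a product description tailored to the following person. Use the persona's name ({name}).",
        f"# Product Description\n{product_description.strip()}",
        f"# Persona\n{persona_description.strip()}",
    ])

# Made-up inputs for illustration.
prompt = build_persona_prompt(
    "Alice",
    "Insulated steel travel mug. ",
    "Alice commutes by bike.",
)
print(prompt)
```

The resulting string could be passed to `generate_content` as a single item, which keeps the section headings on their own lines.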

### Fetch the objects in the `Personalized` collection

In [None]:
personalized = client.collections.get("Personalized")

response = personalized.query.fetch_objects(
    limit=2,
    include_vector=True,
    return_references=[
        QueryReference(
            link_on="ofProduct",
            return_properties=["title"]  # return the title property from the Products collection
        ),
        QueryReference(
            link_on="ofPersona",
            return_properties=["name"]  # return the name property from the Personas collection
        )
    ]
)

for item in response.objects:
    print(item.properties)
    for ref_obj in item.references["ofProduct"].objects:
        print(ref_obj.properties)
    for ref_obj in item.references["ofPersona"].objects:
        print(ref_obj.properties)
    print(item.vector["default"])
    print("===")
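
Conceptually, each `Personalized` object stores pointers to one product and one persona, and the query follows those pointers to return the linked properties. The sketch below mimics that with plain dictionaries. All identifiers and texts are made up; only the reference names `ofProduct` and `ofPersona` come from the schema above.

```python
# Made-up stand-ins for the Products and Personas collections, keyed by ID.
products_by_id = {"prod-1": {"title": "Travel Jacket"}}
personas_by_id = {"pers-1": {"name": "Alice"}, "pers-2": {"name": "Bob"}}

# Each personalized row points at one product and one persona.
personalized_rows = [
    {"description": "A jacket for Alice's commute.", "ofProduct": "prod-1", "ofPersona": "pers-1"},
    {"description": "A jacket for Bob's hikes.", "ofProduct": "prod-1", "ofPersona": "pers-2"},
]

def resolve(entry):
    """Follow both references and return the description with the linked title and name."""
    return {
        "description": entry["description"],
        "product_title": products_by_id[entry["ofProduct"]]["title"],
        "persona_name": personas_by_id[entry["ofPersona"]]["name"],
    }

resolved = [resolve(entry) for entry in personalized_rows]
for row in resolved:
    print(row)
```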