Evaluate Text Generation¶
Author(s) | Renato Leite (renatoleite@), Egon Soares (egon@) |
Last updated | 10/22/2023 |
BLEU¶
Explanations¶
Original paper¶
https://dl.acm.org/doi/pdf/10.3115/1073083.1073135
$BLEU = \text{Brevity Penalty}\times(\exp(\sum_{n=1}^{N}w_n\log(\text{modified precision}(n))))$
$N = 4$ - This is the baseline used in the paper
$w_n = 1 / N$ - This is for using uniform weights
$\text{Brevity Penalty} = \begin{cases} 1 & \quad \text{if } c > r\\ e^{(1-r/c)} & \quad \text{if } c \leq r \end{cases}$
$\text{modified precision}(n) = \cfrac{\sum \text{Count Clip}(n)}{\sum \text{Count n-gram}_{candidate}}$
$\text{Count Clip}(n) = \min(\text{Count n-gram}_{candidate}, \max(\text{Count n-gram}_{reference}))$
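To make the clipping concrete, here is the classic example from the BLEU paper: a candidate that repeats "the" seven times against a reference containing "the" only twice, so the clipped unigram count is min(7, 2) = 2 and the modified precision is 2/7. A minimal sketch using the stdlib `Counter`:

```python
from collections import Counter

candidate = "the the the the the the the".split()
reference = "the cat is on the mat".split()

candidate_counts = Counter(candidate)  # {"the": 7}
reference_counts = Counter(reference)  # "the" appears twice

# Clip each candidate n-gram count at its maximum reference count
clipped = {
    ngram: min(count, reference_counts[ngram])
    for ngram, count in candidate_counts.items()
}
modified_precision = sum(clipped.values()) / sum(candidate_counts.values())
print(modified_precision)  # 0.2857142857142857 (= 2/7)
```

Without clipping, ordinary precision would be a perfect 7/7, which is exactly the pathology the modified precision corrects.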
Alternative explanation¶
https://cloud.google.com/translate/automl/docs/evaluate#bleu
$\text{BLEU} = \underbrace{\vphantom{\prod_i^4}\min\Big(1, \exp\big(1-\frac{reference_{length}} {candidate_{length}}\big)\Big)}_{\text{brevity penalty}} \underbrace{\Big(\prod_{i=1}^{4} precision_i\Big)^{1/4}}_{\text{n-gram overlap}}$
$\text{Brevity Penalty} = min(1, \exp(1-\cfrac{reference_{length}}{candidate_{length}}))$
$\text{n-gram overlap} = (\displaystyle\prod_{i=1}^{4} precision_i)^\frac{1}{4}$
$precision_i = \dfrac{\sum_{\text{sentence}\in\text{Candidate-Corpus}}\sum_{i\in\text{sentence}}\min(m^i_{candidate}, m^i_{reference})}{w^i_{totalCandidate}}$, where $w^i_{totalCandidate} = \sum_{\text{sentence}'\in\text{Candidate-Corpus}}\sum_{i'\in\text{sentence}'} m^{i'}_{candidate}$
$m_{candidate}^i$: the count of i-grams in the candidate that match the reference
$m_{reference}^i$: the count of i-grams in the reference
$w_{totalCandidate}^i$: the total number of i-grams in the candidate
Brevity Penalty¶
$\text{Brevity Penalty} = \begin{cases} 1 & \quad \text{if } c \geq r\\ e^{(1-r/c)} & \quad \text{if } c < r \end{cases}$
$ c = length_{candidate}$, $r = length_{reference}$
$\text{Brevity Penalty} = min(1, \exp(1-\cfrac{reference_{length}}{candidate_{length}}))$
import math


def calculate_brevity_penalty(reference_len: int, candidate_len: int) -> float:
    # Raise an error if any length is negative
    if reference_len < 0 or candidate_len < 0:
        raise ValueError("Length cannot be negative")
    # If the candidate is longer than the reference, r/c < 1, so exp(1 - r/c) > 1 and the penalty is capped at 1
    if candidate_len > reference_len:
        print(f"Candidate length \t ({candidate_len}) \t is greater than the reference length \t ({reference_len}), \t so the Brevity Penalty is equal to \t 1.000")
        return 1.0
    # If the lengths are equal, then r/c = 1 and exp(0) = 1
    if candidate_len == reference_len:
        print(f"Candidate length \t ({candidate_len}) \t is equal to the reference length \t ({reference_len}), \t so the Brevity Penalty is equal to \t 1.000")
        return 1.0
    # If the candidate is empty, the penalty is 0, because r/c -> inf and exp(1 - r/c) -> 0
    if candidate_len == 0:
        print(f"Candidate length \t ({candidate_len}) \t is equal to 0.0, \t\t\t\t so the Brevity Penalty is equal to \t 0.000")
        return 0.0
    # If the candidate is shorter than the reference, the penalty is exp(1 - r/c)
    print(f"Candidate length \t ({candidate_len}) \t is less than the reference length \t ({reference_len}),\t so the Brevity Penalty is equal to \t {math.exp(1 - reference_len / candidate_len):.3f}")
    return math.exp(1 - reference_len / candidate_len)
def calculate_brevity_penalty_2(reference_len: int, candidate_len: int) -> float:
    # Raise an error if any length is negative
    if reference_len < 0 or candidate_len < 0:
        raise ValueError("Length cannot be negative")
    # Avoid a division by 0
    if candidate_len == 0:
        if reference_len == 0:
            return 1.0
        else:
            return 0.0
    return min(1.0, math.exp(1 - reference_len / candidate_len))
candidates = [
    "It is a guide to action which ensures that the military always obeys the commands of the party.",
    "It is to insure the troops forever hearing the activity guidebook that party direct.",
    "",
]
references = [
    "It is a guide to action that ensures that the military will forever heed Party commands.",
    "It is the guiding principle which guarantees the military forces always being under the command of the Party.",
    "It is the practical guide for the army always to heed the directions of the party.",
]
from itertools import product

# Note: len() gives character counts here, which is fine for exercising the
# brevity-penalty function; BLEU proper uses token (word) counts.
bp1 = [calculate_brevity_penalty(len(reference), len(candidate)) for reference, candidate in product(references, candidates)]
Candidate length 	 (95) 	 is greater than the reference length 	 (88), 	 so the Brevity Penalty is equal to 	 1.000
Candidate length 	 (84) 	 is less than the reference length 	 (88),	 so the Brevity Penalty is equal to 	 0.953
Candidate length 	 (0) 	 is equal to 0.0, 				 so the Brevity Penalty is equal to 	 0.000
Candidate length 	 (95) 	 is less than the reference length 	 (109),	 so the Brevity Penalty is equal to 	 0.863
Candidate length 	 (84) 	 is less than the reference length 	 (109),	 so the Brevity Penalty is equal to 	 0.743
Candidate length 	 (0) 	 is equal to 0.0, 				 so the Brevity Penalty is equal to 	 0.000
Candidate length 	 (95) 	 is greater than the reference length 	 (82), 	 so the Brevity Penalty is equal to 	 1.000
Candidate length 	 (84) 	 is greater than the reference length 	 (82), 	 so the Brevity Penalty is equal to 	 1.000
Candidate length 	 (0) 	 is equal to 0.0, 				 so the Brevity Penalty is equal to 	 0.000
bp_2 = [calculate_brevity_penalty_2(len(reference), len(candidate)) for reference, candidate in product(references, candidates)]
bp1 == bp_2
True
Precision¶
$\text{modified precision}(n) = \cfrac{\sum \text{Count Clip}(n)}{\sum \text{Count n-gram}_{candidate}}$
$\text{Count Clip}(n) = \min(\text{Count n-gram}_{candidate}, \max(\text{Count n-gram}_{reference}))$
from collections import Counter
from fractions import Fraction
from itertools import tee
def ngrams(sequence, n):
    # Create n copies of the iterator; advancing the i-th copy by i items
    # and zipping the copies together yields a sliding window of n items.
    iterables = tee(iter(sequence), n)
    for i, sub_iterable in enumerate(iterables):
        for _ in range(i):
            next(sub_iterable, None)  # advance the i-th copy by i items
    return zip(*iterables)
def count_clip(counts: Counter, max_counts: dict) -> dict:
    clipped_counts = {}
    for ngram, count in counts.items():
        # Clip the candidate count at the maximum reference count
        clipped_count = min(count, max_counts[ngram])
        clipped_counts[ngram] = clipped_count
    return clipped_counts
def calculate_modified_precision(references, candidate, n):
    candidate = candidate.split()
    candidate_counts = Counter(ngrams(candidate, n)) if len(candidate) >= n else Counter()
    # For each candidate n-gram, keep the maximum count observed in any reference
    max_counts = {}
    for ref in references:
        reference = ref.split()
        reference_counts = (
            Counter(ngrams(reference, n)) if len(reference) >= n else Counter()
        )
        for ngram in candidate_counts:
            max_counts[ngram] = max(max_counts.get(ngram, 0), reference_counts[ngram])
    clipped_counts = count_clip(candidate_counts, max_counts)
    numerator = sum(clipped_counts.values())
    # Ensure the denominator is at least 1 to avoid ZeroDivisionError.
    denominator = max(1, sum(candidate_counts.values()))
    return Fraction(numerator, denominator, _normalize=False)
print("References\n")
_ = [print(reference) for reference in references]
References

It is a guide to action that ensures that the military will forever heed Party commands.
It is the guiding principle which guarantees the military forces always being under the command of the Party.
It is the practical guide for the army always to heed the directions of the party.
print("Candidates\n")
_ = [print(f"Candidate {i} is '{candidate}'") for i, candidate in enumerate(candidates)]
Candidates

Candidate 0 is 'It is a guide to action which ensures that the military always obeys the commands of the party.'
Candidate 1 is 'It is to insure the troops forever hearing the activity guidebook that party direct.'
Candidate 2 is ''
[f"The {j+1}-gram modified precision for candidate {i} is {calculate_modified_precision(references, candidate, j+1)}" for i, candidate in enumerate(candidates) for j in range(4)]
['The 1-gram modified precision for candidate 0 is 16/18',
 'The 2-gram modified precision for candidate 0 is 10/17',
 'The 3-gram modified precision for candidate 0 is 7/16',
 'The 4-gram modified precision for candidate 0 is 4/15',
 'The 1-gram modified precision for candidate 1 is 7/14',
 'The 2-gram modified precision for candidate 1 is 1/13',
 'The 3-gram modified precision for candidate 1 is 0/12',
 'The 4-gram modified precision for candidate 1 is 0/11',
 'The 1-gram modified precision for candidate 2 is 0',
 'The 2-gram modified precision for candidate 2 is 0',
 'The 3-gram modified precision for candidate 2 is 0',
 'The 4-gram modified precision for candidate 2 is 0']
n-gram overlap¶
$\text{n-gram overlap} = \exp(\sum_{n=1}^{N}w_n\log(\text{modified precision}(n)))$
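With uniform weights $w_n = 1/4$, the weighted log-sum is just the geometric mean of the four modified precisions, i.e. the product form $(\prod_i precision_i)^{1/4}$ from the alternative explanation. A quick sketch checking the equivalence, using candidate 0's modified precisions (16/18, 10/17, 7/16, 4/15) from the output above:

```python
import math

# Candidate 0's modified precisions for n = 1..4 (taken from the output above)
precisions = [16 / 18, 10 / 17, 7 / 16, 4 / 15]
weights = [0.25] * 4

# exp(sum w_n * log p_n) ...
log_sum_form = math.exp(math.fsum(w * math.log(p) for w, p in zip(weights, precisions)))
# ... equals (prod p_n)^(1/4), the geometric mean
geometric_mean = math.prod(precisions) ** 0.25

print(abs(log_sum_form - geometric_mean) < 1e-12)  # True
```

The log-sum form is preferred numerically because summing logarithms avoids underflow when many small precisions are multiplied.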
def calculate_n_gram_overlap(references, candidate, weights=(0.25, 0.25, 0.25, 0.25)):
    # Compute the modified precision for 1- to 4-grams
    modified_precision_numerators = Counter()
    modified_precision_denominators = Counter()
    for i, _ in enumerate(weights, start=1):
        modified_precision_i = calculate_modified_precision(references, candidate, i)
        modified_precision_numerators[i] += modified_precision_i.numerator
        modified_precision_denominators[i] += modified_precision_i.denominator
    # Drop zero precisions so the logarithm below is defined
    modified_precision_n = [
        Fraction(modified_precision_numerators[i], modified_precision_denominators[i],
                 _normalize=False)
        for i, _ in enumerate(weights, start=1)
        if modified_precision_numerators[i] > 0
    ]
    weighted_precisions = (
        weight_i * math.log(precision_i)
        for weight_i, precision_i in zip(weights, modified_precision_n)
    )
    precisions_sum = math.fsum(weighted_precisions)
    return math.exp(precisions_sum)
BLEU¶
$BLEU = \text{Brevity Penalty}\times\text{n-gram overlap}$
def bleu(references, candidate, weights=(0.25, 0.25, 0.25, 0.25)):
    candidate_len = len(candidate.split())
    references_lens = (len(reference.split()) for reference in references)
    # Pick the reference length closest to the candidate length (ties go to the shorter reference)
    closest_reference_len = min(
        references_lens,
        key=lambda reference_len: (abs(reference_len - candidate_len), reference_len),
    )
    brevity_penalty = calculate_brevity_penalty_2(closest_reference_len, candidate_len)
    n_gram_overlap = calculate_n_gram_overlap(references, candidate, weights)
    return brevity_penalty * n_gram_overlap
bleu(references, candidates[0])
0.4969770530031034
NLTK Implementation¶
!pip install -U nltk
from nltk.translate.bleu_score import sentence_bleu
nltk_bleu_score = sentence_bleu([reference.split() for reference in references], candidates[0].split())
print(nltk_bleu_score)
0.4969770530031034
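Candidate 1 has zero 3-gram and 4-gram modified precision, so plain `sentence_bleu` collapses to 0 (NLTK emits a warning). NLTK's `SmoothingFunction` addresses this; a sketch using `method1`, which adds a small epsilon to zero counts, with the same references and candidate 1 from above:

```python
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

references = [
    "It is a guide to action that ensures that the military will forever heed Party commands.",
    "It is the guiding principle which guarantees the military forces always being under the command of the Party.",
    "It is the practical guide for the army always to heed the directions of the party.",
]
candidate_1 = "It is to insure the troops forever hearing the activity guidebook that party direct."

tokenized_references = [reference.split() for reference in references]
# method1 replaces zero precisions with a small epsilon so log(0) never occurs
smoothed = sentence_bleu(
    tokenized_references,
    candidate_1.split(),
    smoothing_function=SmoothingFunction().method1,
)
print(smoothed)  # a small positive score instead of 0
```

Smoothing matters mostly for sentence-level BLEU on short texts; at corpus level, zero higher-order precisions are rare.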
ROUGE-L¶
See Theory_Evaluate_2_Summarization.ipynb
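As a minimal sketch of the idea before moving to that notebook: ROUGE-L scores the longest common subsequence (LCS) between candidate and reference, combined into an F-measure. The helper names below are hypothetical, and this simplified version assumes a single reference:

```python
def lcs_length(a: list, b: list) -> int:
    # Standard dynamic-programming longest common subsequence length
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a, 1):
        for j, y in enumerate(b, 1):
            dp[i][j] = dp[i - 1][j - 1] + 1 if x == y else max(dp[i - 1][j], dp[i][j - 1])
    return dp[-1][-1]


def rouge_l_f1(reference: str, candidate: str, beta: float = 1.0) -> float:
    ref, cand = reference.split(), candidate.split()
    lcs = lcs_length(ref, cand)
    if lcs == 0:
        return 0.0
    recall = lcs / len(ref)       # how much of the reference the LCS covers
    precision = lcs / len(cand)   # how much of the candidate the LCS covers
    return (1 + beta**2) * precision * recall / (recall + beta**2 * precision)
```

Unlike BLEU's fixed n-grams, the LCS rewards in-order word matches of any span without requiring them to be contiguous.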