21. Statistical Divergence Measures#

21.1. Overview#

A statistical divergence quantifies the discrepancy between two probability distributions that can be challenging to distinguish for the following reason:

  • every event that has positive probability under one of the distributions also has positive probability under the other distribution

  • this means that there is no “smoking gun” event whose occurrence tells a statistician that one of the probability distributions surely governs the data

A statistical divergence is a function that maps two probability distributions into a nonnegative real number.

Statistical divergence functions play important roles in statistics, information theory, and what many people now call “machine learning”.

This lecture describes three divergence measures:

  • Kullback–Leibler (KL) divergence

  • Jensen–Shannon (JS) divergence

  • Chernoff entropy

These will appear in several quantecon lectures.

Let’s start by importing the necessary Python tools.

import matplotlib.pyplot as plt
import numpy as np
from numba import vectorize, jit
from math import gamma
from scipy.integrate import quad
from scipy.optimize import minimize_scalar
import pandas as pd
from IPython.display import display, Math

21.2. Primer on entropy, cross-entropy, KL divergence#

Before diving in, we’ll introduce some useful concepts in a simple setting.

We’ll temporarily assume that $f$ and $g$ are two probability mass functions for discrete random variables on state space $I = \{1, 2, \ldots, n\}$ that satisfy $f_i \geq 0, \sum_i f_i = 1$ and $g_i \geq 0, \sum_i g_i = 1$.

We follow some statisticians and information theorists who define the surprise or surprisal associated with having observed a single draw x=i from distribution f as

$$\log\left(\frac{1}{f_i}\right)$$

They then define the information that you can anticipate to gather from observing a single realization as the expected surprisal

$$H(f) = \sum_i f_i \log\left(\frac{1}{f_i}\right).$$

Claude Shannon [Shannon, 1948] called H(f) the entropy of distribution f.

Note

By maximizing $H(f)$ with respect to $\{f_1, f_2, \ldots, f_n\}$ subject to $\sum_i f_i = 1$, we can verify that the distribution that maximizes entropy is the uniform distribution $f_i = \frac{1}{n}$. Entropy $H(f)$ for the uniform distribution evidently equals $\log(n)$.
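As a quick numerical check of this claim (not part of the derivation above), we can draw many random pmfs on $n = 5$ states and confirm that none of their entropies exceeds $\log n$; the helper function below is ours.

# Numerical check: entropy of random pmfs never exceeds that of the uniform pmf
def entropy_pmf(p_vec):
    """Shannon entropy H(p) = -Σ p_i log p_i of a discrete pmf."""
    return -np.sum(p_vec * np.log(p_vec))

n_states = 5
uniform_pmf = np.full(n_states, 1 / n_states)

np.random.seed(0)
random_pmfs = np.random.dirichlet(np.ones(n_states), size=1_000)
max_random_entropy = max(entropy_pmf(q) for q in random_pmfs)

print(f"H(uniform)                  = {entropy_pmf(uniform_pmf):.4f}")
print(f"log(n)                      = {np.log(n_states):.4f}")
print(f"max H over 1000 random pmfs = {max_random_entropy:.4f}")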

Kullback and Leibler [Kullback and Leibler, 1951] define the amount of information that a single draw of x provides for distinguishing f from g as the log likelihood ratio

$$\log \frac{f(x)}{g(x)}$$

The following two concepts are widely used to compare two distributions f and g.

Cross-Entropy:

$$H(f, g) = -\sum_i f_i \log g_i \tag{21.1}$$

Kullback-Leibler (KL) Divergence:

$$D_{KL}(f\|g) = \sum_i f_i \log\left[\frac{f_i}{g_i}\right] \tag{21.2}$$

These concepts are related by the following equality.

$$D_{KL}(f\|g) = H(f, g) - H(f) \tag{21.3}$$

To prove (21.3), note that

$$
\begin{aligned}
D_{KL}(f\|g) &= \sum_i f_i \log\left[\frac{f_i}{g_i}\right] \\
&= \sum_i f_i \left[\log f_i - \log g_i\right] \\
&= \sum_i f_i \log f_i - \sum_i f_i \log g_i \\
&= -H(f) + H(f, g) \\
&= H(f, g) - H(f)
\end{aligned} \tag{21.4}
$$

Remember that H(f) is the anticipated surprisal from drawing x from f.

Then the above equation tells us that the KL divergence is an anticipated “excess surprise” that comes from anticipating that $x$ is drawn from $g$ when it is actually drawn from $f$.
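To make the primer concrete, here is a small check on two illustrative pmfs (the numbers are ours, chosen only for this sketch): we compute entropy, cross-entropy, and KL divergence directly and confirm the identity (21.3).

# Two illustrative pmfs on a 4-state space
f_pmf = np.array([0.1, 0.2, 0.3, 0.4])
g_pmf = np.array([0.25, 0.25, 0.25, 0.25])

H_f  = -np.sum(f_pmf * np.log(f_pmf))          # entropy H(f)
H_fg = -np.sum(f_pmf * np.log(g_pmf))          # cross-entropy H(f, g)
D_kl = np.sum(f_pmf * np.log(f_pmf / g_pmf))   # KL divergence D_KL(f, g)

print(f"H(f)           = {H_f:.4f}")
print(f"H(f, g)        = {H_fg:.4f}")
print(f"KL(f, g)       = {D_kl:.4f}")
print(f"H(f, g) - H(f) = {H_fg - H_f:.4f}")    # matches KL(f, g), as in (21.3)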

21.3. Two Beta distributions: running example#

We’ll use Beta distributions extensively to illustrate concepts.

The Beta distribution is particularly convenient as it’s defined on [0,1] and exhibits diverse shapes by appropriately choosing its two parameters.

The density of a Beta distribution with parameters a and b is given by

$$f(z; a, b) = \frac{\Gamma(a+b) \, z^{a-1} (1-z)^{b-1}}{\Gamma(a)\Gamma(b)}
\quad \text{where} \quad
\Gamma(p) := \int_0^\infty x^{p-1} e^{-x} \, dx$$

Let’s define parameters and density functions in Python

# Parameters in the two Beta distributions
F_a, F_b = 1, 1
G_a, G_b = 3, 1.2

@vectorize
def p(x, a, b):
    r = gamma(a + b) / (gamma(a) * gamma(b))
    return r * x** (a-1) * (1 - x) ** (b-1)

# The two density functions
f = jit(lambda x: p(x, F_a, F_b))
g = jit(lambda x: p(x, G_a, G_b))

# Plot the distributions
x_range = np.linspace(0.001, 0.999, 1000)
f_vals = [f(x) for x in x_range]
g_vals = [g(x) for x in x_range]

plt.figure(figsize=(10, 6))
plt.plot(x_range, f_vals, 'b-', linewidth=2, label=r'$f(x) \sim \text{Beta}(1,1)$')
plt.plot(x_range, g_vals, 'r-', linewidth=2, label=r'$g(x) \sim \text{Beta}(3,1.2)$')

# Fill overlap region
overlap = np.minimum(f_vals, g_vals)
plt.fill_between(x_range, 0, overlap, alpha=0.3, color='purple', label='overlap')

plt.xlabel('x')
plt.ylabel('density')
plt.legend()
plt.show()

21.4. Kullback–Leibler divergence#

Our first divergence function is the Kullback–Leibler (KL) divergence.

For probability densities (or pmfs) f and g it is defined by

$$D_{KL}(f\|g) = KL(f, g) = \int f(x) \log \frac{f(x)}{g(x)} \, dx.$$

We can interpret $D_{KL}(f\|g)$ as the expected excess log loss (expected excess surprisal) incurred when we use $g$ while the data are generated by $f$.

It has several important properties:

  • Non-negativity (Gibbs’ inequality): $D_{KL}(f\|g) \geq 0$, with equality if and only if $f = g$ almost everywhere.

  • Asymmetry: $D_{KL}(f\|g) \neq D_{KL}(g\|f)$ in general (hence it is not a metric).

  • Information decomposition: $D_{KL}(f\|g) = H(f, g) - H(f)$, where $H(f, g)$ is the cross-entropy and $H(f)$ is the Shannon entropy of $f$.

  • Chain rule: For joint distributions $f(x, y)$ and $g(x, y)$, $D_{KL}(f(x,y)\,\|\,g(x,y)) = D_{KL}(f(x)\,\|\,g(x)) + E_{f(x)}\left[D_{KL}(f(y|x)\,\|\,g(y|x))\right]$ (a small numerical check follows this list).
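Here is the promised check of the chain rule on a pair of made-up joint pmfs over a $2 \times 3$ grid; the numbers are ours, chosen only for this sketch.

# Illustrative joint pmfs over (x, y); rows index x, columns index y
f_xy = np.array([[0.10, 0.20, 0.10],
                 [0.15, 0.25, 0.20]])
g_xy = np.array([[0.20, 0.15, 0.15],
                 [0.10, 0.20, 0.20]])

def kl_discrete(p_arr, q_arr):
    """KL divergence between two discrete distributions of the same shape."""
    return np.sum(p_arr * np.log(p_arr / q_arr))

# Left-hand side: divergence between the joint distributions
lhs = kl_discrete(f_xy, g_xy)

# Right-hand side: marginal term plus expected conditional term
f_x, g_x = f_xy.sum(axis=1), g_xy.sum(axis=1)
rhs = kl_discrete(f_x, g_x) + sum(
    f_x[i] * kl_discrete(f_xy[i] / f_x[i], g_xy[i] / g_x[i])
    for i in range(len(f_x)))

print(f"KL between joints         = {lhs:.6f}")
print(f"marginal + E[conditional] = {rhs:.6f}")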

KL divergence plays a central role in statistical inference, including model selection and hypothesis testing.

Likelihood Ratio Processes describes a link between KL divergence and the expected log likelihood ratio, and the lecture A Problem that Stumped Milton Friedman connects it to the test performance of the sequential probability ratio test.

Let’s compute the KL divergence between our example distributions f and g.

def compute_KL(f, g):
    """
    Compute KL divergence KL(f, g) via numerical integration
    """
    def integrand(w):
        fw = f(w)
        gw = g(w)
        return fw * np.log(fw / gw)
    val, _ = quad(integrand, 1e-5, 1-1e-5)
    return val

# Compute KL divergences between our example distributions
kl_fg = compute_KL(f, g)
kl_gf = compute_KL(g, f)

print(f"KL(f, g) = {kl_fg:.4f}")
print(f"KL(g, f) = {kl_gf:.4f}")
KL(f, g) = 0.7590
KL(g, f) = 0.3436
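The link to the expected log likelihood ratio can also be checked by simulation: drawing from $f$ and averaging $\log \frac{f(x)}{g(x)}$ should come close to the quadrature value above (the sample size and seed below are our choices).

# Monte Carlo check: KL(f, g) = E_f[log f(X) / g(X)]
np.random.seed(0)
sample = np.random.beta(F_a, F_b, size=100_000)            # draws from f
log_ratio = np.log(p(sample, F_a, F_b) / p(sample, G_a, G_b))

print(f"Monte Carlo estimate of KL(f, g) = {log_ratio.mean():.4f}")
print(f"Quadrature value                 = {kl_fg:.4f}")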

The asymmetry of KL divergence has important practical implications.

$D_{KL}(f\|g)$ heavily penalizes regions where $f > 0$ but $g$ is close to zero, reflecting the cost of using $g$ to model data generated by $f$; $D_{KL}(g\|f)$ penalizes the reverse mismatch.
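To see this concretely, compare Beta(1,1) with Beta(5,1): the latter has almost no mass near $x = 0$, so using it to model the uniform distribution is penalized far more heavily than the reverse (this short check reuses compute_KL; the pair is chosen only for illustration).

# Asymmetry in action: g_steep ≈ 0 near x = 0, where f_flat still has mass
f_flat  = jit(lambda x: p(x, 1, 1))
g_steep = jit(lambda x: p(x, 5, 1))

print(f"KL(f_flat, g_steep) = {compute_KL(f_flat, g_steep):.4f}")
print(f"KL(g_steep, f_flat) = {compute_KL(g_steep, f_flat):.4f}")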

21.5. Jensen-Shannon divergence#

Sometimes we want a symmetric measure of divergence that captures the difference between two distributions without favoring one over the other.

This often arises in applications like clustering, where we want to compare distributions without assuming one is the true model.

The Jensen-Shannon (JS) divergence symmetrizes KL divergence by comparing both distributions to their mixture:

$$JS(f, g) = \frac{1}{2} D_{KL}(f\|m) + \frac{1}{2} D_{KL}(g\|m),
\qquad m = \frac{1}{2}(f + g),$$

where $m$ is the mixture distribution that averages $f$ and $g$.

Let’s also visualize the mixture distribution m:

def m(x):
    return 0.5 * (f(x) + g(x))

m_vals = [m(x) for x in x_range]

plt.figure(figsize=(10, 6))
plt.plot(x_range, f_vals, 'b-', linewidth=2, label=r'$f(x)$')
plt.plot(x_range, g_vals, 'r-', linewidth=2, label=r'$g(x)$')
plt.plot(x_range, m_vals, 'g--', linewidth=2, label=r'$m(x) = \frac{1}{2}(f(x) + g(x))$')

plt.xlabel('x')
plt.ylabel('density')
plt.legend()
plt.show()

The JS divergence has several useful properties:

  • Symmetry: $JS(f, g) = JS(g, f)$.

  • Boundedness: $0 \leq JS(f, g) \leq \log 2$.

  • Its square root $\sqrt{JS}$ is a metric (the Jensen–Shannon distance) on the space of probability distributions.

  • JS divergence equals the mutual information between a binary random variable $Z \sim \text{Bernoulli}(1/2)$ indicating the source and a sample $X$ drawn from $f$ if $Z = 0$ or from $g$ if $Z = 1$.

The Jensen–Shannon divergence plays a key role in the optimization of certain generative models, as it is bounded, symmetric, and smoother than KL divergence, often providing more stable gradients for training.

Let’s compute the JS divergence between our example distributions f and g

def compute_JS(f, g):
    """Compute Jensen-Shannon divergence."""
    def m(w):
        return 0.5 * (f(w) + g(w))
    js_div = 0.5 * compute_KL(f, m) + 0.5 * compute_KL(g, m)
    return js_div

js_div = compute_JS(f, g)
print(f"Jensen-Shannon divergence JS(f,g) = {js_div:.4f}")
Jensen-Shannon divergence JS(f,g) = 0.0984
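The upper bound $\log 2 \approx 0.6931$ is approached when the two densities barely overlap; the pair Beta(2, 20) and Beta(20, 2) below is our choice for a quick illustration, reusing compute_JS.

# JS divergence approaches log 2 for nearly non-overlapping densities
f_left  = jit(lambda x: p(x, 2, 20))    # mass concentrated near 0
f_right = jit(lambda x: p(x, 20, 2))    # mass concentrated near 1

print(f"JS(f_left, f_right) = {compute_JS(f_left, f_right):.4f}")
print(f"log 2               = {np.log(2):.4f}")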

We can easily generalize to more than two distributions using the generalized Jensen-Shannon divergence with weights $\alpha = (\alpha_i)_{i=1}^n$ (a numerical sketch follows the definitions below):

$$JS_\alpha(f_1, \ldots, f_n) = H\left(\sum_{i=1}^n \alpha_i f_i\right) - \sum_{i=1}^n \alpha_i H(f_i)$$

where:

  • $\alpha_i \geq 0$ and $\sum_{i=1}^n \alpha_i = 1$, and

  • $H(f) = -\int f(x) \log f(x) \, dx$ is the Shannon entropy of distribution $f$
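The sketch below implements this generalized divergence for three Beta distributions with equal weights; the helper functions and the parameter choices are ours, for illustration only.

def entropy_continuous(h):
    """Differential entropy -∫ h(x) log h(x) dx on (0, 1)."""
    val, _ = quad(lambda w: -h(w) * np.log(h(w)), 1e-5, 1-1e-5)
    return val

def compute_JS_general(dists, weights):
    """Generalized JS divergence H(Σ α_i f_i) - Σ α_i H(f_i)."""
    mixture = lambda w: sum(α * h(w) for α, h in zip(weights, dists))
    return (entropy_continuous(mixture)
            - sum(α * entropy_continuous(h) for α, h in zip(weights, dists)))

dists = [jit(lambda x: p(x, 1, 1)),
         jit(lambda x: p(x, 3, 1.2)),
         jit(lambda x: p(x, 2, 2))]
weights = [1/3, 1/3, 1/3]

print(f"Generalized JS divergence = {compute_JS_general(dists, weights):.4f}")

With two distributions and equal weights, this expression reduces to the $JS(f, g)$ computed above.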

21.6. Chernoff entropy#

Chernoff entropy originates from early applications of the theory of large deviations, which refines central limit approximations by providing exponential decay rates for rare events.

For densities f and g the Chernoff entropy is

$$C(f, g) = -\log \min_{\phi \in (0,1)} \int f^\phi(x) g^{1-\phi}(x) \, dx.$$

Remarks:

  • The inner integral is the Chernoff coefficient.

  • At $\phi = 1/2$ it becomes the Bhattacharyya coefficient $\int \sqrt{f(x) g(x)} \, dx$.

  • In binary hypothesis testing with $T$ i.i.d. observations, the optimal error probability decays as $e^{-C(f,g) T}$.

We will see an example of the third point in the lecture Likelihood Ratio Processes, where we study the Chernoff entropy in the context of model selection.

Let’s compute the Chernoff entropy between our example distributions f and g.

def chernoff_integrand(ϕ, f, g):
    """Integral entering Chernoff entropy for a given ϕ."""
    def integrand(w):
        return f(w)**ϕ * g(w)**(1-ϕ)
    result, _ = quad(integrand, 1e-5, 1-1e-5)
    return result

def compute_chernoff_entropy(f, g):
    """Compute Chernoff entropy C(f,g)."""
    def objective(ϕ):
        return chernoff_integrand(ϕ, f, g)
    result = minimize_scalar(objective, bounds=(1e-5, 1-1e-5), method='bounded')
    min_value = result.fun
    ϕ_optimal = result.x
    chernoff_entropy = -np.log(min_value)
    return chernoff_entropy, ϕ_optimal

C_fg, ϕ_optimal = compute_chernoff_entropy(f, g)
print(f"Chernoff entropy C(f,g) = {C_fg:.4f}")
print(f"Optimal ϕ = {ϕ_optimal:.4f}")
Chernoff entropy C(f,g) = 0.1212
Optimal ϕ = 0.5969
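Since the Chernoff coefficient is minimized over $\phi$, the value of $-\log(\cdot)$ at $\phi = 1/2$ (the Bhattacharyya point) can never exceed the Chernoff entropy; the short check below reuses the functions defined above.

# Compare the Bhattacharyya point ϕ = 1/2 with the Chernoff-optimal ϕ
bhattacharyya_coef = chernoff_integrand(0.5, f, g)

print(f"Coefficient at ϕ = 1/2   : {bhattacharyya_coef:.4f}")
print(f"-log of it               : {-np.log(bhattacharyya_coef):.4f}")
print(f"Coefficient at optimal ϕ : {chernoff_integrand(ϕ_optimal, f, g):.4f}")
print(f"Chernoff entropy C(f, g) : {C_fg:.4f}")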

21.7. Comparing divergence measures#

We now compare these measures across several pairs of Beta distributions


distribution_pairs = [
    # (f_params, g_params)
    ((1, 1), (0.1, 0.2)),
    ((1, 1), (0.3, 0.3)),
    ((1, 1), (0.3, 0.4)),
    ((1, 1), (0.5, 0.5)),
    ((1, 1), (0.7, 0.6)),
    ((1, 1), (0.9, 0.8)),
    ((1, 1), (1.1, 1.05)),
    ((1, 1), (1.2, 1.1)),
    ((1, 1), (1.5, 1.2)),
    ((1, 1), (2, 1.5)),
    ((1, 1), (2.5, 1.8)),
    ((1, 1), (3, 1.2)),
    ((1, 1), (4, 1)),
    ((1, 1), (5, 1))
]

# Create comparison table
results = []
for i, ((f_a, f_b), (g_a, g_b)) in enumerate(distribution_pairs):
    f = jit(lambda x, a=f_a, b=f_b: p(x, a, b))
    g = jit(lambda x, a=g_a, b=g_b: p(x, a, b))
    kl_fg = compute_KL(f, g)
    kl_gf = compute_KL(g, f)
    js_div = compute_JS(f, g)
    chernoff_ent, _ = compute_chernoff_entropy(f, g)
    results.append({
        'Pair (f, g)': f"\\text{{Beta}}({f_a},{f_b}), \\text{{Beta}}({g_a},{g_b})",
        'KL(f, g)': f"{kl_fg:.4f}",
        'KL(g, f)': f"{kl_gf:.4f}",
        'JS': f"{js_div:.4f}",
        'C': f"{chernoff_ent:.4f}"
    })

df = pd.DataFrame(results)
# Sort by JS divergence
df['JS_numeric'] = df['JS'].astype(float)
df = df.sort_values('JS_numeric').drop('JS_numeric', axis=1)

columns = ' & '.join([f'\\text{{{col}}}' for col in df.columns])
rows = ' \\\\\n'.join(
    [' & '.join([f'{val}' for val in row]) 
     for row in df.values])

latex_code = rf"""
\begin{{array}}{{lcccc}}
{columns} \\
\hline
{rows}
\end{{array}}
"""

display(Math(latex_code))
| Pair (f, g) | KL(f, g) | KL(g, f) | JS | C |
|---|---|---|---|---|
| Beta(1,1), Beta(1.1,1.05) | 0.0028 | 0.0026 | 0.0007 | 0.0007 |
| Beta(1,1), Beta(1.2,1.1) | 0.0105 | 0.0092 | 0.0024 | 0.0025 |
| Beta(1,1), Beta(0.9,0.8) | 0.0143 | 0.0166 | 0.0038 | 0.0039 |
| Beta(1,1), Beta(1.5,1.2) | 0.0589 | 0.0437 | 0.0121 | 0.0126 |
| Beta(1,1), Beta(0.7,0.6) | 0.0673 | 0.0924 | 0.0186 | 0.0201 |
| Beta(1,1), Beta(2,1.5) | 0.1781 | 0.1081 | 0.0309 | 0.0339 |
| Beta(1,1), Beta(0.5,0.5) | 0.1448 | 0.2190 | 0.0400 | 0.0461 |
| Beta(1,1), Beta(2.5,1.8) | 0.3323 | 0.1731 | 0.0502 | 0.0577 |
| Beta(1,1), Beta(0.3,0.4) | 0.3317 | 0.5572 | 0.0869 | 0.1203 |
| Beta(1,1), Beta(3,1.2) | 0.7590 | 0.3436 | 0.0984 | 0.1212 |
| Beta(1,1), Beta(0.3,0.3) | 0.3935 | 0.6516 | 0.1008 | 0.1456 |
| Beta(1,1), Beta(4,1) | 1.6134 | 0.6362 | 0.1733 | 0.2341 |
| Beta(1,1), Beta(0.1,0.2) | 0.9811 | 1.0036 | 0.1783 | 0.4556 |
| Beta(1,1), Beta(5,1) | 2.3901 | 0.8094 | 0.2162 | 0.3128 |

We can clearly see co-movement across the divergence measures as we vary the parameters of the Beta distributions.

Next we visualize relationships among KL, JS, and Chernoff entropy.

kl_fg_values = [float(result['KL(f, g)']) for result in results]
js_values = [float(result['JS']) for result in results]
chernoff_values = [float(result['C']) for result in results]

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

axes[0].scatter(kl_fg_values, js_values, alpha=0.7, s=60)
axes[0].set_xlabel('KL divergence KL(f, g)')
axes[0].set_ylabel('JS divergence')
axes[0].set_title('JS divergence vs KL divergence')

axes[1].scatter(js_values, chernoff_values, alpha=0.7, s=60)
axes[1].set_xlabel('JS divergence')
axes[1].set_ylabel('Chernoff entropy')
axes[1].set_title('Chernoff entropy vs JS divergence')

plt.tight_layout()
plt.show()

We now generate plots illustrating how the visual overlap between densities shrinks as the divergence measures increase.

param_grid = [
    ((1, 1), (1, 1)),   
    ((1, 1), (1.5, 1.2)),
    ((1, 1), (2, 1.5)),  
    ((1, 1), (3, 1.2)),  
    ((1, 1), (5, 1)),
    ((1, 1), (0.3, 0.3))
]


def plot_dist_diff(param_grid):
    """Plot overlap of selected Beta distribution pairs."""

    fig, axes = plt.subplots(3, 2, figsize=(15, 12))
    divergence_data = []
    for i, ((f_a, f_b), (g_a, g_b)) in enumerate(param_grid):
        row, col = divmod(i, 2)
        f = jit(lambda x, a=f_a, b=f_b: p(x, a, b))
        g = jit(lambda x, a=g_a, b=g_b: p(x, a, b))
        kl_fg = compute_KL(f, g)
        js_div = compute_JS(f, g)
        chernoff_ent, _ = compute_chernoff_entropy(f, g)
        divergence_data.append({
            'f_params': (f_a, f_b),
            'g_params': (g_a, g_b),
            'kl_fg': kl_fg,
            'js_div': js_div,
            'chernoff': chernoff_ent
        })
        x_range = np.linspace(0, 1, 200)
        f_vals = [f(x) for x in x_range]
        g_vals = [g(x) for x in x_range]
        axes[row, col].plot(x_range, f_vals, 'b-', 
                        linewidth=2, label=f'f ~ Beta({f_a},{f_b})')
        axes[row, col].plot(x_range, g_vals, 'r-', 
                        linewidth=2, label=f'g ~ Beta({g_a},{g_b})')
        overlap = np.minimum(f_vals, g_vals)
        axes[row, col].fill_between(x_range, 0, 
                        overlap, alpha=0.3, color='purple', label='overlap')
        axes[row, col].set_title(
            f'KL(f,g)={kl_fg:.3f}, JS={js_div:.3f}, C={chernoff_ent:.3f}', 
            fontsize=12)
        axes[row, col].legend(fontsize=12)
    plt.tight_layout()
    plt.show()
    return divergence_data

divergence_data = plot_dist_diff(param_grid)

21.8. KL divergence and maximum-likelihood estimation#

Given a sample of $n$ observations $X = \{x_1, x_2, \ldots, x_n\}$, the empirical distribution is

$$p_e(x) = \frac{1}{n} \sum_{i=1}^n \delta(x - x_i)$$

where $\delta(x - x_i)$ is the Dirac delta function centered at $x_i$:

$$\delta(x - x_i) = \begin{cases} +\infty & \text{if } x = x_i \\ 0 & \text{if } x \neq x_i \end{cases}$$

The empirical distribution has the following properties:

  • Discrete probability measure: it assigns probability $\frac{1}{n}$ to each observed data point

  • Empirical expectation: $\langle X \rangle_{p_e} = \frac{1}{n} \sum_{i=1}^n x_i = \bar{\mu}$

  • Support: only on the observed data points $\{x_1, x_2, \ldots, x_n\}$

The KL divergence from the empirical distribution $p_e$ to a parametric model $p_\theta(x)$ is:

$$D_{KL}(p_e \,\|\, p_\theta) = \int p_e(x) \log \frac{p_e(x)}{p_\theta(x)} \, dx$$

Using the mathematics of the Dirac delta function, it follows that

$$
\begin{aligned}
D_{KL}(p_e \,\|\, p_\theta) &= \sum_{i=1}^n \frac{1}{n} \log \frac{1/n}{p_\theta(x_i)} \\
&= \frac{1}{n} \sum_{i=1}^n \log \frac{1}{n} - \frac{1}{n} \sum_{i=1}^n \log p_\theta(x_i) \\
&= -\log n - \frac{1}{n} \sum_{i=1}^n \log p_\theta(x_i)
\end{aligned}
$$

Since the log-likelihood function for parameter θ is:

$$\ell(\theta; X) = \sum_{i=1}^n \log p_\theta(x_i),$$

it follows that maximum likelihood chooses parameters to minimize

$$D_{KL}(p_e \,\|\, p_\theta).$$

Thus, MLE is equivalent to minimizing the KL divergence from the empirical distribution $p_e$ to the statistical model $p_\theta$.
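A short simulation illustrates this equivalence; the data-generating distribution, the restricted one-parameter family Beta(a, 1.2), and the sample size are all our choices for this sketch.

# MLE vs. minimizing the KL objective (they differ only by the constant -log n)
np.random.seed(1)
data = np.random.beta(3, 1.2, size=1_000)
n_obs = len(data)

def neg_avg_loglik(a):
    """Negative average log likelihood of Beta(a, 1.2) on the sample."""
    return -np.mean(np.log(p(data, a, 1.2)))

def kl_objective(a):
    """The KL objective -log n - (1/n) Σ log p_θ(x_i) from the derivation above."""
    return -np.log(n_obs) + neg_avg_loglik(a)

res_mle = minimize_scalar(neg_avg_loglik, bounds=(0.5, 10), method='bounded')
res_kl  = minimize_scalar(kl_objective, bounds=(0.5, 10), method='bounded')

print(f"a that maximizes the likelihood   = {res_mle.x:.4f}")
print(f"a that minimizes the KL objective = {res_kl.x:.4f}")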