Bayesian BM25 Documentation

API reference for transforms, fusion, scoring, and diagnostics.

Installation

pip install bayesian-bm25

To use the integrated search scorer (requires bm25s):

pip install bayesian-bm25[scorer]

Requirements: Python >= 3.10, numpy >= 1.21. The core library has no other dependencies. The [scorer] extra adds bm25s for end-to-end search functionality.

BayesianProbabilityTransform

The core class that converts raw BM25 scores into calibrated relevance probabilities. Implements sigmoid likelihood, composite prior (TF + document length), and Bayesian posterior with optional base rate.

Constructor

BayesianProbabilityTransform(
    alpha=1.0,       # sigmoid steepness
    beta=0.0,        # sigmoid midpoint
    base_rate=None,  # corpus-level relevance rate (0, 1) or None
    prior_fn=None,   # callable or None -- custom prior function
)
| Parameter | Type | Description |
| --- | --- | --- |
| alpha | float | Sigmoid steepness. Higher values make the transition sharper. Learnable via fit(). |
| beta | float | Sigmoid midpoint. Scores below beta map to P < 0.5. Learnable via fit(). |
| base_rate | float or None | Corpus-level base rate of relevant documents, in (0, 1). When set, decomposes the posterior into three additive log-odds terms. Reduces ECE by 68–77%. |
| prior_fn | callable or None | Custom prior function (score, tf, doc_len_ratio) -> float. Replaces the composite prior when set. Overridden by prior_free mode. |

score_to_probability(score, tf, doc_len_ratio)

Convert raw BM25 scores to calibrated probabilities via the full Bayesian pipeline.

probabilities = transform.score_to_probability(
    score,          # float or ndarray — raw BM25 score(s)
    tf,             # float or ndarray — term frequency
    doc_len_ratio,  # float or ndarray — doc_len / avg_doc_len
)

Returns a float or ndarray of probabilities in $[0, 1]$. Supports vectorized operation for batch processing.
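At the core of the pipeline is a sigmoid likelihood in the raw score. A minimal numpy sketch of that one step, with the composite prior and posterior stages omitted (an illustration, not the library's actual implementation):

```python
import numpy as np

def sigmoid_likelihood(score, alpha=1.0, beta=0.0):
    """P(score | relevant) as a sigmoid: alpha sets steepness, beta the midpoint."""
    score = np.asarray(score, dtype=float)
    return 1.0 / (1.0 + np.exp(-alpha * (score - beta)))

# Scores below beta map to P < 0.5; scores above beta map to P > 0.5.
probs = sigmoid_likelihood(np.array([-2.0, 0.0, 2.0]), alpha=1.0, beta=0.0)
```

Vectorization falls out of numpy broadcasting: the same call accepts a scalar or an array of scores.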

fit(scores, labels, mode, ...)

Learn $\alpha$ and $\beta$ from labeled data via batch gradient descent.

transform.fit(
    scores,               # ndarray (m,) — raw BM25 scores
    labels,               # ndarray (m,) — binary relevance labels
    mode="balanced",      # "balanced" (C1), "prior_aware" (C2), "prior_free" (C3)
    learning_rate=0.01,   # gradient descent step size
    max_iterations=1000,  # maximum iterations
    tfs=None,             # ndarray (m,) — term frequencies (required for C2)
    doc_len_ratios=None,  # ndarray (m,) — document length ratios (required for C2)
)

update(score, label, ...)

Online SGD update from a single observation with EMA-smoothed gradients and Polyak averaging.

transform.update(
    score,                # float — raw BM25 score
    label,                # float — binary relevance label (0 or 1)
    learning_rate=0.01,   # SGD step size
    momentum=0.9,         # EMA momentum for gradient smoothing
)

After updating, use transform.averaged_alpha and transform.averaged_beta for Polyak-averaged parameters with better stability.
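Polyak averaging reports the running mean of the parameter iterates rather than the latest (noisy) one. A tiny self-contained sketch of the idea, not the library's internals:

```python
# Illustrative Polyak averaging: the reported parameter is the running
# mean of all SGD iterates, which smooths out gradient noise.
class PolyakAverage:
    def __init__(self, value):
        self.value = value      # current (noisy) iterate
        self.averaged = value   # Polyak-averaged iterate
        self.steps = 1

    def step(self, new_value):
        self.value = new_value
        self.steps += 1
        # Incremental mean: avg_t = avg_{t-1} + (x_t - avg_{t-1}) / t
        self.averaged += (new_value - self.averaged) / self.steps

avg = PolyakAverage(0.0)
for x in [2.0, 4.0, 6.0]:
    avg.step(x)
# avg.value is the last iterate; avg.averaged is the mean of all four
# iterates (0, 2, 4, 6).
```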

wand_upper_bound(bm25_score)

Compute a safe Bayesian probability upper bound for WAND pruning.

bound = transform.wand_upper_bound(bm25_upper_bound)

Returns the tightest safe probability upper bound using $p_{max} = 0.9$ (Theorem 6.1.2). Any document's actual probability is guaranteed to be at most this value. Supports base_rate-aware bounds for tighter pruning.

Fusion Functions

log_odds_conjunction(probs, alpha, weights, gating)

The primary fusion function. Combines multiple probability signals in log-odds space with confidence scaling.

log_odds_conjunction(
    probs,          # ndarray — probabilities, last axis = signals
    alpha=None,     # float or "auto" — confidence scaling (None=auto)
    weights=None,   # ndarray — per-signal weights (sum to 1)
    gating="none",  # "none", "relu", "swish", "gelu", or "softplus"
    gating_beta=1.0,  # sharpness for swish/softplus gating
)
| Parameter | Description |
| --- | --- |
| alpha=None | Auto: resolves to 0.5 (unweighted) or 0.0 (weighted). An explicit alpha applies in both modes. |
| weights | Per-signal reliability weights in $[0,1]$, summing to 1. Log-OP formulation. |
| gating="relu" | MAP estimation: zeroes negative logits (Paper 2, Theorem 6.5.3). |
| gating="swish" | Bayes estimation: soft suppression of negative logits (Paper 2, Theorem 6.7.4). |
| gating="gelu" | Gaussian noise model: logit * sigmoid(1.702 * logit) (Paper 2, Theorem 6.8.1). |
| gating="softplus" | Evidence-preserving smooth ReLU: log(1 + exp(beta * logit)) / beta. Never zeroes evidence; suitable for small datasets (Remark 6.5.4). Since softplus(x) > x for all finite x, consider a lower alpha. |
| gating_beta | Sharpness for swish and softplus gating. Default 1.0. As beta → ∞, both approach ReLU (Theorem 6.7.6). |
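Conceptually, log-odds conjunction sums evidence in logit space, optionally gates negative logits, scales the total by alpha, and maps back through the sigmoid. A minimal unweighted sketch of that recipe (an illustration of the idea, not the library code):

```python
import numpy as np

def logit(p, eps=1e-9):
    p = np.clip(np.asarray(p, dtype=float), eps, 1 - eps)
    return np.log(p / (1 - p))

def log_odds_conjunction_sketch(probs, alpha=0.5, gating="none"):
    """Unweighted log-odds fusion: gate logits, sum, scale by alpha, sigmoid."""
    logits = logit(probs)
    if gating == "relu":
        logits = np.maximum(logits, 0.0)   # MAP-style: drop negative evidence
    elif gating == "softplus":
        logits = np.log1p(np.exp(logits))  # smooth ReLU, never zeroes evidence
    fused = alpha * logits.sum(axis=-1)    # confidence-scaled total evidence
    return 1.0 / (1.0 + np.exp(-fused))

# Three agreeing signals reinforce each other past any individual value
p = log_odds_conjunction_sketch([0.8, 0.7, 0.9], alpha=0.5)
```

Note how ReLU gating can only raise the fused probability relative to no gating, since it removes negative contributions from the sum.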

balanced_log_odds_fusion(probs, dense_sims, weight, alpha)

Designed for hybrid BM25 + dense retrieval. Min-max normalizes both signals in logit space before combining, preventing heavy-tailed sparse logits from drowning the dense signal.

balanced_log_odds_fusion(
    probs,       # ndarray — Bayesian BM25 probabilities
    dense_sims,  # ndarray — cosine similarities [-1, 1]
    weight=0.5,  # float — sparse weight (dense = 1 - weight)
    alpha=0.5,   # float — confidence scaling
)
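The balancing idea can be sketched directly: convert both signals to logits, min-max normalize each to a common scale, then take the weighted sum. An illustrative sketch under those assumptions, not the library's exact normalization:

```python
import numpy as np

def logit(p, eps=1e-9):
    p = np.clip(np.asarray(p, dtype=float), eps, 1 - eps)
    return np.log(p / (1 - p))

def minmax(x):
    rng = x.max() - x.min()
    return (x - x.min()) / rng if rng > 0 else np.zeros_like(x)

def balanced_fusion_sketch(sparse_probs, dense_sims, weight=0.5, alpha=0.5):
    sparse_logits = logit(sparse_probs)
    dense_logits = logit((1.0 + np.asarray(dense_sims, dtype=float)) / 2.0)  # cosine -> (0, 1)
    # Min-max normalize each signal's logits so heavy-tailed sparse logits
    # cannot drown the dense signal
    combined = weight * minmax(sparse_logits) + (1 - weight) * minmax(dense_logits)
    return 1.0 / (1.0 + np.exp(-alpha * combined))

# Both signals agree on the ranking; fusion preserves it on a common scale
fused = balanced_fusion_sketch(np.array([0.9, 0.6, 0.2]), np.array([0.8, 0.5, -0.2]))
```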

prob_and / prob_or / prob_not

prob_and(probs)   # Product rule: P = prod(P_i), in log-space for stability
prob_or(probs)    # Complement rule: P = 1 - prod(1 - P_i), in log-space
prob_not(prob)    # Complement: P = 1 - P_i, with epsilon clamping

These satisfy Boolean algebra laws including De Morgan's: NOT(A AND B) = OR(NOT A, NOT B).
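The Boolean-algebra claim is easy to check numerically. Minimal reimplementations of the three rules (the library versions compute in log-space for stability; these plain versions are just for the demonstration):

```python
import numpy as np

def prob_and(probs):
    return float(np.prod(probs))                           # P = prod(P_i)

def prob_or(probs):
    return float(1.0 - np.prod(1.0 - np.asarray(probs)))   # P = 1 - prod(1 - P_i)

def prob_not(p):
    return 1.0 - p                                         # complement

a, b = 0.8, 0.3
# De Morgan: NOT(A AND B) == OR(NOT A, NOT B)
lhs = prob_not(prob_and([a, b]))
rhs = prob_or([prob_not(a), prob_not(b)])
```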

cosine_to_probability(score)

Maps cosine similarity $[-1, 1]$ to probability $(0, 1)$ via $(1 + \text{score}) / 2$ with epsilon clamping (Definition 7.1.2).
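The mapping is a straightforward affine rescaling. A minimal sketch (the epsilon value here is illustrative):

```python
import numpy as np

def cosine_to_probability_sketch(score, eps=1e-9):
    """Affine map [-1, 1] -> (0, 1): p = (1 + score) / 2, clamped away from 0 and 1."""
    p = (1.0 + np.asarray(score, dtype=float)) / 2.0
    return np.clip(p, eps, 1.0 - eps)

# Orthogonal vectors (cosine 0) map to an uninformative 0.5
p = cosine_to_probability_sketch(np.array([-1.0, 0.0, 1.0]))
```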

LearnableLogOddsWeights

Learns per-signal reliability weights from labeled data via a Hebbian gradient. Starts from Naive Bayes uniform initialization ($w_i = 1/n$).

learner = LearnableLogOddsWeights(
    n_signals=3,      # number of input signals
    alpha=0.0,        # confidence scaling (0.0 for weighted mode)
    base_rate=None,   # corpus-level base rate bias in log-odds space
)

# Batch fit from labeled data
learner.fit(probs, labels, learning_rate=0.1, max_iterations=500)

# Online update from streaming feedback
learner.update(probs, label, learning_rate=0.05, momentum=0.9)

# Inference with Polyak-averaged weights
fused = learner(test_probs, use_averaged=True)

Key properties:

  • weights — current softmax-normalized weights
  • averaged_weights — Polyak-averaged weights for stable inference
  • Hebbian gradient: backprop-free, $O(n)$ per update

AttentionLogOddsWeights

Query-dependent signal weighting via attention mechanism. Replaces static per-signal weights with a linear projection from query features to softmax attention weights (Paper 2, Section 8).

attn = AttentionLogOddsWeights(
    n_signals=2,         # number of input signals
    n_query_features=3,  # dimensionality of query feature vector
    alpha=0.5,           # confidence scaling
    normalize=True,      # per-signal logit normalization
    seed=None,           # random seed for reproducible initialization
    base_rate=None,      # corpus-level base rate bias in log-odds space
)

# Train: probs (m, 2), labels (m,), features (m, 3)
attn.fit(probs, labels, features,
         learning_rate=0.01, max_iterations=500)

# Query-dependent fusion
fused = attn(test_probs, test_features, use_averaged=True)

When normalize=True, applies per-column min-max normalization in logit space before the weighted sum, equalizing signal scales (same scaling as balanced_log_odds_fusion).

Supports exact pruning via compute_upper_bounds() and prune() methods (Theorem 8.7.1). Given candidate probability upper bounds, the pruner safely eliminates documents whose fused probability cannot exceed the current threshold.

MultiHeadAttentionLogOddsWeights

Multi-head attention fusion that creates multiple independent attention heads, averages their log-odds, and applies sigmoid (Remark 8.6, Corollary 8.7.2).

mh = MultiHeadAttentionLogOddsWeights(
    n_heads=4,           # number of independent heads
    n_signals=2,         # number of input signals
    n_query_features=3,  # query feature dimensionality
    alpha=0.5,           # confidence scaling
    normalize=False,     # per-signal logit normalization
)

# Train all heads on the same data (different init -> diversity)
mh.fit(probs, labels, query_features)

# Inference: average log-odds across heads, then sigmoid
fused = mh(test_probs, test_features, use_averaged=True)

# Attention pruning (Theorem 8.7.1)
surviving_idx, fused_probs = mh.prune(
    candidate_probs, query_features, threshold=0.5,
    upper_bound_probs=candidate_upper_bounds,
)

VectorProbabilityTransform

Converts vector distances into calibrated relevance probabilities via a likelihood ratio framework (Paper 3, Theorem 3.1.1): $P(R|d) = \sigma(\log(f_R(d) / f_G(d)) + \text{logit}(P_{base}))$.

fit_background(distances, base_rate)

Estimate the background Gaussian ($\mu_G$, $\sigma_G$) from a corpus distance sample.

vpt = VectorProbabilityTransform.fit_background(
    corpus_distances,    # ndarray -- distances from a corpus sample
    base_rate=0.01,      # corpus-level relevance prior
)

Returns a configured VectorProbabilityTransform instance. The background distribution models how distances look for random (non-relevant) documents.
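Under Gaussian assumptions, the likelihood-ratio formula from Theorem 3.1.1 is simple to write down. An illustrative sketch with both $f_R$ and $f_G$ taken as Gaussians (the library estimates $f_R$ with KDE/GMM; all parameter values here are made up):

```python
import numpy as np

def gaussian_pdf(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

def likelihood_ratio_probability(d, mu_r, sigma_r, mu_g, sigma_g, base_rate):
    """P(R | d) = sigmoid(log(f_R(d) / f_G(d)) + logit(base_rate))."""
    log_lr = np.log(gaussian_pdf(d, mu_r, sigma_r)) - np.log(gaussian_pdf(d, mu_g, sigma_g))
    logit_base = np.log(base_rate / (1.0 - base_rate))
    return 1.0 / (1.0 + np.exp(-(log_lr + logit_base)))

# Hypothetical setup: relevant docs cluster near distance 0.3,
# the background (random docs) near 0.9
d = np.array([0.3, 0.6, 0.9])
p = likelihood_ratio_probability(d, mu_r=0.3, sigma_r=0.1,
                                 mu_g=0.9, sigma_g=0.2, base_rate=0.01)
# probabilities fall with distance: the closest point is most likely relevant
```

Note how the low base rate pulls all probabilities down: strong density evidence is required to overcome a 1% prior.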

calibrate(distances, weights, method, ...)

Full calibration pipeline. Uses the input distances as both the density estimation sample and the evaluation points.

probabilities = vpt.calibrate(
    distances,            # float or ndarray -- vector distances
    weights=None,         # ndarray or None -- per-doc weights (e.g. BM25 probs)
    method="auto",        # "auto", "kde", or "gmm"
    bandwidth_factor=2.0, # multiplicative KDE bandwidth factor
    density_prior=None,   # ndarray or None -- external density prior weights
)

Auto-routing selects between KDE and GMM based on gap detection and sample size.

calibrate_with_sample(eval_distances, sample_distances, ...)

Index-aware calibration path. Density $f_R$ is estimated from a local ANN neighborhood sample, while probabilities are produced for a separate evaluation set.

probabilities = vpt.calibrate_with_sample(
    eval_distances,       # float or ndarray -- distances to calibrate
    sample_distances,     # ndarray -- local neighborhood (e.g. IVF probed cells)
    weights=None,         # ndarray or None -- weights for the sample
    method="auto",        # "auto", "kde", or "gmm"
    bandwidth_factor=2.0, # multiplicative KDE bandwidth factor
    density_prior=None,   # ndarray or None -- external density prior weights
)
| Parameter | Description |
| --- | --- |
| eval_distances | Distances at which to produce output probabilities. Can differ from the sample. |
| sample_distances | Distances used to estimate $f_R$. Typically from ANN index probed cells. |
| weights | Per-sample weights (e.g. BM25 probabilities from the sample documents). |
| density_prior | External density prior (e.g. from ivf_density_prior()). |

When to use which: use calibrate() when the same distances serve as both the density sample and the evaluation points (e.g. exact search). Use calibrate_with_sample() when the density landscape comes from one set of distances (the ANN neighborhood) and you need probabilities at different points (the final candidates).

Density Priors

Standalone helper functions that provide density-based prior weights for informing the vector calibration (Strategy 4.6.2).

from bayesian_bm25 import ivf_density_prior, knn_density_prior

# IVF cell density: denser cells suggest more relevant neighborhoods
prior = ivf_density_prior(
    cell_population=150,   # population of the cell containing the document
    avg_population=100,    # average cell population across the index
)

# k-NN density: closer k-th neighbor suggests denser (more relevant) region
prior = knn_density_prior(
    kth_distance=0.5,        # distance to the k-th neighbor
    global_median_kth=0.8,   # median k-th neighbor distance across corpus
)

Calibrators

Convert raw neural model scores into calibrated probabilities for Bayesian fusion.

PlattCalibrator

Sigmoid calibration: $P = \sigma(a \cdot s + b)$ with BCE gradient descent.

from bayesian_bm25.calibration import PlattCalibrator

platt = PlattCalibrator(a=1.0, b=0.0)
platt.fit(scores, labels, learning_rate=0.01, max_iterations=1000)
calibrated = platt.calibrate(new_scores)  # or platt(new_scores)

IsotonicCalibrator

Non-parametric monotone calibration via Pool Adjacent Violators Algorithm (PAVA). numpy-only, no scipy dependency.

from bayesian_bm25.calibration import IsotonicCalibrator

iso = IsotonicCalibrator()
iso.fit(scores, labels)
calibrated = iso.calibrate(new_scores)  # or iso(new_scores)

Both calibrators produce output in $(0, 1)$ suitable for log_odds_conjunction.

TemporalBayesianTransform

Extends BayesianProbabilityTransform with exponential temporal weighting for non-stationary relevance patterns (Section 12.2 #3).

from bayesian_bm25.probability import TemporalBayesianTransform

transform = TemporalBayesianTransform(
    alpha=1.0,              # sigmoid steepness
    beta=0.0,               # sigmoid midpoint
    decay_half_life=100.0,  # observations until weight halves
)

# Batch fit with temporal weights: recent data gets more influence
transform.fit(scores, labels, timestamps=timestamps)

# Online update: auto-increments internal timestamp
transform.update(score, label)
| Parameter | Description |
| --- | --- |
| decay_half_life | Number of observations after which a sample's weight decays to 50%. Must be positive. Very large values recover the parent class's unweighted behavior. |
| timestamps | Per-sample timestamps for fit(). When None, all samples are weighted equally. |
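The half-life parameterization corresponds to an exponential weight $w = 0.5^{\text{age}/\text{half-life}}$. A minimal sketch of what such weighting could look like (illustrative, not the library's internal code):

```python
import numpy as np

def temporal_weights(timestamps, decay_half_life=100.0):
    """Exponential decay: a sample's weight halves every decay_half_life observations."""
    timestamps = np.asarray(timestamps, dtype=float)
    age = timestamps.max() - timestamps   # newest sample has age 0
    return 0.5 ** (age / decay_half_life)

w = temporal_weights([0, 100, 200], decay_half_life=100.0)
# newest sample gets weight 1.0; one half-life older gets 0.5; two get 0.25
```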

BlockMaxIndex

Block-max index for BMW-style upper bounds (Section 6.2, Corollary 7.4.2). Partitions documents into blocks and stores per-block maximum BM25 contributions for tighter pruning bounds than global WAND.

from bayesian_bm25.scorer import BlockMaxIndex

idx = BlockMaxIndex(block_size=128)
idx.build(score_matrix)  # shape: (n_terms, n_docs)

# Per-block BM25 upper bound
block_ub = idx.block_upper_bound(term_idx, block_id)

# Bayesian probability upper bound (tighter than global WAND)
bayesian_ub = idx.bayesian_block_upper_bound(
    term_idx, block_id, transform, p_max=0.9,
)

BayesianBM25Scorer

Drop-in scorer wrapping bm25s that returns calibrated probabilities instead of raw scores. Requires the [scorer] extra.

Constructor

scorer = BayesianBM25Scorer(
    k1=1.2,               # BM25 k1 parameter
    b=0.75,               # BM25 b parameter
    method="lucene",      # BM25 variant
    base_rate="auto",     # None, "auto", or float
    base_rate_method="percentile",  # "percentile", "mixture", or "elbow"
)

index / retrieve

# Build index from pre-tokenized corpus
scorer.index(corpus_tokens, show_progress=False)

# Retrieve top-k with probabilities
doc_ids, probabilities = scorer.retrieve(queries, k=10)

# Retrieve with per-document explanations
result = scorer.retrieve(queries, k=10, explain=True)
# result is a RetrievalResult with .doc_ids, .probabilities, .explanations

add_documents

# Incremental indexing (rebuilds full index with updated statistics)
scorer.add_documents(new_tokens)

MultiFieldScorer

Manages separate BM25 indexes per field and fuses field-level probabilities via log_odds_conjunction with configurable per-field weights.

scorer = MultiFieldScorer(
    fields=["title", "body"],
    field_weights={"title": 0.4, "body": 0.6},
    k1=1.2, b=0.75, method="lucene",
)

# Documents are dicts mapping field names to token lists
scorer.index(documents, show_progress=False)
doc_ids, probabilities = scorer.retrieve(query_tokens, k=10)

FusionDebugger

Transparent inspection of the full probability pipeline. Records every intermediate value (likelihood, prior, posterior, fusion) for debugging and document comparison.

debugger = FusionDebugger(transform)

# Trace a single document through the pipeline
trace = debugger.trace_document(
    bm25_score=8.42, tf=5, doc_len_ratio=0.60,
    cosine_score=0.74, doc_id="doc-42",
)
print(debugger.format_trace(trace))

# Compare two documents
comparison = debugger.compare(trace_a, trace_b)
print(debugger.format_comparison(comparison))

# Hierarchical fusion: AND(OR(title, body), vector, NOT(spam))
step1 = debugger.trace_fusion([0.85, 0.70],
    names=["title", "body"], method="prob_or")
step2 = debugger.trace_not(0.90, name="spam")
step3 = debugger.trace_fusion(
    [step1.fused_probability, 0.80, step2.complement],
    names=["OR(title,body)", "vector", "NOT(spam)"],
    method="prob_and",
)

Calibration Metrics

from bayesian_bm25 import (
    expected_calibration_error, brier_score,
    reliability_diagram, calibration_report,
)

ece = expected_calibration_error(probs, labels)       # lower is better
bs = brier_score(probs, labels)                        # lower is better
bins = reliability_diagram(probs, labels, n_bins=10)  # (avg_pred, avg_actual, count)

# One-call diagnostic report
report = calibration_report(probs, labels)
print(report.summary())  # formatted text with ECE, Brier, reliability table

CalibrationReport dataclass bundles ECE, Brier score, and reliability diagram data into a single diagnostic with a summary() method for formatted output.
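ECE is a weighted average of per-bin gaps between mean predicted probability and empirical accuracy. A compact numpy sketch of the standard equal-width-bin definition (the library's exact binning may differ):

```python
import numpy as np

def ece_sketch(probs, labels, n_bins=10):
    """Expected Calibration Error: sum over bins of (bin weight) * |avg prediction - avg outcome|."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        # last bin is closed on the right so probs == 1.0 are counted
        mask = (probs >= lo) & ((probs < hi) if hi < 1.0 else (probs <= hi))
        if mask.any():
            gap = abs(probs[mask].mean() - labels[mask].mean())
            ece += mask.mean() * gap
    return ece

# Perfectly calibrated toy data: predictions match empirical frequencies
probs = np.array([0.25, 0.25, 0.25, 0.25, 0.75, 0.75, 0.75, 0.75])
labels = np.array([0, 0, 0, 1, 1, 1, 1, 0])
```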

Training Modes

Three training modes control how the gradient flows through the model (Algorithm 8.3.1 from the paper):

| Mode | Condition | Trains On | Inference Prior |
| --- | --- | --- | --- |
| "balanced" | C1 (default) | Sigmoid likelihood | Full composite prior |
| "prior_aware" | C2 | Full Bayesian posterior (chain-rule gradients through dP/dL) | Full composite prior |
| "prior_free" | C3 | Sigmoid likelihood | prior = 0.5 (neutral) |

transform.fit(scores, labels, mode="balanced")       # C1
transform.fit(scores, labels, mode="prior_aware",     # C2
              tfs=tfs, doc_len_ratios=ratios)
transform.fit(scores, labels, mode="prior_free")      # C3
Tip: C1 (balanced) is the recommended default. C2 (prior-aware) couples alpha/beta learning with the prior, useful when TF and document length data are available. C3 (prior-free) is best when you want a clean likelihood fit without prior influence.