fix(embed-candle): address round-1 review

- commit track-spec + meta-spec/plan into branch (HIGH: dangling `amends:` ref)
- inline parity evidence (cosine 1.0, max_abs_diff 2.01e-7) into HOTFIXES +
  release notes; drop refs to deleted IMPL_REPORT/SPIKE_REPORT (MEDIUM)
- model guard: reject non-e5-large `model` before the 2GB download so
  model_id() can't mislabel vectors (MEDIUM) + unit test
- parity test now covers BOTH query: and passage: prefixes (MEDIUM)
- guard encodings.first() index; document zero-attention/pooling invariant;
  clarify embed_batch prefixing doc (LOW)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-01 16:54:20 +00:00
parent 1011c75fff
commit 6ec4e6809f
7 changed files with 339 additions and 16 deletions

View File

@@ -46,6 +46,11 @@ const CANDLE_CACHE_SUBDIR: &str = "candle";
/// onnxruntime path uses, just the safetensors variant candle can read.
const HF_MODEL: &str = "intfloat/multilingual-e5-large";
/// The short model name the candle adapter accepts for
/// `config.models.embedding.model` (the e5-large weights `HF_MODEL` resolves
/// to). The model guard in `CandleEmbedder::new` additionally accepts the full
/// `HF_MODEL` id; every other value is rejected. This guards against silently
/// downloading e5-large while `model_id()` reports a different name.
const SUPPORTED_MODEL: &str = "multilingual-e5-large";
/// Token truncation length (e5 was trained at 512).
const MAX_LEN: usize = 512;
@@ -99,6 +104,22 @@ impl CandleEmbedder {
}
}
// 1b. Model guard. `HF_MODEL` is hard-coded (candle currently only wires
// e5-large), so if the operator configured a *different* model name
// we must NOT silently download e5-large and then label its vectors
// with the configured name via `model_id()` — that would mislabel
// `embedding_version` and corrupt a mixed index. Fail fast, before
// the ~2GB download.
let want = config.models.embedding.model.as_str();
if want != SUPPORTED_MODEL && want != HF_MODEL {
anyhow::bail!(
"candle provider currently supports only '{SUPPORTED_MODEL}' (or \
the HF id '{HF_MODEL}'), but config.models.embedding.model = \
'{want}'. Use provider=fastembed for other models, or set \
model = \"{SUPPORTED_MODEL}\"."
);
}
// 2. Resolve `{data_dir}/models/candle/` exactly like the fastembed
// adapter resolves its own subdir.
let data_dir = expand_path(&config.storage.data_dir, "");
@@ -177,8 +198,9 @@ impl CandleEmbedder {
})
}
/// Embed one batch (already prefixed) through the candle pipeline:
/// tokenize → forward → masked mean pool → L2 normalize.
/// Embed one batch of **already-prefixed** strings (the e5 `query:`/
/// `passage:` prefix is applied by the caller [`CandleEmbedder::embed`])
/// through the candle pipeline: tokenize → forward → masked mean pool → L2.
fn embed_batch(&self, prefixed: &[String]) -> Result<Vec<Vec<f32>>> {
let encodings = self
.tokenizer
@@ -186,7 +208,13 @@ impl CandleEmbedder {
.map_err(|e| anyhow::anyhow!("kb-embed-candle: encode_batch: {e}"))?;
let bsz = encodings.len();
let seq = encodings[0].get_ids().len();
// `embed` already returns early on empty input and `.chunks()` never
// yields an empty slice, so this is currently unreachable — but guard
// the index so a future refactor can't turn it into a panic.
let Some(first) = encodings.first() else {
return Ok(Vec::new());
};
let seq = first.get_ids().len();
let mut ids = Vec::with_capacity(bsz * seq);
let mut mask = Vec::with_capacity(bsz * seq);
@@ -212,6 +240,9 @@ impl CandleEmbedder {
// attention-mask-weighted mean pooling
let mask3 = attn_f32.unsqueeze(2)?; // (b, seq, 1)
let summed = hidden.broadcast_mul(&mask3)?.sum(1)?; // (b, hidden)
// counts ≥ 1 always: every input is e5-prefixed AND special tokens are
// added (encode_batch(_, true)), so no row has an all-zero mask. If that
// invariant ever breaks, broadcast_div would emit NaN vectors.
let counts = mask3.sum(1)?; // (b, 1)
let mean = summed.broadcast_div(&counts)?;
@@ -360,4 +391,26 @@ mod tests {
"error must mention both dims, got: {msg}"
);
}
// ── model guard ──────────────────────────────────────────────────
// A non-e5-large model name must fail fast (BEFORE the ~2GB download),
// so we never download e5-large yet label its vectors with another name
// via model_id() — which would mislabel embedding_version.
#[test]
fn new_rejects_unsupported_model() {
    // Take the stock configuration and swap in a model name the candle
    // adapter's guard does not accept.
    let mut cfg = kebab_config::Config::defaults();
    cfg.models.embedding.model = String::from("multilingual-e5-small");

    // NOTE(review): per the original author's note, `num_threads` defaults
    // to 0, so constructing the embedder here should have no global rayon
    // side effect — confirm against `Config::defaults()`.
    //
    // Destructure with `let … else` rather than `expect_err`: this needs no
    // `CandleEmbedder: Debug` bound (the type holds a Mutex/Tokenizer and
    // intentionally derives no Debug).
    let Err(failure) = CandleEmbedder::new(&cfg) else {
        panic!("unsupported model must error");
    };

    // Render the full error chain and check it is the model-guard message.
    let rendered = format!("{failure:#}");
    assert!(
        rendered.contains("candle provider currently supports only"),
        "expected model-guard error, got: {rendered}"
    );
}
}

View File

@@ -49,11 +49,14 @@ fn candle_matches_fastembed() {
let candle = CandleEmbedder::new(&config).expect("build CandleEmbedder");
let fastembed = FastembedEmbedder::new(&config).expect("build FastembedEmbedder");
// Cover BOTH prefix paths (`passage:` for Document, `query:` for Query) so
// a query-side prefix/pooling divergence can't slip through (reviewer note).
let inputs: Vec<EmbeddingInput> = SENTENCES
.iter()
.map(|s| EmbeddingInput {
text: s,
kind: EmbeddingKind::Document,
.flat_map(|s| {
[EmbeddingKind::Document, EmbeddingKind::Query]
.into_iter()
.map(move |kind| EmbeddingInput { text: s, kind })
})
.collect();
@@ -61,11 +64,12 @@ fn candle_matches_fastembed() {
let fv = fastembed.embed(&inputs).expect("fastembed embed");
assert_eq!(cv.len(), fv.len(), "embedding counts must match");
assert_eq!(cv.len(), inputs.len(), "one vector per input");
assert_eq!(candle.dimensions(), 1024);
let mut min_cos = f32::INFINITY;
let mut max_abs_diff = 0f32;
for (i, s) in SENTENCES.iter().enumerate() {
for (i, inp) in inputs.iter().enumerate() {
assert_eq!(cv[i].len(), 1024, "candle dim");
assert_eq!(fv[i].len(), 1024, "fastembed dim");
let c = cosine(&cv[i], &fv[i]);
@@ -76,8 +80,12 @@ fn candle_matches_fastembed() {
.map(|(a, b)| (a - b).abs())
.fold(0f32, f32::max);
max_abs_diff = max_abs_diff.max(diff);
let preview: String = s.chars().take(40).collect();
println!("[{i:>2}] cos={c:.6} max_abs_diff={diff:.6e} {preview}");
let kind = match inp.kind {
EmbeddingKind::Document => "doc",
EmbeddingKind::Query => "qry",
};
let preview: String = inp.text.chars().take(36).collect();
println!("[{i:>2}] {kind} cos={c:.6} max_abs_diff={diff:.6e} {preview}");
}
println!("PARITY_SUMMARY cosine_min={min_cos:.6} max_abs_diff={max_abs_diff:.6e}");