fix(embed-candle): address round-1 review
- commit track-spec + meta-spec/plan into branch (HIGH: dangling `amends:` ref)
- inline parity evidence (cosine 1.0, max_abs_diff 2.01e-7) into HOTFIXES + release notes; drop refs to deleted IMPL_REPORT/SPIKE_REPORT (MEDIUM)
- model guard: reject non-e5-large `model` before the 2GB download so model_id() can't mislabel vectors (MEDIUM) + unit test
- parity test now covers BOTH query: and passage: prefixes (MEDIUM)
- guard encodings.first() index; document zero-attention/pooling invariant; clarify embed_batch prefixing doc (LOW)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -46,6 +46,11 @@ const CANDLE_CACHE_SUBDIR: &str = "candle";
|
||||
/// onnxruntime path uses, just the safetensors variant candle can read.
|
||||
const HF_MODEL: &str = "intfloat/multilingual-e5-large";
|
||||
|
||||
/// The only `config.models.embedding.model` value the candle adapter accepts
|
||||
/// (the e5-large weights `HF_MODEL` resolves to). Guards against silently
|
||||
/// downloading e5-large while `model_id()` reports a different name.
|
||||
const SUPPORTED_MODEL: &str = "multilingual-e5-large";
|
||||
|
||||
/// Token truncation length (e5 was trained at 512).
|
||||
const MAX_LEN: usize = 512;
|
||||
|
||||
@@ -99,6 +104,22 @@ impl CandleEmbedder {
|
||||
}
|
||||
}
|
||||
|
||||
// 1b. Model guard. `HF_MODEL` is hard-coded (candle currently only wires
|
||||
// e5-large), so if the operator configured a *different* model name
|
||||
// we must NOT silently download e5-large and then label its vectors
|
||||
// with the configured name via `model_id()` — that would mislabel
|
||||
// `embedding_version` and corrupt a mixed index. Fail fast, before
|
||||
// the ~2GB download.
|
||||
let want = config.models.embedding.model.as_str();
|
||||
if want != SUPPORTED_MODEL && want != HF_MODEL {
|
||||
anyhow::bail!(
|
||||
"candle provider currently supports only '{SUPPORTED_MODEL}' (or \
|
||||
the HF id '{HF_MODEL}'), but config.models.embedding.model = \
|
||||
'{want}'. Use provider=fastembed for other models, or set \
|
||||
model = \"{SUPPORTED_MODEL}\"."
|
||||
);
|
||||
}
|
||||
|
||||
// 2. Resolve `{data_dir}/models/candle/` exactly like the fastembed
|
||||
// adapter resolves its own subdir.
|
||||
let data_dir = expand_path(&config.storage.data_dir, "");
|
||||
@@ -177,8 +198,9 @@ impl CandleEmbedder {
|
||||
})
|
||||
}
|
||||
|
||||
/// Embed one batch (already prefixed) through the candle pipeline:
|
||||
/// tokenize → forward → masked mean pool → L2 normalize.
|
||||
/// Embed one batch of **already-prefixed** strings (the e5 `query:`/
|
||||
/// `passage:` prefix is applied by the caller [`CandleEmbedder::embed`])
|
||||
/// through the candle pipeline: tokenize → forward → masked mean pool → L2.
|
||||
fn embed_batch(&self, prefixed: &[String]) -> Result<Vec<Vec<f32>>> {
|
||||
let encodings = self
|
||||
.tokenizer
|
||||
@@ -186,7 +208,13 @@ impl CandleEmbedder {
|
||||
.map_err(|e| anyhow::anyhow!("kb-embed-candle: encode_batch: {e}"))?;
|
||||
|
||||
let bsz = encodings.len();
|
||||
let seq = encodings[0].get_ids().len();
|
||||
// `embed` already returns early on empty input and `.chunks()` never
|
||||
// yields an empty slice, so this is currently unreachable — but guard
|
||||
// the index so a future refactor can't turn it into a panic.
|
||||
let Some(first) = encodings.first() else {
|
||||
return Ok(Vec::new());
|
||||
};
|
||||
let seq = first.get_ids().len();
|
||||
|
||||
let mut ids = Vec::with_capacity(bsz * seq);
|
||||
let mut mask = Vec::with_capacity(bsz * seq);
|
||||
@@ -212,6 +240,9 @@ impl CandleEmbedder {
|
||||
// attention-mask-weighted mean pooling
|
||||
let mask3 = attn_f32.unsqueeze(2)?; // (b, seq, 1)
|
||||
let summed = hidden.broadcast_mul(&mask3)?.sum(1)?; // (b, hidden)
|
||||
// counts ≥ 1 always: every input is e5-prefixed AND special tokens are
|
||||
// added (encode_batch(_, true)), so no row has an all-zero mask. If that
|
||||
// invariant ever breaks, broadcast_div would emit NaN vectors.
|
||||
let counts = mask3.sum(1)?; // (b, 1)
|
||||
let mean = summed.broadcast_div(&counts)?;
|
||||
|
||||
@@ -360,4 +391,26 @@ mod tests {
|
||||
"error must mention both dims, got: {msg}"
|
||||
);
|
||||
}
|
||||
|
||||
// ── model guard ──────────────────────────────────────────────────
|
||||
// A non-e5-large model name must fail fast (BEFORE the ~2GB download),
|
||||
// so we never download e5-large yet label its vectors with another name
|
||||
// via model_id() — which would mislabel embedding_version.
|
||||
|
||||
#[test]
fn new_rejects_unsupported_model() {
    // Start from the default configuration and swap in a model name other
    // than the one the candle adapter accepts.
    let mut cfg = kebab_config::Config::defaults();
    cfg.models.embedding.model = "multilingual-e5-small".to_string();

    // Per the original note, num_threads defaults to 0, so constructing the
    // embedder here has no global rayon side effect.
    //
    // `expect_err` would require `CandleEmbedder: Debug`; the type holds a
    // Mutex/Tokenizer and intentionally derives no Debug, so the Ok arm is
    // rejected by pattern instead.
    let Err(err) = CandleEmbedder::new(&cfg) else {
        panic!("unsupported model must error");
    };

    // Render the whole error chain and look for the guard's wording.
    let msg = format!("{err:#}");
    let is_guard_error = msg.contains("candle provider currently supports only");
    assert!(is_guard_error, "expected model-guard error, got: {msg}");
}
|
||||
}
|
||||
|
||||
@@ -49,11 +49,14 @@ fn candle_matches_fastembed() {
|
||||
let candle = CandleEmbedder::new(&config).expect("build CandleEmbedder");
|
||||
let fastembed = FastembedEmbedder::new(&config).expect("build FastembedEmbedder");
|
||||
|
||||
// Cover BOTH prefix paths (`passage:` for Document, `query:` for Query) so
|
||||
// a query-side prefix/pooling divergence can't slip through (reviewer note).
|
||||
let inputs: Vec<EmbeddingInput> = SENTENCES
|
||||
.iter()
|
||||
.map(|s| EmbeddingInput {
|
||||
text: s,
|
||||
kind: EmbeddingKind::Document,
|
||||
.flat_map(|s| {
|
||||
[EmbeddingKind::Document, EmbeddingKind::Query]
|
||||
.into_iter()
|
||||
.map(move |kind| EmbeddingInput { text: s, kind })
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -61,11 +64,12 @@ fn candle_matches_fastembed() {
|
||||
let fv = fastembed.embed(&inputs).expect("fastembed embed");
|
||||
|
||||
assert_eq!(cv.len(), fv.len(), "embedding counts must match");
|
||||
assert_eq!(cv.len(), inputs.len(), "one vector per input");
|
||||
assert_eq!(candle.dimensions(), 1024);
|
||||
|
||||
let mut min_cos = f32::INFINITY;
|
||||
let mut max_abs_diff = 0f32;
|
||||
for (i, s) in SENTENCES.iter().enumerate() {
|
||||
for (i, inp) in inputs.iter().enumerate() {
|
||||
assert_eq!(cv[i].len(), 1024, "candle dim");
|
||||
assert_eq!(fv[i].len(), 1024, "fastembed dim");
|
||||
let c = cosine(&cv[i], &fv[i]);
|
||||
@@ -76,8 +80,12 @@ fn candle_matches_fastembed() {
|
||||
.map(|(a, b)| (a - b).abs())
|
||||
.fold(0f32, f32::max);
|
||||
max_abs_diff = max_abs_diff.max(diff);
|
||||
let preview: String = s.chars().take(40).collect();
|
||||
println!("[{i:>2}] cos={c:.6} max_abs_diff={diff:.6e} {preview}");
|
||||
let kind = match inp.kind {
|
||||
EmbeddingKind::Document => "doc",
|
||||
EmbeddingKind::Query => "qry",
|
||||
};
|
||||
let preview: String = inp.text.chars().take(36).collect();
|
||||
println!("[{i:>2}] {kind} cos={c:.6} max_abs_diff={diff:.6e} {preview}");
|
||||
}
|
||||
|
||||
println!("PARITY_SUMMARY cosine_min={min_cos:.6} max_abs_diff={max_abs_diff:.6e}");
|
||||
|
||||
Reference in New Issue
Block a user