feat(embed): candle 모델 레지스트리 + arctic-embed-l-v2.0 (CLS pooling)

e5 하드코딩(HF_MODEL/SUPPORTED_MODEL/mean/query:+passage:) → 모델 레지스트리
EmbedModelSpec{name,hf_repo,pooling,query_prefix,doc_prefix,dim,version_tag}.
e5(mean, query:/passage:) + arctic(CLS, query:/무접두어). pooling 모델별 분기
(mean=attention-mask-weighted / CLS=hidden[:,0,:]), tokenize/forward/L2 공유.
arctic pooling=CLS 는 HF 1_Pooling/config.json(pooling_mode_cls_token:true) 확인.
model_version 은 arctic 일 때 +arctic-cls 태그(embedding_version cascade 트리거);
e5 는 fastembed-e5 호환(NUMA 드롭인) 위해 plain config.version 유지.

correctness 게이트: tests/arctic_ollama_parity.rs (#[ignore], live Ollama) —
candle arctic vs Ollama snowflake-arctic-embed2 per-sentence 코사인>0.99.
수동 실측 cosine_min=0.999984 (recall@10 130 재현 보장).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-03 04:59:11 +00:00
parent 7505645008
commit cbcae69abf
3 changed files with 390 additions and 84 deletions

View File

@@ -38,6 +38,9 @@ metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]
# not the library's own (non-dev) dependencies — so rayon/kebab-config/kebab-core # not the library's own (non-dev) dependencies — so rayon/kebab-config/kebab-core
# are repeated here for tests/parity.rs and tests/thread_cap.rs. # are repeated here for tests/parity.rs and tests/thread_cap.rs.
kebab-embed-local = { path = "../kebab-embed-local" } kebab-embed-local = { path = "../kebab-embed-local" }
# arctic↔Ollama parity test drives the real Ollama adapter for the reference
# vectors (tests/arctic_ollama_parity.rs, `#[ignore]` — live Ollama).
kebab-embed-ollama = { path = "../kebab-embed-ollama" }
kebab-config = { path = "../kebab-config" } kebab-config = { path = "../kebab-config" }
kebab-core = { path = "../kebab-core" } kebab-core = { path = "../kebab-core" }
rayon = "1" rayon = "1"

View File

@@ -1,31 +1,44 @@
//! `kebab-embed-candle` — [`CandleEmbedder`], a pure-Rust (candle) //! `kebab-embed-candle` — [`CandleEmbedder`], a pure-Rust (candle)
//! implementation of [`Embedder`](kebab_core::Embedder). //! implementation of [`Embedder`](kebab_core::Embedder).
//! //!
//! Runs the same `intfloat/multilingual-e5-large` model as the default //! Runs an XLM-RoBERTa-large embedding model through `candle`
//! [`FastembedEmbedder`](kebab_embed_local) but through `candle` //! (`candle-transformers`' XLM-RoBERTa) instead of onnxruntime. Two models
//! (`candle-transformers`' XLM-RoBERTa) instead of onnxruntime. Motivation: //! are wired through a small **registry** ([`MODEL_REGISTRY`]):
//! fastembed 4.9's onnxruntime hard-codes 48 intra-op threads, which corrupts
//! the heap (double-free) on dual-socket NUMA hosts. candle's CPU backend
//! sizes its threads off the global rayon pool, so a one-shot
//! [`rayon::ThreadPoolBuilder`] cap (config `num_threads` / env
//! `KEBAB_EMBED_THREADS`) keeps the worker count NUMA-safe.
//! //!
//! Output parity with the onnxruntime path was proven by the Phase 0 spike //! * `multilingual-e5-large` — the same weights the default
//! (cosine 1.000000); this crate absorbs that pipeline verbatim: //! [`FastembedEmbedder`](kebab_embed_local) uses (mean pooling,
//! `query: `/`passage: ` prefixes). candle is the NUMA-safe drop-in:
//! fastembed 4.9's onnxruntime hard-codes 48 intra-op threads, which
//! corrupts the heap (double-free) on dual-socket NUMA hosts. candle's
//! CPU backend sizes its threads off the global rayon pool, so a one-shot
//! [`rayon::ThreadPoolBuilder`] cap (config `num_threads` / env
//! `KEBAB_EMBED_THREADS`) keeps the worker count NUMA-safe.
//! * `snowflake-arctic-embed-l-v2.0` — Snowflake's arctic-embed v2.0
//! (CLS pooling, `query: ` on queries / no prefix on documents). Same
//! XLM-RoBERTa-large architecture, dim 1024, so it rides the exact same
//! tokenize → forward → L2 pipeline; only the pooling step and prefixes
//! differ (both keyed off the per-model [`EmbedModelSpec`]).
//! //!
//! 1. e5 prefix (`passage: ` for documents, `query: ` for queries — the same //! Output parity with the onnxruntime path (for e5) was proven by the
//! convention as `kebab-embed-local`'s `prefix_input`); //! Phase 0 spike (cosine 1.000000); the arctic path's pooling/prefix
//! correctness is pinned by an `#[ignore]`d cosine>0.99 cross-check against
//! Ollama's `snowflake-arctic-embed2` (see `tests/arctic_ollama_parity.rs`).
//! The shared pipeline:
//!
//! 1. instruction prefix per [`EmbedModelSpec`] (query/doc);
//! 2. tokenize (max_len 512, batch-longest padding, special tokens); //! 2. tokenize (max_len 512, batch-longest padding, special tokens);
//! 3. XLM-RoBERTa forward on `Device::Cpu`; //! 3. XLM-RoBERTa forward on the selected [`Device`];
//! 4. attention-mask-weighted mean pooling; //! 4. pooling — mean (attention-mask-weighted) or CLS (first token);
//! 5. L2 normalization. //! 5. L2 normalization.
//! //!
//! Model files (`config.json`, `tokenizer.json`, `model.safetensors`) are //! Model files (`config.json`, `tokenizer.json`, `model.safetensors`) are
//! fetched via `hf-hub` into `{config.storage.model_dir}/candle/`. //! fetched via `hf-hub` into `{config.storage.model_dir}/candle/` (hf-hub's
//! cache layout namespaces by repo, so e5 and arctic never collide).
//! //!
//! This crate is **opt-in** (`config.models.embedding.provider = "candle"`); //! This crate is **opt-in** (`config.models.embedding.provider = "candle"`);
//! the default provider stays `fastembed`. See //! the default provider stays `fastembed`. See
//! `docs/superpowers/specs/2026-06-01-embed-candle-track-spec.md`. //! `docs/superpowers/specs/2026-06-01-embed-candle-track-spec.md` and
//! `docs/superpowers/specs/2026-06-03-arctic-embedder-spec.md`.
use std::sync::Mutex; use std::sync::Mutex;
@@ -42,22 +55,95 @@ use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
/// `fastembed/` subdir so the two backends never collide. /// `fastembed/` subdir so the two backends never collide.
const CANDLE_CACHE_SUBDIR: &str = "candle"; const CANDLE_CACHE_SUBDIR: &str = "candle";
/// HuggingFace repo id for the multilingual e5 large model. Same weights the /// Token truncation length (both e5 and arctic-embed-l-v2.0 train at 512).
/// onnxruntime path uses, just the safetensors variant candle can read.
const HF_MODEL: &str = "intfloat/multilingual-e5-large";
/// The only `config.models.embedding.model` value the candle adapter accepts
/// (the e5-large weights `HF_MODEL` resolves to). Guards against silently
/// downloading e5-large while `model_id()` reports a different name.
const SUPPORTED_MODEL: &str = "multilingual-e5-large";
/// Token truncation length (e5 was trained at 512).
const MAX_LEN: usize = 512; const MAX_LEN: usize = 512;
/// Env var that overrides `config.models.embedding.num_threads`. Read once in /// Env var that overrides `config.models.embedding.num_threads`. Read once in
/// [`CandleEmbedder::new`]; `0`/unset/unparseable means "leave rayon default". /// [`CandleEmbedder::new`]; `0`/unset/unparseable means "leave rayon default".
const ENV_EMBED_THREADS: &str = "KEBAB_EMBED_THREADS"; const ENV_EMBED_THREADS: &str = "KEBAB_EMBED_THREADS";
/// Pooling strategy over the model's last hidden state. Keyed per-model by
/// [`EmbedModelSpec::pooling`] — e5 is mean, arctic is CLS.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Pooling {
/// Attention-mask-weighted mean over all tokens (e5 / sentence-transformers
/// `pooling_mode_mean_tokens`).
Mean,
/// First token (`<s>`/`[CLS]`) hidden state (arctic-embed v2.0 —
/// `1_Pooling/config.json` has `pooling_mode_cls_token: true`).
Cls,
}
/// One supported embedding model: the HF repo candle downloads, the pooling
/// strategy, and the e5-style instruction prefixes. [`MODEL_REGISTRY`] maps a
/// `config.models.embedding.model` value to one of these.
#[derive(Clone, Copy, Debug)]
pub struct EmbedModelSpec {
/// The short `config.models.embedding.model` value that selects this spec.
pub name: &'static str,
/// HuggingFace repo id candle fetches `config.json` / `tokenizer.json` /
/// `model.safetensors` from.
pub hf_repo: &'static str,
/// Pooling over the last hidden state.
pub pooling: Pooling,
/// Prefix prepended to **query** inputs before tokenization.
pub query_prefix: &'static str,
/// Prefix prepended to **document** inputs before tokenization (arctic
/// uses `""` — documents are embedded raw).
pub doc_prefix: &'static str,
/// Expected embedding dimension (model hidden size).
pub dim: usize,
/// Suffix folded into `model_version` so switching **to** this model
/// triggers the `embedding_version` cascade even if the operator forgets
/// to bump `config.version`. `None` keeps the bare `config.version` — used
/// by e5 so candle-e5 and fastembed-e5 report the *same* version and stay
/// interchangeable (the NUMA drop-in invariant — Phase 0 cosine 1.0).
pub version_tag: Option<&'static str>,
}
/// The models the candle adapter can load. Adding a model = one entry here
/// (plus, for a non-XLM-R architecture, a new forward path — both current
/// entries are XLM-RoBERTa-large so they share everything but pooling/prefix).
static MODEL_REGISTRY: &[EmbedModelSpec] = &[
EmbedModelSpec {
name: "multilingual-e5-large",
hf_repo: "intfloat/multilingual-e5-large",
pooling: Pooling::Mean,
query_prefix: "query: ",
doc_prefix: "passage: ",
dim: 1024,
version_tag: None,
},
EmbedModelSpec {
name: "snowflake-arctic-embed-l-v2.0",
hf_repo: "Snowflake/snowflake-arctic-embed-l-v2.0",
pooling: Pooling::Cls,
query_prefix: "query: ",
doc_prefix: "",
dim: 1024,
version_tag: Some("arctic-cls"),
},
];
/// Look up a model spec by `config.models.embedding.model`. Accepts either the
/// short `name` or the full `hf_repo` id (mirrors the old e5 guard, which
/// accepted both `multilingual-e5-large` and `intfloat/multilingual-e5-large`).
pub(crate) fn lookup_spec(model: &str) -> Option<&'static EmbedModelSpec> {
MODEL_REGISTRY
.iter()
.find(|s| s.name == model || s.hf_repo == model)
}
/// Comma-separated list of supported model names, for the
/// unsupported-model error message.
fn supported_models() -> String {
MODEL_REGISTRY
.iter()
.map(|s| s.name)
.collect::<Vec<_>>()
.join("`, `")
}
/// Pure-Rust candle adapter. Construct via [`CandleEmbedder::new`]; the /// Pure-Rust candle adapter. Construct via [`CandleEmbedder::new`]; the
/// constructor downloads the model on first use, so share one instance. /// constructor downloads the model on first use, so share one instance.
pub struct CandleEmbedder { pub struct CandleEmbedder {
@@ -68,6 +154,9 @@ pub struct CandleEmbedder {
model: Mutex<XLMRobertaModel>, model: Mutex<XLMRobertaModel>,
tokenizer: Tokenizer, tokenizer: Tokenizer,
device: Device, device: Device,
/// The resolved model spec (pooling + prefixes) — drives `embed` and
/// `embed_batch`.
spec: &'static EmbedModelSpec,
model_id: EmbeddingModelId, model_id: EmbeddingModelId,
version: EmbeddingVersion, version: EmbeddingVersion,
dimensions: usize, dimensions: usize,
@@ -75,7 +164,8 @@ pub struct CandleEmbedder {
} }
impl CandleEmbedder { impl CandleEmbedder {
/// Build an embedder from `Config`. Applies the NUMA thread cap, fetches /// Build an embedder from `Config`. Resolves the model spec from
/// `config.models.embedding.model`, applies the NUMA thread cap, fetches
/// the model into `{model_dir}/candle/`, and validates that the model's /// the model into `{model_dir}/candle/`, and validates that the model's
/// hidden size matches `config.models.embedding.dimensions` before /// hidden size matches `config.models.embedding.dimensions` before
/// returning. /// returning.
@@ -104,21 +194,20 @@ impl CandleEmbedder {
} }
} }
// 1b. Model guard. `HF_MODEL` is hard-coded (candle currently only wires // 1b. Model registry lookup. If the operator configured a model the
// e5-large), so if the operator configured a *different* model name // candle adapter doesn't know, fail fast (BEFORE the ~2GB
// we must NOT silently download e5-large and then label its vectors // download) — never silently download one model and then label its
// with the configured name via `model_id()` — that would mislabel // vectors with another name via `model_id()`, which would mislabel
// `embedding_version` and corrupt a mixed index. Fail fast, before // `embedding_version` and corrupt a mixed index.
// the ~2GB download.
let want = config.models.embedding.model.as_str(); let want = config.models.embedding.model.as_str();
if want != SUPPORTED_MODEL && want != HF_MODEL { let spec = lookup_spec(want).ok_or_else(|| {
anyhow::bail!( anyhow::anyhow!(
"candle provider currently supports only '{SUPPORTED_MODEL}' (or \ "candle provider supports the models `{}`, but \
the HF id '{HF_MODEL}'), but config.models.embedding.model = \ config.models.embedding.model = '{want}'. Use provider=fastembed \
'{want}'. Use provider=fastembed for other models, or set \ for other models, or pick a supported one.",
model = \"{SUPPORTED_MODEL}\"." supported_models()
); )
} })?;
// 2. Resolve `{data_dir}/models/candle/` exactly like the fastembed // 2. Resolve `{data_dir}/models/candle/` exactly like the fastembed
// adapter resolves its own subdir. // adapter resolves its own subdir.
@@ -134,14 +223,15 @@ impl CandleEmbedder {
tracing::info!( tracing::info!(
target: "kebab-embed-candle", target: "kebab-embed-candle",
cache_dir = %cache_dir.display(), cache_dir = %cache_dir.display(),
model = HF_MODEL, model = spec.hf_repo,
pooling = ?spec.pooling,
"loading candle embedding model (first run downloads ~2GB safetensors)" "loading candle embedding model (first run downloads ~2GB safetensors)"
); );
let api = hf_hub::api::sync::ApiBuilder::new() let api = hf_hub::api::sync::ApiBuilder::new()
.with_cache_dir(cache_dir.clone()) .with_cache_dir(cache_dir.clone())
.build() .build()
.context("kb-embed-candle: build hf-hub api")?; .context("kb-embed-candle: build hf-hub api")?;
let repo = api.model(HF_MODEL.to_string()); let repo = api.model(spec.hf_repo.to_string());
let config_path = repo.get("config.json").context("download config.json")?; let config_path = repo.get("config.json").context("download config.json")?;
let tokenizer_path = repo let tokenizer_path = repo
.get("tokenizer.json") .get("tokenizer.json")
@@ -180,10 +270,21 @@ impl CandleEmbedder {
})) }))
.map_err(|e| anyhow::anyhow!("kb-embed-candle: set truncation: {e}"))?; .map_err(|e| anyhow::anyhow!("kb-embed-candle: set truncation: {e}"))?;
// model_version: fold the model tag in for non-e5 models so a switch
// triggers the embedding_version cascade; e5 keeps the bare
// config.version to stay interchangeable with fastembed-e5.
let version = match spec.version_tag {
Some(tag) => {
EmbeddingVersion(format!("{}+{}", config.models.embedding.version, tag))
}
None => EmbeddingVersion(config.models.embedding.version.clone()),
};
tracing::info!( tracing::info!(
target: "kebab-embed-candle", target: "kebab-embed-candle",
dimensions = cfg.hidden_size, dimensions = cfg.hidden_size,
layers = cfg.num_hidden_layers, layers = cfg.num_hidden_layers,
model = spec.name,
"candle embedding model loaded" "candle embedding model loaded"
); );
@@ -191,16 +292,17 @@ impl CandleEmbedder {
model: Mutex::new(model), model: Mutex::new(model),
tokenizer, tokenizer,
device, device,
spec,
model_id: EmbeddingModelId(config.models.embedding.model.clone()), model_id: EmbeddingModelId(config.models.embedding.model.clone()),
version: EmbeddingVersion(config.models.embedding.version.clone()), version,
dimensions: cfg.hidden_size, dimensions: cfg.hidden_size,
batch_size: config.models.embedding.batch_size.max(1), batch_size: config.models.embedding.batch_size.max(1),
}) })
} }
/// Embed one batch of **already-prefixed** strings (the e5 `query:`/ /// Embed one batch of **already-prefixed** strings (the per-model prefix
/// `passage:` prefix is applied by the caller [`CandleEmbedder::embed`]) /// is applied by the caller [`CandleEmbedder::embed`]) through the candle
/// through the candle pipeline: tokenize → forward → masked mean pool → L2. /// pipeline: tokenize → forward → pool (mean|CLS) → L2.
fn embed_batch(&self, prefixed: &[String]) -> Result<Vec<Vec<f32>>> { fn embed_batch(&self, prefixed: &[String]) -> Result<Vec<Vec<f32>>> {
let encodings = self let encodings = self
.tokenizer .tokenizer
@@ -237,18 +339,30 @@ impl CandleEmbedder {
guard.forward(&input_ids, &attn_f32, &token_type_ids, None, None, None)? guard.forward(&input_ids, &attn_f32, &token_type_ids, None, None, None)?
}; };
// attention-mask-weighted mean pooling // Pooling — per the model spec.
let mask3 = attn_f32.unsqueeze(2)?; // (b, seq, 1) let pooled = match self.spec.pooling {
let summed = hidden.broadcast_mul(&mask3)?.sum(1)?; // (b, hidden) Pooling::Mean => {
// counts ≥ 1 always: every input is e5-prefixed AND special tokens are // attention-mask-weighted mean pooling
// added (encode_batch(_, true)), so no row has an all-zero mask. If that let mask3 = attn_f32.unsqueeze(2)?; // (b, seq, 1)
// invariant ever breaks, broadcast_div would emit NaN vectors. let summed = hidden.broadcast_mul(&mask3)?.sum(1)?; // (b, hidden)
let counts = mask3.sum(1)?; // (b, 1) // counts ≥ 1 always: every input is prefixed AND special
let mean = summed.broadcast_div(&counts)?; // tokens are added (encode_batch(_, true)), so no row has an
// all-zero mask. If that invariant ever breaks, broadcast_div
// would emit NaN vectors.
let counts = mask3.sum(1)?; // (b, 1)
summed.broadcast_div(&counts)?
}
Pooling::Cls => {
// CLS pooling: the first token's hidden state. arctic-embed
// v2.0 prepends `<s>` (the XLM-R BOS/CLS) at index 0, so
// `hidden[:, 0, :]` is the sentence embedding.
hidden.narrow(1, 0, 1)?.squeeze(1)? // (b, hidden)
}
};
// L2 normalize // L2 normalize
let norm = mean.sqr()?.sum_keepdim(1)?.sqrt()?; let norm = pooled.sqr()?.sum_keepdim(1)?.sqrt()?;
let normalized = mean.broadcast_div(&norm)?; let normalized = pooled.broadcast_div(&norm)?;
// `.contiguous()` before host copy: broadcast ops can leave a strided // `.contiguous()` before host copy: broadcast ops can leave a strided
// view, which `to_vec2` rejects on the Metal backend (CPU tolerates it). // view, which `to_vec2` rejects on the Metal backend (CPU tolerates it).
@@ -274,9 +388,9 @@ impl Embedder for CandleEmbedder {
return Ok(Vec::new()); return Ok(Vec::new());
} }
// e5 prefix per §11.3 BEFORE tokenization (same convention as // Per-model instruction prefix BEFORE tokenization (same convention as
// FastembedEmbedder so the two backends produce comparable vectors). // FastembedEmbedder for e5; arctic uses `query: `/no-prefix).
let prefixed: Vec<String> = inputs.iter().map(prefix_input).collect(); let prefixed: Vec<String> = inputs.iter().map(|i| prefix_input(self.spec, i)).collect();
let mut out: Vec<Vec<f32>> = Vec::with_capacity(prefixed.len()); let mut out: Vec<Vec<f32>> = Vec::with_capacity(prefixed.len());
for chunk in prefixed.chunks(self.batch_size) { for chunk in prefixed.chunks(self.batch_size) {
@@ -298,22 +412,22 @@ impl Embedder for CandleEmbedder {
} }
} }
/// Build the e5-prefixed string for one [`EmbeddingInput`]. Free function so /// Build the prefixed string for one [`EmbeddingInput`] using the model spec.
/// a unit test can pin the format without loading the model. Byte-identical to /// Free function so a unit test can pin the format without loading the model.
/// `kebab-embed-local`'s `prefix_input` — the two backends MUST agree here or /// For e5 this is byte-identical to `kebab-embed-local`'s `prefix_input` — the
/// their vectors diverge. /// two backends MUST agree there or their vectors diverge.
fn prefix_input(input: &EmbeddingInput<'_>) -> String { fn prefix_input(spec: &EmbedModelSpec, input: &EmbeddingInput<'_>) -> String {
match input.kind { match input.kind {
EmbeddingKind::Document => format!("passage: {}", input.text), EmbeddingKind::Document => format!("{}{}", spec.doc_prefix, input.text),
EmbeddingKind::Query => format!("query: {}", input.text), EmbeddingKind::Query => format!("{}{}", spec.query_prefix, input.text),
} }
} }
/// Select the compute device. Built with the `metal` feature (Apple Silicon /// Select the compute device. Built with the `metal` feature (Apple Silicon
/// GPU), try Metal and fall back to CPU on failure; otherwise CPU. Metal only /// GPU), try Metal and fall back to CPU on failure; otherwise CPU. Metal only
/// compiles/runs on macOS — the Linux server builds the CPU path. e5-large /// compiles/runs on macOS — the Linux server builds the CPU path. Embedding
/// vectors are model-defined, so Metal-produced and CPU-produced embeddings are /// vectors are model-defined, so Metal-produced and CPU-produced embeddings
/// cross-compatible (a Mac can ingest on GPU, the server query on CPU). /// are cross-compatible (a Mac can ingest on GPU, the server query on CPU).
fn select_device() -> Device { fn select_device() -> Device {
#[cfg(feature = "metal")] #[cfg(feature = "metal")]
{ {
@@ -367,26 +481,85 @@ pub(crate) fn check_dim(model_dim: usize, cfg_dim: usize) -> Result<()> {
mod tests { mod tests {
use super::*; use super::*;
// ── prefix_input ───────────────────────────────────────────────── fn e5_spec() -> &'static EmbedModelSpec {
// Pin the exact e5 prefix strings; these MUST match lookup_spec("multilingual-e5-large").expect("e5 in registry")
// kebab-embed-local::prefix_input or candle vs fastembed parity breaks. }
fn arctic_spec() -> &'static EmbedModelSpec {
lookup_spec("snowflake-arctic-embed-l-v2.0").expect("arctic in registry")
}
// ── registry ─────────────────────────────────────────────────────
#[test] #[test]
fn prefix_document_uses_passage() { fn registry_resolves_e5_by_name_and_hf_repo() {
assert_eq!(
lookup_spec("multilingual-e5-large").map(|s| s.name),
Some("multilingual-e5-large")
);
assert_eq!(
lookup_spec("intfloat/multilingual-e5-large").map(|s| s.name),
Some("multilingual-e5-large")
);
}
#[test]
fn registry_resolves_arctic_and_its_pooling_is_cls() {
let s = arctic_spec();
assert_eq!(s.name, "snowflake-arctic-embed-l-v2.0");
assert_eq!(s.hf_repo, "Snowflake/snowflake-arctic-embed-l-v2.0");
assert_eq!(s.pooling, Pooling::Cls);
assert_eq!(s.dim, 1024);
assert_eq!(s.version_tag, Some("arctic-cls"));
}
#[test]
fn registry_e5_is_mean_pooling_no_version_tag() {
let s = e5_spec();
assert_eq!(s.pooling, Pooling::Mean);
assert_eq!(s.version_tag, None);
}
#[test]
fn registry_rejects_unknown_model() {
assert!(lookup_spec("multilingual-e5-small").is_none());
}
// ── prefix_input ─────────────────────────────────────────────────
// e5 prefixes MUST match kebab-embed-local::prefix_input or candle vs
// fastembed parity breaks; arctic uses query-only prefixing.
#[test]
fn e5_prefix_document_uses_passage() {
let input = EmbeddingInput { let input = EmbeddingInput {
text: "hello world", text: "hello world",
kind: EmbeddingKind::Document, kind: EmbeddingKind::Document,
}; };
assert_eq!(prefix_input(&input), "passage: hello world"); assert_eq!(prefix_input(e5_spec(), &input), "passage: hello world");
} }
#[test] #[test]
fn prefix_query_uses_query() { fn e5_prefix_query_uses_query() {
let input = EmbeddingInput { let input = EmbeddingInput {
text: "hello world", text: "hello world",
kind: EmbeddingKind::Query, kind: EmbeddingKind::Query,
}; };
assert_eq!(prefix_input(&input), "query: hello world"); assert_eq!(prefix_input(e5_spec(), &input), "query: hello world");
}
#[test]
fn arctic_prefix_query_uses_query_doc_is_bare() {
let doc = EmbeddingInput {
text: "후입선출 자료구조",
kind: EmbeddingKind::Document,
};
let qry = EmbeddingInput {
text: "스택 자료구조",
kind: EmbeddingKind::Query,
};
// arctic: documents are embedded raw, queries get `query: `.
assert_eq!(prefix_input(arctic_spec(), &doc), "후입선출 자료구조");
assert_eq!(prefix_input(arctic_spec(), &qry), "query: 스택 자료구조");
} }
#[test] #[test]
@@ -399,8 +572,10 @@ mod tests {
text: "", text: "",
kind: EmbeddingKind::Query, kind: EmbeddingKind::Query,
}; };
assert_eq!(prefix_input(&doc), "passage: "); assert_eq!(prefix_input(e5_spec(), &doc), "passage: ");
assert_eq!(prefix_input(&qry), "query: "); assert_eq!(prefix_input(e5_spec(), &qry), "query: ");
assert_eq!(prefix_input(arctic_spec(), &doc), "");
assert_eq!(prefix_input(arctic_spec(), &qry), "query: ");
} }
// ── check_dim ──────────────────────────────────────────────────── // ── check_dim ────────────────────────────────────────────────────
@@ -421,9 +596,9 @@ mod tests {
} }
// ── model guard ────────────────────────────────────────────────── // ── model guard ──────────────────────────────────────────────────
// A non-e5-large model name must fail fast (BEFORE the ~2GB download), // A model name not in the registry must fail fast (BEFORE the ~2GB
// so we never download e5-large yet label its vectors with another name // download), so we never download one model yet label its vectors with
// via model_id() — which would mislabel embedding_version. // another name via model_id() — which would mislabel embedding_version.
#[test] #[test]
fn new_rejects_unsupported_model() { fn new_rejects_unsupported_model() {
@@ -437,8 +612,8 @@ mod tests {
.expect("unsupported model must error"); .expect("unsupported model must error");
let msg = format!("{err:#}"); let msg = format!("{err:#}");
assert!( assert!(
msg.contains("candle provider currently supports only"), msg.contains("candle provider supports the models"),
"expected model-guard error, got: {msg}" "expected model-registry error, got: {msg}"
); );
} }
} }

View File

@@ -0,0 +1,128 @@
//! arctic-embed-l-v2.0 correctness gate (`#[ignore]` — needs the ~2GB candle
//! model + a live Ollama serving `snowflake-arctic-embed2`).
//!
//! This is the load-bearing pooling/prefix check for the arctic integration.
//! The recall measurement that justified adopting arctic (recall@10 130/132)
//! went through Ollama's `snowflake-arctic-embed2`. The candle path
//! re-implements the model (XLM-RoBERTa-large + **CLS** pooling + `query: ` on
//! queries / no prefix on documents). If candle's pooling or prefix is wrong,
//! its vectors silently diverge from the measured route and the 130 number
//! does NOT carry over. This test pins them together: per-sentence cosine
//! between the candle vector and the Ollama vector must be **> 0.99**.
//!
//! `#[ignore]` because it depends on an external Ollama daemon (CI is
//! headless/offline). The leader MUST run it once before merge.
//!
//! ## Manual run
//!
//! 1. Confirm Ollama is reachable and has the model:
//! ```sh
//! curl -s http://192.168.0.47:11434/api/tags # should list snowflake-arctic-embed2
//! ```
//! 2. Run (downloads the ~2GB candle safetensors on first run):
//! ```sh
//! CARGO_TARGET_DIR=/build/out/cargo-target \
//! KEBAB_ARCTIC_OLLAMA_ENDPOINT=http://192.168.0.47:11434 \
//! cargo test -p kebab-embed-candle --test arctic_ollama_parity -- --ignored --nocapture
//! ```
//! The endpoint defaults to `http://192.168.0.47:11434` if the env is unset.
//!
//! Record the printed `ARCTIC_PARITY_SUMMARY cosine_min=...` in
//! `/tmp/arctic-result.md` + `tasks/HOTFIXES.md`.
use kebab_config::Config;
use kebab_core::{Embedder, EmbeddingInput, EmbeddingKind};
use kebab_embed_candle::CandleEmbedder;
use kebab_embed_ollama::OllamaEmbedder;
const DOGFOOD_CONFIG: &str = "/build/dogfood/config.toml";
const DEFAULT_OLLAMA_ENDPOINT: &str = "http://192.168.0.47:11434";
/// Mixed Korean / English + the descriptive-recall shapes arctic was adopted
/// for (synonym / abbreviation / English term). Covers both prefix paths.
const SENTENCES: &[&str] = &[
"스택 자료구조",
"후입선출 방식으로 동작하는 자료구조",
"큐는 선입선출 자료구조이다",
"Rust ownership and the borrow checker",
"소유권과 빌림 검사기는 메모리 안전성을 보장한다",
"SVM 은 support vector machine 의 약자이다",
"정렬 알고리즘의 시간 복잡도",
"The capital of France is Paris.",
];
fn cosine(a: &[f32], b: &[f32]) -> f32 {
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
dot / (na * nb)
}
/// Base config: prefer the canonical dogfood config (for storage/cache roots),
/// fall back to `Config::defaults()` so the test still runs on a bare clone.
fn base_config() -> Config {
Config::load(Some(std::path::Path::new(DOGFOOD_CONFIG))).unwrap_or_else(|_| Config::defaults())
}
#[test]
#[ignore = "needs ~2GB candle model + live Ollama (snowflake-arctic-embed2); run manually before merge"]
fn candle_arctic_matches_ollama_arctic() {
let endpoint = std::env::var("KEBAB_ARCTIC_OLLAMA_ENDPOINT")
.unwrap_or_else(|_| DEFAULT_OLLAMA_ENDPOINT.to_string());
// candle side: the in-process arctic model.
let mut candle_cfg = base_config();
candle_cfg.models.embedding.provider = "candle".to_string();
candle_cfg.models.embedding.model = "snowflake-arctic-embed-l-v2.0".to_string();
candle_cfg.models.embedding.dimensions = 1024;
// Ollama side: the reference route the recall numbers came from.
let mut ollama_cfg = base_config();
ollama_cfg.models.embedding.provider = "ollama".to_string();
ollama_cfg.models.embedding.model = "snowflake-arctic-embed2".to_string();
ollama_cfg.models.embedding.dimensions = 1024;
ollama_cfg.models.embedding.endpoint = Some(endpoint.clone());
let candle = CandleEmbedder::new(&candle_cfg).expect("build candle arctic embedder");
let ollama = OllamaEmbedder::new(&ollama_cfg).expect("build ollama arctic embedder");
// Exercise BOTH prefix paths so a query-side divergence can't hide.
let inputs: Vec<EmbeddingInput> = SENTENCES
.iter()
.flat_map(|s| {
[EmbeddingKind::Document, EmbeddingKind::Query]
.into_iter()
.map(move |kind| EmbeddingInput { text: s, kind })
})
.collect();
let cv = candle.embed(&inputs).expect("candle embed");
let ov = ollama
.embed(&inputs)
.expect("ollama embed (is snowflake-arctic-embed2 pulled @ the endpoint?)");
assert_eq!(cv.len(), ov.len(), "embedding counts must match");
assert_eq!(cv.len(), inputs.len(), "one vector per input");
assert_eq!(candle.dimensions(), 1024);
let mut min_cos = f32::INFINITY;
for (i, inp) in inputs.iter().enumerate() {
assert_eq!(cv[i].len(), 1024, "candle dim");
assert_eq!(ov[i].len(), 1024, "ollama dim");
let c = cosine(&cv[i], &ov[i]);
min_cos = min_cos.min(c);
let kind = match inp.kind {
EmbeddingKind::Document => "doc",
EmbeddingKind::Query => "qry",
};
let preview: String = inp.text.chars().take(36).collect();
println!("[{i:>2}] {kind} cos={c:.6} {preview}");
}
println!("ARCTIC_PARITY_SUMMARY cosine_min={min_cos:.6} endpoint={endpoint}");
assert!(
min_cos > 0.99,
"candle arctic vs Ollama arctic cosine_min={min_cos:.6} ≤ 0.99 — \
pooling/prefix mismatch; the recall=130 measurement will NOT reproduce"
);
}