From cbcae69abf0e4ba4578d95fb7dd801d3ddea3d2f Mon Sep 17 00:00:00 2001 From: altair823 Date: Wed, 3 Jun 2026 04:59:11 +0000 Subject: [PATCH] =?UTF-8?q?feat(embed):=20candle=20=EB=AA=A8=EB=8D=B8=20?= =?UTF-8?q?=EB=A0=88=EC=A7=80=EC=8A=A4=ED=8A=B8=EB=A6=AC=20+=20arctic-embe?= =?UTF-8?q?d-l-v2.0=20(CLS=20pooling)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit e5 하드코딩(HF_MODEL/SUPPORTED_MODEL/mean/query:+passage:) → 모델 레지스트리 EmbedModelSpec{name,hf_repo,pooling,query_prefix,doc_prefix,dim,version_tag}. e5(mean, query:/passage:) + arctic(CLS, query:/무접두어). pooling 모델별 분기 (mean=attention-mask-weighted / CLS=hidden[:,0,:]), tokenize/forward/L2 공유. arctic pooling=CLS 는 HF 1_Pooling/config.json(pooling_mode_cls_token:true) 확인. model_version 은 arctic 일 때 +arctic-cls 태그(embedding_version cascade 트리거); e5 는 fastembed-e5 호환(NUMA 드롭인) 위해 plain config.version 유지. correctness 게이트: tests/arctic_ollama_parity.rs (#[ignore], live Ollama) — candle arctic vs Ollama snowflake-arctic-embed2 per-sentence 코사인>0.99. 수동 실측 cosine_min=0.999984 (recall@10 130 재현 보장). Co-Authored-By: Claude Opus 4.8 (1M context) --- crates/kebab-embed-candle/Cargo.toml | 3 + crates/kebab-embed-candle/src/lib.rs | 343 +++++++++++++----- .../tests/arctic_ollama_parity.rs | 128 +++++++ 3 files changed, 390 insertions(+), 84 deletions(-) create mode 100644 crates/kebab-embed-candle/tests/arctic_ollama_parity.rs diff --git a/crates/kebab-embed-candle/Cargo.toml b/crates/kebab-embed-candle/Cargo.toml index 40a17d9..f9d2c67 100644 --- a/crates/kebab-embed-candle/Cargo.toml +++ b/crates/kebab-embed-candle/Cargo.toml @@ -38,6 +38,9 @@ metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"] # not the library's own (non-dev) dependencies — so rayon/kebab-config/kebab-core # are repeated here for tests/parity.rs and tests/thread_cap.rs. kebab-embed-local = { path = "../kebab-embed-local" } +# arctic↔Ollama parity test drives the real Ollama adapter for the reference +# vectors (tests/arctic_ollama_parity.rs, `#[ignore]` — live Ollama). +kebab-embed-ollama = { path = "../kebab-embed-ollama" } kebab-config = { path = "../kebab-config" } kebab-core = { path = "../kebab-core" } rayon = "1" diff --git a/crates/kebab-embed-candle/src/lib.rs b/crates/kebab-embed-candle/src/lib.rs index 0649e14..f45d529 100644 --- a/crates/kebab-embed-candle/src/lib.rs +++ b/crates/kebab-embed-candle/src/lib.rs @@ -1,31 +1,44 @@ //! `kebab-embed-candle` — [`CandleEmbedder`], a pure-Rust (candle) //! implementation of [`Embedder`](kebab_core::Embedder). //! -//! Runs the same `intfloat/multilingual-e5-large` model as the default -//! [`FastembedEmbedder`](kebab_embed_local) but through `candle` -//! (`candle-transformers`' XLM-RoBERTa) instead of onnxruntime. Motivation: -//! fastembed 4.9's onnxruntime hard-codes 48 intra-op threads, which corrupts -//! the heap (double-free) on dual-socket NUMA hosts. candle's CPU backend -//! sizes its threads off the global rayon pool, so a one-shot -//! [`rayon::ThreadPoolBuilder`] cap (config `num_threads` / env -//! `KEBAB_EMBED_THREADS`) keeps the worker count NUMA-safe. +//! Runs an XLM-RoBERTa-large embedding model through `candle` +//! (`candle-transformers`' XLM-RoBERTa) instead of onnxruntime. Two models +//! are wired through a small **registry** ([`MODEL_REGISTRY`]): //! -//! Output parity with the onnxruntime path was proven by the Phase 0 spike -//! (cosine 1.000000); this crate absorbs that pipeline verbatim: +//! * `multilingual-e5-large` — the same weights the default +//! [`FastembedEmbedder`](kebab_embed_local) uses (mean pooling, +//! `query: `/`passage: ` prefixes). candle is the NUMA-safe drop-in: +//! fastembed 4.9's onnxruntime hard-codes 48 intra-op threads, which +//! corrupts the heap (double-free) on dual-socket NUMA hosts. candle's +//! CPU backend sizes its threads off the global rayon pool, so a one-shot +//! [`rayon::ThreadPoolBuilder`] cap (config `num_threads` / env +//! `KEBAB_EMBED_THREADS`) keeps the worker count NUMA-safe. +//! * `snowflake-arctic-embed-l-v2.0` — Snowflake's arctic-embed v2.0 +//! (CLS pooling, `query: ` on queries / no prefix on documents). Same +//! XLM-RoBERTa-large architecture, dim 1024, so it rides the exact same +//! tokenize → forward → L2 pipeline; only the pooling step and prefixes +//! differ (both keyed off the per-model [`EmbedModelSpec`]). //! -//! 1. e5 prefix (`passage: ` for documents, `query: ` for queries — the same -//! convention as `kebab-embed-local`'s `prefix_input`); +//! Output parity with the onnxruntime path (for e5) was proven by the +//! Phase 0 spike (cosine 1.000000); the arctic path's pooling/prefix +//! correctness is pinned by an `#[ignore]`d cosine>0.99 cross-check against +//! Ollama's `snowflake-arctic-embed2` (see `tests/arctic_ollama_parity.rs`). +//! The shared pipeline: +//! +//! 1. instruction prefix per [`EmbedModelSpec`] (query/doc); //! 2. tokenize (max_len 512, batch-longest padding, special tokens); -//! 3. XLM-RoBERTa forward on `Device::Cpu`; -//! 4. attention-mask-weighted mean pooling; +//! 3. XLM-RoBERTa forward on the selected [`Device`]; +//! 4. pooling — mean (attention-mask-weighted) or CLS (first token); //! 5. L2 normalization. //! //! Model files (`config.json`, `tokenizer.json`, `model.safetensors`) are -//! fetched via `hf-hub` into `{config.storage.model_dir}/candle/`. +//! fetched via `hf-hub` into `{config.storage.model_dir}/candle/` (hf-hub's +//! cache layout namespaces by repo, so e5 and arctic never collide). //! //! This crate is **opt-in** (`config.models.embedding.provider = "candle"`); //! the default provider stays `fastembed`. See -//! `docs/superpowers/specs/2026-06-01-embed-candle-track-spec.md`. +//! `docs/superpowers/specs/2026-06-01-embed-candle-track-spec.md` and +//! `docs/superpowers/specs/2026-06-03-arctic-embedder-spec.md`. use std::sync::Mutex; @@ -42,22 +55,95 @@ use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams}; /// `fastembed/` subdir so the two backends never collide. const CANDLE_CACHE_SUBDIR: &str = "candle"; -/// HuggingFace repo id for the multilingual e5 large model. Same weights the -/// onnxruntime path uses, just the safetensors variant candle can read. -const HF_MODEL: &str = "intfloat/multilingual-e5-large"; - -/// The only `config.models.embedding.model` value the candle adapter accepts -/// (the e5-large weights `HF_MODEL` resolves to). Guards against silently -/// downloading e5-large while `model_id()` reports a different name. -const SUPPORTED_MODEL: &str = "multilingual-e5-large"; - -/// Token truncation length (e5 was trained at 512). +/// Token truncation length (both e5 and arctic-embed-l-v2.0 train at 512). const MAX_LEN: usize = 512; /// Env var that overrides `config.models.embedding.num_threads`. Read once in /// [`CandleEmbedder::new`]; `0`/unset/unparseable means "leave rayon default". const ENV_EMBED_THREADS: &str = "KEBAB_EMBED_THREADS"; +/// Pooling strategy over the model's last hidden state. Keyed per-model by +/// [`EmbedModelSpec::pooling`] — e5 is mean, arctic is CLS. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Pooling { + /// Attention-mask-weighted mean over all tokens (e5 / sentence-transformers + /// `pooling_mode_mean_tokens`). + Mean, + /// First token (``/`[CLS]`) hidden state (arctic-embed v2.0 — + /// `1_Pooling/config.json` has `pooling_mode_cls_token: true`). + Cls, +} + +/// One supported embedding model: the HF repo candle downloads, the pooling +/// strategy, and the e5-style instruction prefixes. [`MODEL_REGISTRY`] maps a +/// `config.models.embedding.model` value to one of these. +#[derive(Clone, Copy, Debug)] +pub struct EmbedModelSpec { + /// The short `config.models.embedding.model` value that selects this spec. + pub name: &'static str, + /// HuggingFace repo id candle fetches `config.json` / `tokenizer.json` / + /// `model.safetensors` from. + pub hf_repo: &'static str, + /// Pooling over the last hidden state. + pub pooling: Pooling, + /// Prefix prepended to **query** inputs before tokenization. + pub query_prefix: &'static str, + /// Prefix prepended to **document** inputs before tokenization (arctic + /// uses `""` — documents are embedded raw). + pub doc_prefix: &'static str, + /// Expected embedding dimension (model hidden size). + pub dim: usize, + /// Suffix folded into `model_version` so switching **to** this model + /// triggers the `embedding_version` cascade even if the operator forgets + /// to bump `config.version`. `None` keeps the bare `config.version` — used + /// by e5 so candle-e5 and fastembed-e5 report the *same* version and stay + /// interchangeable (the NUMA drop-in invariant — Phase 0 cosine 1.0). + pub version_tag: Option<&'static str>, +} + +/// The models the candle adapter can load. Adding a model = one entry here +/// (plus, for a non-XLM-R architecture, a new forward path — both current +/// entries are XLM-RoBERTa-large so they share everything but pooling/prefix). +static MODEL_REGISTRY: &[EmbedModelSpec] = &[ + EmbedModelSpec { + name: "multilingual-e5-large", + hf_repo: "intfloat/multilingual-e5-large", + pooling: Pooling::Mean, + query_prefix: "query: ", + doc_prefix: "passage: ", + dim: 1024, + version_tag: None, + }, + EmbedModelSpec { + name: "snowflake-arctic-embed-l-v2.0", + hf_repo: "Snowflake/snowflake-arctic-embed-l-v2.0", + pooling: Pooling::Cls, + query_prefix: "query: ", + doc_prefix: "", + dim: 1024, + version_tag: Some("arctic-cls"), + }, +]; + +/// Look up a model spec by `config.models.embedding.model`. Accepts either the +/// short `name` or the full `hf_repo` id (mirrors the old e5 guard, which +/// accepted both `multilingual-e5-large` and `intfloat/multilingual-e5-large`). +pub(crate) fn lookup_spec(model: &str) -> Option<&'static EmbedModelSpec> { + MODEL_REGISTRY + .iter() + .find(|s| s.name == model || s.hf_repo == model) +} + +/// Comma-separated list of supported model names, for the +/// unsupported-model error message. +fn supported_models() -> String { + MODEL_REGISTRY + .iter() + .map(|s| s.name) + .collect::>() + .join("`, `") +} + /// Pure-Rust candle adapter. Construct via [`CandleEmbedder::new`]; the /// constructor downloads the model on first use, so share one instance. pub struct CandleEmbedder { @@ -68,6 +154,9 @@ pub struct CandleEmbedder { model: Mutex, tokenizer: Tokenizer, device: Device, + /// The resolved model spec (pooling + prefixes) — drives `embed` and + /// `embed_batch`. + spec: &'static EmbedModelSpec, model_id: EmbeddingModelId, version: EmbeddingVersion, dimensions: usize, @@ -75,7 +164,8 @@ pub struct CandleEmbedder { } impl CandleEmbedder { - /// Build an embedder from `Config`. Applies the NUMA thread cap, fetches + /// Build an embedder from `Config`. Resolves the model spec from + /// `config.models.embedding.model`, applies the NUMA thread cap, fetches /// the model into `{model_dir}/candle/`, and validates that the model's /// hidden size matches `config.models.embedding.dimensions` before /// returning. @@ -104,21 +194,20 @@ impl CandleEmbedder { } } - // 1b. Model guard. `HF_MODEL` is hard-coded (candle currently only wires - // e5-large), so if the operator configured a *different* model name - // we must NOT silently download e5-large and then label its vectors - // with the configured name via `model_id()` — that would mislabel - // `embedding_version` and corrupt a mixed index. Fail fast, before - // the ~2GB download. + // 1b. Model registry lookup. If the operator configured a model the + // candle adapter doesn't know, fail fast (BEFORE the ~2GB + // download) — never silently download one model and then label its + // vectors with another name via `model_id()`, which would mislabel + // `embedding_version` and corrupt a mixed index. let want = config.models.embedding.model.as_str(); - if want != SUPPORTED_MODEL && want != HF_MODEL { - anyhow::bail!( - "candle provider currently supports only '{SUPPORTED_MODEL}' (or \ - the HF id '{HF_MODEL}'), but config.models.embedding.model = \ - '{want}'. Use provider=fastembed for other models, or set \ - model = \"{SUPPORTED_MODEL}\"." - ); - } + let spec = lookup_spec(want).ok_or_else(|| { + anyhow::anyhow!( + "candle provider supports the models `{}`, but \ + config.models.embedding.model = '{want}'. Use provider=fastembed \ + for other models, or pick a supported one.", + supported_models() + ) + })?; // 2. Resolve `{data_dir}/models/candle/` exactly like the fastembed // adapter resolves its own subdir. @@ -134,14 +223,15 @@ impl CandleEmbedder { tracing::info!( target: "kebab-embed-candle", cache_dir = %cache_dir.display(), - model = HF_MODEL, + model = spec.hf_repo, + pooling = ?spec.pooling, "loading candle embedding model (first run downloads ~2GB safetensors)" ); let api = hf_hub::api::sync::ApiBuilder::new() .with_cache_dir(cache_dir.clone()) .build() .context("kb-embed-candle: build hf-hub api")?; - let repo = api.model(HF_MODEL.to_string()); + let repo = api.model(spec.hf_repo.to_string()); let config_path = repo.get("config.json").context("download config.json")?; let tokenizer_path = repo .get("tokenizer.json") @@ -180,10 +270,21 @@ impl CandleEmbedder { })) .map_err(|e| anyhow::anyhow!("kb-embed-candle: set truncation: {e}"))?; + // model_version: fold the model tag in for non-e5 models so a switch + // triggers the embedding_version cascade; e5 keeps the bare + // config.version to stay interchangeable with fastembed-e5. + let version = match spec.version_tag { + Some(tag) => { + EmbeddingVersion(format!("{}+{}", config.models.embedding.version, tag)) + } + None => EmbeddingVersion(config.models.embedding.version.clone()), + }; + tracing::info!( target: "kebab-embed-candle", dimensions = cfg.hidden_size, layers = cfg.num_hidden_layers, + model = spec.name, "candle embedding model loaded" ); @@ -191,16 +292,17 @@ impl CandleEmbedder { model: Mutex::new(model), tokenizer, device, + spec, model_id: EmbeddingModelId(config.models.embedding.model.clone()), - version: EmbeddingVersion(config.models.embedding.version.clone()), + version, dimensions: cfg.hidden_size, batch_size: config.models.embedding.batch_size.max(1), }) } - /// Embed one batch of **already-prefixed** strings (the e5 `query:`/ - /// `passage:` prefix is applied by the caller [`CandleEmbedder::embed`]) - /// through the candle pipeline: tokenize → forward → masked mean pool → L2. + /// Embed one batch of **already-prefixed** strings (the per-model prefix + /// is applied by the caller [`CandleEmbedder::embed`]) through the candle + /// pipeline: tokenize → forward → pool (mean|CLS) → L2. fn embed_batch(&self, prefixed: &[String]) -> Result>> { let encodings = self .tokenizer @@ -237,18 +339,30 @@ impl CandleEmbedder { guard.forward(&input_ids, &attn_f32, &token_type_ids, None, None, None)? }; - // attention-mask-weighted mean pooling - let mask3 = attn_f32.unsqueeze(2)?; // (b, seq, 1) - let summed = hidden.broadcast_mul(&mask3)?.sum(1)?; // (b, hidden) - // counts ≥ 1 always: every input is e5-prefixed AND special tokens are - // added (encode_batch(_, true)), so no row has an all-zero mask. If that - // invariant ever breaks, broadcast_div would emit NaN vectors. - let counts = mask3.sum(1)?; // (b, 1) - let mean = summed.broadcast_div(&counts)?; + // Pooling — per the model spec. + let pooled = match self.spec.pooling { + Pooling::Mean => { + // attention-mask-weighted mean pooling + let mask3 = attn_f32.unsqueeze(2)?; // (b, seq, 1) + let summed = hidden.broadcast_mul(&mask3)?.sum(1)?; // (b, hidden) + // counts ≥ 1 always: every input is prefixed AND special + // tokens are added (encode_batch(_, true)), so no row has an + // all-zero mask. If that invariant ever breaks, broadcast_div + // would emit NaN vectors. + let counts = mask3.sum(1)?; // (b, 1) + summed.broadcast_div(&counts)? + } + Pooling::Cls => { + // CLS pooling: the first token's hidden state. arctic-embed + // v2.0 prepends `` (the XLM-R BOS/CLS) at index 0, so + // `hidden[:, 0, :]` is the sentence embedding. + hidden.narrow(1, 0, 1)?.squeeze(1)? // (b, hidden) + } + }; // L2 normalize - let norm = mean.sqr()?.sum_keepdim(1)?.sqrt()?; - let normalized = mean.broadcast_div(&norm)?; + let norm = pooled.sqr()?.sum_keepdim(1)?.sqrt()?; + let normalized = pooled.broadcast_div(&norm)?; // `.contiguous()` before host copy: broadcast ops can leave a strided // view, which `to_vec2` rejects on the Metal backend (CPU tolerates it). @@ -274,9 +388,9 @@ impl Embedder for CandleEmbedder { return Ok(Vec::new()); } - // e5 prefix per §11.3 BEFORE tokenization (same convention as - // FastembedEmbedder so the two backends produce comparable vectors). - let prefixed: Vec = inputs.iter().map(prefix_input).collect(); + // Per-model instruction prefix BEFORE tokenization (same convention as + // FastembedEmbedder for e5; arctic uses `query: `/no-prefix). + let prefixed: Vec = inputs.iter().map(|i| prefix_input(self.spec, i)).collect(); let mut out: Vec> = Vec::with_capacity(prefixed.len()); for chunk in prefixed.chunks(self.batch_size) { @@ -298,22 +412,22 @@ impl Embedder for CandleEmbedder { } } -/// Build the e5-prefixed string for one [`EmbeddingInput`]. Free function so -/// a unit test can pin the format without loading the model. Byte-identical to -/// `kebab-embed-local`'s `prefix_input` — the two backends MUST agree here or -/// their vectors diverge. -fn prefix_input(input: &EmbeddingInput<'_>) -> String { +/// Build the prefixed string for one [`EmbeddingInput`] using the model spec. +/// Free function so a unit test can pin the format without loading the model. +/// For e5 this is byte-identical to `kebab-embed-local`'s `prefix_input` — the +/// two backends MUST agree there or their vectors diverge. +fn prefix_input(spec: &EmbedModelSpec, input: &EmbeddingInput<'_>) -> String { match input.kind { - EmbeddingKind::Document => format!("passage: {}", input.text), - EmbeddingKind::Query => format!("query: {}", input.text), + EmbeddingKind::Document => format!("{}{}", spec.doc_prefix, input.text), + EmbeddingKind::Query => format!("{}{}", spec.query_prefix, input.text), } } /// Select the compute device. Built with the `metal` feature (Apple Silicon /// GPU), try Metal and fall back to CPU on failure; otherwise CPU. Metal only -/// compiles/runs on macOS — the Linux server builds the CPU path. e5-large -/// vectors are model-defined, so Metal-produced and CPU-produced embeddings are -/// cross-compatible (a Mac can ingest on GPU, the server query on CPU). +/// compiles/runs on macOS — the Linux server builds the CPU path. Embedding +/// vectors are model-defined, so Metal-produced and CPU-produced embeddings +/// are cross-compatible (a Mac can ingest on GPU, the server query on CPU). fn select_device() -> Device { #[cfg(feature = "metal")] { @@ -367,26 +481,85 @@ pub(crate) fn check_dim(model_dim: usize, cfg_dim: usize) -> Result<()> { mod tests { use super::*; - // ── prefix_input ───────────────────────────────────────────────── - // Pin the exact e5 prefix strings; these MUST match - // kebab-embed-local::prefix_input or candle vs fastembed parity breaks. + fn e5_spec() -> &'static EmbedModelSpec { + lookup_spec("multilingual-e5-large").expect("e5 in registry") + } + + fn arctic_spec() -> &'static EmbedModelSpec { + lookup_spec("snowflake-arctic-embed-l-v2.0").expect("arctic in registry") + } + + // ── registry ───────────────────────────────────────────────────── #[test] - fn prefix_document_uses_passage() { + fn registry_resolves_e5_by_name_and_hf_repo() { + assert_eq!( + lookup_spec("multilingual-e5-large").map(|s| s.name), + Some("multilingual-e5-large") + ); + assert_eq!( + lookup_spec("intfloat/multilingual-e5-large").map(|s| s.name), + Some("multilingual-e5-large") + ); + } + + #[test] + fn registry_resolves_arctic_and_its_pooling_is_cls() { + let s = arctic_spec(); + assert_eq!(s.name, "snowflake-arctic-embed-l-v2.0"); + assert_eq!(s.hf_repo, "Snowflake/snowflake-arctic-embed-l-v2.0"); + assert_eq!(s.pooling, Pooling::Cls); + assert_eq!(s.dim, 1024); + assert_eq!(s.version_tag, Some("arctic-cls")); + } + + #[test] + fn registry_e5_is_mean_pooling_no_version_tag() { + let s = e5_spec(); + assert_eq!(s.pooling, Pooling::Mean); + assert_eq!(s.version_tag, None); + } + + #[test] + fn registry_rejects_unknown_model() { + assert!(lookup_spec("multilingual-e5-small").is_none()); + } + + // ── prefix_input ───────────────────────────────────────────────── + // e5 prefixes MUST match kebab-embed-local::prefix_input or candle vs + // fastembed parity breaks; arctic uses query-only prefixing. + + #[test] + fn e5_prefix_document_uses_passage() { let input = EmbeddingInput { text: "hello world", kind: EmbeddingKind::Document, }; - assert_eq!(prefix_input(&input), "passage: hello world"); + assert_eq!(prefix_input(e5_spec(), &input), "passage: hello world"); } #[test] - fn prefix_query_uses_query() { + fn e5_prefix_query_uses_query() { let input = EmbeddingInput { text: "hello world", kind: EmbeddingKind::Query, }; - assert_eq!(prefix_input(&input), "query: hello world"); + assert_eq!(prefix_input(e5_spec(), &input), "query: hello world"); + } + + #[test] + fn arctic_prefix_query_uses_query_doc_is_bare() { + let doc = EmbeddingInput { + text: "후입선출 자료구조", + kind: EmbeddingKind::Document, + }; + let qry = EmbeddingInput { + text: "스택 자료구조", + kind: EmbeddingKind::Query, + }; + // arctic: documents are embedded raw, queries get `query: `. + assert_eq!(prefix_input(arctic_spec(), &doc), "후입선출 자료구조"); + assert_eq!(prefix_input(arctic_spec(), &qry), "query: 스택 자료구조"); } #[test] @@ -399,8 +572,10 @@ mod tests { text: "", kind: EmbeddingKind::Query, }; - assert_eq!(prefix_input(&doc), "passage: "); - assert_eq!(prefix_input(&qry), "query: "); + assert_eq!(prefix_input(e5_spec(), &doc), "passage: "); + assert_eq!(prefix_input(e5_spec(), &qry), "query: "); + assert_eq!(prefix_input(arctic_spec(), &doc), ""); + assert_eq!(prefix_input(arctic_spec(), &qry), "query: "); } // ── check_dim ──────────────────────────────────────────────────── @@ -421,9 +596,9 @@ mod tests { } // ── model guard ────────────────────────────────────────────────── - // A non-e5-large model name must fail fast (BEFORE the ~2GB download), - // so we never download e5-large yet label its vectors with another name - // via model_id() — which would mislabel embedding_version. + // A model name not in the registry must fail fast (BEFORE the ~2GB + // download), so we never download one model yet label its vectors with + // another name via model_id() — which would mislabel embedding_version. #[test] fn new_rejects_unsupported_model() { @@ -437,8 +612,8 @@ mod tests { .expect("unsupported model must error"); let msg = format!("{err:#}"); assert!( - msg.contains("candle provider currently supports only"), - "expected model-guard error, got: {msg}" + msg.contains("candle provider supports the models"), + "expected model-registry error, got: {msg}" ); } } diff --git a/crates/kebab-embed-candle/tests/arctic_ollama_parity.rs b/crates/kebab-embed-candle/tests/arctic_ollama_parity.rs new file mode 100644 index 0000000..ccc3504 --- /dev/null +++ b/crates/kebab-embed-candle/tests/arctic_ollama_parity.rs @@ -0,0 +1,128 @@ +//! arctic-embed-l-v2.0 correctness gate (`#[ignore]` — needs the ~2GB candle +//! model + a live Ollama serving `snowflake-arctic-embed2`). +//! +//! This is the load-bearing pooling/prefix check for the arctic integration. +//! The recall measurement that justified adopting arctic (recall@10 130/132) +//! went through Ollama's `snowflake-arctic-embed2`. The candle path +//! re-implements the model (XLM-RoBERTa-large + **CLS** pooling + `query: ` on +//! queries / no prefix on documents). If candle's pooling or prefix is wrong, +//! its vectors silently diverge from the measured route and the 130 number +//! does NOT carry over. This test pins them together: per-sentence cosine +//! between the candle vector and the Ollama vector must be **> 0.99**. +//! +//! `#[ignore]` because it depends on an external Ollama daemon (CI is +//! headless/offline). The leader MUST run it once before merge. +//! +//! ## Manual run +//! +//! 1. Confirm Ollama is reachable and has the model: +//! ```sh +//! curl -s http://192.168.0.47:11434/api/tags # should list snowflake-arctic-embed2 +//! ``` +//! 2. Run (downloads the ~2GB candle safetensors on first run): +//! ```sh +//! CARGO_TARGET_DIR=/build/out/cargo-target \ +//! KEBAB_ARCTIC_OLLAMA_ENDPOINT=http://192.168.0.47:11434 \ +//! cargo test -p kebab-embed-candle --test arctic_ollama_parity -- --ignored --nocapture +//! ``` +//! The endpoint defaults to `http://192.168.0.47:11434` if the env is unset. +//! +//! Record the printed `ARCTIC_PARITY_SUMMARY cosine_min=...` in +//! `/tmp/arctic-result.md` + `tasks/HOTFIXES.md`. + +use kebab_config::Config; +use kebab_core::{Embedder, EmbeddingInput, EmbeddingKind}; +use kebab_embed_candle::CandleEmbedder; +use kebab_embed_ollama::OllamaEmbedder; + +const DOGFOOD_CONFIG: &str = "/build/dogfood/config.toml"; +const DEFAULT_OLLAMA_ENDPOINT: &str = "http://192.168.0.47:11434"; + +/// Mixed Korean / English + the descriptive-recall shapes arctic was adopted +/// for (synonym / abbreviation / English term). Covers both prefix paths. +const SENTENCES: &[&str] = &[ + "스택 자료구조", + "후입선출 방식으로 동작하는 자료구조", + "큐는 선입선출 자료구조이다", + "Rust ownership and the borrow checker", + "소유권과 빌림 검사기는 메모리 안전성을 보장한다", + "SVM 은 support vector machine 의 약자이다", + "정렬 알고리즘의 시간 복잡도", + "The capital of France is Paris.", +]; + +fn cosine(a: &[f32], b: &[f32]) -> f32 { + let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum(); + let na: f32 = a.iter().map(|x| x * x).sum::().sqrt(); + let nb: f32 = b.iter().map(|x| x * x).sum::().sqrt(); + dot / (na * nb) +} + +/// Base config: prefer the canonical dogfood config (for storage/cache roots), +/// fall back to `Config::defaults()` so the test still runs on a bare clone. +fn base_config() -> Config { + Config::load(Some(std::path::Path::new(DOGFOOD_CONFIG))).unwrap_or_else(|_| Config::defaults()) +} + +#[test] +#[ignore = "needs ~2GB candle model + live Ollama (snowflake-arctic-embed2); run manually before merge"] +fn candle_arctic_matches_ollama_arctic() { + let endpoint = std::env::var("KEBAB_ARCTIC_OLLAMA_ENDPOINT") + .unwrap_or_else(|_| DEFAULT_OLLAMA_ENDPOINT.to_string()); + + // candle side: the in-process arctic model. + let mut candle_cfg = base_config(); + candle_cfg.models.embedding.provider = "candle".to_string(); + candle_cfg.models.embedding.model = "snowflake-arctic-embed-l-v2.0".to_string(); + candle_cfg.models.embedding.dimensions = 1024; + + // Ollama side: the reference route the recall numbers came from. + let mut ollama_cfg = base_config(); + ollama_cfg.models.embedding.provider = "ollama".to_string(); + ollama_cfg.models.embedding.model = "snowflake-arctic-embed2".to_string(); + ollama_cfg.models.embedding.dimensions = 1024; + ollama_cfg.models.embedding.endpoint = Some(endpoint.clone()); + + let candle = CandleEmbedder::new(&candle_cfg).expect("build candle arctic embedder"); + let ollama = OllamaEmbedder::new(&ollama_cfg).expect("build ollama arctic embedder"); + + // Exercise BOTH prefix paths so a query-side divergence can't hide. + let inputs: Vec = SENTENCES + .iter() + .flat_map(|s| { + [EmbeddingKind::Document, EmbeddingKind::Query] + .into_iter() + .map(move |kind| EmbeddingInput { text: s, kind }) + }) + .collect(); + + let cv = candle.embed(&inputs).expect("candle embed"); + let ov = ollama + .embed(&inputs) + .expect("ollama embed (is snowflake-arctic-embed2 pulled @ the endpoint?)"); + + assert_eq!(cv.len(), ov.len(), "embedding counts must match"); + assert_eq!(cv.len(), inputs.len(), "one vector per input"); + assert_eq!(candle.dimensions(), 1024); + + let mut min_cos = f32::INFINITY; + for (i, inp) in inputs.iter().enumerate() { + assert_eq!(cv[i].len(), 1024, "candle dim"); + assert_eq!(ov[i].len(), 1024, "ollama dim"); + let c = cosine(&cv[i], &ov[i]); + min_cos = min_cos.min(c); + let kind = match inp.kind { + EmbeddingKind::Document => "doc", + EmbeddingKind::Query => "qry", + }; + let preview: String = inp.text.chars().take(36).collect(); + println!("[{i:>2}] {kind} cos={c:.6} {preview}"); + } + + println!("ARCTIC_PARITY_SUMMARY cosine_min={min_cos:.6} endpoint={endpoint}"); + assert!( + min_cos > 0.99, + "candle arctic vs Ollama arctic cosine_min={min_cos:.6} ≤ 0.99 — \ + pooling/prefix mismatch; the recall=130 measurement will NOT reproduce" + ); +}