feat(embed): candle model registry + arctic-embed-l-v2.0 (CLS pooling)
Replace the hard-coded e5 setup (HF_MODEL/SUPPORTED_MODEL/mean/query:+passage:)
with a model registry:
EmbedModelSpec{name,hf_repo,pooling,query_prefix,doc_prefix,dim,version_tag}.
Entries: e5 (mean, query:/passage:) and arctic (CLS, query:/no document prefix).
Pooling branches per model (mean = attention-mask-weighted / CLS = hidden[:,0,:]);
tokenize/forward/L2 are shared.
arctic pooling=CLS was confirmed against the HF 1_Pooling/config.json
(pooling_mode_cls_token: true).
For arctic, model_version carries a +arctic-cls tag (triggers the
embedding_version cascade); e5 keeps the plain config.version to stay
compatible with fastembed-e5 (NUMA drop-in).
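The registry lookup, per-model prefixing, and version tagging described above can be sketched in isolation. This is a simplified illustration, not the crate's code: `Spec`, `REGISTRY`, `lookup`, `model_version`, and `prefixed` are hypothetical stand-ins for `EmbedModelSpec`, `MODEL_REGISTRY`, `lookup_spec`, and the logic in `CandleEmbedder::new` / `prefix_input`, and the `dim` field is omitted.

```rust
// Hypothetical, trimmed-down mirror of the model registry (illustration only).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Pooling {
    Mean,
    Cls,
}

#[allow(dead_code)]
#[derive(Debug)]
struct Spec {
    name: &'static str,
    hf_repo: &'static str,
    pooling: Pooling,
    query_prefix: &'static str,
    doc_prefix: &'static str,
    version_tag: Option<&'static str>,
}

static REGISTRY: &[Spec] = &[
    Spec {
        name: "multilingual-e5-large",
        hf_repo: "intfloat/multilingual-e5-large",
        pooling: Pooling::Mean,
        query_prefix: "query: ",
        doc_prefix: "passage: ",
        version_tag: None,
    },
    Spec {
        name: "snowflake-arctic-embed-l-v2.0",
        hf_repo: "Snowflake/snowflake-arctic-embed-l-v2.0",
        pooling: Pooling::Cls,
        query_prefix: "query: ",
        doc_prefix: "",
        version_tag: Some("arctic-cls"),
    },
];

/// Accept either the short name or the full HF repo id.
fn lookup(model: &str) -> Option<&'static Spec> {
    REGISTRY.iter().find(|s| s.name == model || s.hf_repo == model)
}

/// Fold the optional tag into the configured version string.
fn model_version(spec: &Spec, config_version: &str) -> String {
    match spec.version_tag {
        Some(tag) => format!("{config_version}+{tag}"),
        None => config_version.to_string(),
    }
}

/// Prepend the per-model query or document prefix.
fn prefixed(spec: &Spec, is_query: bool, text: &str) -> String {
    let prefix = if is_query { spec.query_prefix } else { spec.doc_prefix };
    format!("{prefix}{text}")
}
```

Keeping `version_tag` as an `Option` is what lets e5 report the bare configured version while arctic reports a distinct one; the version value `"v1"` used when trying this out is an arbitrary placeholder.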
Correctness gate: tests/arctic_ollama_parity.rs (#[ignore], live Ollama):
candle arctic vs Ollama snowflake-arctic-embed2, per-sentence cosine > 0.99.
Manually measured cosine_min=0.999984 (guarantees the recall@10 130 result
is reproduced).
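The two pooling modes and the cosine gate can be illustrated on plain vectors, without candle. This is a minimal sketch under stated assumptions: a single sequence's hidden states are represented as a `&[Vec<f32>]` of shape (seq, dim), and `mean_pool`, `cls_pool`, and `l2_normalize` are hypothetical helper names. The crate itself does this with batched tensor ops.

```rust
/// Attention-mask-weighted mean over tokens.
/// `hidden` is (seq, dim); `mask` is (seq) with 1.0 for real tokens, 0.0 for padding.
/// Assumes at least one unmasked token, otherwise the division yields NaN.
fn mean_pool(hidden: &[Vec<f32>], mask: &[f32]) -> Vec<f32> {
    let dim = hidden[0].len();
    let mut sum = vec![0.0f32; dim];
    let mut count = 0.0f32;
    for (row, &m) in hidden.iter().zip(mask) {
        for (acc, &h) in sum.iter_mut().zip(row) {
            *acc += h * m;
        }
        count += m;
    }
    sum.iter().map(|s| s / count).collect()
}

/// CLS pooling: the hidden state of the first token (index 0).
fn cls_pool(hidden: &[Vec<f32>]) -> Vec<f32> {
    hidden[0].clone()
}

/// Scale a vector to unit L2 norm.
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    v.iter().map(|x| x / norm).collect()
}

/// Cosine similarity, the quantity the parity gate compares against 0.99.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (na * nb)
}
```

On the same hidden states the two modes generally give different vectors, which is why a wrong pooling choice would show up as a low cosine against the reference embeddings rather than as an error.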
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@@ -38,6 +38,9 @@ metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]
 # not the library's own (non-dev) dependencies — so rayon/kebab-config/kebab-core
 # are repeated here for tests/parity.rs and tests/thread_cap.rs.
 kebab-embed-local = { path = "../kebab-embed-local" }
+# arctic↔Ollama parity test drives the real Ollama adapter for the reference
+# vectors (tests/arctic_ollama_parity.rs, `#[ignore]` — live Ollama).
+kebab-embed-ollama = { path = "../kebab-embed-ollama" }
 kebab-config = { path = "../kebab-config" }
 kebab-core = { path = "../kebab-core" }
 rayon = "1"
@@ -1,31 +1,44 @@
 //! `kebab-embed-candle` — [`CandleEmbedder`], a pure-Rust (candle)
 //! implementation of [`Embedder`](kebab_core::Embedder).
 //!
-//! Runs the same `intfloat/multilingual-e5-large` model as the default
-//! [`FastembedEmbedder`](kebab_embed_local) but through `candle`
-//! (`candle-transformers`' XLM-RoBERTa) instead of onnxruntime. Motivation:
-//! fastembed 4.9's onnxruntime hard-codes 48 intra-op threads, which corrupts
-//! the heap (double-free) on dual-socket NUMA hosts. candle's CPU backend
-//! sizes its threads off the global rayon pool, so a one-shot
-//! [`rayon::ThreadPoolBuilder`] cap (config `num_threads` / env
-//! `KEBAB_EMBED_THREADS`) keeps the worker count NUMA-safe.
+//! Runs an XLM-RoBERTa-large embedding model through `candle`
+//! (`candle-transformers`' XLM-RoBERTa) instead of onnxruntime. Two models
+//! are wired through a small **registry** ([`MODEL_REGISTRY`]):
 //!
-//! Output parity with the onnxruntime path was proven by the Phase 0 spike
-//! (cosine 1.000000); this crate absorbs that pipeline verbatim:
+//! * `multilingual-e5-large` — the same weights the default
+//! [`FastembedEmbedder`](kebab_embed_local) uses (mean pooling,
+//! `query: `/`passage: ` prefixes). candle is the NUMA-safe drop-in:
+//! fastembed 4.9's onnxruntime hard-codes 48 intra-op threads, which
+//! corrupts the heap (double-free) on dual-socket NUMA hosts. candle's
+//! CPU backend sizes its threads off the global rayon pool, so a one-shot
+//! [`rayon::ThreadPoolBuilder`] cap (config `num_threads` / env
+//! `KEBAB_EMBED_THREADS`) keeps the worker count NUMA-safe.
+//! * `snowflake-arctic-embed-l-v2.0` — Snowflake's arctic-embed v2.0
+//! (CLS pooling, `query: ` on queries / no prefix on documents). Same
+//! XLM-RoBERTa-large architecture, dim 1024, so it rides the exact same
+//! tokenize → forward → L2 pipeline; only the pooling step and prefixes
+//! differ (both keyed off the per-model [`EmbedModelSpec`]).
 //!
-//! 1. e5 prefix (`passage: ` for documents, `query: ` for queries — the same
-//! convention as `kebab-embed-local`'s `prefix_input`);
+//! Output parity with the onnxruntime path (for e5) was proven by the
+//! Phase 0 spike (cosine 1.000000); the arctic path's pooling/prefix
+//! correctness is pinned by an `#[ignore]`d cosine>0.99 cross-check against
+//! Ollama's `snowflake-arctic-embed2` (see `tests/arctic_ollama_parity.rs`).
+//! The shared pipeline:
+//!
+//! 1. instruction prefix per [`EmbedModelSpec`] (query/doc);
 //! 2. tokenize (max_len 512, batch-longest padding, special tokens);
-//! 3. XLM-RoBERTa forward on `Device::Cpu`;
-//! 4. attention-mask-weighted mean pooling;
+//! 3. XLM-RoBERTa forward on the selected [`Device`];
+//! 4. pooling — mean (attention-mask-weighted) or CLS (first token);
 //! 5. L2 normalization.
 //!
 //! Model files (`config.json`, `tokenizer.json`, `model.safetensors`) are
-//! fetched via `hf-hub` into `{config.storage.model_dir}/candle/`.
+//! fetched via `hf-hub` into `{config.storage.model_dir}/candle/` (hf-hub's
+//! cache layout namespaces by repo, so e5 and arctic never collide).
 //!
 //! This crate is **opt-in** (`config.models.embedding.provider = "candle"`);
 //! the default provider stays `fastembed`. See
-//! `docs/superpowers/specs/2026-06-01-embed-candle-track-spec.md`.
+//! `docs/superpowers/specs/2026-06-01-embed-candle-track-spec.md` and
+//! `docs/superpowers/specs/2026-06-03-arctic-embedder-spec.md`.
 
 use std::sync::Mutex;
 
@@ -42,22 +55,95 @@ use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
 /// `fastembed/` subdir so the two backends never collide.
 const CANDLE_CACHE_SUBDIR: &str = "candle";
 
-/// HuggingFace repo id for the multilingual e5 large model. Same weights the
-/// onnxruntime path uses, just the safetensors variant candle can read.
-const HF_MODEL: &str = "intfloat/multilingual-e5-large";
-
-/// The only `config.models.embedding.model` value the candle adapter accepts
-/// (the e5-large weights `HF_MODEL` resolves to). Guards against silently
-/// downloading e5-large while `model_id()` reports a different name.
-const SUPPORTED_MODEL: &str = "multilingual-e5-large";
-
-/// Token truncation length (e5 was trained at 512).
+/// Token truncation length (both e5 and arctic-embed-l-v2.0 train at 512).
 const MAX_LEN: usize = 512;
 
 /// Env var that overrides `config.models.embedding.num_threads`. Read once in
 /// [`CandleEmbedder::new`]; `0`/unset/unparseable means "leave rayon default".
 const ENV_EMBED_THREADS: &str = "KEBAB_EMBED_THREADS";
 
+/// Pooling strategy over the model's last hidden state. Keyed per-model by
+/// [`EmbedModelSpec::pooling`] — e5 is mean, arctic is CLS.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum Pooling {
+    /// Attention-mask-weighted mean over all tokens (e5 / sentence-transformers
+    /// `pooling_mode_mean_tokens`).
+    Mean,
+    /// First token (`<s>`/`[CLS]`) hidden state (arctic-embed v2.0 —
+    /// `1_Pooling/config.json` has `pooling_mode_cls_token: true`).
+    Cls,
+}
+
+/// One supported embedding model: the HF repo candle downloads, the pooling
+/// strategy, and the e5-style instruction prefixes. [`MODEL_REGISTRY`] maps a
+/// `config.models.embedding.model` value to one of these.
+#[derive(Clone, Copy, Debug)]
+pub struct EmbedModelSpec {
+    /// The short `config.models.embedding.model` value that selects this spec.
+    pub name: &'static str,
+    /// HuggingFace repo id candle fetches `config.json` / `tokenizer.json` /
+    /// `model.safetensors` from.
+    pub hf_repo: &'static str,
+    /// Pooling over the last hidden state.
+    pub pooling: Pooling,
+    /// Prefix prepended to **query** inputs before tokenization.
+    pub query_prefix: &'static str,
+    /// Prefix prepended to **document** inputs before tokenization (arctic
+    /// uses `""` — documents are embedded raw).
+    pub doc_prefix: &'static str,
+    /// Expected embedding dimension (model hidden size).
+    pub dim: usize,
+    /// Suffix folded into `model_version` so switching **to** this model
+    /// triggers the `embedding_version` cascade even if the operator forgets
+    /// to bump `config.version`. `None` keeps the bare `config.version` — used
+    /// by e5 so candle-e5 and fastembed-e5 report the *same* version and stay
+    /// interchangeable (the NUMA drop-in invariant — Phase 0 cosine 1.0).
+    pub version_tag: Option<&'static str>,
+}
+
+/// The models the candle adapter can load. Adding a model = one entry here
+/// (plus, for a non-XLM-R architecture, a new forward path — both current
+/// entries are XLM-RoBERTa-large so they share everything but pooling/prefix).
+static MODEL_REGISTRY: &[EmbedModelSpec] = &[
+    EmbedModelSpec {
+        name: "multilingual-e5-large",
+        hf_repo: "intfloat/multilingual-e5-large",
+        pooling: Pooling::Mean,
+        query_prefix: "query: ",
+        doc_prefix: "passage: ",
+        dim: 1024,
+        version_tag: None,
+    },
+    EmbedModelSpec {
+        name: "snowflake-arctic-embed-l-v2.0",
+        hf_repo: "Snowflake/snowflake-arctic-embed-l-v2.0",
+        pooling: Pooling::Cls,
+        query_prefix: "query: ",
+        doc_prefix: "",
+        dim: 1024,
+        version_tag: Some("arctic-cls"),
+    },
+];
+
+/// Look up a model spec by `config.models.embedding.model`. Accepts either the
+/// short `name` or the full `hf_repo` id (mirrors the old e5 guard, which
+/// accepted both `multilingual-e5-large` and `intfloat/multilingual-e5-large`).
+pub(crate) fn lookup_spec(model: &str) -> Option<&'static EmbedModelSpec> {
+    MODEL_REGISTRY
+        .iter()
+        .find(|s| s.name == model || s.hf_repo == model)
+}
+
+/// Comma-separated list of supported model names, for the
+/// unsupported-model error message.
+fn supported_models() -> String {
+    MODEL_REGISTRY
+        .iter()
+        .map(|s| s.name)
+        .collect::<Vec<_>>()
+        .join("`, `")
+}
+
 /// Pure-Rust candle adapter. Construct via [`CandleEmbedder::new`]; the
 /// constructor downloads the model on first use, so share one instance.
 pub struct CandleEmbedder {
@@ -68,6 +154,9 @@ pub struct CandleEmbedder {
     model: Mutex<XLMRobertaModel>,
     tokenizer: Tokenizer,
     device: Device,
+    /// The resolved model spec (pooling + prefixes) — drives `embed` and
+    /// `embed_batch`.
+    spec: &'static EmbedModelSpec,
     model_id: EmbeddingModelId,
     version: EmbeddingVersion,
     dimensions: usize,
@@ -75,7 +164,8 @@ pub struct CandleEmbedder {
 }
 
 impl CandleEmbedder {
-    /// Build an embedder from `Config`. Applies the NUMA thread cap, fetches
+    /// Build an embedder from `Config`. Resolves the model spec from
+    /// `config.models.embedding.model`, applies the NUMA thread cap, fetches
     /// the model into `{model_dir}/candle/`, and validates that the model's
     /// hidden size matches `config.models.embedding.dimensions` before
     /// returning.
@@ -104,21 +194,20 @@ impl CandleEmbedder {
             }
         }
 
-        // 1b. Model guard. `HF_MODEL` is hard-coded (candle currently only wires
-        // e5-large), so if the operator configured a *different* model name
-        // we must NOT silently download e5-large and then label its vectors
-        // with the configured name via `model_id()` — that would mislabel
-        // `embedding_version` and corrupt a mixed index. Fail fast, before
-        // the ~2GB download.
+        // 1b. Model registry lookup. If the operator configured a model the
+        // candle adapter doesn't know, fail fast (BEFORE the ~2GB
+        // download) — never silently download one model and then label its
+        // vectors with another name via `model_id()`, which would mislabel
+        // `embedding_version` and corrupt a mixed index.
         let want = config.models.embedding.model.as_str();
-        if want != SUPPORTED_MODEL && want != HF_MODEL {
-            anyhow::bail!(
-                "candle provider currently supports only '{SUPPORTED_MODEL}' (or \
-                the HF id '{HF_MODEL}'), but config.models.embedding.model = \
-                '{want}'. Use provider=fastembed for other models, or set \
-                model = \"{SUPPORTED_MODEL}\"."
-            );
-        }
+        let spec = lookup_spec(want).ok_or_else(|| {
+            anyhow::anyhow!(
+                "candle provider supports the models `{}`, but \
+                config.models.embedding.model = '{want}'. Use provider=fastembed \
+                for other models, or pick a supported one.",
+                supported_models()
+            )
+        })?;
 
         // 2. Resolve `{data_dir}/models/candle/` exactly like the fastembed
         // adapter resolves its own subdir.
@@ -134,14 +223,15 @@ impl CandleEmbedder {
         tracing::info!(
             target: "kebab-embed-candle",
             cache_dir = %cache_dir.display(),
-            model = HF_MODEL,
+            model = spec.hf_repo,
+            pooling = ?spec.pooling,
             "loading candle embedding model (first run downloads ~2GB safetensors)"
         );
         let api = hf_hub::api::sync::ApiBuilder::new()
             .with_cache_dir(cache_dir.clone())
             .build()
             .context("kb-embed-candle: build hf-hub api")?;
-        let repo = api.model(HF_MODEL.to_string());
+        let repo = api.model(spec.hf_repo.to_string());
         let config_path = repo.get("config.json").context("download config.json")?;
         let tokenizer_path = repo
             .get("tokenizer.json")
@@ -180,10 +270,21 @@ impl CandleEmbedder {
             }))
             .map_err(|e| anyhow::anyhow!("kb-embed-candle: set truncation: {e}"))?;
 
+        // model_version: fold the model tag in for non-e5 models so a switch
+        // triggers the embedding_version cascade; e5 keeps the bare
+        // config.version to stay interchangeable with fastembed-e5.
+        let version = match spec.version_tag {
+            Some(tag) => {
+                EmbeddingVersion(format!("{}+{}", config.models.embedding.version, tag))
+            }
+            None => EmbeddingVersion(config.models.embedding.version.clone()),
+        };
+
         tracing::info!(
             target: "kebab-embed-candle",
             dimensions = cfg.hidden_size,
             layers = cfg.num_hidden_layers,
+            model = spec.name,
             "candle embedding model loaded"
         );
 
@@ -191,16 +292,17 @@ impl CandleEmbedder {
             model: Mutex::new(model),
             tokenizer,
             device,
+            spec,
             model_id: EmbeddingModelId(config.models.embedding.model.clone()),
-            version: EmbeddingVersion(config.models.embedding.version.clone()),
+            version,
             dimensions: cfg.hidden_size,
             batch_size: config.models.embedding.batch_size.max(1),
         })
     }
 
-    /// Embed one batch of **already-prefixed** strings (the e5 `query:`/
-    /// `passage:` prefix is applied by the caller [`CandleEmbedder::embed`])
-    /// through the candle pipeline: tokenize → forward → masked mean pool → L2.
+    /// Embed one batch of **already-prefixed** strings (the per-model prefix
+    /// is applied by the caller [`CandleEmbedder::embed`]) through the candle
+    /// pipeline: tokenize → forward → pool (mean|CLS) → L2.
     fn embed_batch(&self, prefixed: &[String]) -> Result<Vec<Vec<f32>>> {
         let encodings = self
             .tokenizer
@@ -237,18 +339,30 @@ impl CandleEmbedder {
             guard.forward(&input_ids, &attn_f32, &token_type_ids, None, None, None)?
         };
 
-        // attention-mask-weighted mean pooling
-        let mask3 = attn_f32.unsqueeze(2)?; // (b, seq, 1)
-        let summed = hidden.broadcast_mul(&mask3)?.sum(1)?; // (b, hidden)
-        // counts ≥ 1 always: every input is e5-prefixed AND special tokens are
-        // added (encode_batch(_, true)), so no row has an all-zero mask. If that
-        // invariant ever breaks, broadcast_div would emit NaN vectors.
-        let counts = mask3.sum(1)?; // (b, 1)
-        let mean = summed.broadcast_div(&counts)?;
+        // Pooling — per the model spec.
+        let pooled = match self.spec.pooling {
+            Pooling::Mean => {
+                // attention-mask-weighted mean pooling
+                let mask3 = attn_f32.unsqueeze(2)?; // (b, seq, 1)
+                let summed = hidden.broadcast_mul(&mask3)?.sum(1)?; // (b, hidden)
+                // counts ≥ 1 always: every input is prefixed AND special
+                // tokens are added (encode_batch(_, true)), so no row has an
+                // all-zero mask. If that invariant ever breaks, broadcast_div
+                // would emit NaN vectors.
+                let counts = mask3.sum(1)?; // (b, 1)
+                summed.broadcast_div(&counts)?
+            }
+            Pooling::Cls => {
+                // CLS pooling: the first token's hidden state. arctic-embed
+                // v2.0 prepends `<s>` (the XLM-R BOS/CLS) at index 0, so
+                // `hidden[:, 0, :]` is the sentence embedding.
+                hidden.narrow(1, 0, 1)?.squeeze(1)? // (b, hidden)
+            }
+        };
 
         // L2 normalize
-        let norm = mean.sqr()?.sum_keepdim(1)?.sqrt()?;
-        let normalized = mean.broadcast_div(&norm)?;
+        let norm = pooled.sqr()?.sum_keepdim(1)?.sqrt()?;
+        let normalized = pooled.broadcast_div(&norm)?;
 
         // `.contiguous()` before host copy: broadcast ops can leave a strided
         // view, which `to_vec2` rejects on the Metal backend (CPU tolerates it).
@@ -274,9 +388,9 @@ impl Embedder for CandleEmbedder {
             return Ok(Vec::new());
         }
 
-        // e5 prefix per §11.3 BEFORE tokenization (same convention as
-        // FastembedEmbedder so the two backends produce comparable vectors).
-        let prefixed: Vec<String> = inputs.iter().map(prefix_input).collect();
+        // Per-model instruction prefix BEFORE tokenization (same convention as
+        // FastembedEmbedder for e5; arctic uses `query: `/no-prefix).
+        let prefixed: Vec<String> = inputs.iter().map(|i| prefix_input(self.spec, i)).collect();
 
         let mut out: Vec<Vec<f32>> = Vec::with_capacity(prefixed.len());
         for chunk in prefixed.chunks(self.batch_size) {
@@ -298,22 +412,22 @@ impl Embedder for CandleEmbedder {
     }
 }
 
-/// Build the e5-prefixed string for one [`EmbeddingInput`]. Free function so
-/// a unit test can pin the format without loading the model. Byte-identical to
-/// `kebab-embed-local`'s `prefix_input` — the two backends MUST agree here or
-/// their vectors diverge.
-fn prefix_input(input: &EmbeddingInput<'_>) -> String {
+/// Build the prefixed string for one [`EmbeddingInput`] using the model spec.
+/// Free function so a unit test can pin the format without loading the model.
+/// For e5 this is byte-identical to `kebab-embed-local`'s `prefix_input` — the
+/// two backends MUST agree there or their vectors diverge.
+fn prefix_input(spec: &EmbedModelSpec, input: &EmbeddingInput<'_>) -> String {
     match input.kind {
-        EmbeddingKind::Document => format!("passage: {}", input.text),
-        EmbeddingKind::Query => format!("query: {}", input.text),
+        EmbeddingKind::Document => format!("{}{}", spec.doc_prefix, input.text),
+        EmbeddingKind::Query => format!("{}{}", spec.query_prefix, input.text),
     }
 }
 
 /// Select the compute device. Built with the `metal` feature (Apple Silicon
 /// GPU), try Metal and fall back to CPU on failure; otherwise CPU. Metal only
-/// compiles/runs on macOS — the Linux server builds the CPU path. e5-large
-/// vectors are model-defined, so Metal-produced and CPU-produced embeddings are
-/// cross-compatible (a Mac can ingest on GPU, the server query on CPU).
+/// compiles/runs on macOS — the Linux server builds the CPU path. Embedding
+/// vectors are model-defined, so Metal-produced and CPU-produced embeddings
+/// are cross-compatible (a Mac can ingest on GPU, the server query on CPU).
 fn select_device() -> Device {
     #[cfg(feature = "metal")]
     {
@@ -367,26 +481,85 @@ pub(crate) fn check_dim(model_dim: usize, cfg_dim: usize) -> Result<()> {
 mod tests {
     use super::*;
 
-    // ── prefix_input ─────────────────────────────────────────────────
-    // Pin the exact e5 prefix strings; these MUST match
-    // kebab-embed-local::prefix_input or candle vs fastembed parity breaks.
+    fn e5_spec() -> &'static EmbedModelSpec {
+        lookup_spec("multilingual-e5-large").expect("e5 in registry")
+    }
 
+    fn arctic_spec() -> &'static EmbedModelSpec {
+        lookup_spec("snowflake-arctic-embed-l-v2.0").expect("arctic in registry")
+    }
+
+    // ── registry ─────────────────────────────────────────────────────
+
     #[test]
-    fn prefix_document_uses_passage() {
+    fn registry_resolves_e5_by_name_and_hf_repo() {
+        assert_eq!(
+            lookup_spec("multilingual-e5-large").map(|s| s.name),
+            Some("multilingual-e5-large")
+        );
+        assert_eq!(
+            lookup_spec("intfloat/multilingual-e5-large").map(|s| s.name),
+            Some("multilingual-e5-large")
+        );
+    }
+
+    #[test]
+    fn registry_resolves_arctic_and_its_pooling_is_cls() {
+        let s = arctic_spec();
+        assert_eq!(s.name, "snowflake-arctic-embed-l-v2.0");
+        assert_eq!(s.hf_repo, "Snowflake/snowflake-arctic-embed-l-v2.0");
+        assert_eq!(s.pooling, Pooling::Cls);
+        assert_eq!(s.dim, 1024);
+        assert_eq!(s.version_tag, Some("arctic-cls"));
+    }
+
+    #[test]
+    fn registry_e5_is_mean_pooling_no_version_tag() {
+        let s = e5_spec();
+        assert_eq!(s.pooling, Pooling::Mean);
+        assert_eq!(s.version_tag, None);
+    }
+
+    #[test]
+    fn registry_rejects_unknown_model() {
+        assert!(lookup_spec("multilingual-e5-small").is_none());
+    }
+
+    // ── prefix_input ─────────────────────────────────────────────────
+    // e5 prefixes MUST match kebab-embed-local::prefix_input or candle vs
+    // fastembed parity breaks; arctic uses query-only prefixing.
+
+    #[test]
+    fn e5_prefix_document_uses_passage() {
         let input = EmbeddingInput {
             text: "hello world",
             kind: EmbeddingKind::Document,
         };
-        assert_eq!(prefix_input(&input), "passage: hello world");
+        assert_eq!(prefix_input(e5_spec(), &input), "passage: hello world");
     }
 
     #[test]
-    fn prefix_query_uses_query() {
+    fn e5_prefix_query_uses_query() {
         let input = EmbeddingInput {
             text: "hello world",
             kind: EmbeddingKind::Query,
         };
-        assert_eq!(prefix_input(&input), "query: hello world");
+        assert_eq!(prefix_input(e5_spec(), &input), "query: hello world");
+    }
+
+    #[test]
+    fn arctic_prefix_query_uses_query_doc_is_bare() {
+        let doc = EmbeddingInput {
+            text: "후입선출 자료구조",
+            kind: EmbeddingKind::Document,
+        };
+        let qry = EmbeddingInput {
+            text: "스택 자료구조",
+            kind: EmbeddingKind::Query,
+        };
+        // arctic: documents are embedded raw, queries get `query: `.
+        assert_eq!(prefix_input(arctic_spec(), &doc), "후입선출 자료구조");
+        assert_eq!(prefix_input(arctic_spec(), &qry), "query: 스택 자료구조");
     }
 
     #[test]
@@ -399,8 +572,10 @@ mod tests {
             text: "",
             kind: EmbeddingKind::Query,
         };
-        assert_eq!(prefix_input(&doc), "passage: ");
-        assert_eq!(prefix_input(&qry), "query: ");
+        assert_eq!(prefix_input(e5_spec(), &doc), "passage: ");
+        assert_eq!(prefix_input(e5_spec(), &qry), "query: ");
+        assert_eq!(prefix_input(arctic_spec(), &doc), "");
+        assert_eq!(prefix_input(arctic_spec(), &qry), "query: ");
     }
 
     // ── check_dim ────────────────────────────────────────────────────
@@ -421,9 +596,9 @@ mod tests {
     }
 
     // ── model guard ──────────────────────────────────────────────────
-    // A non-e5-large model name must fail fast (BEFORE the ~2GB download),
-    // so we never download e5-large yet label its vectors with another name
-    // via model_id() — which would mislabel embedding_version.
+    // A model name not in the registry must fail fast (BEFORE the ~2GB
+    // download), so we never download one model yet label its vectors with
+    // another name via model_id() — which would mislabel embedding_version.
 
     #[test]
     fn new_rejects_unsupported_model() {
@@ -437,8 +612,8 @@ mod tests {
             .expect("unsupported model must error");
         let msg = format!("{err:#}");
         assert!(
-            msg.contains("candle provider currently supports only"),
-            "expected model-guard error, got: {msg}"
+            msg.contains("candle provider supports the models"),
+            "expected model-registry error, got: {msg}"
         );
     }
 }
128
crates/kebab-embed-candle/tests/arctic_ollama_parity.rs
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
//! arctic-embed-l-v2.0 correctness gate (`#[ignore]` — needs the ~2GB candle
|
||||||
|
//! model + a live Ollama serving `snowflake-arctic-embed2`).
|
||||||
|
//!
|
||||||
|
//! This is the load-bearing pooling/prefix check for the arctic integration.
|
||||||
|
//! The recall measurement that justified adopting arctic (recall@10 130/132)
|
||||||
|
//! went through Ollama's `snowflake-arctic-embed2`. The candle path
|
||||||
|
//! re-implements the model (XLM-RoBERTa-large + **CLS** pooling + `query: ` on
|
||||||
|
//! queries / no prefix on documents). If candle's pooling or prefix is wrong,
|
||||||
|
//! its vectors silently diverge from the measured route and the 130 number
|
||||||
|
//! does NOT carry over. This test pins them together: per-sentence cosine
|
||||||
|
//! between the candle vector and the Ollama vector must be **> 0.99**.
|
||||||
|
//!
|
||||||
|
//! `#[ignore]` because it depends on an external Ollama daemon (CI is
|
||||||
|
//! headless/offline). The leader MUST run it once before merge.
|
||||||
|
//!
|
||||||
|
//! ## Manual run
|
||||||
|
//!
|
||||||
|
//! 1. Confirm Ollama is reachable and has the model:
|
||||||
|
//! ```sh
|
||||||
|
//! curl -s http://192.168.0.47:11434/api/tags # should list snowflake-arctic-embed2
|
||||||
|
//! ```
|
||||||
|
//! 2. Run (downloads the ~2GB candle safetensors on first run):
|
||||||
|
//! ```sh
|
||||||
|
//! CARGO_TARGET_DIR=/build/out/cargo-target \
|
||||||
|
//! KEBAB_ARCTIC_OLLAMA_ENDPOINT=http://192.168.0.47:11434 \
|
||||||
|
//! cargo test -p kebab-embed-candle --test arctic_ollama_parity -- --ignored --nocapture
|
||||||
|
//! ```
|
||||||
|
//! The endpoint defaults to `http://192.168.0.47:11434` if `KEBAB_ARCTIC_OLLAMA_ENDPOINT` is unset.
|
||||||
|
//!
|
||||||
|
//! Record the printed `ARCTIC_PARITY_SUMMARY cosine_min=...` in
|
||||||
|
//! `/tmp/arctic-result.md` + `tasks/HOTFIXES.md`.
|
||||||
|
|
||||||
|
use kebab_config::Config;
|
||||||
|
use kebab_core::{Embedder, EmbeddingInput, EmbeddingKind};
|
||||||
|
use kebab_embed_candle::CandleEmbedder;
|
||||||
|
use kebab_embed_ollama::OllamaEmbedder;
|
||||||
|
|
||||||
|
const DOGFOOD_CONFIG: &str = "/build/dogfood/config.toml";
|
||||||
|
const DEFAULT_OLLAMA_ENDPOINT: &str = "http://192.168.0.47:11434";
|
||||||
|
|
||||||
|
/// Mixed Korean / English + the descriptive-recall shapes arctic was adopted
|
||||||
|
/// for (synonym / abbreviation / English term). Covers both prefix paths.
|
||||||
|
const SENTENCES: &[&str] = &[
|
||||||
|
"스택 자료구조",
|
||||||
|
"후입선출 방식으로 동작하는 자료구조",
|
||||||
|
"큐는 선입선출 자료구조이다",
|
||||||
|
"Rust ownership and the borrow checker",
|
||||||
|
"소유권과 빌림 검사기는 메모리 안전성을 보장한다",
|
||||||
|
"SVM 은 support vector machine 의 약자이다",
|
||||||
|
"정렬 알고리즘의 시간 복잡도",
|
||||||
|
"The capital of France is Paris.",
|
||||||
|
];
|
||||||
|
|
||||||
|
fn cosine(a: &[f32], b: &[f32]) -> f32 {
|
||||||
|
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
|
||||||
|
let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||||
|
let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||||
|
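// A zero-norm vector would make this NaN, and `f32::min` silently skips NaN,
// so report 0.0 instead and let the > 0.99 gate fail loudly.
let denom = na * nb;
if denom == 0.0 { 0.0 } else { dot / denom }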
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Base config: prefer the canonical dogfood config (for storage/cache roots),
|
||||||
|
/// fall back to `Config::defaults()` so the test still runs on a bare clone.
|
||||||
|
fn base_config() -> Config {
|
||||||
|
Config::load(Some(std::path::Path::new(DOGFOOD_CONFIG))).unwrap_or_else(|_| Config::defaults())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[ignore = "needs ~2GB candle model + live Ollama (snowflake-arctic-embed2); run manually before merge"]
|
||||||
|
fn candle_arctic_matches_ollama_arctic() {
|
||||||
|
let endpoint = std::env::var("KEBAB_ARCTIC_OLLAMA_ENDPOINT")
|
||||||
|
.unwrap_or_else(|_| DEFAULT_OLLAMA_ENDPOINT.to_string());
|
||||||
|
|
||||||
|
// candle side: the in-process arctic model.
|
||||||
|
let mut candle_cfg = base_config();
|
||||||
|
candle_cfg.models.embedding.provider = "candle".to_string();
|
||||||
|
candle_cfg.models.embedding.model = "snowflake-arctic-embed-l-v2.0".to_string();
|
||||||
|
candle_cfg.models.embedding.dimensions = 1024;
|
||||||
|
|
||||||
|
// Ollama side: the reference route the recall numbers came from.
|
||||||
|
let mut ollama_cfg = base_config();
|
||||||
|
ollama_cfg.models.embedding.provider = "ollama".to_string();
|
||||||
|
ollama_cfg.models.embedding.model = "snowflake-arctic-embed2".to_string();
|
||||||
|
ollama_cfg.models.embedding.dimensions = 1024;
|
||||||
|
ollama_cfg.models.embedding.endpoint = Some(endpoint.clone());
|
||||||
|
|
||||||
|
let candle = CandleEmbedder::new(&candle_cfg).expect("build candle arctic embedder");
|
||||||
|
let ollama = OllamaEmbedder::new(&ollama_cfg).expect("build ollama arctic embedder");
|
||||||
|
|
||||||
|
// Exercise BOTH prefix paths so a query-side divergence can't hide.
|
||||||
|
let inputs: Vec<EmbeddingInput> = SENTENCES
|
||||||
|
.iter()
|
||||||
|
.flat_map(|s| {
|
||||||
|
[EmbeddingKind::Document, EmbeddingKind::Query]
|
||||||
|
.into_iter()
|
||||||
|
.map(move |kind| EmbeddingInput { text: s, kind })
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let cv = candle.embed(&inputs).expect("candle embed");
|
||||||
|
let ov = ollama
|
||||||
|
.embed(&inputs)
|
||||||
|
.expect("ollama embed (is snowflake-arctic-embed2 pulled @ the endpoint?)");
|
||||||
|
|
||||||
|
assert_eq!(cv.len(), ov.len(), "embedding counts must match");
|
||||||
|
assert_eq!(cv.len(), inputs.len(), "one vector per input");
|
||||||
|
assert_eq!(candle.dimensions(), 1024);
|
||||||
|
|
||||||
|
let mut min_cos = f32::INFINITY;
|
||||||
|
for (i, inp) in inputs.iter().enumerate() {
|
||||||
|
assert_eq!(cv[i].len(), 1024, "candle dim");
|
||||||
|
assert_eq!(ov[i].len(), 1024, "ollama dim");
|
||||||
|
let c = cosine(&cv[i], &ov[i]);
|
||||||
|
min_cos = min_cos.min(c);
|
||||||
|
let kind = match inp.kind {
|
||||||
|
EmbeddingKind::Document => "doc",
|
||||||
|
EmbeddingKind::Query => "qry",
|
||||||
|
};
|
||||||
|
let preview: String = inp.text.chars().take(36).collect();
|
||||||
|
println!("[{i:>2}] {kind} cos={c:.6} {preview}");
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("ARCTIC_PARITY_SUMMARY cosine_min={min_cos:.6} endpoint={endpoint}");
|
||||||
|
assert!(
|
||||||
|
min_cos > 0.99,
|
||||||
|
"candle arctic vs Ollama arctic cosine_min={min_cos:.6} ≤ 0.99 — \
|
||||||
|
pooling/prefix mismatch; the recall=130 measurement will NOT reproduce"
|
||||||
|
);
|
||||||
|
}
|
||||||
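To illustrate why the parity gate above is sensitive to the pooling choice, here is a minimal, self-contained sketch with toy numbers. It is not the crate's implementation: `mean_pool`, `cls_pool`, `l2_normalize`, and the 2-dimensional hidden states are hypothetical names and data chosen for illustration only. It assumes hidden states shaped `[tokens][dim]` and a 0/1 attention mask, and shows that CLS pooling and mask-weighted mean pooling of the same hidden states can produce vectors whose cosine is far below the 0.99 threshold.

```rust
/// Mean pooling (illustrative): average hidden states over non-padding tokens.
fn mean_pool(hidden: &[Vec<f32>], mask: &[u8]) -> Vec<f32> {
    let dim = hidden[0].len();
    let mut out = vec![0.0f32; dim];
    let mut count = 0.0f32;
    for (tok, &m) in hidden.iter().zip(mask) {
        if m == 1 {
            for (o, v) in out.iter_mut().zip(tok) {
                *o += *v;
            }
            count += 1.0;
        }
    }
    // Avoid dividing by zero when every token is masked out.
    for o in out.iter_mut() {
        *o /= count.max(1.0);
    }
    out
}

/// CLS pooling (illustrative): take the hidden state of the first token.
fn cls_pool(hidden: &[Vec<f32>]) -> Vec<f32> {
    hidden[0].clone()
}

/// Scale a vector to unit L2 norm; a zero vector is left unchanged.
fn l2_normalize(v: &mut [f32]) {
    let n = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if n > 0.0 {
        for x in v.iter_mut() {
            *x /= n;
        }
    }
}

/// Cosine similarity; returns 0.0 if either vector has zero norm.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    let denom = na * nb;
    if denom == 0.0 { 0.0 } else { dot / denom }
}

fn main() {
    // Toy hidden states: 3 real tokens plus 1 padding token, dim = 2.
    let hidden = vec![
        vec![1.0, 0.0], // position 0 (the CLS slot)
        vec![0.0, 1.0],
        vec![0.0, 1.0],
        vec![9.0, 9.0], // padding, ignored by the mask
    ];
    let mask = [1u8, 1, 1, 0];

    let mut cls = cls_pool(&hidden);
    let mut mean = mean_pool(&hidden, &mask);
    l2_normalize(&mut cls);
    l2_normalize(&mut mean);

    println!("cos(cls, cls)  = {:.6}", cosine(&cls, &cls));
    println!("cos(cls, mean) = {:.6}", cosine(&cls, &mean));
}
```

With these toy inputs the two pooled vectors point in clearly different directions, which is the kind of silent divergence the per-sentence cosine threshold in the test is meant to catch.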