feat(embed): candle 임베딩 provider (NUMA-안전, opt-in) + v0.22.0

duo-socket NUMA 서버에서 fastembed(onnxruntime)가 intra-op 스레드를 48개로
하드코딩해 NUMA 힙 손상 → double-free 로 ingest 가 죽는 문제를 회피하기 위해,
같은 multilingual-e5-large 모델을 순수 Rust(candle)로 돌리는 opt-in 임베딩
provider 를 추가한다.

- 신규 crate kebab-embed-candle: CandleEmbedder (kebab_core::Embedder).
  hf-hub safetensors → XLMRobertaModel forward → mask mean-pool → L2 → e5
  prefix. candle 의존성 트리를 이 crate 에 격리 (core/config 외 kebab-* 의존 0).
- 스레드 캡: [models.embedding].num_threads + env KEBAB_EMBED_THREADS →
  글로벌 rayon 풀 1회 캡 (NUMA-안전 레버).
- kebab-app::embedder() 가 provider 분기 (fastembed/onnx/"" → 기존 경로 불변,
  candle → CandleEmbedder, 미지값 → 에러).
- Phase 0 스파이크 crate 제거 (production 흡수).
- 버전 0.21.1 → 0.22.0 (신규 config surface, pre-1.0 minor bump).

패리티: cosine_min=1.000000, max abs diff=2.01e-7 (< 1e-5) → embedding_version
유지, 재색인 0. fastembed default 동작/벡터 불변. wire schema 변경 없음.

검증(파일+exit code): clippy -D warnings EXIT=0(warning 0), test EXIT=0
(candle unit 5 + thread_cap rayon=4 + config 68), parity #[ignore] EXIT=0.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-01 14:52:25 +00:00
parent 76841af7d3
commit 8f7b6ee538
18 changed files with 825 additions and 330 deletions

View File

@@ -18,6 +18,7 @@ kebab-store-vector = { path = "../kebab-store-vector" }
kebab-search = { path = "../kebab-search" }
kebab-embed = { path = "../kebab-embed" }
kebab-embed-local = { path = "../kebab-embed-local" }
kebab-embed-candle = { path = "../kebab-embed-candle" }
kebab-llm = { path = "../kebab-llm" }
kebab-llm-local = { path = "../kebab-llm-local" }
kebab-rag = { path = "../kebab-rag" }

View File

@@ -43,6 +43,7 @@ use kebab_core::{
Answer, DocumentStore, Embedder, ExtractContext, Extractor, IndexVersion, LanguageModel,
MediaType, Retriever, SearchHit, SearchMode, SearchOpts, SearchQuery, VectorStore,
};
use kebab_embed_candle::CandleEmbedder;
use kebab_embed_local::FastembedEmbedder;
use kebab_llm_local::OllamaLanguageModel;
use kebab_parse_code::{
@@ -833,9 +834,26 @@ impl App {
if let Some(e) = self.embedder.get() {
return Ok(Some(e.clone()));
}
let emb: Arc<dyn Embedder + Send + Sync> = Arc::new(
FastembedEmbedder::new(&self.config).context("kb-app: load FastembedEmbedder")?,
);
// Provider branch (Track 1 spec §3). `embeddings_disabled()` above
// already handled `"none"`; here we route the live providers.
// `fastembed`/`onnx`/(empty) keep the default onnxruntime path
// (vectors unchanged — `embedding_version` is preserved); `candle`
// selects the pure-Rust NUMA-safe backend.
let provider = self.config.models.embedding.provider.as_str();
let emb: Arc<dyn Embedder + Send + Sync> = match provider {
"fastembed" | "onnx" | "" => Arc::new(
FastembedEmbedder::new(&self.config).context("kb-app: load FastembedEmbedder")?,
),
"candle" => Arc::new(
CandleEmbedder::new(&self.config).context("kb-app: load CandleEmbedder")?,
),
other => {
return Err(anyhow!(
"kb-app: unknown embedding provider {other:?}; expected one of \
`fastembed` (default), `candle`, or `none` (lexical-only)"
));
}
};
// `set` returns Err if another thread won the race; in that case
// the loser still returns the (now-cached) winner via `get()`.
let _ = self.embedder.set(emb.clone());

View File

@@ -155,11 +155,21 @@ impl NliCfg {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct EmbeddingModelCfg {
/// `fastembed` (default, onnxruntime) or `candle` (pure-Rust,
/// NUMA-safe). `none` disables embeddings (lexical-only). Unknown
/// values error at embedder construction.
pub provider: String,
pub model: String,
pub version: String,
pub dimensions: usize,
pub batch_size: usize,
/// Cap on the CPU worker threads the `candle` provider spins up
/// (sizes the global rayon pool; env `KEBAB_EMBED_THREADS` overrides).
/// `0` = auto (rayon default = #cores). Lever to sidestep the
/// onnxruntime 48-thread NUMA double-free; ignored by the `fastembed`
/// provider. Defaulted on load so pre-0.22 config files still parse.
#[serde(default)]
pub num_threads: u32,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
@@ -707,6 +717,7 @@ impl Config {
version: "v1".to_string(),
dimensions: 1024,
batch_size: 64,
num_threads: 0,
},
llm: LlmCfg {
provider: "ollama".to_string(),
@@ -964,6 +975,11 @@ impl Config {
self.models.embedding.batch_size = n;
}
}
"KEBAB_MODELS_EMBEDDING_NUM_THREADS" => {
if let Ok(n) = v.parse::<u32>() {
self.models.embedding.num_threads = n;
}
}
// models.llm
"KEBAB_MODELS_LLM_PROVIDER" => self.models.llm.provider = v.clone(),

View File

@@ -0,0 +1,39 @@
[package]
name = "kebab-embed-candle"
version = { workspace = true }
edition = { workspace = true }
rust-version = { workspace = true }
license = { workspace = true }
repository = { workspace = true }
description = "Pure-Rust candle adapter implementing kb_core::Embedder (multilingual-e5-large, NUMA-safe thread cap)"
[dependencies]
kebab-core = { path = "../kebab-core" }
kebab-config = { path = "../kebab-config" }
# candle stack — pinned to the workspace-locked crates.io release (0.10.x),
# same versions the Phase 0 spike compiled so build artifacts are reused.
candle-core = "0.10.2"
candle-nn = "0.10.2"
candle-transformers = "0.10.2"
tokenizers = "0.21"
hf-hub = { version = "0.4", features = ["ureq"] }
serde_json = { workspace = true }
# Thread cap: a one-shot global rayon pool sizes candle's CPU threads
# (the Phase 0 spike proved RAYON_NUM_THREADS caps candle), so a NUMA host
# can keep onnxruntime's hard-coded 48-intra-op heap corruption at bay.
rayon = "1"
anyhow = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
# Integration-test binaries can only see the library's public API + these,
# not the library's own (non-dev) dependencies — so rayon/kebab-config/kebab-core
# are repeated here for tests/parity.rs and tests/thread_cap.rs.
kebab-embed-local = { path = "../kebab-embed-local" }
kebab-config = { path = "../kebab-config" }
kebab-core = { path = "../kebab-core" }
rayon = "1"
tempfile = { workspace = true }
[lints]
workspace = true

View File

@@ -0,0 +1,363 @@
//! `kebab-embed-candle` — [`CandleEmbedder`], a pure-Rust (candle)
//! implementation of [`Embedder`](kebab_core::Embedder).
//!
//! Runs the same `intfloat/multilingual-e5-large` model as the default
//! [`FastembedEmbedder`](kebab_embed_local) but through `candle`
//! (`candle-transformers`' XLM-RoBERTa) instead of onnxruntime. Motivation:
//! fastembed 4.9's onnxruntime hard-codes 48 intra-op threads, which corrupts
//! the heap (double-free) on dual-socket NUMA hosts. candle's CPU backend
//! sizes its threads off the global rayon pool, so a one-shot
//! [`rayon::ThreadPoolBuilder`] cap (config `num_threads` / env
//! `KEBAB_EMBED_THREADS`) keeps the worker count NUMA-safe.
//!
//! Output parity with the onnxruntime path was proven by the Phase 0 spike
//! (cosine 1.000000); this crate absorbs that pipeline verbatim:
//!
//! 1. e5 prefix (`passage: ` for documents, `query: ` for queries — the same
//! convention as `kebab-embed-local`'s `prefix_input`);
//! 2. tokenize (max_len 512, batch-longest padding, special tokens);
//! 3. XLM-RoBERTa forward on `Device::Cpu`;
//! 4. attention-mask-weighted mean pooling;
//! 5. L2 normalization.
//!
//! Model files (`config.json`, `tokenizer.json`, `model.safetensors`) are
//! fetched via `hf-hub` into `{config.storage.model_dir}/candle/`.
//!
//! This crate is **opt-in** (`config.models.embedding.provider = "candle"`);
//! the default provider stays `fastembed`. See
//! `docs/superpowers/specs/2026-06-01-embed-candle-track-spec.md`.
use std::sync::Mutex;
use anyhow::{Context, Result};
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::xlm_roberta::{Config as XlmConfig, XLMRobertaModel};
use kebab_config::{Config, expand_path};
use kebab_core::{Embedder, EmbeddingInput, EmbeddingKind, EmbeddingModelId, EmbeddingVersion};
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
/// Subdirectory under `config.storage.model_dir` where the candle adapter
/// caches safetensors + tokenizer. Mirrors `kebab-embed-local`'s
/// `fastembed/` subdir so the two backends never collide.
const CANDLE_CACHE_SUBDIR: &str = "candle";
/// HuggingFace repo id for the multilingual e5 large model. Same weights the
/// onnxruntime path uses, just the safetensors variant candle can read.
const HF_MODEL: &str = "intfloat/multilingual-e5-large";
/// Token truncation length (e5 was trained at 512).
const MAX_LEN: usize = 512;
/// Env var that overrides `config.models.embedding.num_threads`. Read once in
/// [`CandleEmbedder::new`]; `0`/unset/unparseable means "leave rayon default".
const ENV_EMBED_THREADS: &str = "KEBAB_EMBED_THREADS";
/// Pure-Rust candle adapter. Construct via [`CandleEmbedder::new`]; the
/// constructor downloads the model on first use, so share one instance.
pub struct CandleEmbedder {
// candle's `forward` is `&self`, but `XLMRobertaModel` is not guaranteed
// `Sync`; the `Mutex` both supplies that bound and serializes inference
// (callers batch sequentially anyway — same rationale as
// `FastembedEmbedder`).
model: Mutex<XLMRobertaModel>,
tokenizer: Tokenizer,
device: Device,
model_id: EmbeddingModelId,
version: EmbeddingVersion,
dimensions: usize,
batch_size: usize,
}
impl CandleEmbedder {
/// Build an embedder from `Config`. Applies the NUMA thread cap, fetches
/// the model into `{model_dir}/candle/`, and validates that the model's
/// hidden size matches `config.models.embedding.dimensions` before
/// returning.
pub fn new(config: &Config) -> Result<Self> {
// 1. NUMA thread cap. env `KEBAB_EMBED_THREADS` wins over the config
// field; `0`/unset leaves rayon's default. `build_global` errors if
// the pool was already initialized — intentionally ignored so a
// second embedder (or a prior rayon user) is a no-op, not a failure.
let n_threads = std::env::var(ENV_EMBED_THREADS)
.ok()
.and_then(|v| v.parse::<usize>().ok())
.unwrap_or(config.models.embedding.num_threads as usize);
if n_threads > 0 {
if apply_thread_cap(n_threads) {
tracing::info!(
target: "kebab-embed-candle",
num_threads = n_threads,
"capped global rayon pool for candle CPU backend"
);
} else {
tracing::debug!(
target: "kebab-embed-candle",
requested = n_threads,
"global rayon pool already initialized; thread cap not applied"
);
}
}
// 2. Resolve `{data_dir}/models/candle/` exactly like the fastembed
// adapter resolves its own subdir.
let data_dir = expand_path(&config.storage.data_dir, "");
let model_dir = expand_path(&config.storage.model_dir, &data_dir.to_string_lossy());
let cache_dir = model_dir.join(CANDLE_CACHE_SUBDIR);
std::fs::create_dir_all(&cache_dir)
.with_context(|| format!("create candle cache dir {}", cache_dir.display()))?;
let device = Device::Cpu;
// 3. Fetch model files via hf-hub into the candle cache.
tracing::info!(
target: "kebab-embed-candle",
cache_dir = %cache_dir.display(),
model = HF_MODEL,
"loading candle embedding model (first run downloads ~2GB safetensors)"
);
let api = hf_hub::api::sync::ApiBuilder::new()
.with_cache_dir(cache_dir.clone())
.build()
.context("kb-embed-candle: build hf-hub api")?;
let repo = api.model(HF_MODEL.to_string());
let config_path = repo.get("config.json").context("download config.json")?;
let tokenizer_path = repo
.get("tokenizer.json")
.context("download tokenizer.json")?;
let weights_path = repo
.get("model.safetensors")
.context("download model.safetensors")?;
// 4. Build the candle XLM-RoBERTa model.
let cfg_json = std::fs::read_to_string(&config_path)
.with_context(|| format!("read {}", config_path.display()))?;
let cfg: XlmConfig =
serde_json::from_str(&cfg_json).context("kb-embed-candle: parse XLM-R config")?;
// Validate dim BEFORE building the model so a misconfigured
// `dimensions` fails cheaply (matches FastembedEmbedder's contract).
check_dim(cfg.hidden_size, config.models.embedding.dimensions)?;
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device)
.context("kb-embed-candle: mmap safetensors")?
};
let model =
XLMRobertaModel::new(&cfg, vb).context("kb-embed-candle: build XLMRobertaModel")?;
let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("kb-embed-candle: load tokenizer: {e}"))?;
tokenizer
.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
..Default::default()
}))
.with_truncation(Some(TruncationParams {
max_length: MAX_LEN,
..Default::default()
}))
.map_err(|e| anyhow::anyhow!("kb-embed-candle: set truncation: {e}"))?;
tracing::info!(
target: "kebab-embed-candle",
dimensions = cfg.hidden_size,
layers = cfg.num_hidden_layers,
"candle embedding model loaded"
);
Ok(Self {
model: Mutex::new(model),
tokenizer,
device,
model_id: EmbeddingModelId(config.models.embedding.model.clone()),
version: EmbeddingVersion(config.models.embedding.version.clone()),
dimensions: cfg.hidden_size,
batch_size: config.models.embedding.batch_size.max(1),
})
}
/// Embed one batch (already prefixed) through the candle pipeline:
/// tokenize → forward → masked mean pool → L2 normalize.
fn embed_batch(&self, prefixed: &[String]) -> Result<Vec<Vec<f32>>> {
let encodings = self
.tokenizer
.encode_batch(prefixed.to_vec(), true)
.map_err(|e| anyhow::anyhow!("kb-embed-candle: encode_batch: {e}"))?;
let bsz = encodings.len();
let seq = encodings[0].get_ids().len();
let mut ids = Vec::with_capacity(bsz * seq);
let mut mask = Vec::with_capacity(bsz * seq);
for enc in &encodings {
ids.extend(enc.get_ids().iter().copied());
mask.extend(enc.get_attention_mask().iter().map(|&m| m as f32));
}
let input_ids = Tensor::from_vec(ids, (bsz, seq), &self.device)?;
let attn_f32 = Tensor::from_vec(mask, (bsz, seq), &self.device)?;
let token_type_ids = input_ids.zeros_like()?;
let hidden = {
let guard = self
.model
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
// forward: (input_ids, attention_mask, token_type_ids, past,
// encoder_hidden, encoder_mask)
guard.forward(&input_ids, &attn_f32, &token_type_ids, None, None, None)?
};
// attention-mask-weighted mean pooling
let mask3 = attn_f32.unsqueeze(2)?; // (b, seq, 1)
let summed = hidden.broadcast_mul(&mask3)?.sum(1)?; // (b, hidden)
let counts = mask3.sum(1)?; // (b, 1)
let mean = summed.broadcast_div(&counts)?;
// L2 normalize
let norm = mean.sqr()?.sum_keepdim(1)?.sqrt()?;
let normalized = mean.broadcast_div(&norm)?;
Ok(normalized.to_vec2::<f32>()?)
}
}
impl Embedder for CandleEmbedder {
fn model_id(&self) -> EmbeddingModelId {
self.model_id.clone()
}
fn model_version(&self) -> EmbeddingVersion {
self.version.clone()
}
fn dimensions(&self) -> usize {
self.dimensions
}
fn embed(&self, inputs: &[EmbeddingInput<'_>]) -> Result<Vec<Vec<f32>>> {
if inputs.is_empty() {
return Ok(Vec::new());
}
// e5 prefix per §11.3 BEFORE tokenization (same convention as
// FastembedEmbedder so the two backends produce comparable vectors).
let prefixed: Vec<String> = inputs.iter().map(prefix_input).collect();
let mut out: Vec<Vec<f32>> = Vec::with_capacity(prefixed.len());
for chunk in prefixed.chunks(self.batch_size) {
let batch = self.embed_batch(chunk)?;
for v in &batch {
if v.len() != self.dimensions {
anyhow::bail!(
"candle returned vector of length {} but adapter expects {}",
v.len(),
self.dimensions
);
}
}
out.extend(batch);
}
debug_assert_eq!(out.len(), inputs.len());
Ok(out)
}
}
/// Build the e5-prefixed string for one [`EmbeddingInput`]. Free function so
/// a unit test can pin the format without loading the model. Byte-identical to
/// `kebab-embed-local`'s `prefix_input` — the two backends MUST agree here or
/// their vectors diverge.
fn prefix_input(input: &EmbeddingInput<'_>) -> String {
match input.kind {
EmbeddingKind::Document => format!("passage: {}", input.text),
EmbeddingKind::Query => format!("query: {}", input.text),
}
}
/// Apply a one-shot global rayon thread cap (the NUMA-safety lever). Returns
/// `true` if this call set the pool, `false` if it was already initialized
/// (cap not applied) or `n_threads == 0`. `#[doc(hidden)] pub` so the
/// thread-cap test can drive it without loading the 2GB model.
#[doc(hidden)]
pub fn apply_thread_cap(n_threads: usize) -> bool {
if n_threads == 0 {
return false;
}
rayon::ThreadPoolBuilder::new()
.num_threads(n_threads)
.build_global()
.is_ok()
}
/// Compare model hidden size against the configured dim. Extracted so a unit
/// test can exercise the error branch without loading the model.
pub(crate) fn check_dim(model_dim: usize, cfg_dim: usize) -> Result<()> {
if model_dim != cfg_dim {
anyhow::bail!(
"dimension mismatch: model={model_dim}, config={cfg_dim}; \
update `config.models.embedding.dimensions` to match the model \
(or pick a different model)."
);
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
// ── prefix_input ─────────────────────────────────────────────────
// Pin the exact e5 prefix strings; these MUST match
// kebab-embed-local::prefix_input or candle vs fastembed parity breaks.
#[test]
fn prefix_document_uses_passage() {
let input = EmbeddingInput {
text: "hello world",
kind: EmbeddingKind::Document,
};
assert_eq!(prefix_input(&input), "passage: hello world");
}
#[test]
fn prefix_query_uses_query() {
let input = EmbeddingInput {
text: "hello world",
kind: EmbeddingKind::Query,
};
assert_eq!(prefix_input(&input), "query: hello world");
}
#[test]
fn prefix_handles_empty_text() {
let doc = EmbeddingInput {
text: "",
kind: EmbeddingKind::Document,
};
let qry = EmbeddingInput {
text: "",
kind: EmbeddingKind::Query,
};
assert_eq!(prefix_input(&doc), "passage: ");
assert_eq!(prefix_input(&qry), "query: ");
}
// ── check_dim ────────────────────────────────────────────────────
#[test]
fn check_dim_passes_for_1024() {
check_dim(1024, 1024).expect("matching dims must pass");
}
#[test]
fn check_dim_rejects_384_vs_1024() {
let err = check_dim(384, 1024).expect_err("dim mismatch must error");
let msg = format!("{err}");
assert!(
msg.contains("384") && msg.contains("1024"),
"error must mention both dims, got: {msg}"
);
}
}

View File

@@ -0,0 +1,88 @@
//! Parity test (spec §7, `#[ignore]` — needs the ~2GB model + network).
//!
//! Confirms the candle backend reproduces the onnxruntime `FastembedEmbedder`
//! vectors closely enough that no re-index is required (spec D-reindex):
//! per-sentence cosine ≥ 0.9999, and reports the dimension-wise max absolute
//! difference (the number the re-index decision hangs on).
//!
//! Run manually:
//! CARGO_TARGET_DIR=/build/out/cargo-target/target \
//! cargo test -p kebab-embed-candle --release -- --ignored --nocapture
//!
//! Uses the canonical dogfood config so both backends resolve the same model
//! identifiers and cache roots.
use kebab_config::Config;
use kebab_core::{Embedder, EmbeddingInput, EmbeddingKind};
use kebab_embed_candle::CandleEmbedder;
use kebab_embed_local::FastembedEmbedder;
const DOGFOOD_CONFIG: &str = "/build/dogfood/config.toml";
/// Mixed Korean / English parity set (≥ 8 sentences, mirrors the Phase 0 spike).
const SENTENCES: &[&str] = &[
"The quick brown fox jumps over the lazy dog.",
"오늘 날씨가 정말 좋아서 산책을 나가고 싶다.",
"Rust is a systems programming language focused on safety and performance.",
"벡터 검색은 임베딩 사이의 코사인 유사도를 이용한다.",
"Machine learning models require large amounts of training data.",
"한국어와 영어가 섞인 문장도 멀티링구얼 모델은 잘 처리한다.",
"The capital of France is Paris, a city known for its art and culture.",
"이 프로젝트는 로컬 우선 지식 베이스와 검색 증강 생성을 목표로 한다.",
"Database indexing dramatically speeds up query performance.",
"임베딩 모델을 candle 로 옮기면 NUMA 서버에서 안전하게 돌릴 수 있다.",
];
fn cosine(a: &[f32], b: &[f32]) -> f32 {
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
dot / (na * nb)
}
#[test]
#[ignore = "needs ~2GB model + network; run manually for the re-index decision"]
fn candle_matches_fastembed() {
let config = Config::load(Some(std::path::Path::new(DOGFOOD_CONFIG)))
.expect("load dogfood config for parity baseline");
let candle = CandleEmbedder::new(&config).expect("build CandleEmbedder");
let fastembed = FastembedEmbedder::new(&config).expect("build FastembedEmbedder");
let inputs: Vec<EmbeddingInput> = SENTENCES
.iter()
.map(|s| EmbeddingInput {
text: s,
kind: EmbeddingKind::Document,
})
.collect();
let cv = candle.embed(&inputs).expect("candle embed");
let fv = fastembed.embed(&inputs).expect("fastembed embed");
assert_eq!(cv.len(), fv.len(), "embedding counts must match");
assert_eq!(candle.dimensions(), 1024);
let mut min_cos = f32::INFINITY;
let mut max_abs_diff = 0f32;
for (i, s) in SENTENCES.iter().enumerate() {
assert_eq!(cv[i].len(), 1024, "candle dim");
assert_eq!(fv[i].len(), 1024, "fastembed dim");
let c = cosine(&cv[i], &fv[i]);
min_cos = min_cos.min(c);
let diff = cv[i]
.iter()
.zip(&fv[i])
.map(|(a, b)| (a - b).abs())
.fold(0f32, f32::max);
max_abs_diff = max_abs_diff.max(diff);
let preview: String = s.chars().take(40).collect();
println!("[{i:>2}] cos={c:.6} max_abs_diff={diff:.6e} {preview}");
}
println!("PARITY_SUMMARY cosine_min={min_cos:.6} max_abs_diff={max_abs_diff:.6e}");
assert!(
min_cos >= 0.9999,
"candle vs fastembed cosine_min={min_cos:.6} < 0.9999 — investigate before merge"
);
}

View File

@@ -0,0 +1,32 @@
//! Thread-cap test (spec §7). Own integration binary → clean process, so the
//! one-shot global rayon pool is initialized exactly once, by us.
//!
//! Verifies that `apply_thread_cap(4)` sizes the global rayon pool to 4, which
//! is the lever that keeps candle's CPU backend NUMA-safe (vs onnxruntime's
//! hard-coded 48 intra-op threads).
use kebab_embed_candle::apply_thread_cap;
#[test]
fn thread_cap_sizes_global_rayon_pool() {
// Must run before any other rayon use in this process. As the only test in
// this binary that touches rayon, that holds.
let applied = apply_thread_cap(4);
assert!(applied, "first build_global call should succeed");
assert_eq!(
rayon::current_num_threads(),
4,
"global rayon pool must be capped at the requested 4 threads"
);
// A second cap attempt is a no-op (pool already built), not a panic.
assert!(
!apply_thread_cap(8),
"second build_global must report not-applied"
);
assert_eq!(
rayon::current_num_threads(),
4,
"thread count must stay at the first cap"
);
}

View File

@@ -1,32 +0,0 @@
# Track 1 / Phase 0 feasibility SPIKE — NOT production.
# Isolated binary that loads multilingual-e5-large via candle (pure Rust)
# and compares its output against the existing onnxruntime FastembedEmbedder.
# candle deps live ONLY here so the production crates stay untouched.
[package]
name = "spike-embed-candle"
version = "0.0.0"
edition = "2024"
publish = false
[[bin]]
name = "spike-embed-candle"
path = "src/main.rs"
[dependencies]
anyhow = "1"
serde_json = "1"
# candle stack — pinned to the current crates.io release (0.10.2).
candle-core = "0.10.2"
candle-nn = "0.10.2"
candle-transformers = "0.10.2"
# Align with workspace-locked versions so we reuse compiled artifacts.
tokenizers = "0.21"
hf-hub = { version = "0.4", features = ["ureq"] }
rayon = "1"
# Parity baseline: reuse the real production embedder + its config loader.
kebab-config = { path = "../kebab-config" }
kebab-embed = { path = "../kebab-embed" }
kebab-embed-local = { path = "../kebab-embed-local" }
# Keep the spike out of the workspace pedantic-lint gate; it is throwaway.
[lints]

View File

@@ -1,251 +0,0 @@
//! Track 1 / Phase 0 feasibility SPIKE (NOT production code).
//!
//! Proves whether candle (pure Rust) can run `intfloat/multilingual-e5-large`
//! with output parity against the existing onnxruntime `FastembedEmbedder`,
//! so the NUMA double-free in fastembed 4.9.1 can be sidestepped.
//!
//! What it checks (see SPIKE_BRIEF.md):
//! 1. numeric parity — per-sentence cosine vs FastembedEmbedder
//! 2. padding_idx — XLM-R position ids start at pad_token_id+1
//! 3. thread control — RAYON_NUM_THREADS caps candle's CPU threads
//! 4. CPU latency — batch wall-clock, rough vs onnxruntime
//!
//! Run:
//! CARGO_TARGET_DIR=/build/out/cargo-target/target \
//! HF_HOME=/build/cache/huggingface \
//! RAYON_NUM_THREADS=4 \
//! cargo run -j 4 -p spike-embed-candle --release
use std::path::PathBuf;
use std::time::Instant;
use anyhow::{Context, Result};
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::xlm_roberta::{Config as XlmConfig, XLMRobertaModel};
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
use kebab_embed::{Embedder, EmbeddingInput, EmbeddingKind};
use kebab_embed_local::FastembedEmbedder;
const HF_MODEL: &str = "intfloat/multilingual-e5-large";
const DOGFOOD_CONFIG: &str = "/build/dogfood/config.toml";
const MAX_LEN: usize = 512;
/// Mixed Korean / English parity set (≥ 8, brief §3).
const SENTENCES: &[&str] = &[
"The quick brown fox jumps over the lazy dog.",
"오늘 날씨가 정말 좋아서 산책을 나가고 싶다.",
"Rust is a systems programming language focused on safety and performance.",
"벡터 검색은 임베딩 사이의 코사인 유사도를 이용한다.",
"Machine learning models require large amounts of training data.",
"한국어와 영어가 섞인 문장도 멀티링구얼 모델은 잘 처리한다.",
"The capital of France is Paris, a city known for its art and culture.",
"이 프로젝트는 로컬 우선 지식 베이스와 검색 증강 생성을 목표로 한다.",
"Database indexing dramatically speeds up query performance.",
"임베딩 모델을 candle 로 옮기면 NUMA 서버에서 안전하게 돌릴 수 있다.",
];
fn main() -> Result<()> {
// Touch the rayon global pool early so RAYON_NUM_THREADS is honored and
// reportable before any candle compute spins it up.
let rayon_threads = rayon::current_num_threads();
let avail = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(0);
let rayon_env = std::env::var("RAYON_NUM_THREADS").unwrap_or_else(|_| "<unset>".into());
println!("== spike-embed-candle ==");
println!("available_parallelism = {avail}");
println!("RAYON_NUM_THREADS env = {rayon_env}");
println!("rayon::current_num_threads() = {rayon_threads}");
let device = Device::Cpu;
// ── 1. Fetch model files (candle reads safetensors, not the ONNX cache) ──
let cache_dir = std::env::var("HF_HOME")
.map(PathBuf::from)
.unwrap_or_else(|_| PathBuf::from("/build/cache/huggingface"));
let api = hf_hub::api::sync::ApiBuilder::new()
.with_cache_dir(cache_dir.clone())
.build()
.context("build hf-hub api")?;
let repo = api.model(HF_MODEL.to_string());
println!("\n[load] fetching {HF_MODEL} into {} ...", cache_dir.display());
let config_path = repo.get("config.json").context("download config.json")?;
let tokenizer_path = repo.get("tokenizer.json").context("download tokenizer.json")?;
let weights_path = repo
.get("model.safetensors")
.context("download model.safetensors")?;
println!("[load] config = {}", config_path.display());
println!("[load] tokenizer = {}", tokenizer_path.display());
println!("[load] weights = {}", weights_path.display());
// ── 2. Build the candle XLM-RoBERTa model ──
let cfg_json = std::fs::read_to_string(&config_path)?;
let cfg: XlmConfig = serde_json::from_str(&cfg_json).context("parse XLM-R config")?;
println!(
"[load] config: hidden={} layers={} heads={} pad_token_id={} max_pos={} pos_emb={}",
cfg.hidden_size,
cfg.num_hidden_layers,
cfg.num_attention_heads,
cfg.pad_token_id,
cfg.max_position_embeddings,
cfg.position_embedding_type,
);
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device)
.context("mmap safetensors")?
};
let model = XLMRobertaModel::new(&cfg, vb).context("build XLMRobertaModel")?;
let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
tokenizer
.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
..Default::default()
}))
.with_truncation(Some(TruncationParams {
max_length: MAX_LEN,
..Default::default()
}))
.map_err(|e| anyhow::anyhow!("set truncation: {e}"))?;
let pad_id = cfg.pad_token_id;
// ── 3. candle embedding path (passage prefix, masked mean pool, L2) ──
let candle_vecs = candle_embed(&model, &tokenizer, &device, pad_id, SENTENCES)?;
println!("\n[candle] embedded {} sentences, dim={}", candle_vecs.len(), candle_vecs[0].len());
// L2 norm sanity (should be ~1.0 after normalization)
let norm0 = l2(&candle_vecs[0]);
println!("[candle] ‖v0‖ = {norm0:.6}");
// ── 4. FastembedEmbedder (onnxruntime) baseline ──
println!("\n[fastembed] loading FastembedEmbedder from {DOGFOOD_CONFIG} ...");
let config = kebab_config::Config::load(Some(std::path::Path::new(DOGFOOD_CONFIG)))
.context("load dogfood config")?;
let fb_t0 = Instant::now();
let fb = FastembedEmbedder::new(&config).context("build FastembedEmbedder")?;
println!("[fastembed] model loaded in {:.2}s", fb_t0.elapsed().as_secs_f64());
let fb_inputs: Vec<EmbeddingInput> = SENTENCES
.iter()
.map(|s| EmbeddingInput { text: s, kind: EmbeddingKind::Document })
.collect();
let fb_vecs = fb.embed(&fb_inputs).context("fastembed embed")?;
// ── 5. Per-sentence parity (both L2-normalized → cosine = dot) ──
println!("\n== PARITY (candle vs fastembed, EmbeddingKind::Document / passage:) ==");
let mut cosines = Vec::with_capacity(SENTENCES.len());
for (i, s) in SENTENCES.iter().enumerate() {
let c = cosine(&candle_vecs[i], &fb_vecs[i]);
cosines.push(c);
let preview: String = s.chars().take(40).collect();
println!(" [{i:>2}] cos={c:.6} {preview}");
}
let min = cosines.iter().cloned().fold(f32::INFINITY, f32::min);
let mean = cosines.iter().sum::<f32>() / cosines.len() as f32;
println!(" --> cosine min={min:.6} mean={mean:.6}");
// ── 6. Latency: batch of 32 (replicated) through candle ──
let batch: Vec<&str> = SENTENCES.iter().cloned().cycle().take(32).collect();
// warmup
let _ = candle_embed(&model, &tokenizer, &device, pad_id, &batch[..4])?;
let t0 = Instant::now();
let _ = candle_embed(&model, &tokenizer, &device, pad_id, &batch)?;
let candle_lat = t0.elapsed();
let fb_batch: Vec<EmbeddingInput> = batch
.iter()
.map(|s| EmbeddingInput { text: s, kind: EmbeddingKind::Document })
.collect();
let t1 = Instant::now();
let _ = fb.embed(&fb_batch)?;
let fb_lat = t1.elapsed();
let peak_threads = proc_threads();
println!("\n== LATENCY (batch=32) ==");
println!(" candle : {:.3}s ({:.1} ms/sentence)", candle_lat.as_secs_f64(), candle_lat.as_secs_f64() * 1000.0 / 32.0);
println!(" fastembed : {:.3}s ({:.1} ms/sentence)", fb_lat.as_secs_f64(), fb_lat.as_secs_f64() * 1000.0 / 32.0);
println!("\n== THREAD CONTROL ==");
println!(" RAYON_NUM_THREADS env = {rayon_env}");
println!(" rayon::current_num_threads = {rayon_threads}");
println!(" available_parallelism = {avail}");
println!(" peak OS threads (/proc) = {peak_threads}");
// ── 7. Machine verdict line for the report ──
let verdict = if mean >= 0.99 { "PASS" } else if mean >= 0.95 { "MARGINAL" } else { "FAIL" };
println!("\n== SUMMARY ==");
println!("VERDICT_HINT={verdict} cosine_min={min:.6} cosine_mean={mean:.6} candle_batch32_s={:.3} fb_batch32_s={:.3} rayon_threads={rayon_threads} rayon_env={rayon_env}", candle_lat.as_secs_f64(), fb_lat.as_secs_f64());
Ok(())
}
/// candle embedding: apply e5 `passage:` prefix, tokenize (batch-padded),
/// forward through XLM-R, attention-mask-weighted mean pool, L2 normalize.
fn candle_embed(
model: &XLMRobertaModel,
tokenizer: &Tokenizer,
device: &Device,
_pad_id: u32,
sentences: &[&str],
) -> Result<Vec<Vec<f32>>> {
let prefixed: Vec<String> = sentences.iter().map(|s| format!("passage: {s}")).collect();
let encodings = tokenizer
.encode_batch(prefixed, true)
.map_err(|e| anyhow::anyhow!("encode_batch: {e}"))?;
let bsz = encodings.len();
let seq = encodings[0].get_ids().len();
let mut ids = Vec::with_capacity(bsz * seq);
let mut mask = Vec::with_capacity(bsz * seq);
for enc in &encodings {
ids.extend(enc.get_ids().iter().copied());
mask.extend(enc.get_attention_mask().iter().map(|&m| m as f32));
}
let input_ids = Tensor::from_vec(ids, (bsz, seq), device)?;
let attn_f32 = Tensor::from_vec(mask, (bsz, seq), device)?;
let token_type_ids = input_ids.zeros_like()?;
// forward: (input_ids, attention_mask, token_type_ids, past, enc_hidden, enc_mask)
let hidden = model.forward(&input_ids, &attn_f32, &token_type_ids, None, None, None)?;
// masked mean pool
let mask3 = attn_f32.unsqueeze(2)?; // (b, seq, 1)
let summed = hidden.broadcast_mul(&mask3)?.sum(1)?; // (b, hidden)
let counts = mask3.sum(1)?; // (b, 1)
let mean = summed.broadcast_div(&counts)?;
// L2 normalize
let norm = mean.sqr()?.sum_keepdim(1)?.sqrt()?;
let normalized = mean.broadcast_div(&norm)?;
Ok(normalized.to_vec2::<f32>()?)
}
fn cosine(a: &[f32], b: &[f32]) -> f32 {
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
let na = l2(a);
let nb = l2(b);
dot / (na * nb)
}
fn l2(v: &[f32]) -> f32 {
v.iter().map(|x| x * x).sum::<f32>().sqrt()
}
/// Peak OS thread count for this process from /proc/self/status.
fn proc_threads() -> usize {
std::fs::read_to_string("/proc/self/status")
.ok()
.and_then(|s| {
s.lines()
.find(|l| l.starts_with("Threads:"))
.and_then(|l| l.split_whitespace().nth(1).map(str::to_string))
})
.and_then(|n| n.parse().ok())
.unwrap_or(0)
}