kebab/crates/kb-embed-local/tests/embed_model.rs
altair823 bcbe2b8531 feat(p3-2): kb-embed-local crate — fastembed adapter for multilingual-e5-small
First real Embedder implementation. Wraps fastembed-rs (ONNX runtime)
with the e5 prefix convention, batching, and {data_dir}/${XDG_DATA_HOME}
template expansion so model files land under config.storage.model_dir/
fastembed/ without polluting kb-config's public API.
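
A minimal sketch of the template expansion described above. The function
name, signature, and substitution order are assumptions for illustration,
not the crate's actual `expand_path`; environment lookup is passed in as a
closure (also an assumption) so the sketch can be exercised without
touching the process environment.

```rust
use std::path::PathBuf;

/// Hypothetical helper (not the crate's real API): expands `{data_dir}`,
/// `${XDG_DATA_HOME}`, and a leading `~` in a path template.
/// `env` abstracts environment lookup so callers can inject values.
fn expand_path_sketch(
    template: &str,
    data_dir: &str,
    env: &dyn Fn(&str) -> Option<String>,
) -> PathBuf {
    // Fall back to ~/.local/share when XDG_DATA_HOME is unset or empty.
    let xdg = env("XDG_DATA_HOME")
        .filter(|v| !v.is_empty())
        .unwrap_or_else(|| "~/.local/share".to_string());

    // `{data_dir}` goes first so a data_dir that itself mentions
    // `${XDG_DATA_HOME}` is expanded by the second replace.
    let mut s = template
        .replace("{data_dir}", data_dir)
        .replace("${XDG_DATA_HOME}", &xdg);

    // Expand a leading `~` last so it also covers the fallback above.
    // If HOME is unknown the `~` is left in place.
    if s == "~" || s.starts_with("~/") {
        if let Some(home) = env("HOME") {
            s = format!("{}{}", home, &s[1..]);
        }
    }
    PathBuf::from(s)
}
```

With `HOME=/home/u` and `XDG_DATA_HOME` unset, the template
`${XDG_DATA_HOME}/kb` resolves to `/home/u/.local/share/kb` in this sketch,
which mirrors the fallback behavior the unit tests are said to pin.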

Public surface:
- pub struct FastembedEmbedder
- pub fn new(config: &Config) -> Result<Self>
- impl kb_core::Embedder (via kb-embed re-export)

Behavior:
- Default model multilingual-e5-small (384 dims). model_id and
  model_version come from config.models.embedding.{model,version}.
- Pre-load dim check via TextEmbedding::get_model_info: dim mismatch
  bails before paying the ~470MB ONNX init cost.
- e5 prefix applied BEFORE tokenization: "passage: " for
  EmbeddingKind::Document, "query: " for EmbeddingKind::Query. Pinned
  by prefix_input unit tests.
- Batches inputs into chunks of config.models.embedding.batch_size,
  concatenates results in input order.
- L2 normalization is performed by fastembed 4.9's default transformer
  pipeline (verified at fastembed/src/text_embedding/output.rs:43);
  we skip re-normalization. Integration test pins ‖v‖ ≈ 1.0 ± 1e-3 so
  a future fastembed bump that drops this invariant fails loudly.
- Synchronous (no async runtime). Mutex serializes calls into the
  underlying ONNX session — conservative; ORT Session is Send+Sync but
  callers (kb-app indexer) batch sequentially anyway. Revisit if
  profiling shows contention.
- First-run model download surfaces via tracing::info before/after
  TextEmbedding::try_new — users no longer stare at a silent 30-60s
  pause during the 470MB pull.
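
The prefix and batching behavior above can be sketched roughly as follows.
This is an illustration under stated assumptions, not the adapter's code:
`Kind` stands in for `EmbeddingKind`, `embed_in_batches` is a hypothetical
name, and the model call is replaced by a caller-supplied closure so the
sketch has no fastembed dependency.

```rust
/// Stand-in for `EmbeddingKind` (the real enum lives in kb-embed).
#[derive(Clone, Copy, Debug, PartialEq)]
enum Kind {
    Document,
    Query,
}

/// Prepend the e5 role prefix before the text reaches the tokenizer.
fn prefix_input(kind: Kind, text: &str) -> String {
    match kind {
        Kind::Document => format!("passage: {text}"),
        Kind::Query => format!("query: {text}"),
    }
}

/// Hypothetical driver: feeds `embed_batch` chunks of at most
/// `batch_size` prefixed strings and concatenates the results, so the
/// output order matches the input order.
fn embed_in_batches<F>(
    inputs: &[(Kind, &str)],
    batch_size: usize,
    mut embed_batch: F,
) -> Vec<Vec<f32>>
where
    F: FnMut(&[String]) -> Vec<Vec<f32>>,
{
    // `chunks(0)` would panic, so reject it with a clearer message.
    assert!(batch_size > 0, "batch_size must be non-zero");
    let mut out = Vec::with_capacity(inputs.len());
    for chunk in inputs.chunks(batch_size) {
        let prefixed: Vec<String> = chunk
            .iter()
            .map(|(kind, text)| prefix_input(*kind, text))
            .collect();
        out.extend(embed_batch(&prefixed));
    }
    out
}
```

Because chunks are processed sequentially and appended in turn, three
inputs with `batch_size = 2` produce two model calls and three output rows
in the original order.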

Tests:
- 11 default-lane tests covering: check_dim match/mismatch (no model
  load), prefix_input Document/Query/empty, resolve_model
  known/unknown, expand_path substitution + no-op + XDG_DATA_HOME set
  + XDG_DATA_HOME unset (falls back to ~/.local/share with recursive
  ~ expansion). XDG tests serialize on a Mutex + RAII guard since
  edition 2024 makes set_var/remove_var unsafe.
- 7 #[ignore] integration tests covering: full construction with
  default config, dim-mismatch belt-and-braces, Document vs Query
  cosine differential, L2 unit norm, byte-equal determinism, batch-64
  performance under 5s, snapshot-hash stability over a 5-sentence
  multilingual fixture.
- Snapshot test fails LOUDLY when SNAPSHOT_HASH_BASELINE is 0 — prints
  the captured hash and panics with paste-back instructions, so first
  --ignored run forces the maintainer to pin the baseline rather than
  silently passing.
- Workspace: 222 tests pass (default lane); clippy clean.
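
For the XDG tests, a Mutex plus RAII guard might look like the sketch
below. `ENV_LOCK` and `EnvVarGuard` are hypothetical names, not taken from
the crate. The `unsafe` blocks are what edition 2024 requires around
`set_var`/`remove_var`; the `allow` keeps the sketch warning-free on older
editions where those calls are safe.

```rust
use std::env;
use std::ffi::OsString;
use std::sync::{Mutex, MutexGuard};

/// Process-wide lock: environment variables are global state, so tests
/// that mutate them must not interleave.
static ENV_LOCK: Mutex<()> = Mutex::new(());

/// Hypothetical RAII guard: holds the lock for its lifetime and restores
/// the variable's previous value on drop.
struct EnvVarGuard {
    key: &'static str,
    previous: Option<OsString>,
    // Declared last so it is released only after `drop` has restored
    // the variable.
    _lock: MutexGuard<'static, ()>,
}

impl EnvVarGuard {
    /// Set (`Some`) or remove (`None`) `key`, remembering the old value.
    #[allow(unused_unsafe)]
    fn set(key: &'static str, value: Option<&str>) -> Self {
        // A panicking test poisons the lock; the `()` payload cannot be
        // left in a bad state, so recover the guard and continue.
        let lock = ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let previous = env::var_os(key);
        // SAFETY (assumption): every test that mutates the environment
        // goes through this guard. The lock does not protect against
        // unrelated threads reading the environment concurrently.
        unsafe {
            match value {
                Some(v) => env::set_var(key, v),
                None => env::remove_var(key),
            }
        }
        Self { key, previous, _lock: lock }
    }
}

impl Drop for EnvVarGuard {
    #[allow(unused_unsafe)]
    fn drop(&mut self) {
        // SAFETY: same assumption as in `set`; the lock is still held.
        unsafe {
            match self.previous.take() {
                Some(v) => env::set_var(self.key, v),
                None => env::remove_var(self.key),
            }
        }
    }
}
```

A test would bind the guard to a local (`let _g = EnvVarGuard::set("XDG_DATA_HOME", None);`)
so the original value comes back even if an assertion panics midway.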

Allowed deps respected: kb-config, kb-embed (re-exports kb-core
trait surface), fastembed = "4.9", tracing, anyhow. tokenizers and
ort enter transitively through fastembed; reqwest/hyper/hf-hub also
transitive (model download is fastembed's responsibility per spec
carve-out). No direct kb-core dep needed — re-exports cover it.

Pinned to fastembed 4.x rather than the recent 5.x to limit blast
radius; consider bump when p3-3 (lancedb-store) consumes the embedder
output shape.

Out of scope: reranker (P+), Ollama embedding endpoint, candle
adapter, image embeddings (P6).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-01 08:39:38 +00:00

//! Integration tests for [`FastembedEmbedder`] that load the real ONNX
//! model.
//!
//! ## Why every test in this file is `#[ignore]`
//!
//! The first call to `FastembedEmbedder::new` downloads ~470 MB of
//! weights from Hugging Face into `{data_dir}/models/fastembed/`. Doing
//! that on every `cargo test` invocation is wasteful, so the bare
//! invocation skips this file entirely.
//!
//! Run the full suite with:
//! ```text
//! cargo test -p kb-embed-local -- --ignored
//! ```
//!
//! Every test except the dim-mismatch one (which must fail before any
//! model loads) shares a `OnceLock<FastembedEmbedder>`, so the model
//! loads exactly once per process invocation (ONNX runtime first-load
//! latency is 1-2 s on M-series Macs per design risks list).

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;
use std::time::Instant;

use kb_embed::{Embedder, EmbeddingInput, EmbeddingKind};
use kb_embed_local::FastembedEmbedder;

/// Build a `Config` whose `data_dir` lives in a per-process temp dir so
/// the test never writes into the developer's real `~/.local/share/kb`.
/// Returns the `Config` and the `TempDir` guard (caller keeps the guard
/// alive for the test duration).
fn test_config() -> (kb_config::Config, tempfile::TempDir) {
    let tmp = tempfile::tempdir().expect("create tempdir");
    let mut cfg = kb_config::Config::defaults();
    cfg.storage.data_dir = tmp.path().to_string_lossy().into_owned();
    // model_dir keeps its default `{data_dir}/models` template; the
    // adapter resolves it itself.
    (cfg, tmp)
}

/// Single shared embedder for the `--ignored` lane. Held behind a
/// `OnceLock` so we pay the ~1-2 s ONNX init + (first run only) the
/// network download just once.
fn shared_embedder() -> &'static FastembedEmbedder {
    static EMBEDDER: OnceLock<FastembedEmbedder> = OnceLock::new();
    EMBEDDER.get_or_init(|| {
        let (cfg, tmp) = test_config();
        // Deliberately leak `tmp`: the OnceLock outlives this closure,
        // so the cache directory must persist for the whole process.
        // (`tempfile::TempDir`'s `Drop` would erase the cache and wreck
        // subsequent calls.) Leaking skips that cleanup, so the
        // directory stays on disk after the test process exits until
        // the system's temp cleanup removes it.
        std::mem::forget(tmp);
        FastembedEmbedder::new(&cfg).expect("init FastembedEmbedder")
    })
}

// ─── construction ─────────────────────────────────────────────────────

#[test]
#[ignore = "downloads ~470MB ONNX model on first run; CI-only"]
fn default_config_constructs_with_dims_384() {
    let emb = shared_embedder();
    assert_eq!(emb.dimensions(), 384);
    assert_eq!(emb.model_id().0, "multilingual-e5-small");
    assert_eq!(emb.model_version().0, "v1");
}

#[test]
#[ignore = "downloads ~470MB ONNX model on first run; CI-only"]
fn mismatched_dims_in_config_errors_at_construction() {
    let (mut cfg, _tmp) = test_config();
    cfg.models.embedding.dimensions = 512; // model is 384
    // `FastembedEmbedder` deliberately does not implement `Debug`
    // (its inner ONNX session has no useful debug shape), so we
    // can't use `expect_err`; match the Result manually.
    let err = match FastembedEmbedder::new(&cfg) {
        Ok(_) => panic!("dim mismatch must error"),
        Err(e) => e,
    };
    let msg = format!("{err}");
    assert!(msg.contains("dimension mismatch"), "msg={msg}");
    assert!(msg.contains("384"), "msg={msg}");
    assert!(msg.contains("512"), "msg={msg}");
}

// ─── e5 prefix differentiation ────────────────────────────────────────

#[test]
#[ignore = "loads ONNX model; CI-only"]
fn document_and_query_yield_different_vectors() {
    let emb = shared_embedder();
    let text = "The quick brown fox jumps over the lazy dog.";
    let out = emb
        .embed(&[
            EmbeddingInput {
                text,
                kind: EmbeddingKind::Document,
            },
            EmbeddingInput {
                text,
                kind: EmbeddingKind::Query,
            },
        ])
        .expect("embed two inputs");
    assert_eq!(out.len(), 2);
    assert_eq!(out[0].len(), 384);
    assert_eq!(out[1].len(), 384);
    // Both vectors are L2-normalized → cosine similarity == dot product.
    let cos: f32 = out[0]
        .iter()
        .zip(out[1].iter())
        .map(|(a, b)| a * b)
        .sum();
    // Same text, different prefix → vectors must NOT be identical.
    assert!(
        cos < 0.9999,
        "expected distinct vectors for Document vs Query, got cos={cos}"
    );
}

// ─── L2 normalization ─────────────────────────────────────────────────

#[test]
#[ignore = "loads ONNX model; CI-only"]
fn output_vectors_are_l2_normalized() {
    let emb = shared_embedder();
    let inputs = [
        EmbeddingInput {
            text: "hello world",
            kind: EmbeddingKind::Document,
        },
        EmbeddingInput {
            text: "vector search",
            kind: EmbeddingKind::Document,
        },
        EmbeddingInput {
            text: "embedding model",
            kind: EmbeddingKind::Query,
        },
    ];
    let out = emb.embed(&inputs).expect("embed");
    // `kb_embed::assert_unit_norm` documents `5e-4` as the safe bound at
    // 384 dims: pure rounding error is about f32::EPSILON × √384 ≈ 2.3e-6,
    // but ONNX kernels add their own per-component noise. We use the
    // looser `1e-3`, which is very generous and matches the spec's
    // `± 1e-3`.
    kb_embed::assert_unit_norm(&out, 1e-3);
    kb_embed::assert_vector_shape(&out, 384);
}

// ─── determinism ──────────────────────────────────────────────────────

#[test]
#[ignore = "loads ONNX model; CI-only"]
fn identical_input_yields_identical_output() {
    let emb = shared_embedder();
    let inputs = [
        EmbeddingInput {
            text: "deterministic embedding test",
            kind: EmbeddingKind::Document,
        },
        EmbeddingInput {
            text: "second sentence for variety",
            kind: EmbeddingKind::Document,
        },
    ];
    let a = emb.embed(&inputs).expect("first embed");
    let b = emb.embed(&inputs).expect("second embed");
    assert_eq!(a, b, "two calls with the same inputs must be byte-equal");
}

// ─── performance ──────────────────────────────────────────────────────

#[test]
#[ignore = "performance test; downloads model and runs 64-vec batch"]
fn batch_of_64_short_inputs_under_5s() {
    let emb = shared_embedder();
    // 64 distinct short strings → forces the full default batch_size
    // through one fastembed call.
    let texts: Vec<String> = (0..64)
        .map(|i| format!("perf-test sentence number {i}"))
        .collect();
    let inputs: Vec<EmbeddingInput<'_>> = texts
        .iter()
        .map(|t| EmbeddingInput {
            text: t.as_str(),
            kind: EmbeddingKind::Document,
        })
        .collect();
    let t0 = Instant::now();
    let out = emb.embed(&inputs).expect("embed batch of 64");
    let elapsed = t0.elapsed();
    assert_eq!(out.len(), 64);
    assert!(
        elapsed.as_secs_f32() < 5.0,
        "batch-64 took {elapsed:?}, expected < 5s"
    );
}

// ─── snapshot ─────────────────────────────────────────────────────────

/// Aggregate hash of vectors for the 5 fixture sentences.
///
/// Computed by:
/// 1. embed each sentence as `EmbeddingKind::Document`,
/// 2. round each `f32` component to 4 decimal places (multiply by 1e4,
///    round, store as `i32`),
/// 3. write the rounded i32 components into a `DefaultHasher` in
///    row-major order,
/// 4. read out the `u64` finish value.
///
/// The 4-decimal tolerance is intentional float-tolerance per task spec:
/// exact f32 equality is too strict given ONNX kernel + hardware
/// variation.
///
/// Note that `DefaultHasher`'s algorithm is unspecified and may change
/// between Rust releases, so a toolchain upgrade can shift the hash
/// without any change in model output; re-pin in that case.
///
/// **Pinning workflow** (a snapshot test must FAIL UNTIL PINNED):
/// 1. With `SNAPSHOT_HASH_BASELINE = 0`, run
///    `cargo test -p kb-embed-local -- --ignored snapshot`. The test
///    panics with a message containing the captured hash.
/// 2. Paste the printed hex value into `SNAPSHOT_HASH_BASELINE` below.
/// 3. Re-run the same command; the test now asserts equality and
///    passes, confirming the pin.
///
/// On a genuine model upgrade, reset to `0`, re-pin, and bump
/// `EmbeddingVersion` per design §9 in the same PR.
const SNAPSHOT_HASH_BASELINE: u64 = 0;

#[test]
#[ignore = "loads ONNX model; CI-only"]
fn snapshot_aggregate_hash_is_stable() {
    let emb = shared_embedder();
    let fixture_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("tests/fixtures/embed/known-sentences.json");
    let raw = std::fs::read_to_string(&fixture_path).expect("read fixture");
    let json: serde_json::Value = serde_json::from_str(&raw).expect("parse fixture json");
    let sentences: Vec<String> = json["sentences"]
        .as_array()
        .expect("`sentences` array")
        .iter()
        .map(|v| v.as_str().expect("sentence is str").to_string())
        .collect();
    assert_eq!(sentences.len(), 5, "fixture must have exactly 5 sentences");
    let inputs: Vec<EmbeddingInput<'_>> = sentences
        .iter()
        .map(|s| EmbeddingInput {
            text: s.as_str(),
            kind: EmbeddingKind::Document,
        })
        .collect();
    let out = emb.embed(&inputs).expect("embed snapshot fixture");
    // Round every component to 4 decimal places, hash deterministically.
    let mut hasher = DefaultHasher::new();
    for (i, v) in out.iter().enumerate() {
        assert_eq!(v.len(), 384, "row {i} dim mismatch");
        for x in v {
            let rounded: i32 = (*x * 1.0e4).round() as i32;
            rounded.hash(&mut hasher);
        }
    }
    let observed = hasher.finish();
    if SNAPSHOT_HASH_BASELINE == 0 {
        // Unpinned baseline: panic with the captured hash. A snapshot
        // test that silently passes on first run defeats its purpose,
        // so we hard-fail until a maintainer commits the pin. Both
        // hex (paste-friendly) and decimal forms are printed.
        eprintln!(
            "kb-embed-local snapshot baseline (paste into SNAPSHOT_HASH_BASELINE): \
             {observed:#x} ({observed})"
        );
        panic!(
            "snapshot baseline unpinned — paste {observed:#x} into \
             SNAPSHOT_HASH_BASELINE then re-run"
        );
    }
    assert_eq!(
        observed, SNAPSHOT_HASH_BASELINE,
        "snapshot drift: model output for the fixture sentences changed; \
         either fastembed weights changed (bump EmbeddingVersion per §9) \
         or there's an ONNX kernel diff."
    );
}