feat(p3-2): fastembed-adapter — kb-embed-local 크레이트 + FastembedEmbedder #15
2197
Cargo.lock
generated
@@ -11,6 +11,7 @@ members = [
|
||||
"crates/kb-store-sqlite",
|
||||
"crates/kb-search",
|
||||
"crates/kb-embed",
|
||||
"crates/kb-embed-local",
|
||||
"crates/kb-app",
|
||||
"crates/kb-cli",
|
||||
]
|
||||
@@ -37,3 +38,8 @@ rusqlite = { version = "0.32", features = ["bundled"] }
|
||||
globset = "0.4"
|
||||
tempfile = "3"
|
||||
proptest = "1"
|
||||
# fastembed-rs ships ONNX runtime via the `ort-download-binaries` feature
|
||||
# in its default set (which also pulls `hf-hub` for first-run model
|
||||
# downloads). Pinned to the 4.x line per task p3-2 (current 5.x release
|
||||
# remains untested for this workspace).
|
||||
fastembed = "4.9"
|
||||
|
||||
22
crates/kb-embed-local/Cargo.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "kb-embed-local"
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
rust-version = { workspace = true }
|
||||
license = { workspace = true }
|
||||
repository = { workspace = true }
|
||||
description = "Local fastembed-rs adapter implementing kb_core::Embedder (multilingual-e5-small default)"
|
||||
|
||||
[dependencies]
|
||||
kb-config = { path = "../kb-config" }
|
||||
kb-embed = { path = "../kb-embed" }
|
||||
# Default features bring `ort-download-binaries` (bundled ONNX runtime)
|
||||
# and `hf-hub-native-tls` (first-run model download). No extra features
|
||||
# needed for the multilingual-e5-small path.
|
||||
fastembed = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
|
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
433
crates/kb-embed-local/src/lib.rs
Normal file
@@ -0,0 +1,433 @@
|
||||
//! `kb-embed-local` — `FastembedEmbedder`, a local ONNX-backed
|
||||
//! [`Embedder`](kb_embed::Embedder) implementation.
|
||||
//!
|
||||
//! Wraps [`fastembed::TextEmbedding`] for the default `multilingual-e5-small`
|
||||
//! (384-dim) model. Honors `config.models.embedding.batch_size` and applies
|
||||
//! the e5 prefix convention (§11.3 of the design report):
|
||||
//!
|
||||
//! * `EmbeddingKind::Document` → `"passage: "` prefix
|
||||
//! * `EmbeddingKind::Query` → `"query: "` prefix
|
||||
//!
|
||||
//! The underlying fastembed `TextEmbedding::embed` already L2-normalizes each
|
||||
//! row (see `fastembed::text_embedding::output::transformer_with_precedence`),
|
||||
//! so we do not re-normalize; the unit-norm test in `tests/` keeps that
|
||||
//! invariant pinned in case fastembed changes its default.
|
||||
//!
|
||||
//! Model files are cached under
|
||||
//! `config.storage.model_dir/fastembed/`. The `model_dir` template
|
||||
//! (default `"{data_dir}/models"`) is resolved with the same expansion
|
||||
//! rules `kb-store-sqlite` applies to `data_dir` (`${XDG_DATA_HOME:-…}`,
|
||||
//! leading `~`, `{data_dir}` substitution).
|
||||
//!
|
||||
//! See `docs/superpowers/specs/2026-04-27-kb-final-form-design.md`
|
||||
//! §7.2 (Embedder), §6.4 ([models.embedding]), §9 (versioning).
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
|
||||
use kb_embed::{Embedder, EmbeddingInput, EmbeddingKind, EmbeddingModelId, EmbeddingVersion};
|
||||
|
||||
/// Subdirectory under `config.storage.model_dir` where the fastembed
|
||||
/// adapter writes / reads ONNX + tokenizer files. Hard-coded per task
|
||||
/// spec ("Model files cached under `config.storage.model_dir/fastembed/`").
|
||||
const FASTEMBED_CACHE_SUBDIR: &str = "fastembed";
|
||||
|
||||
/// Local fastembed-rs adapter.
|
||||
///
|
||||
/// Construct via [`FastembedEmbedder::new`]. The constructor performs the
|
||||
/// (potentially network-bound) model download on first use, so prefer to
|
||||
/// share an instance across calls.
|
||||
pub struct FastembedEmbedder {
|
||||
// Mutex serializes calls into TextEmbedding's underlying ONNX session.
|
||||
// fastembed::TextEmbedding::embed is `&self` in 4.9 and ORT Session is
|
||||
// Send + Sync, so this Mutex is conservative — it serializes inference
|
||||
// where parallel ORT calls would in principle work. Acceptable here
|
||||
// because callers (kb-app indexer) batch sequentially anyway. Revisit
|
||||
// in P3-3+ if profiling shows contention.
|
||||
inner: Mutex<TextEmbedding>,
|
||||
model_id: EmbeddingModelId,
|
||||
version: EmbeddingVersion,
|
||||
dimensions: usize,
|
||||
batch_size: usize,
|
||||
}
|
||||
|
||||
impl FastembedEmbedder {
|
||||
/// Build an embedder from `Config`. Validates that
|
||||
/// `config.models.embedding.dimensions` matches the model's actual
|
||||
/// dim BEFORE returning, so a mismatch fails at construction (not on
|
||||
/// first `embed`).
|
||||
pub fn new(config: &kb_config::Config) -> Result<Self> {
|
||||
// 1. Resolve `{data_dir}/models/fastembed/` from the config
|
||||
// templates. `kb-config` does not expose a public path
|
||||
// resolver yet, so we hand-roll a tiny one mirroring
|
||||
// kb-store-sqlite's `expand_data_dir`.
|
||||
let data_dir = expand_path(&config.storage.data_dir, "");
|
||||
let model_dir = expand_path(&config.storage.model_dir, &data_dir.to_string_lossy());
|
||||
let cache_dir = model_dir.join(FASTEMBED_CACHE_SUBDIR);
|
||||
std::fs::create_dir_all(&cache_dir)
|
||||
.with_context(|| format!("create fastembed cache dir {}", cache_dir.display()))?;
|
||||
|
||||
// 2. Resolve the fastembed enum variant from
|
||||
// `config.models.embedding.model`. Currently only the default
|
||||
// `multilingual-e5-small` is wired; other model names error
|
||||
// out with a clear message rather than silently misconfiguring.
|
||||
let model_name = resolve_model(&config.models.embedding.model)?;
|
||||
|
||||
// 3. Verify dim match BEFORE loading the model — if the config
|
||||
// is wrong we want to fail without paying the ONNX
|
||||
// initialization cost.
|
||||
let model_info = TextEmbedding::get_model_info(&model_name)
|
||||
.context("fastembed: get_model_info")?;
|
||||
check_dim(model_info.dim, config.models.embedding.dimensions)?;
|
||||
|
||||
tracing::info!(
|
||||
|
claude-reviewer-01
commented
Pre-load dim 검증이 핵심입니다. Pre-load dim 검증이 핵심입니다. `TextEmbedding::get_model_info`가 ONNX session 초기화 없이 모델 메타만 정적으로 반환하는 점 (fastembed 4.9.1 기준 검증됨)을 활용해서 ~470MB 다운로드 + ONNX init를 시작하기 전에 bail합니다. 사용자가 dim 설정을 잘못 적었을 때 "커피 한 잔 마시고 와서 fail"이 아니라 즉시 fail하는 UX 차이가 큽니다.
|
||||
target: "kb-embed-local",
|
||||
cache_dir = %cache_dir.display(),
|
||||
model = %config.models.embedding.model,
|
||||
dims = model_info.dim,
|
||||
"initializing FastembedEmbedder"
|
||||
);
|
||||
|
||||
// 4. Build the underlying TextEmbedding. `show_download_progress`
|
||||
// is forced to `false` so test output stays clean; first-run
|
||||
// download progress is surfaced via the `tracing::info!`
|
||||
// pair around `TextEmbedding::try_new` instead.
|
||||
let opts = InitOptions::new(model_name.clone())
|
||||
.with_cache_dir(cache_dir.clone())
|
||||
.with_show_download_progress(false);
|
||||
tracing::info!(
|
||||
target: "kb-embed-local",
|
||||
model = %config.models.embedding.model,
|
||||
cache_dir = %cache_dir.display(),
|
||||
"loading embedding model (first run will download ~470MB)"
|
||||
);
|
||||
|
claude-reviewer-01
commented
첫 실행 다운로드 가시화. 첫 실행 다운로드 가시화. `show_download_progress: false`로 fastembed 자체 progress bar는 끄고 그 대신 `tracing::info!` 두 번 (load 시작 시점에 "~470MB will download" 명시 + 성공 후 dimensions와 함께 confirm)으로 깔끔하게 처리한 게 정답입니다. progress bar는 stdout 파괴적이라 CLI 출력 포맷을 망가뜨리는데, tracing은 user가 RUST_LOG로 통제할 수 있어 production에서 안정적입니다.
|
||||
let inner = TextEmbedding::try_new(opts)
|
||||
.context("fastembed: TextEmbedding::try_new")?;
|
||||
let dimensions = model_info.dim;
|
||||
tracing::info!(
|
||||
target: "kb-embed-local",
|
||||
model = %config.models.embedding.model,
|
||||
dimensions,
|
||||
"embedding model loaded"
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
inner: Mutex::new(inner),
|
||||
model_id: EmbeddingModelId(config.models.embedding.model.clone()),
|
||||
version: EmbeddingVersion(config.models.embedding.version.clone()),
|
||||
dimensions,
|
||||
batch_size: config.models.embedding.batch_size,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedder for FastembedEmbedder {
|
||||
fn model_id(&self) -> EmbeddingModelId {
|
||||
self.model_id.clone()
|
||||
}
|
||||
|
||||
fn model_version(&self) -> EmbeddingVersion {
|
||||
self.version.clone()
|
||||
}
|
||||
|
||||
fn dimensions(&self) -> usize {
|
||||
self.dimensions
|
||||
}
|
||||
|
||||
fn embed(&self, inputs: &[EmbeddingInput<'_>]) -> Result<Vec<Vec<f32>>> {
|
||||
if inputs.is_empty() {
|
||||
|
claude-reviewer-01
commented
L2 정규화는 fastembed 4.9의 default transformer가 이미 한다는 점을 코드 코멘트와 검증 path 둘 다로 박아둔 게 좋습니다. 통합 테스트 L2 정규화는 fastembed 4.9의 default transformer가 이미 한다는 점을 코드 코멘트와 검증 path 둘 다로 박아둔 게 좋습니다. 통합 테스트 `output_vectors_are_l2_normalized`가 `‖v‖ ≈ 1.0 ± 1e-3`을 pin해서 향후 fastembed bump이 normalize를 빼면 즉시 fail합니다. "외부 라이브러리가 알아서 하리라"는 가정을 테스트로 못 박는 정확한 패턴입니다.
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Apply e5 prefix per §11.3 BEFORE tokenization. The fastembed
|
||||
// model is unaware of the document/query distinction; the prefix
|
||||
// is the only signal that lets it produce different embeddings
|
||||
// for the same surface text in different roles.
|
||||
let prefixed: Vec<String> = inputs.iter().map(prefix_input).collect();
|
||||
|
||||
// We run our own batch loop on top of fastembed's internal one
|
||||
// so that `config.models.embedding.batch_size` is honored
|
||||
// exactly. fastembed's `embed(_, Some(batch_size))` does the
|
||||
// same internally; calling once with our batch size matches
|
||||
// intent and avoids an extra per-batch allocation.
|
||||
let mut out: Vec<Vec<f32>> = Vec::with_capacity(prefixed.len());
|
||||
for chunk in prefixed.chunks(self.batch_size) {
|
||||
let chunk_vec: Vec<&str> = chunk.iter().map(String::as_str).collect();
|
||||
let guard = self
|
||||
.inner
|
||||
.lock()
|
||||
.unwrap_or_else(|p| p.into_inner());
|
||||
let batch: Vec<Vec<f32>> = guard
|
||||
.embed(chunk_vec, Some(self.batch_size))
|
||||
.context("fastembed: embed")?;
|
||||
drop(guard);
|
||||
// Defensive shape check — every returned vector must match
|
||||
// the configured `dimensions`. Mismatch here means fastembed
|
||||
// and our config drifted at runtime (extremely unlikely;
|
||||
// would have been caught at construction).
|
||||
for v in &batch {
|
||||
if v.len() != self.dimensions {
|
||||
anyhow::bail!(
|
||||
"fastembed returned vector of length {} but adapter expects {}",
|
||||
v.len(),
|
||||
self.dimensions
|
||||
);
|
||||
}
|
||||
}
|
||||
out.extend(batch);
|
||||
}
|
||||
|
||||
debug_assert_eq!(out.len(), inputs.len());
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the prefixed string for one [`EmbeddingInput`]. Free function so
|
||||
/// the unit test can pin the exact format without going through `embed`.
|
||||
fn prefix_input(input: &EmbeddingInput<'_>) -> String {
|
||||
match input.kind {
|
||||
EmbeddingKind::Document => format!("passage: {}", input.text),
|
||||
EmbeddingKind::Query => format!("query: {}", input.text),
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve a `config.models.embedding.model` string to a fastembed
|
||||
/// `EmbeddingModel` enum variant. Only `multilingual-e5-small` is wired
|
||||
/// for p3-2; additional model names should be added (and their dims
|
||||
/// pinned in tests) as needed.
|
||||
fn resolve_model(name: &str) -> Result<EmbeddingModel> {
|
||||
match name {
|
||||
"multilingual-e5-small" => Ok(EmbeddingModel::MultilingualE5Small),
|
||||
other => anyhow::bail!(
|
||||
"kb-embed-local: unsupported embedding model {other:?}; \
|
||||
this adapter currently only ships `multilingual-e5-small`. \
|
||||
Add a new arm to `resolve_model` (and a fastembed feature \
|
||||
flag if needed) to support more."
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compare model dim against the configured dim. Extracted so a unit
|
||||
/// test can exercise the error branch without loading ONNX.
|
||||
pub(crate) fn check_dim(model_dim: usize, cfg_dim: usize) -> Result<()> {
|
||||
if model_dim != cfg_dim {
|
||||
anyhow::bail!(
|
||||
"dimension mismatch: model={model_dim}, config={cfg_dim}; \
|
||||
update `config.models.embedding.dimensions` to match the model \
|
||||
(or pick a different model)."
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Expand the limited template language `kb-config` uses for storage
|
||||
/// paths.
|
||||
///
|
||||
/// Supported substitutions, applied in order:
|
||||
/// 1. `{data_dir}` → `data_dir` (caller-supplied resolved string). This
|
||||
/// is a no-op when `data_dir` is empty (used by the recursive call
|
||||
/// that resolves `data_dir` itself).
|
||||
/// 2. `${XDG_DATA_HOME:-~/.local/share}` (and the bare
|
||||
/// `${XDG_DATA_HOME}`) → env var if set, else the default after
|
||||
/// `:-`.
|
||||
/// 3. Leading `~` → `$HOME`.
|
||||
///
|
||||
/// Mirrors `kb-store-sqlite::store::expand_data_dir`. Kept private to
|
||||
/// this crate; promoting it to a public `kb-config` API is a separate
|
||||
/// task (see task p3-2 risks: "don't expand kb-config's public API").
|
||||
fn expand_path(raw: &str, data_dir: &str) -> PathBuf {
|
||||
let mut s = raw.to_string();
|
||||
|
||||
if !data_dir.is_empty() {
|
||||
s = s.replace("{data_dir}", data_dir);
|
||||
}
|
||||
|
||||
// ${XDG_DATA_HOME:-~/.local/share}: respect env override, else fall
|
||||
// back to the suffix after `:-`.
|
||||
if let Some(start) = s.find("${XDG_DATA_HOME") {
|
||||
if let Some(rel_end) = s[start..].find('}') {
|
||||
let end = start + rel_end + 1; // include trailing '}'
|
||||
let inner = &s[start + 2..end - 1]; // strip ${ and }
|
||||
let replacement = match std::env::var("XDG_DATA_HOME") {
|
||||
Ok(v) if !v.is_empty() => v,
|
||||
_ => {
|
||||
if let Some((_, default)) = inner.split_once(":-") {
|
||||
default.to_string()
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
};
|
||||
s.replace_range(start..end, &replacement);
|
||||
}
|
||||
}
|
||||
|
||||
// Leading `~` → $HOME.
|
||||
if let Some(rest) = s.strip_prefix('~') {
|
||||
if let Some(home) = std::env::var_os("HOME").map(PathBuf::from) {
|
||||
return home.join(rest.trim_start_matches('/'));
|
||||
}
|
||||
}
|
||||
|
||||
PathBuf::from(s)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use kb_embed::EmbeddingInput;
|
||||
|
||||
// ── check_dim ────────────────────────────────────────────────────
|
||||
//
|
||||
// Exercises the construction-time dim mismatch branch WITHOUT
|
||||
// loading the real model. The integration test that builds a full
|
||||
// FastembedEmbedder is `#[ignore]`d (loads ~470 MB of weights).
|
||||
|
||||
#[test]
|
||||
fn check_dim_match_ok() {
|
||||
check_dim(384, 384).expect("matching dims must pass");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_dim_mismatch_errors() {
|
||||
let err = check_dim(384, 512).expect_err("mismatch must error");
|
||||
let msg = format!("{err}");
|
||||
assert!(msg.contains("dimension mismatch"), "msg={msg}");
|
||||
assert!(msg.contains("384"), "msg={msg}");
|
||||
assert!(msg.contains("512"), "msg={msg}");
|
||||
}
|
||||
|
||||
// ── prefix_input ─────────────────────────────────────────────────
|
||||
//
|
||||
// Pin the exact e5 prefix strings; a silent regression here
|
||||
// degrades retrieval quality without any test failing in the
|
||||
// dim/norm/snapshot suite.
|
||||
|
||||
#[test]
|
||||
fn prefix_document_uses_passage() {
|
||||
let input = EmbeddingInput {
|
||||
text: "hello world",
|
||||
kind: EmbeddingKind::Document,
|
||||
};
|
||||
assert_eq!(prefix_input(&input), "passage: hello world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prefix_query_uses_query() {
|
||||
let input = EmbeddingInput {
|
||||
text: "hello world",
|
||||
kind: EmbeddingKind::Query,
|
||||
};
|
||||
assert_eq!(prefix_input(&input), "query: hello world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prefix_handles_empty_text() {
|
||||
let doc = EmbeddingInput {
|
||||
text: "",
|
||||
kind: EmbeddingKind::Document,
|
||||
};
|
||||
let qry = EmbeddingInput {
|
||||
text: "",
|
||||
kind: EmbeddingKind::Query,
|
||||
};
|
||||
assert_eq!(prefix_input(&doc), "passage: ");
|
||||
assert_eq!(prefix_input(&qry), "query: ");
|
||||
}
|
||||
|
||||
// ── resolve_model ────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn resolve_default_model_ok() {
|
||||
// The exact enum variant is opaque, but `is_ok` plus a
|
||||
// round-trip through the fastembed metadata gives confidence
|
||||
// we hit the right arm.
|
||||
resolve_model("multilingual-e5-small").expect("default model resolves");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_unknown_model_errors() {
|
||||
let err = resolve_model("not-a-real-model").expect_err("unknown model errors");
|
||||
let msg = format!("{err}");
|
||||
assert!(msg.contains("unsupported embedding model"), "msg={msg}");
|
||||
}
|
||||
|
||||
// ── expand_path ──────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn expand_path_substitutes_data_dir_template() {
|
||||
let p = expand_path("{data_dir}/models", "/tmp/kbtest");
|
||||
assert_eq!(p, PathBuf::from("/tmp/kbtest/models"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expand_path_no_op_without_template() {
|
||||
let p = expand_path("/abs/path", "/tmp/kbtest");
|
||||
assert_eq!(p, PathBuf::from("/abs/path"));
|
||||
}
|
||||
|
||||
// ── expand_path: XDG_DATA_HOME fallback ──────────────────────────
|
||||
//
|
||||
// These two tests mutate the process-wide `XDG_DATA_HOME` env var,
|
||||
// which is unsafe under edition 2024 and racy under cargo's default
|
||||
// parallel test runner. The shared `ENV_LOCK` serializes them; each
|
||||
// test snapshots the prior value and restores it on exit.
|
||||
|
||||
use std::sync::Mutex as StdMutex;
|
||||
static ENV_LOCK: StdMutex<()> = StdMutex::new(());
|
||||
|
||||
/// RAII guard: snapshots `XDG_DATA_HOME` on construction, restores
|
||||
/// it on drop. Pair with the `ENV_LOCK` guard for serial access.
|
||||
struct XdgGuard {
|
||||
prior: Option<String>,
|
||||
}
|
||||
|
||||
impl XdgGuard {
|
||||
fn capture() -> Self {
|
||||
Self {
|
||||
prior: std::env::var("XDG_DATA_HOME").ok(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for XdgGuard {
|
||||
fn drop(&mut self) {
|
||||
// SAFETY: edition 2024 marks `set_var`/`remove_var` unsafe
|
||||
// because env mutation is not thread-safe. Callers hold
|
||||
// `ENV_LOCK` for the duration of the test, so no other
|
||||
// thread observes the mutation.
|
||||
unsafe {
|
||||
match &self.prior {
|
||||
Some(v) => std::env::set_var("XDG_DATA_HOME", v),
|
||||
None => std::env::remove_var("XDG_DATA_HOME"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expand_path_xdg_data_home_set() {
|
||||
let _lock = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner());
|
||||
let _guard = XdgGuard::capture();
|
||||
// SAFETY: lock held for the duration of this test.
|
||||
unsafe { std::env::set_var("XDG_DATA_HOME", "/custom/path") };
|
||||
|
claude-reviewer-01
commented
XDG 테스트의 직렬화 패턴 — 정적 Mutex + RAII guard로 환경변수 snapshot/restore. edition 2024가 XDG 테스트의 직렬화 패턴 — 정적 Mutex + RAII guard로 환경변수 snapshot/restore. edition 2024가 `set_var`/`remove_var`을 unsafe로 분류한 이유가 정확히 이것 (다른 스레드가 환경변수를 동시 read하면 UB)인데, `ENV_LOCK`로 cross-test 직렬화 + `XdgGuard`로 prior value 복원하는 두 단계 모두 갖췄습니다. 향후 환경변수를 건드리는 테스트가 같은 ENV_LOCK을 공유하도록 확장 가능한 구조이기도 합니다.
|
||||
|
||||
let p = expand_path("${XDG_DATA_HOME:-~/.local/share}/kb", "");
|
||||
assert_eq!(p, PathBuf::from("/custom/path/kb"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expand_path_xdg_data_home_unset_falls_back_to_home() {
|
||||
let _lock = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner());
|
||||
let _guard = XdgGuard::capture();
|
||||
// SAFETY: lock held for the duration of this test.
|
||||
unsafe { std::env::remove_var("XDG_DATA_HOME") };
|
||||
|
||||
let home = std::env::var("HOME").expect("HOME must be set in tests");
|
||||
let expected = PathBuf::from(home).join(".local/share/kb");
|
||||
let p = expand_path("${XDG_DATA_HOME:-~/.local/share}/kb", "");
|
||||
assert_eq!(p, expected);
|
||||
}
|
||||
}
|
||||
284
crates/kb-embed-local/tests/embed_model.rs
Normal file
@@ -0,0 +1,284 @@
|
||||
//! Integration tests for [`FastembedEmbedder`] that load the real ONNX
|
||||
//! model.
|
||||
//!
|
||||
//! ## Why every test in this file is `#[ignore]`
|
||||
//!
|
||||
//! The first call to `FastembedEmbedder::new` downloads ~470 MB of
|
||||
//! weights from Hugging Face into `data_dir/models/fastembed/`. Doing
|
||||
//! that on every `cargo test` invocation is wasteful, so the bare
|
||||
//! invocation skips this file entirely.
|
||||
//!
|
||||
//! Run the full suite with:
|
||||
//! ```text
|
||||
//! cargo test -p kb-embed-local -- --ignored
|
||||
//! ```
|
||||
//!
|
||||
//! All tests share a `OnceLock<FastembedEmbedder>` so the model loads
|
||||
//! exactly once per process invocation (ONNX runtime first-load latency
|
||||
//! is 1-2 s on M-series Macs per design risks list).
|
||||
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Instant;
|
||||
|
||||
use kb_embed::{Embedder, EmbeddingInput, EmbeddingKind};
|
||||
use kb_embed_local::FastembedEmbedder;
|
||||
|
||||
/// Build a `Config` whose `data_dir` lives in a per-process temp dir so
|
||||
/// the test never writes into the developer's real `~/.local/share/kb`.
|
||||
/// Returns the `Config` and the `TempDir` guard (caller keeps the guard
|
||||
/// alive for the test duration).
|
||||
fn test_config() -> (kb_config::Config, tempfile::TempDir) {
|
||||
let tmp = tempfile::tempdir().expect("create tempdir");
|
||||
let mut cfg = kb_config::Config::defaults();
|
||||
cfg.storage.data_dir = tmp.path().to_string_lossy().into_owned();
|
||||
// model_dir keeps its default `{data_dir}/models` template; the
|
||||
// adapter resolves it itself.
|
||||
(cfg, tmp)
|
||||
}
|
||||
|
||||
/// Single shared embedder for the `--ignored` lane. Held behind a
|
||||
/// `OnceLock` so we pay the ~1-2 s ONNX init + (first run only) the
|
||||
/// network download just once.
|
||||
fn shared_embedder() -> &'static FastembedEmbedder {
|
||||
static EMBEDDER: OnceLock<FastembedEmbedder> = OnceLock::new();
|
||||
EMBEDDER.get_or_init(|| {
|
||||
let (cfg, _tmp) = test_config();
|
||||
// We deliberately leak `_tmp` here: the OnceLock outlives the
|
||||
// function scope so the cache directory must persist for the
|
||||
// process. (`tempfile::TempDir`'s `Drop` would erase the cache
|
||||
// and wreck subsequent calls.) The OS will reclaim the leaked
|
||||
// path when the test process exits.
|
||||
let _ = std::mem::ManuallyDrop::new(_tmp);
|
||||
FastembedEmbedder::new(&cfg).expect("init FastembedEmbedder")
|
||||
})
|
||||
}
|
||||
|
||||
// ─── construction ─────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
#[ignore = "downloads ~470MB ONNX model on first run; CI-only"]
|
||||
fn default_config_constructs_with_dims_384() {
|
||||
let emb = shared_embedder();
|
||||
assert_eq!(emb.dimensions(), 384);
|
||||
assert_eq!(emb.model_id().0, "multilingual-e5-small");
|
||||
assert_eq!(emb.model_version().0, "v1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "downloads ~470MB ONNX model on first run; CI-only"]
|
||||
fn mismatched_dims_in_config_errors_at_construction() {
|
||||
let (mut cfg, _tmp) = test_config();
|
||||
cfg.models.embedding.dimensions = 512; // model is 384
|
||||
// `FastembedEmbedder` deliberately does not implement `Debug`
|
||||
// (its inner ONNX session has no useful debug shape), so we
|
||||
// can't use `expect_err`; match the Result manually.
|
||||
let err = match FastembedEmbedder::new(&cfg) {
|
||||
Ok(_) => panic!("dim mismatch must error"),
|
||||
Err(e) => e,
|
||||
};
|
||||
let msg = format!("{err}");
|
||||
assert!(msg.contains("dimension mismatch"), "msg={msg}");
|
||||
assert!(msg.contains("384"), "msg={msg}");
|
||||
assert!(msg.contains("512"), "msg={msg}");
|
||||
}
|
||||
|
||||
// ─── e5 prefix differentiation ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
#[ignore = "loads ONNX model; CI-only"]
|
||||
fn document_and_query_yield_different_vectors() {
|
||||
let emb = shared_embedder();
|
||||
let text = "The quick brown fox jumps over the lazy dog.";
|
||||
let out = emb
|
||||
.embed(&[
|
||||
EmbeddingInput {
|
||||
text,
|
||||
kind: EmbeddingKind::Document,
|
||||
},
|
||||
EmbeddingInput {
|
||||
text,
|
||||
kind: EmbeddingKind::Query,
|
||||
},
|
||||
])
|
||||
.expect("embed two inputs");
|
||||
assert_eq!(out.len(), 2);
|
||||
assert_eq!(out[0].len(), 384);
|
||||
assert_eq!(out[1].len(), 384);
|
||||
|
||||
// Both vectors are L2-normalized → cosine similarity == dot product.
|
||||
let cos: f32 = out[0]
|
||||
.iter()
|
||||
.zip(out[1].iter())
|
||||
.map(|(a, b)| a * b)
|
||||
.sum();
|
||||
// Same text, different prefix → vectors must NOT be identical.
|
||||
assert!(
|
||||
cos < 0.9999,
|
||||
"expected distinct vectors for Document vs Query, got cos={cos}"
|
||||
);
|
||||
}
|
||||
|
||||
// ─── L2 normalization ─────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
#[ignore = "loads ONNX model; CI-only"]
|
||||
fn output_vectors_are_l2_normalized() {
|
||||
let emb = shared_embedder();
|
||||
let inputs = [
|
||||
EmbeddingInput {
|
||||
text: "hello world",
|
||||
kind: EmbeddingKind::Document,
|
||||
},
|
||||
EmbeddingInput {
|
||||
text: "vector search",
|
||||
kind: EmbeddingKind::Document,
|
||||
},
|
||||
EmbeddingInput {
|
||||
text: "embedding model",
|
||||
kind: EmbeddingKind::Query,
|
||||
},
|
||||
];
|
||||
let out = emb.embed(&inputs).expect("embed");
|
||||
// Per `kb_embed::assert_unit_norm` docs: `5e-4` is the safe bound at
|
||||
// 384 dims (f32::EPSILON × √384 ≈ 2.3e-6, but ONNX kernels add
|
||||
// their own per-component noise; 1e-3 is very generous and matches
|
||||
// the spec's `± 1e-3`).
|
||||
kb_embed::assert_unit_norm(&out, 1e-3);
|
||||
kb_embed::assert_vector_shape(&out, 384);
|
||||
}
|
||||
|
||||
// ─── determinism ──────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
#[ignore = "loads ONNX model; CI-only"]
|
||||
fn identical_input_yields_identical_output() {
|
||||
let emb = shared_embedder();
|
||||
let inputs = [
|
||||
EmbeddingInput {
|
||||
text: "deterministic embedding test",
|
||||
kind: EmbeddingKind::Document,
|
||||
},
|
||||
EmbeddingInput {
|
||||
text: "second sentence for variety",
|
||||
kind: EmbeddingKind::Document,
|
||||
},
|
||||
];
|
||||
let a = emb.embed(&inputs).expect("first embed");
|
||||
let b = emb.embed(&inputs).expect("second embed");
|
||||
assert_eq!(a, b, "two calls with the same inputs must be byte-equal");
|
||||
}
|
||||
|
||||
// ─── performance ──────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
#[ignore = "performance test; downloads model and runs 64-vec batch"]
|
||||
fn batch_of_64_short_inputs_under_5s() {
|
||||
let emb = shared_embedder();
|
||||
// 64 distinct short strings → forces the full default batch_size
|
||||
// through one fastembed call.
|
||||
let texts: Vec<String> = (0..64)
|
||||
.map(|i| format!("perf-test sentence number {i}"))
|
||||
.collect();
|
||||
let inputs: Vec<EmbeddingInput<'_>> = texts
|
||||
.iter()
|
||||
.map(|t| EmbeddingInput {
|
||||
text: t.as_str(),
|
||||
kind: EmbeddingKind::Document,
|
||||
})
|
||||
.collect();
|
||||
let t0 = Instant::now();
|
||||
let out = emb.embed(&inputs).expect("embed batch of 64");
|
||||
let elapsed = t0.elapsed();
|
||||
assert_eq!(out.len(), 64);
|
||||
assert!(
|
||||
elapsed.as_secs_f32() < 5.0,
|
||||
"batch-64 took {elapsed:?}, expected < 5s"
|
||||
);
|
||||
}
|
||||
|
||||
// ─── snapshot ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Aggregate hash of vectors for the 5 fixture sentences.
|
||||
///
|
||||
/// Computed by:
|
||||
/// 1. embed each sentence as `EmbeddingKind::Document`,
|
||||
/// 2. round each `f32` component to 4 decimal places (multiply by 1e4,
|
||||
/// round, store as `i32`),
|
||||
/// 3. write the rounded i32 components into a `DefaultHasher` in row-
|
||||
/// major order,
|
||||
/// 4. read out the `u64` finish value.
|
||||
///
|
||||
/// The 4-decimal tolerance is intentional float-tolerance per task spec:
|
||||
/// exact f32 equality is too strict given ONNX kernel + hardware
|
||||
/// variation.
|
||||
///
|
||||
/// **Pinning workflow** (a snapshot test must FAIL UNTIL PINNED):
|
||||
/// 1. With `SNAPSHOT_HASH_BASELINE = 0`, run
|
||||
/// `cargo test -p kb-embed-local -- --ignored snapshot`. The test
|
||||
/// panics with a message containing the captured hash.
|
||||
/// 2. Paste the printed hex value into `SNAPSHOT_HASH_BASELINE` below.
|
||||
/// 3. Re-run the same command — the test now asserts equality and
|
||||
/// passes, confirming the pin.
|
||||
///
|
||||
/// On a genuine model upgrade, reset to `0`, re-pin, and bump
|
||||
/// `EmbeddingVersion` per design §9 in the same PR.
|
||||
const SNAPSHOT_HASH_BASELINE: u64 = 0;
|
||||
|
||||
#[test]
|
||||
#[ignore = "loads ONNX model; CI-only"]
|
||||
fn snapshot_aggregate_hash_is_stable() {
|
||||
let emb = shared_embedder();
|
||||
let fixture_path =
|
||||
std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/embed/known-sentences.json");
|
||||
let raw = std::fs::read_to_string(&fixture_path).expect("read fixture");
|
||||
let json: serde_json::Value = serde_json::from_str(&raw).expect("parse fixture json");
|
||||
let sentences: Vec<String> = json["sentences"]
|
||||
.as_array()
|
||||
.expect("`sentences` array")
|
||||
.iter()
|
||||
.map(|v| v.as_str().expect("sentence is str").to_string())
|
||||
.collect();
|
||||
assert_eq!(sentences.len(), 5, "fixture must have exactly 5 sentences");
|
||||
|
||||
let inputs: Vec<EmbeddingInput<'_>> = sentences
|
||||
.iter()
|
||||
.map(|s| EmbeddingInput {
|
||||
text: s.as_str(),
|
||||
kind: EmbeddingKind::Document,
|
||||
})
|
||||
.collect();
|
||||
let out = emb.embed(&inputs).expect("embed snapshot fixture");
|
||||
|
||||
// Round every component to 4 decimal places, hash deterministically.
|
||||
let mut hasher = DefaultHasher::new();
|
||||
for (i, v) in out.iter().enumerate() {
|
||||
assert_eq!(v.len(), 384, "row {i} dim mismatch");
|
||||
for x in v {
|
||||
let rounded: i32 = (*x * 1.0e4).round() as i32;
|
||||
rounded.hash(&mut hasher);
|
||||
}
|
||||
}
|
||||
let observed = hasher.finish();
|
||||
if SNAPSHOT_HASH_BASELINE == 0 {
|
||||
// Unpinned baseline: panic with the captured hash. A snapshot
|
||||
// test that silently passes on first run defeats its purpose,
|
||||
// so we hard-fail until a maintainer commits the pin. Both
|
||||
// hex (paste-friendly) and decimal forms are printed.
|
||||
eprintln!(
|
||||
"kb-embed-local snapshot baseline (paste into SNAPSHOT_HASH_BASELINE): \
|
||||
|
claude-reviewer-01
commented
Snapshot baseline 정책이 정확합니다. Snapshot baseline 정책이 정확합니다. `SNAPSHOT_HASH_BASELINE = 0`이 silent pass가 아니라 panic — 측정값을 출력하고 "paste back into the const" 가이드를 함께 띄웁니다. snapshot test의 본질이 "pin할 때까지 fail해야 의미가 있다"인데 그 invariant이 실제로 enforce됩니다. 일반적인 "if baseline == 0 then return" 패턴의 함정 (한 번도 진짜 검증되지 않은 채 green 유지)을 정확히 피했습니다.
|
||||
{observed:#x} ({observed})"
|
||||
);
|
||||
panic!(
|
||||
"snapshot baseline unpinned — paste {observed:#x} into \
|
||||
SNAPSHOT_HASH_BASELINE then re-run"
|
||||
);
|
||||
}
|
||||
assert_eq!(
|
||||
observed, SNAPSHOT_HASH_BASELINE,
|
||||
"snapshot drift: model output for the fixture sentences changed; \
|
||||
either fastembed weights changed (bump EmbeddingVersion per §9) \
|
||||
or there's an ONNX kernel diff."
|
||||
);
|
||||
}
|
||||
10
crates/kb-embed-local/tests/fixtures/embed/known-sentences.json
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"_comment": "Snapshot fixture for FastembedEmbedder. Five short sentences in mixed languages chosen so the multilingual-e5-small model exercises both Latin and CJK token paths. The kb-embed-local snapshot test embeds these as `EmbeddingKind::Document`, rounds each f32 component to 4 decimal places, and checks an aggregate hash against a constant baked into the test source. The 4-decimal tolerance hides the per-platform ONNX kernel jitter we observed during P3 bring-up; tighter tolerances would make the test flaky across hardware.",
|
||||
"sentences": [
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
"Vector search is fast on small models.",
|
||||
"Knowledge bases benefit from hybrid retrieval.",
|
||||
"한국어 문장도 임베딩이 잘 됩니다.",
|
||||
"Embeddings should be deterministic given the same input."
|
||||
]
|
||||
}
|
||||
Allowed deps 목록에 있던
kb-core와thiserror을 둘 다 제거한 결정이 정답입니다 —kb-embed이 trait 표면을 재노출하므로kb-coredirect dep는 redundant, error path는anyhow만 사용. clippyunused-crate-dependencies가 잡았을 cruft를 spec compliance를 좁게 해석하지 않고 "실제로 쓰는 deps만 선언" 원칙으로 정리한 게 좋습니다.