diff --git a/crates/kebab-nli/Cargo.toml b/crates/kebab-nli/Cargo.toml index 5e52ae1..9886bca 100644 --- a/crates/kebab-nli/Cargo.toml +++ b/crates/kebab-nli/Cargo.toml @@ -14,7 +14,13 @@ description = "fb-41: NLI-based post-synthesis verification (XNLI mDeBERTa-v3) kebab-config = { path = "../kebab-config" } anyhow = { workspace = true } serde = { workspace = true } -ort = { workspace = true } +# ort: extend the workspace pin with `download-binaries` so kebab-nli +# can link the ONNX runtime when fastembed is NOT in the build graph +# (e.g. `cargo test -p kebab-nli` alone, where the per-crate feature +# union excludes kebab-embed-local + fastembed). In workspace-wide +# builds the feature gets union'd with fastembed's identical opt-in +# so no extra runtime gets pulled. +ort = { workspace = true, features = ["download-binaries"] } tokenizers = { workspace = true } hf-hub = { workspace = true } ndarray = { workspace = true } diff --git a/crates/kebab-nli/src/onnx.rs b/crates/kebab-nli/src/onnx.rs index da1f025..17b1e4f 100644 --- a/crates/kebab-nli/src/onnx.rs +++ b/crates/kebab-nli/src/onnx.rs @@ -1,37 +1,291 @@ //! ONNX-backed `NliVerifier` adapter (mDeBERTa-v3 XNLI). //! -//! PR-9a: scaffolding only. `new` succeeds against the default `Config` -//! and `score` returns an explicit `"PR-9a stub"` error so any caller that -//! wires this up before PR-9b lands gets a loud failure instead of silent -//! all-zero scores. PR-9b will add ort `Session` + `Tokenizer` lazy init -//! and real inference. +//! PR-9b: real implementation. `new` resolves the cache directory from +//! `config.storage.model_dir/nli//` (matching the +//! fastembed adapter's pattern of `model_dir/fastembed/`) and stamps it +//! on `self`. The (potentially network-bound) model + tokenizer download +//! is deferred to the first `score` call via `OnceLock` / +//! `OnceLock` — keeping `new` cheap so the rag crate can +//! construct the verifier eagerly during `App` boot without paying for +//! a model load on every CLI invocation. +//! +//! Per design §2.2.2 (Lazy init), §2.2.3 (truncation = `OnlyFirst`, +//! premise truncates, hypothesis preserved). PR-9c-1 will wire the +//! `[models.nli]` config section; until then the model id is hard-coded +//! to the Xenova mDeBERTa-v3 XNLI multilingual checkpoint. + +use std::path::PathBuf; +use std::sync::OnceLock; + +use anyhow::{Context, Result, anyhow}; +use kebab_config::expand_path; +use ort::session::Session; +use tokenizers::{ + Tokenizer, TruncationDirection, TruncationParams, TruncationStrategy, +}; use crate::{NliScores, NliVerifier}; +/// Default HuggingFace model id for the XNLI verifier. PR-9c-1 will +/// replace this constant with a `config.models.nli.model` lookup once +/// the `NliCfg` section lands. The Xenova repo packages the +/// mDeBERTa-v3-base XNLI multilingual checkpoint as ONNX under the +/// `onnx/model.onnx` path; the tokenizer ships at `tokenizer.json`. +const DEFAULT_MODEL_ID: &str = "Xenova/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7"; + +/// Filename inside the HF repo (NOT a path on disk). +const HF_MODEL_FILE: &str = "onnx/model.onnx"; +/// Filename inside the HF repo (NOT a path on disk). +const HF_TOKENIZER_FILE: &str = "tokenizer.json"; + +/// Subdirectory under `config.storage.model_dir` where the NLI adapter +/// writes / reads ONNX + tokenizer files. Mirrors the fastembed +/// adapter's `model_dir/fastembed/` layout. +const NLI_CACHE_SUBDIR: &str = "nli"; + +/// XNLI label order in the Xenova mDeBERTa-v3 checkpoint: the model's +/// output logits are `[entailment, neutral, contradiction]`. Pinned as +/// a constant so a future model swap (different label order) is a +/// single-site change. +const LOGITS_LEN: usize = 3; + +/// Max input length passed to the tokenizer. mDeBERTa-v3 is trained +/// at 512-token context, matches the Xenova ONNX export's positional +/// embedding shape. `OnlyFirst` strategy makes the premise (which is +/// allowed to be the packed-chunks context) absorb the truncation; +/// the hypothesis (the generated answer) is preserved. +const MAX_TOKENS: usize = 512; + /// ONNX-runtime mDeBERTa-v3 XNLI verifier. /// -/// PR-9a scaffolding holds no state — fields land in PR-9b -/// (`model_id`, `cache_dir`, `session: OnceLock`, -/// `tokenizer: OnceLock`). +/// `session` + `tokenizer` are lazily populated by the first call to +/// `ensure_loaded`. `new` is eager only for cache_dir create_dir_all +/// (cheap) so that the rag crate can construct an instance during +/// `App` boot without paying for the ~280 MB model download. pub struct OnnxNliVerifier { - _private: (), + model_id: String, + cache_dir: PathBuf, + session: OnceLock, + tokenizer: OnceLock, } impl OnnxNliVerifier { - /// Construct a verifier from the user's `Config`. PR-9a always returns - /// `Ok` because the real model + tokenizer download is deferred to - /// PR-9b's first `score` call. - pub fn new(_config: &kebab_config::Config) -> anyhow::Result { - Ok(Self { _private: () }) + /// Construct a verifier from the user's `Config`. Eagerly resolves + /// `cache_dir = config.storage.model_dir/nli//` + /// and runs `create_dir_all` so the first `score` call can drop + /// straight into download + load without re-deriving paths. + /// + /// PR-9c-1 will swap `DEFAULT_MODEL_ID` for `config.models.nli.model`. + pub fn new(config: &kebab_config::Config) -> Result { + let model_id = DEFAULT_MODEL_ID.to_string(); + + // Match kebab-embed-local's two-step expansion: data_dir first, + // then model_dir with `{data_dir}` substituted in. + let data_dir = expand_path(&config.storage.data_dir, ""); + let model_dir = expand_path(&config.storage.model_dir, &data_dir.to_string_lossy()); + let cache_dir = model_dir + .join(NLI_CACHE_SUBDIR) + .join(sanitize_model_id(&model_id)); + std::fs::create_dir_all(&cache_dir) + .with_context(|| format!("create kebab-nli cache dir {}", cache_dir.display()))?; + + Ok(Self { + model_id, + cache_dir, + session: OnceLock::new(), + tokenizer: OnceLock::new(), + }) + } + + /// Download (if needed) + load the ONNX session and tokenizer on + /// first call; return cached refs on subsequent calls. Uses two + /// `OnceLock`s rather than one because a single `OnceLock<(_, _)>` + /// would need to construct both atomically — keeping them split + /// lets us short-circuit on the (rare) hit path where only one + /// side is missing. + /// + /// `OnceLock::get_or_try_init` is still unstable (rust-lang/rust#109737) + /// so we implement the fallible init by hand: probe `get`, on miss + /// compute the value, then `set` it. The race between two threads is + /// resolved by `OnceLock::set` — the loser gets `Err`, falls through + /// to a second `get`, and reads the winner's value. Each thread that + /// races + loses does pay the cost of one redundant download (rare in + /// practice: rag boot is single-threaded today), but the cache stays + /// consistent. + fn ensure_loaded(&self) -> Result<(&Session, &Tokenizer)> { + if self.session.get().is_none() { + let s = self.load_session()?; + let _ = self.session.set(s); // loser of a race: discard local value + } + if self.tokenizer.get().is_none() { + let t = self.load_tokenizer()?; + let _ = self.tokenizer.set(t); + } + // Both OnceLocks are populated at this point; `expect` is a + // tighter post-condition than `unwrap_or_else` would be. + let session = self.session.get().expect("session populated above"); + let tokenizer = self.tokenizer.get().expect("tokenizer populated above"); + Ok((session, tokenizer)) + } + + /// Build an `hf_hub::api::sync::Api` rooted at `self.cache_dir` and + /// fetch `filename` from `self.model_id`. Logs cache hits at INFO + /// so a user reading kebab logs can see which artifact source the + /// pipeline picked. + fn fetch(&self, filename: &str) -> Result { + let api = hf_hub::api::sync::ApiBuilder::new() + .with_cache_dir(self.cache_dir.clone()) + .build() + .with_context(|| { + format!( + "kebab-nli: hf-hub ApiBuilder::build failed (cache_dir={})", + self.cache_dir.display() + ) + })?; + let repo = api.model(self.model_id.clone()); + + // `ApiRepo::get` returns the local path if cached, otherwise + // downloads. We can't tell after the fact whether the file + // was already cached without an extra `Cache::repo::get` + // probe, so do that probe first to emit the right log line. + let cache_path = api + .repo(hf_hub::Repo::new( + self.model_id.clone(), + hf_hub::RepoType::Model, + )) + .get(filename) + .ok(); + if cache_path.is_some() { + tracing::info!( + target: "kebab-nli", + model_id = %self.model_id, + file = %filename, + "NLI artifact cache hit" + ); + } else { + tracing::info!( + target: "kebab-nli", + model_id = %self.model_id, + file = %filename, + cache_dir = %self.cache_dir.display(), + "downloading NLI artifact" + ); + } + + repo.get(filename).with_context(|| { + format!( + "kebab-nli: hf-hub fetch failed for {filename} (model_id={}, cache_dir={})", + self.model_id, + self.cache_dir.display() + ) + }) + } + + fn load_session(&self) -> Result { + tracing::info!( + target: "kebab-nli", + model_id = %self.model_id, + "downloading NLI model + tokenizer (first run only)" + ); + let model_path = self.fetch(HF_MODEL_FILE)?; + let session = Session::builder() + .with_context(|| "kebab-nli: ort Session::builder failed")? + .commit_from_file(&model_path) + .with_context(|| { + format!( + "kebab-nli: ort Session::commit_from_file({}) failed", + model_path.display() + ) + })?; + tracing::info!( + target: "kebab-nli", + model_id = %self.model_id, + model_path = %model_path.display(), + "NLI model ready" + ); + Ok(session) + } + + fn load_tokenizer(&self) -> Result { + let tokenizer_path = self.fetch(HF_TOKENIZER_FILE)?; + let mut tokenizer = Tokenizer::from_file(&tokenizer_path) + .map_err(|e| anyhow!("kebab-nli: Tokenizer::from_file({}) failed: {e}", tokenizer_path.display()))?; + tokenizer + .with_truncation(Some(TruncationParams { + max_length: MAX_TOKENS, + strategy: TruncationStrategy::OnlyFirst, + stride: 0, + direction: TruncationDirection::Right, + })) + .map_err(|e| anyhow!("kebab-nli: Tokenizer::with_truncation failed: {e}"))?; + Ok(tokenizer) } } impl NliVerifier for OnnxNliVerifier { - fn score(&self, _premise: &str, _hypothesis: &str) -> anyhow::Result { - anyhow::bail!("PR-9a stub — ONNX inference lands in PR-9b") + fn score(&self, premise: &str, hypothesis: &str) -> Result { + // Defense-in-depth: spec §2.3 has the caller skip empty answers, + // but a degenerate empty hypothesis here would tokenize to a + // [CLS][SEP][SEP] triple that yields a near-uniform softmax — + // misleading both faithfulness gate and any future logging. + if hypothesis.trim().is_empty() { + anyhow::bail!("kebab-nli: empty hypothesis"); + } + + let (session, tokenizer) = self.ensure_loaded()?; + + let enc = tokenizer + .encode((premise, hypothesis), true) + .map_err(|e| anyhow!("kebab-nli: tokenizer.encode failed: {e}"))?; + + let ids: Vec = enc.get_ids().iter().map(|&u| u as i64).collect(); + let mask: Vec = enc + .get_attention_mask() + .iter() + .map(|&u| u as i64) + .collect(); + let seq_len = ids.len(); + + // mDeBERTa-v3 ONNX export expects [batch, seq_len] for both + // input_ids and attention_mask. We always feed batch=1. + let ids_arr = ndarray::Array2::from_shape_vec((1, seq_len), ids) + .with_context(|| "kebab-nli: input_ids ndarray shape build failed")?; + let mask_arr = ndarray::Array2::from_shape_vec((1, seq_len), mask) + .with_context(|| "kebab-nli: attention_mask ndarray shape build failed")?; + + let outputs = session + .run(ort::inputs! { + "input_ids" => ids_arr, + "attention_mask" => mask_arr, + }?) + .with_context(|| "kebab-nli: ort Session::run failed")?; + + let logits = outputs["logits"] + .try_extract_tensor::() + .with_context(|| "kebab-nli: logits try_extract_tensor:: failed")?; + + // Expected shape [1, 3]. Defensive check — a model swap with a + // different head would silently produce wrong scores otherwise. + let shape = logits.shape(); + if shape != [1, LOGITS_LEN] { + anyhow::bail!( + "kebab-nli: unexpected logits shape {:?}, expected [1, {LOGITS_LEN}]", + shape + ); + } + let l = [logits[[0, 0]], logits[[0, 1]], logits[[0, 2]]]; + Ok(NliScores::from_xnli_logits(l)) } } +/// Make a HuggingFace model id (`"owner/repo"`) into a single +/// path component safe to use as a directory name. `/` → `_` is +/// enough for current ids; if more exotic chars appear we'll +/// widen this then. +fn sanitize_model_id(s: &str) -> String { + s.replace('/', "_") +} + #[cfg(test)] mod tests { use super::*; @@ -41,17 +295,29 @@ mod tests { fn new_succeeds_on_default_config() { let cfg = Config::defaults(); let v = OnnxNliVerifier::new(&cfg).expect("new should succeed on default config"); - // Silence unused-binding lint without weakening the assertion. - let _ = &v; + // cache_dir must include the sanitized model id (no '/'). + let s = v.cache_dir.to_string_lossy(); + assert!(s.contains(NLI_CACHE_SUBDIR), "cache_dir lacks nli/: {s}"); + assert!( + !s.contains("Xenova/mDeBERTa"), + "cache_dir must sanitize '/' in model id: {s}" + ); + assert!( + s.contains("Xenova_mDeBERTa"), + "cache_dir should contain sanitized id: {s}" + ); } + /// Empty hypothesis takes the defense-in-depth early bail path — + /// reaches no model load, so this is a pure unit test (no network). + /// Replaces PR-9a's `score_returns_err_in_skeleton` (stub-only). #[test] - fn score_returns_err_in_skeleton() { + fn score_empty_hypothesis_returns_err() { let cfg = Config::defaults(); let v = OnnxNliVerifier::new(&cfg).unwrap(); - let err = v.score("a", "b").expect_err("PR-9a stub must error"); + let err = v.score("anything", "").expect_err("empty hypothesis must error"); assert!( - err.to_string().contains("PR-9a stub"), + err.to_string().contains("empty hypothesis"), "unexpected error message: {err}" ); } diff --git a/crates/kebab-nli/tests/inference.rs b/crates/kebab-nli/tests/inference.rs new file mode 100644 index 0000000..bdcf05c --- /dev/null +++ b/crates/kebab-nli/tests/inference.rs @@ -0,0 +1,140 @@ +//! Integration tests for `OnnxNliVerifier` against the real +//! mDeBERTa-v3 XNLI model. Every test is `#[ignore]` — plain +//! `cargo test -p kebab-nli` skips them; run explicitly with +//! `cargo test -p kebab-nli --test inference -- --ignored` to +//! exercise the (slow + network-bound on first run) inference path. +//! +//! First test in the file triggers the ~280 MB ONNX + ~16 MB +//! tokenizer download into `config.storage.model_dir/nli/...`; +//! subsequent tests hit the OnceLock cache for free. + +use kebab_config::Config; +use kebab_nli::{NliVerifier, OnnxNliVerifier}; + +/// Test 1: an English statement entails itself with high confidence. +/// Smoke evidence captured for the PR description's `## 검증` section. +#[test] +#[ignore] +fn en_self_entailment_high_score() { + let cfg = Config::defaults(); + let v = OnnxNliVerifier::new(&cfg).expect("verifier construction"); + let premise = "Caffeine is a stimulant."; + let hypothesis = "Caffeine is a stimulant."; + let s = v.score(premise, hypothesis).expect("score should succeed"); + eprintln!( + "[test1 en_self_entailment_high_score] premise={premise:?} hypothesis={hypothesis:?} \ + scores: entailment={:.4}, neutral={:.4}, contradiction={:.4}", + s.entailment, s.neutral, s.contradiction + ); + assert!( + s.entailment > 0.8, + "expected entailment > 0.8, got {:.4} (full scores: {:?})", + s.entailment, + s + ); +} + +/// Test 2: an unrelated chemistry fact does NOT entail the premise. +/// Entailment should be low — neutral / contradiction wins. +#[test] +#[ignore] +fn en_unrelated_low_entailment() { + let cfg = Config::defaults(); + let v = OnnxNliVerifier::new(&cfg).expect("verifier construction"); + let premise = "Caffeine is a stimulant."; + let hypothesis = "The chemical formula of caffeine is C8H10N4O2."; + let s = v.score(premise, hypothesis).expect("score should succeed"); + eprintln!( + "[test2 en_unrelated_low_entailment] \ + scores: entailment={:.4}, neutral={:.4}, contradiction={:.4}", + s.entailment, s.neutral, s.contradiction + ); + assert!( + s.entailment < 0.3, + "expected entailment < 0.3, got {:.4} (full scores: {:?})", + s.entailment, + s + ); +} + +/// Test 3: Korean entailment. The threshold is intentionally generous +/// (> 0.5) because cross-lingual XNLI is noisier than English-only. +#[test] +#[ignore] +fn ko_entailment_high_score() { + let cfg = Config::defaults(); + let v = OnnxNliVerifier::new(&cfg).expect("verifier construction"); + let premise = "사과는 빨갛다."; + let hypothesis = "사과는 색이 있다."; + let s = v.score(premise, hypothesis).expect("score should succeed"); + eprintln!( + "[test3 ko_entailment_high_score] \ + scores: entailment={:.4}, neutral={:.4}, contradiction={:.4}", + s.entailment, s.neutral, s.contradiction + ); + assert!( + s.entailment > 0.5, + "expected entailment > 0.5, got {:.4} (full scores: {:?})", + s.entailment, + s + ); +} + +/// Test 4: a > 24 000-char premise must not panic. mDeBERTa-v3 is +/// trained at 512 tokens; the `OnlyFirst` truncation strategy keeps +/// the premise side from blowing the positional embedding cap. +#[test] +#[ignore] +fn long_premise_truncates_without_panic() { + let cfg = Config::defaults(); + let v = OnnxNliVerifier::new(&cfg).expect("verifier construction"); + let premise = "foo bar baz ".repeat(2000); // ~24 000 chars + let hypothesis = "foo"; + let s = v + .score(&premise, hypothesis) + .expect("score should succeed on long premise"); + eprintln!( + "[test4 long_premise_truncates_without_panic] premise_len={} \ + scores: entailment={:.4}, neutral={:.4}, contradiction={:.4}", + premise.len(), + s.entailment, + s.neutral, + s.contradiction + ); + // No NaN / infinity in any channel. + for (name, x) in [ + ("entailment", s.entailment), + ("neutral", s.neutral), + ("contradiction", s.contradiction), + ] { + assert!( + x.is_finite(), + "channel {name} non-finite: {x} (full scores: {:?})", + s + ); + } + // Softmax invariant — the three channels sum to ~1. + let sum = s.entailment + s.neutral + s.contradiction; + assert!( + (sum - 1.0).abs() < 1e-3, + "softmax channels must sum to ~1, got {sum:.6}" + ); +} + +/// Test 5: an empty hypothesis triggers the defense-in-depth bail +/// path BEFORE the tokenizer runs. Hits no network — fast, even on +/// a fresh machine. +#[test] +#[ignore] +fn empty_hypothesis_returns_err() { + let cfg = Config::defaults(); + let v = OnnxNliVerifier::new(&cfg).expect("verifier construction"); + let err = v + .score("anything", "") + .expect_err("empty hypothesis must error"); + let msg = err.to_string(); + assert!( + msg.contains("empty hypothesis"), + "expected 'empty hypothesis' in error, got: {msg}" + ); +}