feat(nli): fb-41 PR-9b — OnnxNliVerifier 의 ONNX inference + model download

- OnnxNliVerifier fields: model_id, cache_dir (XDG model_dir/nli/<sanitized>), session/tokenizer OnceLock.
- new(): eager cache_dir stamp만 — actual model download + Session::commit_from_file 는 첫 score 호출 시 ensure_loaded() 가 lazy 수행.
- score(): ensure_loaded → tokenizer.encode(pair, OnlyFirst truncation max_length=512) → ndarray Array2<i64> → ort::Session::run → logits[1,3] → NliScores::from_xnli_logits.
- empty hypothesis edge: defense-in-depth bail (spec §2.3 의 caller-side skip 외 추가).
- sanitize_model_id helper: "/" → "_".
- 5 #[ignore] integration tests (EN self-entailment, EN unrelated, KR entailment, long premise truncation, empty hypothesis err) — manual smoke 가 PR description 첨부.

Cargo.toml: `download-binaries` feature 를 kebab-nli 의 ort dep 에 활성화 (PR-9b prep commit 의 후속). 단독 `cargo test -p kebab-nli` 의 per-crate feature 유니온은 fastembed 없이 ort/download-binaries 가 OFF 되어 ort-sys link 가 실패 — kebab-nli 측에서 명시적으로 켜 줘야 standalone build 가 ONNX 런타임 link 됨. workspace 전체 빌드에서는 fastembed 의 동일 opt-in 과 union 되어 부작용 없음.

Verification:
- cargo test -p kebab-nli -j 1 — PR-9a 의 6 unit pass (`score_returns_err_in_skeleton` → `score_empty_hypothesis_returns_err` 로 stub→실 path 갱신, 갯수 유지).
- cargo clippy -p kebab-nli --all-targets -- -D warnings clean.
- cargo build --workspace -j 1 — 회귀 0.
- Manual --ignored smoke 결과 PR body 첨부.

Wire 영향: 없음 (crate-internal).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-25 21:56:22 +00:00
parent 93436f9eca
commit b807fd5aa5
3 changed files with 434 additions and 22 deletions

View File

@@ -14,7 +14,13 @@ description = "fb-41: NLI-based post-synthesis verification (XNLI mDeBERTa-v3)
kebab-config = { path = "../kebab-config" }
anyhow = { workspace = true }
serde = { workspace = true }
ort = { workspace = true }
# ort: extend the workspace pin with `download-binaries` so kebab-nli
# can link the ONNX runtime when fastembed is NOT in the build graph
# (e.g. `cargo test -p kebab-nli` alone, where the per-crate feature
# union excludes kebab-embed-local + fastembed). In workspace-wide
# builds the feature gets union'd with fastembed's identical opt-in
# so no extra runtime gets pulled.
ort = { workspace = true, features = ["download-binaries"] }
tokenizers = { workspace = true }
hf-hub = { workspace = true }
ndarray = { workspace = true }

View File

@@ -1,37 +1,291 @@
//! ONNX-backed `NliVerifier` adapter (mDeBERTa-v3 XNLI).
//!
//! PR-9a: scaffolding only. `new` succeeds against the default `Config`
//! and `score` returns an explicit `"PR-9a stub"` error so any caller that
//! wires this up before PR-9b lands gets a loud failure instead of silent
//! all-zero scores. PR-9b will add ort `Session` + `Tokenizer` lazy init
//! and real inference.
//! PR-9b: real implementation. `new` resolves the cache directory from
//! `config.storage.model_dir/nli/<sanitized-model-id>/` (matching the
//! fastembed adapter's pattern of `model_dir/fastembed/`) and stamps it
//! on `self`. The (potentially network-bound) model + tokenizer download
//! is deferred to the first `score` call via `OnceLock<Session>` /
//! `OnceLock<Tokenizer>` — keeping `new` cheap so the rag crate can
//! construct the verifier eagerly during `App` boot without paying for
//! a model load on every CLI invocation.
//!
//! Per design §2.2.2 (Lazy init), §2.2.3 (truncation = `OnlyFirst`,
//! premise truncates, hypothesis preserved). PR-9c-1 will wire the
//! `[models.nli]` config section; until then the model id is hard-coded
//! to the Xenova mDeBERTa-v3 XNLI multilingual checkpoint.
use std::path::PathBuf;
use std::sync::OnceLock;
use anyhow::{Context, Result, anyhow};
use kebab_config::expand_path;
use ort::session::Session;
use tokenizers::{
Tokenizer, TruncationDirection, TruncationParams, TruncationStrategy,
};
use crate::{NliScores, NliVerifier};
/// Default HuggingFace model id for the XNLI verifier. PR-9c-1 will
/// replace this constant with a `config.models.nli.model` lookup once
/// the `NliCfg` section lands. The Xenova repo packages the
/// mDeBERTa-v3-base XNLI multilingual checkpoint as ONNX under the
/// `onnx/model.onnx` path; the tokenizer ships at `tokenizer.json`.
const DEFAULT_MODEL_ID: &str = "Xenova/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7";
/// Filename inside the HF repo (NOT a path on disk).
const HF_MODEL_FILE: &str = "onnx/model.onnx";
/// Filename inside the HF repo (NOT a path on disk).
const HF_TOKENIZER_FILE: &str = "tokenizer.json";
/// Subdirectory under `config.storage.model_dir` where the NLI adapter
/// writes / reads ONNX + tokenizer files. Mirrors the fastembed
/// adapter's `model_dir/fastembed/` layout.
const NLI_CACHE_SUBDIR: &str = "nli";
/// XNLI label order in the Xenova mDeBERTa-v3 checkpoint: the model's
/// output logits are `[entailment, neutral, contradiction]`. Pinned as
/// a constant so a future model swap (different label order) is a
/// single-site change.
const LOGITS_LEN: usize = 3;
/// Max input length passed to the tokenizer. mDeBERTa-v3 is trained
/// at 512-token context, matches the Xenova ONNX export's positional
/// embedding shape. `OnlyFirst` strategy makes the premise (which is
/// allowed to be the packed-chunks context) absorb the truncation;
/// the hypothesis (the generated answer) is preserved.
const MAX_TOKENS: usize = 512;
/// ONNX-runtime mDeBERTa-v3 XNLI verifier.
///
/// PR-9a scaffolding holds no state — fields land in PR-9b
/// (`model_id`, `cache_dir`, `session: OnceLock<ort::Session>`,
/// `tokenizer: OnceLock<tokenizers::Tokenizer>`).
/// `session` + `tokenizer` are lazily populated by the first call to
/// `ensure_loaded`. `new` is eager only for cache_dir create_dir_all
/// (cheap) so that the rag crate can construct an instance during
/// `App` boot without paying for the ~280 MB model download.
pub struct OnnxNliVerifier {
_private: (),
model_id: String,
cache_dir: PathBuf,
session: OnceLock<Session>,
tokenizer: OnceLock<Tokenizer>,
}
impl OnnxNliVerifier {
/// Construct a verifier from the user's `Config`. PR-9a always returns
/// `Ok` because the real model + tokenizer download is deferred to
/// PR-9b's first `score` call.
pub fn new(_config: &kebab_config::Config) -> anyhow::Result<Self> {
Ok(Self { _private: () })
/// Construct a verifier from the user's `Config`. Eagerly resolves
/// `cache_dir = config.storage.model_dir/nli/<sanitized-model-id>/`
/// and runs `create_dir_all` so the first `score` call can drop
/// straight into download + load without re-deriving paths.
///
/// PR-9c-1 will swap `DEFAULT_MODEL_ID` for `config.models.nli.model`.
pub fn new(config: &kebab_config::Config) -> Result<Self> {
let model_id = DEFAULT_MODEL_ID.to_string();
// Match kebab-embed-local's two-step expansion: data_dir first,
// then model_dir with `{data_dir}` substituted in.
let data_dir = expand_path(&config.storage.data_dir, "");
let model_dir = expand_path(&config.storage.model_dir, &data_dir.to_string_lossy());
let cache_dir = model_dir
.join(NLI_CACHE_SUBDIR)
.join(sanitize_model_id(&model_id));
std::fs::create_dir_all(&cache_dir)
.with_context(|| format!("create kebab-nli cache dir {}", cache_dir.display()))?;
Ok(Self {
model_id,
cache_dir,
session: OnceLock::new(),
tokenizer: OnceLock::new(),
})
}
/// Download (if needed) + load the ONNX session and tokenizer on
/// first call; return cached refs on subsequent calls. Uses two
/// `OnceLock`s rather than one because a single `OnceLock<(_, _)>`
/// would need to construct both atomically — keeping them split
/// lets us short-circuit on the (rare) hit path where only one
/// side is missing.
///
/// `OnceLock::get_or_try_init` is still unstable (rust-lang/rust#109737)
/// so we implement the fallible init by hand: probe `get`, on miss
/// compute the value, then `set` it. The race between two threads is
/// resolved by `OnceLock::set` — the loser gets `Err`, falls through
/// to a second `get`, and reads the winner's value. Each thread that
/// races + loses does pay the cost of one redundant download (rare in
/// practice: rag boot is single-threaded today), but the cache stays
/// consistent.
fn ensure_loaded(&self) -> Result<(&Session, &Tokenizer)> {
if self.session.get().is_none() {
let s = self.load_session()?;
let _ = self.session.set(s); // loser of a race: discard local value
}
if self.tokenizer.get().is_none() {
let t = self.load_tokenizer()?;
let _ = self.tokenizer.set(t);
}
// Both OnceLocks are populated at this point; `expect` is a
// tighter post-condition than `unwrap_or_else` would be.
let session = self.session.get().expect("session populated above");
let tokenizer = self.tokenizer.get().expect("tokenizer populated above");
Ok((session, tokenizer))
}
/// Build an `hf_hub::api::sync::Api` rooted at `self.cache_dir` and
/// fetch `filename` from `self.model_id`. Logs cache hits at INFO
/// so a user reading kebab logs can see which artifact source the
/// pipeline picked.
fn fetch(&self, filename: &str) -> Result<PathBuf> {
let api = hf_hub::api::sync::ApiBuilder::new()
.with_cache_dir(self.cache_dir.clone())
.build()
.with_context(|| {
format!(
"kebab-nli: hf-hub ApiBuilder::build failed (cache_dir={})",
self.cache_dir.display()
)
})?;
let repo = api.model(self.model_id.clone());
// `ApiRepo::get` returns the local path if cached, otherwise
// downloads. We can't tell after the fact whether the file
// was already cached without an extra `Cache::repo::get`
// probe, so do that probe first to emit the right log line.
let cache_path = api
.repo(hf_hub::Repo::new(
self.model_id.clone(),
hf_hub::RepoType::Model,
))
.get(filename)
.ok();
if cache_path.is_some() {
tracing::info!(
target: "kebab-nli",
model_id = %self.model_id,
file = %filename,
"NLI artifact cache hit"
);
} else {
tracing::info!(
target: "kebab-nli",
model_id = %self.model_id,
file = %filename,
cache_dir = %self.cache_dir.display(),
"downloading NLI artifact"
);
}
repo.get(filename).with_context(|| {
format!(
"kebab-nli: hf-hub fetch failed for {filename} (model_id={}, cache_dir={})",
self.model_id,
self.cache_dir.display()
)
})
}
fn load_session(&self) -> Result<Session> {
tracing::info!(
target: "kebab-nli",
model_id = %self.model_id,
"downloading NLI model + tokenizer (first run only)"
);
let model_path = self.fetch(HF_MODEL_FILE)?;
let session = Session::builder()
.with_context(|| "kebab-nli: ort Session::builder failed")?
.commit_from_file(&model_path)
.with_context(|| {
format!(
"kebab-nli: ort Session::commit_from_file({}) failed",
model_path.display()
)
})?;
tracing::info!(
target: "kebab-nli",
model_id = %self.model_id,
model_path = %model_path.display(),
"NLI model ready"
);
Ok(session)
}
fn load_tokenizer(&self) -> Result<Tokenizer> {
let tokenizer_path = self.fetch(HF_TOKENIZER_FILE)?;
let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow!("kebab-nli: Tokenizer::from_file({}) failed: {e}", tokenizer_path.display()))?;
tokenizer
.with_truncation(Some(TruncationParams {
max_length: MAX_TOKENS,
strategy: TruncationStrategy::OnlyFirst,
stride: 0,
direction: TruncationDirection::Right,
}))
.map_err(|e| anyhow!("kebab-nli: Tokenizer::with_truncation failed: {e}"))?;
Ok(tokenizer)
}
}
impl NliVerifier for OnnxNliVerifier {
fn score(&self, _premise: &str, _hypothesis: &str) -> anyhow::Result<NliScores> {
anyhow::bail!("PR-9a stub — ONNX inference lands in PR-9b")
fn score(&self, premise: &str, hypothesis: &str) -> Result<NliScores> {
// Defense-in-depth: spec §2.3 has the caller skip empty answers,
// but a degenerate empty hypothesis here would tokenize to a
// [CLS][SEP][SEP] triple that yields a near-uniform softmax —
// misleading both faithfulness gate and any future logging.
if hypothesis.trim().is_empty() {
anyhow::bail!("kebab-nli: empty hypothesis");
}
let (session, tokenizer) = self.ensure_loaded()?;
let enc = tokenizer
.encode((premise, hypothesis), true)
.map_err(|e| anyhow!("kebab-nli: tokenizer.encode failed: {e}"))?;
let ids: Vec<i64> = enc.get_ids().iter().map(|&u| u as i64).collect();
let mask: Vec<i64> = enc
.get_attention_mask()
.iter()
.map(|&u| u as i64)
.collect();
let seq_len = ids.len();
// mDeBERTa-v3 ONNX export expects [batch, seq_len] for both
// input_ids and attention_mask. We always feed batch=1.
let ids_arr = ndarray::Array2::from_shape_vec((1, seq_len), ids)
.with_context(|| "kebab-nli: input_ids ndarray shape build failed")?;
let mask_arr = ndarray::Array2::from_shape_vec((1, seq_len), mask)
.with_context(|| "kebab-nli: attention_mask ndarray shape build failed")?;
let outputs = session
.run(ort::inputs! {
"input_ids" => ids_arr,
"attention_mask" => mask_arr,
}?)
.with_context(|| "kebab-nli: ort Session::run failed")?;
let logits = outputs["logits"]
.try_extract_tensor::<f32>()
.with_context(|| "kebab-nli: logits try_extract_tensor::<f32> failed")?;
// Expected shape [1, 3]. Defensive check — a model swap with a
// different head would silently produce wrong scores otherwise.
let shape = logits.shape();
if shape != [1, LOGITS_LEN] {
anyhow::bail!(
"kebab-nli: unexpected logits shape {:?}, expected [1, {LOGITS_LEN}]",
shape
);
}
let l = [logits[[0, 0]], logits[[0, 1]], logits[[0, 2]]];
Ok(NliScores::from_xnli_logits(l))
}
}
/// Make a HuggingFace model id (`"owner/repo"`) into a single
/// path component safe to use as a directory name. `/` → `_` is
/// enough for current ids; if more exotic chars appear we'll
/// widen this then.
fn sanitize_model_id(s: &str) -> String {
s.replace('/', "_")
}
#[cfg(test)]
mod tests {
use super::*;
@@ -41,17 +295,29 @@ mod tests {
fn new_succeeds_on_default_config() {
let cfg = Config::defaults();
let v = OnnxNliVerifier::new(&cfg).expect("new should succeed on default config");
// Silence unused-binding lint without weakening the assertion.
let _ = &v;
// cache_dir must include the sanitized model id (no '/').
let s = v.cache_dir.to_string_lossy();
assert!(s.contains(NLI_CACHE_SUBDIR), "cache_dir lacks nli/: {s}");
assert!(
!s.contains("Xenova/mDeBERTa"),
"cache_dir must sanitize '/' in model id: {s}"
);
assert!(
s.contains("Xenova_mDeBERTa"),
"cache_dir should contain sanitized id: {s}"
);
}
/// Empty hypothesis takes the defense-in-depth early bail path —
/// reaches no model load, so this is a pure unit test (no network).
/// Replaces PR-9a's `score_returns_err_in_skeleton` (stub-only).
#[test]
fn score_returns_err_in_skeleton() {
fn score_empty_hypothesis_returns_err() {
let cfg = Config::defaults();
let v = OnnxNliVerifier::new(&cfg).unwrap();
let err = v.score("a", "b").expect_err("PR-9a stub must error");
let err = v.score("anything", "").expect_err("empty hypothesis must error");
assert!(
err.to_string().contains("PR-9a stub"),
err.to_string().contains("empty hypothesis"),
"unexpected error message: {err}"
);
}

View File

@@ -0,0 +1,140 @@
//! Integration tests for `OnnxNliVerifier` against the real
//! mDeBERTa-v3 XNLI model. Every test is `#[ignore]` — plain
//! `cargo test -p kebab-nli` skips them; run explicitly with
//! `cargo test -p kebab-nli --test inference -- --ignored` to
//! exercise the (slow + network-bound on first run) inference path.
//!
//! First test in the file triggers the ~280 MB ONNX + ~16 MB
//! tokenizer download into `config.storage.model_dir/nli/...`;
//! subsequent tests hit the OnceLock cache for free.
use kebab_config::Config;
use kebab_nli::{NliVerifier, OnnxNliVerifier};
/// Test 1: an English statement entails itself with high confidence.
/// Smoke evidence captured for the PR description's `## 검증` section.
#[test]
#[ignore]
fn en_self_entailment_high_score() {
let cfg = Config::defaults();
let v = OnnxNliVerifier::new(&cfg).expect("verifier construction");
let premise = "Caffeine is a stimulant.";
let hypothesis = "Caffeine is a stimulant.";
let s = v.score(premise, hypothesis).expect("score should succeed");
eprintln!(
"[test1 en_self_entailment_high_score] premise={premise:?} hypothesis={hypothesis:?} \
scores: entailment={:.4}, neutral={:.4}, contradiction={:.4}",
s.entailment, s.neutral, s.contradiction
);
assert!(
s.entailment > 0.8,
"expected entailment > 0.8, got {:.4} (full scores: {:?})",
s.entailment,
s
);
}
/// Test 2: an unrelated chemistry fact does NOT entail the premise.
/// Entailment should be low — neutral / contradiction wins.
#[test]
#[ignore]
fn en_unrelated_low_entailment() {
let cfg = Config::defaults();
let v = OnnxNliVerifier::new(&cfg).expect("verifier construction");
let premise = "Caffeine is a stimulant.";
let hypothesis = "The chemical formula of caffeine is C8H10N4O2.";
let s = v.score(premise, hypothesis).expect("score should succeed");
eprintln!(
"[test2 en_unrelated_low_entailment] \
scores: entailment={:.4}, neutral={:.4}, contradiction={:.4}",
s.entailment, s.neutral, s.contradiction
);
assert!(
s.entailment < 0.3,
"expected entailment < 0.3, got {:.4} (full scores: {:?})",
s.entailment,
s
);
}
/// Test 3: Korean entailment. The threshold is intentionally generous
/// (> 0.5) because cross-lingual XNLI is noisier than English-only.
#[test]
#[ignore]
fn ko_entailment_high_score() {
let cfg = Config::defaults();
let v = OnnxNliVerifier::new(&cfg).expect("verifier construction");
let premise = "사과는 빨갛다.";
let hypothesis = "사과는 색이 있다.";
let s = v.score(premise, hypothesis).expect("score should succeed");
eprintln!(
"[test3 ko_entailment_high_score] \
scores: entailment={:.4}, neutral={:.4}, contradiction={:.4}",
s.entailment, s.neutral, s.contradiction
);
assert!(
s.entailment > 0.5,
"expected entailment > 0.5, got {:.4} (full scores: {:?})",
s.entailment,
s
);
}
/// Test 4: a > 24 000-char premise must not panic. mDeBERTa-v3 is
/// trained at 512 tokens; the `OnlyFirst` truncation strategy keeps
/// the premise side from blowing the positional embedding cap.
#[test]
#[ignore]
fn long_premise_truncates_without_panic() {
let cfg = Config::defaults();
let v = OnnxNliVerifier::new(&cfg).expect("verifier construction");
let premise = "foo bar baz ".repeat(2000); // ~24 000 chars
let hypothesis = "foo";
let s = v
.score(&premise, hypothesis)
.expect("score should succeed on long premise");
eprintln!(
"[test4 long_premise_truncates_without_panic] premise_len={} \
scores: entailment={:.4}, neutral={:.4}, contradiction={:.4}",
premise.len(),
s.entailment,
s.neutral,
s.contradiction
);
// No NaN / infinity in any channel.
for (name, x) in [
("entailment", s.entailment),
("neutral", s.neutral),
("contradiction", s.contradiction),
] {
assert!(
x.is_finite(),
"channel {name} non-finite: {x} (full scores: {:?})",
s
);
}
// Softmax invariant — the three channels sum to ~1.
let sum = s.entailment + s.neutral + s.contradiction;
assert!(
(sum - 1.0).abs() < 1e-3,
"softmax channels must sum to ~1, got {sum:.6}"
);
}
/// Test 5: an empty hypothesis triggers the defense-in-depth bail
/// path BEFORE the tokenizer runs. Hits no network — fast, even on
/// a fresh machine.
#[test]
#[ignore]
fn empty_hypothesis_returns_err() {
let cfg = Config::defaults();
let v = OnnxNliVerifier::new(&cfg).expect("verifier construction");
let err = v
.score("anything", "")
.expect_err("empty hypothesis must error");
let msg = err.to_string();
assert!(
msg.contains("empty hypothesis"),
"expected 'empty hypothesis' in error, got: {msg}"
);
}