kebab/crates/kebab-nli/src/onnx.rs
altair823 7c27633df2 chore(rag): post-PR9 refactor — H1/H2/H3/D/E + test coverage + post-refactor dogfood retest
Architectural cleanup by the OMC team `post-pr9-refactor`. After the architect's priorities analysis, the executor and test-engineer made the file edits, and the system-architect's component-level review concluded *pre-cut nothing; defer everything to v0.18.1+*.

## Executor work (H1/H2/H3/D/E)

- **H1** (kebab-nli/src/onnx.rs): activate the `[models.nli]` config wiring. Removed the `DEFAULT_MODEL_ID` const (kebab-config's `NliCfg::defaults` is the single source). `OnnxNliVerifier::new` reads `config.models.nli.model` and calls `anyhow::bail` if `config.models.nli.provider` is not `"onnx"`. Removed 3 stale "PR-9c-1 will wire this" comments. Added 2 unit tests (`new_uses_config_model_id`, `new_rejects_unsupported_provider`).
- **H2** (kebab-rag/src/pipeline.rs): `truncate_for_nli(premise: &str, _hypothesis: &str)` → `truncate_for_nli(premise: &str)`. Removed the v0.18.1 placeholder doc. Updated 4 callsites (tests/multi_hop.rs) and renamed the test `multi_hop_truncate_for_nli_preserves_hypothesis` → `multi_hop_truncate_for_nli_char_budget` to match the contract.
- **H3** (kebab-rag/src/pipeline.rs:1041): `was_truncated` is now surfaced via `tracing::debug!` (adds observability; the signature is preserved for the caller logging contract).
- **D** (kebab-mcp/tests/tools_call_ask_multi_hop.rs): `request_timeout_secs` 2 → 5 (stability on slow CI); removed the `mh_code` discriminator. The dispatch contract is `mh.is_error.unwrap_or(false)` (the existing assertion is sufficient).
- **E** (tasks/HOTFIXES.md + pipeline.rs:1633-1638): added a "### PR-9 NLI refusal: terminal Synthesize hop omitted from hops trace" subsection as a sibling of the fb-41 PR-9 closure entry. In the pipeline, replaced "cleanup deferred to a follow-up" with the cross-link "// See tasks/HOTFIXES.md ... for follow-up".
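The H2/H3 bullets describe a one-argument `truncate_for_nli` that trims only the premise to a character budget and reports a `was_truncated` flag. The pipeline code is not part of this file, so the following is only a minimal sketch, assuming the budget counts Unicode scalar values rather than bytes or tokens; the name `truncate_to_char_budget` and its `(&str, bool)` return shape are hypothetical.

```rust
/// Hypothetical stand-in for the one-argument `truncate_for_nli`:
/// keep at most `max_chars` Unicode scalar values of `premise` and
/// report whether anything was cut (the `was_truncated` flag).
fn truncate_to_char_budget(premise: &str, max_chars: usize) -> (&str, bool) {
    // `char_indices().nth(max_chars)` is the byte offset of the first char
    // past the budget, so slicing there never splits a UTF-8 sequence
    // (a naive `&premise[..max_chars]` can panic mid-codepoint on Korean text).
    match premise.char_indices().nth(max_chars) {
        Some((byte_idx, _)) => (&premise[..byte_idx], true),
        None => (premise, false),
    }
}
```

A caller would forward the returned flag to its logging, which is the observability H3 adds without changing the function's signature.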

## Test-engineer work (T1/T2/T3/T4, 9 new tests)

- **T1** (kebab-nli/src/onnx.rs::tests): 3 unit tests for `sanitize_model_id` (replaces_slash / idempotent / leaves_other_chars).
- **T2** (kebab-rag/tests/multi_hop_nli_panic.rs, new): 2 panic-path tests: a `#[should_panic]` test for the facade invariant (`expect("verifier must be Some when nli_threshold > 0.0")`) plus its threshold=0 companion.
- **T3** (kebab-rag/tests/multi_hop_nli_stream.rs, new): 2 `StreamEvent::Final` tests pinning the wire shape of the `stream_sink` Final branch for `refuse_nli_verification` and `refuse_nli_model_unavailable`.
- **T4** (kebab-app/tests/open_with_config_nli.rs, new): 2 NLI failure-path tests: `App::open_with_config` returns `Err` from its `Result<App>` (with "OnnxNliVerifier" in the error chain) when `model_dir` is unwritable, and skips NLI gracefully when threshold=0.
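T2 and T4 together pin one contract: the verifier is built only when `nli_threshold > 0.0`, a construction failure propagates as an `Err`, and a facade whose gate is on without a verifier panics with the `expect` message quoted in T2. A minimal sketch of that contract, assuming simplified stand-ins (`Verifier`, `Facade`, `open`, and `gate_enabled` are all hypothetical names, not kebab's API):

```rust
/// Hypothetical stand-in for `OnnxNliVerifier`.
struct Verifier;

/// Hypothetical facade: the verifier must exist whenever the NLI gate is on.
struct Facade {
    nli_threshold: f32,
    verifier: Option<Verifier>,
}

impl Facade {
    /// Builds the verifier only when the gate is on. A construction error
    /// (e.g. an unwritable model dir, as in T4) propagates; threshold 0
    /// skips construction entirely, so `build` is never called.
    fn open(
        nli_threshold: f32,
        build: impl FnOnce() -> Result<Verifier, String>,
    ) -> Result<Self, String> {
        let verifier = if nli_threshold > 0.0 { Some(build()?) } else { None };
        Ok(Self { nli_threshold, verifier })
    }

    /// Returns whether the NLI gate runs. Panics if the invariant is broken
    /// (gate on, no verifier), which is what a `#[should_panic]` test pins.
    fn gate_enabled(&self) -> bool {
        if self.nli_threshold > 0.0 {
            let _verifier = self
                .verifier
                .as_ref()
                .expect("verifier must be Some when nli_threshold > 0.0");
            true
        } else {
            false
        }
    }
}
```

The invariant-violating state can only be reached by bypassing `open`, which is why a panic (rather than an `Err`) is a reasonable way to flag it.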

## System-architect conclusion

Analysis across 3 lenses (absorption / duplication / under-engineered interface) concluded *pre-cut nothing*. All top-3 items are deferred to v0.18.1+:
- Lens 1: kebab-normalize + kebab-parse-types can be absorbed (only parse-md uses them; the other 5 parsers bypass them) → v0.18.1+.
- Lens 3: dead polymorphism in the Extractor + Chunker traits (every callsite is hardcoded) → v0.18.1+.
- Lens 1 bundled: kebab-source-fs drags in kebab-parse-code's 9 tree-sitter grammars → low-risk dep-graph win, bundled into v0.18.1+.
- Defer-with-intent: LanguageModel async refactor (once cloud LLMs arrive), NliVerifier::score_batch + typed NliError (once a 2nd impl exists), compute_stale → kebab-core::stale.
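The defer-with-intent list mentions `NliVerifier::score_batch` plus a typed `NliError`, to be added once a second impl exists. Purely as an illustration of one possible shape, assuming a default method that loops over `score` (every type and variant below is hypothetical, and the real trait currently returns `anyhow::Result`):

```rust
use std::fmt;

/// Hypothetical three-way XNLI probabilities.
#[derive(Debug, Clone, Copy, PartialEq)]
struct NliScores {
    entailment: f32,
    neutral: f32,
    contradiction: f32,
}

/// Hypothetical typed error that would replace `anyhow::Error` at the trait boundary.
#[derive(Debug, PartialEq)]
enum NliError {
    EmptyHypothesis,
    ModelUnavailable(String),
}

impl fmt::Display for NliError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            NliError::EmptyHypothesis => write!(f, "kebab-nli: empty hypothesis"),
            NliError::ModelUnavailable(why) => write!(f, "kebab-nli: model unavailable: {why}"),
        }
    }
}

impl std::error::Error for NliError {}

trait NliVerifier {
    fn score(&self, premise: &str, hypothesis: &str) -> Result<NliScores, NliError>;

    /// Default batch method: scores pairs one at a time and stops at the
    /// first error, so single-pair impls need no changes; a backend with
    /// real batching could override it.
    fn score_batch(&self, pairs: &[(&str, &str)]) -> Result<Vec<NliScores>, NliError> {
        pairs.iter().map(|&(p, h)| self.score(p, h)).collect()
    }
}

/// Trivial stub: rejects blank hypotheses, otherwise reports full entailment.
struct Stub;

impl NliVerifier for Stub {
    fn score(&self, _premise: &str, hypothesis: &str) -> Result<NliScores, NliError> {
        if hypothesis.trim().is_empty() {
            return Err(NliError::EmptyHypothesis);
        }
        Ok(NliScores { entailment: 1.0, neutral: 0.0, contradiction: 0.0 })
    }
}
```

Keeping `score_batch` as a defaulted method is what makes the change additive: existing impls compile unchanged, and only a backend that can batch needs to override it.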

Reports: /build/cache/tmp/post-pr9-refactor-priorities.md, /build/cache/tmp/system-architecture-priorities.md (both live outside the repo; kept to preserve the analysis).

## Verification

- cargo test -p kebab-nli -j 1 → 11/11 pass.
- cargo test -p kebab-rag -j 1 → 75/75 pass (including 5 NLI multi-hop + 4 new T2/T3).
- cargo test -p kebab-app -j 1 → 23 pass + 2 ignored (including the 2 from T4).
- cargo test -p kebab-mcp --test tools_call_ask_multi_hop -j 1 → 1 pass + 1 pre-existing flaky (HOTFIX #15, no_chunks short-circuit; unrelated to executor fix D: the base assertion at line 86 fails because there is no fixture).
- cargo clippy --workspace --all-targets -j 1 -- -D warnings clean.
- cargo test --workspace --no-fail-fast -j 1 → 1304 passed (+11 new) + the same 1 pre-existing flaky.
- **Post-refactor dogfood retest byte-identical** (identical across all 3 runs: PR-9d / post-cleanup / post-refactor): S7 0.0035389824770390987, S1 0.058334656059741974, S10 0.0027875436935573816, S3 nli_model_unavailable.

Added a "Post-architectural-refactor retest" section to docs/dogfood/v0.18.0/SUMMARY.md.

Wire impact: none.
Behavior impact: none (H1's config wiring resolves to the same model as the default, so output is byte-identical).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-26 04:42:37 +00:00


//! ONNX-backed `NliVerifier` adapter (mDeBERTa-v3 XNLI).
//!
//! `new` resolves the cache directory from
//! `config.storage.model_dir/nli/<sanitized-model-id>/` (matching the
//! fastembed adapter's pattern of `model_dir/fastembed/`) and stamps it
//! on `self`. The (potentially network-bound) model + tokenizer download
//! is deferred to the first `score` call via `OnceLock<Session>` /
//! `OnceLock<Tokenizer>`, keeping `new` cheap so the rag crate can
//! construct the verifier eagerly during `App` boot without paying for
//! a model load on every CLI invocation.
//!
//! Per design §2.2.2 (Lazy init) and §2.2.3 (truncation = `OnlyFirst`:
//! the premise truncates, the hypothesis is preserved). The model id
//! flows from `config.models.nli.model`; `config.models.nli.provider`
//! selects the verifier impl (only `"onnx"` is implemented in v0.18).

use std::path::PathBuf;
use std::sync::OnceLock;

use anyhow::{Context, Result, anyhow};
use kebab_config::expand_path;
use ort::session::Session;
use tokenizers::{
    Tokenizer, TruncationDirection, TruncationParams, TruncationStrategy,
};

use crate::{NliScores, NliVerifier};

/// Filename inside the HF repo (NOT a path on disk). The Xenova repo
/// packages the mDeBERTa-v3-base XNLI multilingual checkpoint (the
/// default `config.models.nli.model`; see `kebab-config::NliCfg::defaults`)
/// as ONNX under this path; the tokenizer ships at `tokenizer.json`.
const HF_MODEL_FILE: &str = "onnx/model.onnx";

/// Filename inside the HF repo (NOT a path on disk).
const HF_TOKENIZER_FILE: &str = "tokenizer.json";

/// Subdirectory under `config.storage.model_dir` where the NLI adapter
/// writes / reads ONNX + tokenizer files. Mirrors the fastembed
/// adapter's `model_dir/fastembed/` layout.
const NLI_CACHE_SUBDIR: &str = "nli";

/// Width of the XNLI logits head. The Xenova mDeBERTa-v3 checkpoint's
/// output logits are `[entailment, neutral, contradiction]`; the label
/// order itself is interpreted by `NliScores::from_xnli_logits`, while
/// this constant pins the `[1, 3]` shape that `score` checks, so a model
/// swap with a different head fails loudly instead of mis-scoring.
const LOGITS_LEN: usize = 3;

/// Max input length passed to the tokenizer. mDeBERTa-v3 is trained
/// at a 512-token context, which matches the Xenova ONNX export's
/// positional embedding shape. The `OnlyFirst` strategy makes the
/// premise (which may be the packed-chunks context) absorb the
/// truncation; the hypothesis (the generated answer) is preserved.
const MAX_TOKENS: usize = 512;

/// ONNX-runtime mDeBERTa-v3 XNLI verifier.
///
/// `session` + `tokenizer` are lazily populated by the first call to
/// `ensure_loaded`. `new` is eager only for the cheap `create_dir_all`
/// of `cache_dir`, so the rag crate can construct an instance during
/// `App` boot without paying for the ~280 MB model download.
pub struct OnnxNliVerifier {
    model_id: String,
    cache_dir: PathBuf,
    session: OnceLock<Session>,
    tokenizer: OnceLock<Tokenizer>,
}

impl OnnxNliVerifier {
    /// Construct a verifier from the user's `Config`. Eagerly resolves
    /// `cache_dir = config.storage.model_dir/nli/<sanitized-model-id>/`
    /// and runs `create_dir_all` so the first `score` call can drop
    /// straight into download + load without re-deriving paths.
    ///
    /// Reads `config.models.nli.model` for the HuggingFace model id
    /// and `config.models.nli.provider` to select the verifier impl;
    /// only `"onnx"` is implemented in v0.18, and any other provider is
    /// rejected with an error. The defaults live in
    /// `kebab-config::NliCfg::defaults`, so an unset field still yields
    /// a non-empty model id.
    pub fn new(config: &kebab_config::Config) -> Result<Self> {
        let provider = config.models.nli.provider.as_str();
        if provider != "onnx" {
            anyhow::bail!(
                "kebab-nli: unsupported provider {provider:?} (only 'onnx' is implemented in v0.18)"
            );
        }
        let model_id = config.models.nli.model.clone();
        // Match kebab-embed-local's two-step expansion: data_dir first,
        // then model_dir with `{data_dir}` substituted in.
        let data_dir = expand_path(&config.storage.data_dir, "");
        let model_dir = expand_path(&config.storage.model_dir, &data_dir.to_string_lossy());
        let cache_dir = model_dir
            .join(NLI_CACHE_SUBDIR)
            .join(sanitize_model_id(&model_id));
        std::fs::create_dir_all(&cache_dir)
            .with_context(|| format!("create kebab-nli cache dir {}", cache_dir.display()))?;
        Ok(Self {
            model_id,
            cache_dir,
            session: OnceLock::new(),
            tokenizer: OnceLock::new(),
        })
    }

    /// Download (if needed) + load the ONNX session and tokenizer on
    /// first call; return cached refs on subsequent calls. Uses two
    /// `OnceLock`s rather than one because a single `OnceLock<(_, _)>`
    /// would need to construct both atomically; keeping them split
    /// lets us short-circuit on the (rare) path where only one side is
    /// already loaded (e.g. the tokenizer load failed on a prior call).
    ///
    /// `OnceLock::get_or_try_init` is still unstable (rust-lang/rust#109737),
    /// so we implement the fallible init by hand: probe `get`, on a miss
    /// compute the value, then `set` it. A race between two threads is
    /// resolved by `OnceLock::set`: the loser gets `Err`, falls through
    /// to the second `get`, and reads the winner's value. Each thread that
    /// races and loses pays for one redundant fetch + load (rare in
    /// practice, since rag boot is single-threaded today), but the cache
    /// stays consistent.
    fn ensure_loaded(&self) -> Result<(&Session, &Tokenizer)> {
        if self.session.get().is_none() {
            let s = self.load_session()?;
            let _ = self.session.set(s); // loser of a race: discard local value
        }
        if self.tokenizer.get().is_none() {
            let t = self.load_tokenizer()?;
            let _ = self.tokenizer.set(t);
        }
        // Both OnceLocks are populated at this point; `expect` is a
        // tighter post-condition than `unwrap_or_else` would be.
        let session = self.session.get().expect("session populated above");
        let tokenizer = self.tokenizer.get().expect("tokenizer populated above");
        Ok((session, tokenizer))
    }

    /// Build an `hf_hub::api::sync::Api` rooted at `self.cache_dir` and
    /// fetch `filename` from `self.model_id`. Logs both cache hits and
    /// downloads at INFO, so a user reading kebab logs can see which
    /// artifact source the pipeline picked.
    fn fetch(&self, filename: &str) -> Result<PathBuf> {
        // Round-1 review N1 fix: `Api::get` triggers a download on a miss,
        // so we can't use it as a hit probe. `Cache::get` is fs-only: it
        // returns Some(path) if cached, None otherwise. No network.
        let repo = hf_hub::Repo::new(self.model_id.clone(), hf_hub::RepoType::Model);
        let cached = hf_hub::Cache::new(self.cache_dir.clone())
            .repo(repo.clone())
            .get(filename)
            .is_some();
        if cached {
            tracing::info!(
                target: "kebab-nli",
                model_id = %self.model_id,
                file = %filename,
                "NLI artifact cache hit"
            );
        } else {
            tracing::info!(
                target: "kebab-nli",
                model_id = %self.model_id,
                file = %filename,
                cache_dir = %self.cache_dir.display(),
                "downloading NLI artifact"
            );
        }
        let api = hf_hub::api::sync::ApiBuilder::new()
            .with_cache_dir(self.cache_dir.clone())
            .build()
            .with_context(|| {
                format!(
                    "kebab-nli: hf-hub ApiBuilder::build failed (cache_dir={})",
                    self.cache_dir.display()
                )
            })?;
        api.model(self.model_id.clone())
            .get(filename)
            .with_context(|| {
                format!(
                    "kebab-nli: hf-hub fetch failed for {filename} (model_id={}, cache_dir={})",
                    self.model_id,
                    self.cache_dir.display()
                )
            })
    }

    fn load_session(&self) -> Result<Session> {
        tracing::info!(
            target: "kebab-nli",
            model_id = %self.model_id,
            "downloading NLI model + tokenizer (first run only)"
        );
        let model_path = self.fetch(HF_MODEL_FILE)?;
        let session = Session::builder()
            .with_context(|| "kebab-nli: ort Session::builder failed")?
            .commit_from_file(&model_path)
            .with_context(|| {
                format!(
                    "kebab-nli: ort Session::commit_from_file({}) failed",
                    model_path.display()
                )
            })?;
        tracing::info!(
            target: "kebab-nli",
            model_id = %self.model_id,
            model_path = %model_path.display(),
            "NLI model ready"
        );
        Ok(session)
    }

    fn load_tokenizer(&self) -> Result<Tokenizer> {
        let tokenizer_path = self.fetch(HF_TOKENIZER_FILE)?;
        let mut tokenizer = Tokenizer::from_file(&tokenizer_path).map_err(|e| {
            anyhow!(
                "kebab-nli: Tokenizer::from_file({}) failed: {e}",
                tokenizer_path.display()
            )
        })?;
        tokenizer
            .with_truncation(Some(TruncationParams {
                max_length: MAX_TOKENS,
                strategy: TruncationStrategy::OnlyFirst,
                stride: 0,
                direction: TruncationDirection::Right,
            }))
            .map_err(|e| anyhow!("kebab-nli: Tokenizer::with_truncation failed: {e}"))?;
        Ok(tokenizer)
    }
}

impl NliVerifier for OnnxNliVerifier {
    fn score(&self, premise: &str, hypothesis: &str) -> Result<NliScores> {
        // Defense-in-depth: spec §2.3 has the caller skip empty answers,
        // but a degenerate empty (or whitespace-only) hypothesis here would
        // leave the hypothesis segment with only special tokens, which can
        // yield a near-uniform softmax that misleads both the faithfulness
        // gate and any future logging.
        if hypothesis.trim().is_empty() {
            anyhow::bail!("kebab-nli: empty hypothesis");
        }
        let (session, tokenizer) = self.ensure_loaded()?;
        let enc = tokenizer
            .encode((premise, hypothesis), true)
            .map_err(|e| anyhow!("kebab-nli: tokenizer.encode failed: {e}"))?;
        let ids: Vec<i64> = enc.get_ids().iter().map(|&u| i64::from(u)).collect();
        let mask: Vec<i64> = enc
            .get_attention_mask()
            .iter()
            .map(|&u| i64::from(u))
            .collect();
        let seq_len = ids.len();
        // mDeBERTa-v3 ONNX export expects [batch, seq_len] for both
        // input_ids and attention_mask. We always feed batch=1.
        let ids_arr = ndarray::Array2::from_shape_vec((1, seq_len), ids)
            .with_context(|| "kebab-nli: input_ids ndarray shape build failed")?;
        let mask_arr = ndarray::Array2::from_shape_vec((1, seq_len), mask)
            .with_context(|| "kebab-nli: attention_mask ndarray shape build failed")?;
        let outputs = session
            .run(ort::inputs! {
                "input_ids" => ids_arr,
                "attention_mask" => mask_arr,
            }?)
            .with_context(|| "kebab-nli: ort Session::run failed")?;
        let logits = outputs["logits"]
            .try_extract_tensor::<f32>()
            .with_context(|| "kebab-nli: logits try_extract_tensor::<f32> failed")?;
        // Expected shape [1, 3]. Defensive check: a model swap with a
        // different head would silently produce wrong scores otherwise.
        let shape = logits.shape();
        if shape != [1, LOGITS_LEN] {
            anyhow::bail!(
                "kebab-nli: unexpected logits shape {shape:?}, expected [1, {LOGITS_LEN}]"
            );
        }
        let l = [logits[[0, 0]], logits[[0, 1]], logits[[0, 2]]];
        Ok(NliScores::from_xnli_logits(l))
    }
}

/// Turn a HuggingFace model id (`"owner/repo"`) into a single path
/// component that is safe to use as a directory name. `/` → `_` is
/// enough for current ids; if more exotic characters appear we'll
/// widen this then.
fn sanitize_model_id(s: &str) -> String {
    s.replace('/', "_")
}

#[cfg(test)]
mod tests {
    use super::*;
    use kebab_config::Config;
    use tempfile::TempDir;

    /// Round-1 review N2 fix: redirect `Config.storage.{data,model}_dir`
    /// into a tempdir so unit tests don't litter the user's XDG dirs
    /// with empty `nli/` subdirs.
    fn tempdir_config() -> (TempDir, Config) {
        let tmp = TempDir::new().expect("tempdir");
        let mut cfg = Config::defaults();
        cfg.storage.data_dir = tmp.path().to_string_lossy().into_owned();
        cfg.storage.model_dir = "{data_dir}/models".to_string();
        (tmp, cfg)
    }

    #[test]
    fn new_succeeds_on_default_config() {
        let (_tmp, cfg) = tempdir_config();
        let v = OnnxNliVerifier::new(&cfg).expect("new should succeed on default config");
        // cache_dir must include the sanitized model id (no '/').
        let s = v.cache_dir.to_string_lossy();
        assert!(s.contains(NLI_CACHE_SUBDIR), "cache_dir lacks nli/: {s}");
        assert!(
            !s.contains("Xenova/mDeBERTa"),
            "cache_dir must sanitize '/' in model id: {s}"
        );
        assert!(
            s.contains("Xenova_mDeBERTa"),
            "cache_dir should contain sanitized id: {s}"
        );
    }

    /// An empty hypothesis takes the defense-in-depth early-bail path and
    /// reaches no model load, so this is a pure unit test (no network).
    /// Replaces PR-9a's `score_returns_err_in_skeleton` (stub-only).
    #[test]
    fn score_empty_hypothesis_returns_err() {
        let (_tmp, cfg) = tempdir_config();
        let v = OnnxNliVerifier::new(&cfg).unwrap();
        let err = v.score("anything", "").expect_err("empty hypothesis must error");
        assert!(
            err.to_string().contains("empty hypothesis"),
            "unexpected error message: {err}"
        );
    }

    /// Pins that `config.models.nli.model` flows into `OnnxNliVerifier`
    /// instead of being silently overridden by a hardcoded constant.
    /// `model_id` is a private field, but this test module is a child of
    /// the module that defines it, so it can read the field directly. The
    /// wiring contract is "whatever the user puts in TOML /
    /// KEBAB_MODELS_NLI_MODEL is the id the verifier uses".
    #[test]
    fn new_uses_config_model_id() {
        let (_tmp, mut cfg) = tempdir_config();
        cfg.models.nli.model = "custom-org/custom-nli-model".to_string();
        let v = OnnxNliVerifier::new(&cfg).expect("new should succeed with custom model id");
        assert_eq!(v.model_id, "custom-org/custom-nli-model");
        // The custom id also flows into the on-disk cache_dir layout
        // (sanitized so `/` doesn't escape the namespace).
        let s = v.cache_dir.to_string_lossy();
        assert!(
            s.contains("custom-org_custom-nli-model"),
            "cache_dir should embed sanitized custom model id: {s}"
        );
    }

    /// Pins that a non-`"onnx"` provider value errors out at `new`;
    /// the field is no longer silently ignored.
    #[test]
    fn new_rejects_unsupported_provider() {
        let (_tmp, mut cfg) = tempdir_config();
        cfg.models.nli.provider = "candle".to_string();
        let result = OnnxNliVerifier::new(&cfg);
        assert!(result.is_err(), "non-onnx provider must error");
        let msg = result.err().unwrap().to_string();
        assert!(
            msg.contains("unsupported provider") && msg.contains("candle"),
            "error should name the rejected provider: {msg}"
        );
    }

    // ── sanitize_model_id pure-fn coverage ────────────────────────────────
    //
    // Three tests pin the behavior of the private `sanitize_model_id`
    // helper. They are orthogonal to the H1 executor tests above
    // (which cover config wiring); these cover the transformation
    // contract of the sanitizer itself.

    #[test]
    fn sanitize_model_id_replaces_slash_with_underscore() {
        let input = "Xenova/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7";
        let expected = "Xenova_mDeBERTa-v3-base-xnli-multilingual-nli-2mil7";
        assert_eq!(sanitize_model_id(input), expected);
    }

    #[test]
    fn sanitize_model_id_is_idempotent_on_already_sanitized() {
        // Input with no '/' must come back byte-for-byte unchanged.
        let input = "Xenova_mDeBERTa-v3-base-xnli-multilingual-nli-2mil7";
        assert_eq!(sanitize_model_id(input), input);
    }

    #[test]
    fn sanitize_model_id_leaves_other_chars_untouched() {
        // Hyphens, digits, dots, and underscores must all pass through
        // unchanged; only '/' is replaced with '_'.
        let input = "org_name/model-name_v2.3-alpha";
        let got = sanitize_model_id(input);
        assert_eq!(got, "org_name_model-name_v2.3-alpha");
        assert!(!got.contains('/'), "no slash must remain after sanitize");
        assert!(got.contains('-'), "hyphens must be preserved");
        assert!(got.contains('.'), "dots must be preserved");
        assert!(got.contains('_'), "underscores must be preserved");
    }
}
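
The hand-rolled fallible init that the `ensure_loaded` doc comment describes can be isolated into a small generic helper. A minimal sketch using only `std::sync::OnceLock`; the helper name `get_or_try_init` is hypothetical and only mirrors the unstable std method of the same name:

```rust
use std::sync::OnceLock;

/// Hypothetical helper: probe `get`, compute on a miss, `set`, then read back
/// whichever value won a race.
fn get_or_try_init<T, E>(
    cell: &OnceLock<T>,
    init: impl FnOnce() -> Result<T, E>,
) -> Result<&T, E> {
    if let Some(v) = cell.get() {
        return Ok(v);
    }
    // An `Err` returns early and leaves the cell empty, so a later call retries.
    let v = init()?;
    // A racing loser gets `Err(v)` back from `set`; its value is dropped here.
    let _ = cell.set(v);
    Ok(cell.get().expect("set by this thread or by the race winner"))
}
```

Because a failed `init` leaves the cell empty, the next call simply retries; this is the same property that lets `ensure_loaded` recover on a later `score` call after a failed download.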