diff --git a/crates/kebab-rag/tests/prompt_template_dispatch.rs b/crates/kebab-rag/tests/prompt_template_dispatch.rs new file mode 100644 index 0000000..ba280a4 --- /dev/null +++ b/crates/kebab-rag/tests/prompt_template_dispatch.rs @@ -0,0 +1,161 @@ +//! p9-fb-40: integration tests for rag-v1 / rag-v2 / unknown-version dispatch. +//! +//! Wraps `MockLanguageModel` in a `CapturingLm` that snapshots +//! `GenerateRequest::system` on every `generate_stream` call so the +//! tests can assert which template constant the pipeline rendered. + +mod common; + +use std::sync::{Arc, Mutex}; + +use common::{MockRetriever, RagEnv, id32, mk_hit}; +use kebab_core::{FinishReason, LanguageModel, Retriever, SearchMode, TokenChunk, TokenUsage}; +use kebab_llm::MockLanguageModel; +use kebab_rag::{AskOpts, RagPipeline}; + +const TEST_LM_ID: &str = "mock-lm"; + +/// LM wrapper that captures the system prompt of the most-recent +/// `generate_stream` call, so tests can assert which template was +/// rendered. Mirrors the `CountingLm` pattern from +/// `tests/streaming_events.rs` but stores `req.system` instead of a +/// call counter. +struct CapturingLm { + inner: MockLanguageModel, + captured_system: Arc>>, +} + +impl CapturingLm { + fn new(captured: Arc>>) -> Self { + Self { + inner: MockLanguageModel { + model_id: TEST_LM_ID.to_string(), + provider: "mock".to_string(), + context_tokens: 32_768, + canned_response: "근거가 충분합니다 [#1]".to_string(), + canned_finish: FinishReason::Stop, + canned_usage: TokenUsage { + prompt_tokens: 10, + completion_tokens: 5, + latency_ms: 7, + }, + }, + captured_system: captured, + } + } +} + +impl LanguageModel for CapturingLm { + fn model_ref(&self) -> kebab_core::ModelRef { + self.inner.model_ref() + } + fn context_tokens(&self) -> usize { + self.inner.context_tokens() + } + fn generate_stream( + &self, + req: kebab_core::GenerateRequest, + ) -> anyhow::Result> + Send>> { + *self.captured_system.lock().unwrap() = Some(req.system.clone()); + self.inner.generate_stream(req) + } +} + +/// Mirror of `streaming_events::opts_with_sink` minus the sink — every +/// field is set explicitly because `AskOpts` does not implement `Default`. +fn lexical_opts() -> AskOpts { + AskOpts { + k: 3, + explain: false, + mode: SearchMode::Lexical, + temperature: Some(0.0), + seed: Some(0), + stream_sink: None, + history: Vec::new(), + conversation_id: None, + turn_index: None, + } +} + +/// Build a `RagPipeline` with the given `prompt_template_version`. +/// Returns the pipeline, the captured-system handle, and the env (kept +/// alive for the test body — drops the SqliteStore + tempdir together). +fn build_pipeline_with_template( + version: &str, +) -> (RagPipeline, Arc>>, RagEnv) { + let mut env = RagEnv::new(); + env.config.rag.prompt_template_version = version.to_string(); + // Drop score gate so the seeded hit (fusion_score = 0.9) always + // makes it through — the dispatch we want to exercise lives past + // the gate. + env.config.rag.score_gate = 0.0; + let captured = Arc::new(Mutex::new(None)); + let lm: Arc = Arc::new(CapturingLm::new(captured.clone())); + // Seed one chunk so the [근거] block has content and the LM is + // actually invoked on the success path. + let chunk_id = id32("c"); + let doc_id = id32("d"); + env.seed_chunk(&chunk_id, &doc_id, "a.md", "hello world", &["H"]); + let hit = mk_hit(1, &chunk_id, &doc_id, "a.md", 0.9, &["H"]); + let retriever: Arc = Arc::new(MockRetriever::new(vec![hit])); + let pipeline = RagPipeline::new(env.config.clone(), retriever, lm, env.sqlite.clone()); + (pipeline, captured, env) +} + +#[test] +fn ask_with_rag_v1_uses_v1_system_prompt() { + let (pipeline, captured, _env) = build_pipeline_with_template("rag-v1"); + let _ = pipeline.ask("hello", lexical_opts()); + let s = captured + .lock() + .unwrap() + .clone() + .expect("system prompt captured"); + assert!( + s.contains("로컬 KB 위에서 동작"), + "shared V1/V2 prefix expected, got: {s}" + ); + assert!( + !s.contains("학습 지식"), + "V1 must NOT contain V2-only 학습 지식 rule, got: {s}" + ); + assert!( + !s.contains("확실하지 않다"), + "V1 must NOT contain V2-only 확실하지 않다 rule, got: {s}" + ); +} + +#[test] +fn ask_with_rag_v2_uses_v2_system_prompt() { + let (pipeline, captured, _env) = build_pipeline_with_template("rag-v2"); + let _ = pipeline.ask("hello", lexical_opts()); + let s = captured + .lock() + .unwrap() + .clone() + .expect("system prompt captured"); + assert!( + s.contains("학습 지식"), + "V2 must contain 학습 지식 rule, got: {s}" + ); + assert!( + s.contains("확실하지 않다"), + "V2 must contain 확실하지 않다 rule, got: {s}" + ); + assert!( + s.contains("큰따옴표"), + "V2 must contain 큰따옴표 rule, got: {s}" + ); +} + +#[test] +fn ask_with_unknown_template_returns_early_error() { + let (pipeline, _captured, _env) = build_pipeline_with_template("rag-v99"); + let result = pipeline.ask("hello", lexical_opts()); + assert!(result.is_err(), "expected error on unknown version"); + let msg = format!("{:#}", result.unwrap_err()); + assert!( + msg.contains("rag-v99") && msg.contains("expected"), + "expected error to mention version + expected list, got: {msg}" + ); +}