From 0e8b800b6b630028b1f95059f1a75e992326292c Mon Sep 17 00:00:00 2001
From: th-kim0823
Date: Sun, 10 May 2026 19:18:36 +0900
Subject: [PATCH] test(rag): integration tests for rag-v1/v2/unknown dispatch
(fb-40)
Co-Authored-By: Claude Opus 4.7 (1M context)
---
.../tests/prompt_template_dispatch.rs | 161 ++++++++++++++++++
1 file changed, 161 insertions(+)
create mode 100644 crates/kebab-rag/tests/prompt_template_dispatch.rs
diff --git a/crates/kebab-rag/tests/prompt_template_dispatch.rs b/crates/kebab-rag/tests/prompt_template_dispatch.rs
new file mode 100644
index 0000000..ba280a4
--- /dev/null
+++ b/crates/kebab-rag/tests/prompt_template_dispatch.rs
@@ -0,0 +1,161 @@
+//! p9-fb-40: integration tests for rag-v1 / rag-v2 / unknown-version dispatch.
+//!
+//! Wraps `MockLanguageModel` in a `CapturingLm` that snapshots
+//! `GenerateRequest::system` on every `generate_stream` call so the
+//! tests can assert which template constant the pipeline rendered.
+
+mod common;
+
+use std::sync::{Arc, Mutex};
+
+use common::{MockRetriever, RagEnv, id32, mk_hit};
+use kebab_core::{FinishReason, LanguageModel, Retriever, SearchMode, TokenChunk, TokenUsage};
+use kebab_llm::MockLanguageModel;
+use kebab_rag::{AskOpts, RagPipeline};
+
+const TEST_LM_ID: &str = "mock-lm";
+
+/// LM wrapper that captures the system prompt of the most-recent
+/// `generate_stream` call, so tests can assert which template was
+/// rendered. Mirrors the `CountingLm` pattern from
+/// `tests/streaming_events.rs` but stores `req.system` instead of a
+/// call counter.
+struct CapturingLm {
+ inner: MockLanguageModel,
+ captured_system: Arc>>,
+}
+
+impl CapturingLm {
+ fn new(captured: Arc>>) -> Self {
+ Self {
+ inner: MockLanguageModel {
+ model_id: TEST_LM_ID.to_string(),
+ provider: "mock".to_string(),
+ context_tokens: 32_768,
+ canned_response: "근거가 충분합니다 [#1]".to_string(),
+ canned_finish: FinishReason::Stop,
+ canned_usage: TokenUsage {
+ prompt_tokens: 10,
+ completion_tokens: 5,
+ latency_ms: 7,
+ },
+ },
+ captured_system: captured,
+ }
+ }
+}
+
+impl LanguageModel for CapturingLm {
+ fn model_ref(&self) -> kebab_core::ModelRef {
+ self.inner.model_ref()
+ }
+ fn context_tokens(&self) -> usize {
+ self.inner.context_tokens()
+ }
+ fn generate_stream(
+ &self,
+ req: kebab_core::GenerateRequest,
+ ) -> anyhow::Result> + Send>> {
+ *self.captured_system.lock().unwrap() = Some(req.system.clone());
+ self.inner.generate_stream(req)
+ }
+}
+
+/// Mirror of `streaming_events::opts_with_sink` minus the sink — every
+/// field is set explicitly because `AskOpts` does not implement `Default`.
+fn lexical_opts() -> AskOpts {
+ AskOpts {
+ k: 3,
+ explain: false,
+ mode: SearchMode::Lexical,
+ temperature: Some(0.0),
+ seed: Some(0),
+ stream_sink: None,
+ history: Vec::new(),
+ conversation_id: None,
+ turn_index: None,
+ }
+}
+
+/// Build a `RagPipeline` with the given `prompt_template_version`.
+/// Returns the pipeline, the captured-system handle, and the env (kept
+/// alive for the test body — drops the SqliteStore + tempdir together).
+fn build_pipeline_with_template(
+ version: &str,
+) -> (RagPipeline, Arc>>, RagEnv) {
+ let mut env = RagEnv::new();
+ env.config.rag.prompt_template_version = version.to_string();
+ // Drop score gate so the seeded hit (fusion_score = 0.9) always
+ // makes it through — the dispatch we want to exercise lives past
+ // the gate.
+ env.config.rag.score_gate = 0.0;
+ let captured = Arc::new(Mutex::new(None));
+ let lm: Arc = Arc::new(CapturingLm::new(captured.clone()));
+ // Seed one chunk so the [근거] block has content and the LM is
+ // actually invoked on the success path.
+ let chunk_id = id32("c");
+ let doc_id = id32("d");
+ env.seed_chunk(&chunk_id, &doc_id, "a.md", "hello world", &["H"]);
+ let hit = mk_hit(1, &chunk_id, &doc_id, "a.md", 0.9, &["H"]);
+ let retriever: Arc = Arc::new(MockRetriever::new(vec![hit]));
+ let pipeline = RagPipeline::new(env.config.clone(), retriever, lm, env.sqlite.clone());
+ (pipeline, captured, env)
+}
+
+#[test]
+fn ask_with_rag_v1_uses_v1_system_prompt() {
+ let (pipeline, captured, _env) = build_pipeline_with_template("rag-v1");
+ let _ = pipeline.ask("hello", lexical_opts());
+ let s = captured
+ .lock()
+ .unwrap()
+ .clone()
+ .expect("system prompt captured");
+ assert!(
+ s.contains("로컬 KB 위에서 동작"),
+ "shared V1/V2 prefix expected, got: {s}"
+ );
+ assert!(
+ !s.contains("학습 지식"),
+ "V1 must NOT contain V2-only 학습 지식 rule, got: {s}"
+ );
+ assert!(
+ !s.contains("확실하지 않다"),
+ "V1 must NOT contain V2-only 확실하지 않다 rule, got: {s}"
+ );
+}
+
+#[test]
+fn ask_with_rag_v2_uses_v2_system_prompt() {
+ let (pipeline, captured, _env) = build_pipeline_with_template("rag-v2");
+ let _ = pipeline.ask("hello", lexical_opts());
+ let s = captured
+ .lock()
+ .unwrap()
+ .clone()
+ .expect("system prompt captured");
+ assert!(
+ s.contains("학습 지식"),
+ "V2 must contain 학습 지식 rule, got: {s}"
+ );
+ assert!(
+ s.contains("확실하지 않다"),
+ "V2 must contain 확실하지 않다 rule, got: {s}"
+ );
+ assert!(
+ s.contains("큰따옴표"),
+ "V2 must contain 큰따옴표 rule, got: {s}"
+ );
+}
+
+#[test]
+fn ask_with_unknown_template_returns_early_error() {
+ let (pipeline, _captured, _env) = build_pipeline_with_template("rag-v99");
+ let result = pipeline.ask("hello", lexical_opts());
+ assert!(result.is_err(), "expected error on unknown version");
+ let msg = format!("{:#}", result.unwrap_err());
+ assert!(
+ msg.contains("rag-v99") && msg.contains("expected"),
+ "expected error to mention version + expected list, got: {msg}"
+ );
+}