Files
kebab/crates/kebab-rag/tests/streaming_events.rs
altair823 cf35f36f88 feat(rag): fb-41 PR-2 — RagPipeline::ask_multi_hop skeleton (fixed depth=2)
PR-2 of fb-41 multi-hop RAG. Decompose + retrieve + synthesize 3-stage
pipeline가 `opts.multi_hop=true` 일 때 dispatch. Dynamic decide loop
는 PR-3.

- `AskOpts.multi_hop: bool` 필드 추가 + `impl Default for AskOpts`
  도입 (HOTFIXES 2026-05-07 의 known limitation 해소). 9 explicit
  init site 모두 `multi_hop: false` 추가 — Default 도입으로 향후
  `..Default::default()` 점진 migrate 가능.
- `RagPipeline::ask` 의 entry 에 dispatcher 한 줄
  (`if opts.multi_hop { return self.ask_multi_hop(...) }`).
- `RagPipeline::ask_multi_hop` 신규 method. 1) decompose LLM call
  → JSON array of strings parse, 2) 각 sub-query 로 retrieve +
  chunk_id dedup pool, 3) score gate / no-chunks 가드, 4)
  pack_context (single-pass 와 helper 공유), 5) synthesize LLM
  call w/ MULTI_HOP_SYNTHESIZE_SYSTEM_PROMPT, 6) citation extract
  + Answer build. `prompt_template_version` = "rag-multi-hop-v1"
  로 stamp — eval `compare` 가 single-pass vs multi-hop 분리.
- Prompt const 신규: MULTI_HOP_DECOMPOSE_SYSTEM_PROMPT +
  MULTI_HOP_DECOMPOSE_USER_TEMPLATE + MULTI_HOP_SYNTHESIZE_SYSTEM_PROMPT
  + PROMPT_TEMPLATE_VERSION_MULTI_HOP + MULTI_HOP_MAX_SUB_QUERIES_DEFAULT.
- `kebab_core::RefusalReason::MultiHopDecomposeFailed` variant 신규.
  Cascade: kebab-store-sqlite `refusal_reason_label` + kebab-tui `ask
  refusal render` exhaustive match 갱신.
- `parse_decompose_response` + `strip_markdown_json_fence` helper —
  markdown code fence (```json / ```) strip + JSON array of strings
  parse + trim + drop empty + cap at MULTI_HOP_MAX_SUB_QUERIES_DEFAULT.
  None 반환 시 caller 가 `MultiHopDecomposeFailed` refusal.

Tests (55 passing total, 8 신규):
- 6 unit (parse_decompose_response 의 bare array / fence variants /
  garbage / cap / trim 회귀 핀).
- 2 integration: `ask_multi_hop_dispatches_and_decompose_garbage_refuses`
  (decompose garbage → MultiHopDecomposeFailed + 정확히 1 LLM call) +
  `ask_with_multi_hop_false_keeps_single_pass_path` (회귀 핀, 기존
  caller 자동 backwards-compat).

Happy-path multi-hop (decompose 성공 → synthesize) 의 integration
test 는 ScriptedLm helper 가 PR-3 의 decide loop 와 함께 도입될
때 같이 추가. 현 `MockLanguageModel` 는 canned single response 라
2-LLM-call sequence 핀 불가.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 06:45:32 +00:00

219 lines
7.5 KiB
Rust

//! p9-fb-33: pipeline-level streaming behavior — order invariants,
//! cancel propagation, refusal flagging.
mod common;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::sync::mpsc;
use common::{MockRetriever, RagEnv, id32, mk_hit};
use kebab_core::{
FinishReason, LanguageModel, RefusalReason, Retriever, SearchMode, TokenChunk, TokenUsage,
};
use kebab_llm::MockLanguageModel;
use kebab_rag::{AskOpts, RagPipeline, StreamEvent};
const TEST_LM_ID: &str = "mock-lm";
/// Minimal LM mirroring `tests/pipeline.rs::CountingLm` so the
/// streaming-events suite stays self-contained.
struct CountingLm {
inner: MockLanguageModel,
calls: std::sync::atomic::AtomicUsize,
}
impl CountingLm {
fn new(canned: &str) -> Self {
Self {
inner: MockLanguageModel {
model_id: TEST_LM_ID.to_string(),
provider: "mock".to_string(),
context_tokens: 32_768,
canned_response: canned.to_string(),
canned_finish: FinishReason::Stop,
canned_usage: TokenUsage {
prompt_tokens: 10,
completion_tokens: 5,
latency_ms: 7,
},
},
calls: std::sync::atomic::AtomicUsize::new(0),
}
}
}
impl LanguageModel for CountingLm {
fn model_ref(&self) -> kebab_core::ModelRef {
self.inner.model_ref()
}
fn context_tokens(&self) -> usize {
self.inner.context_tokens()
}
fn generate_stream(
&self,
req: kebab_core::GenerateRequest,
) -> anyhow::Result<Box<dyn Iterator<Item = anyhow::Result<TokenChunk>> + Send>> {
self.calls.fetch_add(1, Ordering::SeqCst);
self.inner.generate_stream(req)
}
}
fn opts_with_sink(tx: mpsc::Sender<StreamEvent>) -> AskOpts {
AskOpts {
k: 3,
explain: false,
mode: SearchMode::Lexical,
temperature: Some(0.0),
seed: Some(0),
stream_sink: Some(tx),
history: Vec::new(),
conversation_id: None,
turn_index: None,
multi_hop: false,
}
}
/// Build a pipeline with one seeded chunk + canned LM response so
/// retrieval lands a single hit and the LM emits at least one token.
fn env_with_one_hit(canned: &str) -> (RagEnv, RagPipeline) {
let env = RagEnv::new();
let cid = id32("c1");
let did = id32("d1");
env.seed_chunk(&cid, &did, "notes/a.md", "apples are red.", &["Intro"]);
let hits = vec![mk_hit(1, &cid, &did, "notes/a.md", 0.85, &["Intro"])];
let retriever: Arc<dyn Retriever> = Arc::new(MockRetriever::new(hits));
let lm: Arc<dyn LanguageModel> = Arc::new(CountingLm::new(canned));
let pipeline = RagPipeline::new(env.config.clone(), retriever, lm, env.sqlite.clone());
(env, pipeline)
}
#[test]
fn ask_emits_retrieval_then_tokens_then_final() {
let (_env, pipeline) = env_with_one_hit("apples are red. [#1]");
let (tx, rx) = mpsc::channel::<StreamEvent>();
let _ans = pipeline.ask("apples", opts_with_sink(tx)).unwrap();
let events: Vec<StreamEvent> = rx.iter().collect();
// First event must be RetrievalDone.
assert!(
matches!(events.first(), Some(StreamEvent::RetrievalDone { .. })),
"first event must be RetrievalDone, got {:?}",
events.first()
);
// Last event must be Final.
assert!(
matches!(events.last(), Some(StreamEvent::Final { .. })),
"last event must be Final, got {:?}",
events.last()
);
// Everything in between is Token.
for ev in &events[1..events.len() - 1] {
assert!(
matches!(ev, StreamEvent::Token { .. }),
"middle events must be Token, got {ev:?}"
);
}
}
#[test]
fn ask_records_llm_stream_aborted_when_receiver_drops() {
let (env, pipeline) = env_with_one_hit("apples are red. [#1]");
let (tx, rx) = mpsc::channel::<StreamEvent>();
// Drop the receiver immediately so the first Token send fails.
drop(rx);
let ans = pipeline.ask("apples", opts_with_sink(tx)).unwrap();
assert!(!ans.grounded);
assert_eq!(ans.refusal_reason, Some(RefusalReason::LlmStreamAborted));
// Persistence still happens on cancel — the row is the audit trail.
assert_eq!(env.count_answers(), 1, "answers row written on cancel");
}
/// p9-fb-33 (PR #124 round 1, item 5): pin the "no Final on cancel"
/// invariant. Uses a barrier-gated LM so the test can observe the
/// `RetrievalDone` event before any `Token`/`Final` lands in the
/// channel — then drops `rx` to force SendError on the next `Token`.
/// The pipeline's cancel branch must avoid emitting `Final` and
/// record `RefusalReason::LlmStreamAborted`.
struct BlockingLm {
inner: MockLanguageModel,
/// Pipeline thread waits on this before yielding any token.
/// Test thread releases it after observing `RetrievalDone`.
gate: Arc<std::sync::Barrier>,
}
impl LanguageModel for BlockingLm {
fn model_ref(&self) -> kebab_core::ModelRef {
self.inner.model_ref()
}
fn context_tokens(&self) -> usize {
self.inner.context_tokens()
}
fn generate_stream(
&self,
req: kebab_core::GenerateRequest,
) -> anyhow::Result<Box<dyn Iterator<Item = anyhow::Result<TokenChunk>> + Send>> {
// Block until the test signals — guarantees `RetrievalDone`
// arrives at the receiver before any `Token` is queued.
self.gate.wait();
self.inner.generate_stream(req)
}
}
#[test]
fn ask_emits_no_final_when_cancelled_mid_stream() {
use std::sync::Barrier;
let env = RagEnv::new();
let cid = id32("c1");
let did = id32("d1");
env.seed_chunk(&cid, &did, "notes/a.md", "apples are red.", &["Intro"]);
let hits = vec![mk_hit(1, &cid, &did, "notes/a.md", 0.85, &["Intro"])];
let retriever: Arc<dyn Retriever> = Arc::new(MockRetriever::new(hits));
let gate = Arc::new(Barrier::new(2));
let lm: Arc<dyn LanguageModel> = Arc::new(BlockingLm {
inner: MockLanguageModel {
model_id: TEST_LM_ID.to_string(),
provider: "mock".to_string(),
context_tokens: 32_768,
canned_response: "apples are red. [#1]".to_string(),
canned_finish: FinishReason::Stop,
canned_usage: TokenUsage {
prompt_tokens: 10,
completion_tokens: 5,
latency_ms: 7,
},
},
gate: Arc::clone(&gate),
});
let pipeline = RagPipeline::new(env.config.clone(), retriever, lm, env.sqlite.clone());
let (tx, rx) = mpsc::channel::<StreamEvent>();
let opts = opts_with_sink(tx);
let handle = std::thread::spawn(move || pipeline.ask("apples", opts));
// Receive RetrievalDone first — pipeline emits this before
// calling generate_stream (where the LM blocks on the gate).
let first = rx.recv().expect("RetrievalDone must arrive");
assert!(
matches!(first, StreamEvent::RetrievalDone { .. }),
"first event must be RetrievalDone, got {first:?}",
);
// Drop rx now, BEFORE releasing the gate. Once the LM unblocks
// and the pipeline tries to send the first Token, it'll get
// SendError → cancel branch.
drop(rx);
gate.wait();
let ans = handle.join().expect("ask thread").unwrap();
// Cancel was observed: no Final emitted, refusal recorded.
assert!(!ans.grounded);
assert_eq!(ans.refusal_reason, Some(RefusalReason::LlmStreamAborted));
assert_eq!(env.count_answers(), 1, "answers row written on cancel");
}