diff --git a/crates/kebab-core/src/traits.rs b/crates/kebab-core/src/traits.rs index 2c48411..bb4d0c3 100644 --- a/crates/kebab-core/src/traits.rs +++ b/crates/kebab-core/src/traits.rs @@ -98,6 +98,11 @@ pub enum FinishReason { Stop, Length, Aborted, + /// p9-fb-33: caller-side cancel. The pipeline breaks the LM loop + /// when a `Token` send into `AskOpts.stream_sink` returns + /// `SendError` (receiver dropped). The persisted answer is + /// flagged with `RefusalReason::LlmStreamAborted`. + Cancelled, Error(String), } diff --git a/crates/kebab-rag/src/lib.rs b/crates/kebab-rag/src/lib.rs index a883dae..e527dc0 100644 --- a/crates/kebab-rag/src/lib.rs +++ b/crates/kebab-rag/src/lib.rs @@ -22,4 +22,4 @@ pub use kebab_core::{Answer, AnswerCitation, AnswerRetrievalSummary, RefusalReas mod pipeline; -pub use pipeline::{AskOpts, RagPipeline}; +pub use pipeline::{AskOpts, RagPipeline, StreamEvent}; diff --git a/crates/kebab-rag/src/pipeline.rs b/crates/kebab-rag/src/pipeline.rs index 3f1024b..21e7844 100644 --- a/crates/kebab-rag/src/pipeline.rs +++ b/crates/kebab-rag/src/pipeline.rs @@ -83,6 +83,12 @@ type PackedContext = (String, Vec, usize); /// `RefusalReason::LlmStreamAborted`). #[derive(Clone, Debug, serde::Serialize)] #[serde(tag = "kind", rename_all = "snake_case")] +// `Final.answer` carries a full `Answer` (~320B) and is the largest +// variant; `Token` is the hot path. Size mismatch is unavoidable +// without boxing the wire-shape, which would force every consumer +// (TUI / CLI / future MCP) to deref. The sink is short-lived (one +// per ask) so the per-event overhead is not material. 
+#[allow(clippy::large_enum_variant)] pub enum StreamEvent { RetrievalDone { hits: Vec, @@ -231,6 +237,16 @@ impl RagPipeline { for h in &mut hits { h.stale = compute_stale(h.indexed_at, now, stale_threshold_days); } + // p9-fb-33: emit retrieval_done as soon as the hit list is + // ready (post stale-stamp so consumers see the same `stale` + // values the App-level wire path emits). Cancel is best-effort + // here — if the caller already dropped the receiver we just + // skip and let the LLM-loop SendError handle it consistently. + if let Some(sink) = &opts.stream_sink { + let _ = sink.send(StreamEvent::RetrievalDone { + hits: hits.clone(), + }); + } let chunks_returned = u32::try_from(hits.len()).unwrap_or(u32::MAX); let top_score = hits.first().map(|h| h.retrieval.fusion_score).unwrap_or(0.0); @@ -329,16 +345,28 @@ impl RagPipeline { .llm .generate_stream(req) .context("kb-rag: llm.generate_stream")?; + let mut cancelled = false; for item in stream { let chunk = item.context("kb-rag: stream item")?; match chunk { TokenChunk::Token(t) => { acc.push_str(&t); if let Some(sink) = &opts.stream_sink { - // SendError silently dropped — caller cancelled but the - // pipeline still drives generation to completion so the - // `answers` row gets a faithful record. - let _ = sink.send(t); + // p9-fb-33: SendError → caller dropped the + // receiver (probably a closed stdout downstream). + // Stop generation, mark the answer cancelled so + // the persistence path records refusal_reason = + // LlmStreamAborted. + if sink + .send(StreamEvent::Token { + delta: t, + turn_index: opts.turn_index, + }) + .is_err() + { + cancelled = true; + break; + } } } TokenChunk::Done { @@ -351,6 +379,9 @@ impl RagPipeline { } } } + if cancelled { + finish_reason = FinishReason::Cancelled; + } // ── 6. 
Citation extract ──────────────────────────────────────────── let extracted: Vec = extract_markers(&acc); @@ -375,15 +406,20 @@ impl RagPipeline { }); let trimmed_answer = acc.trim(); let matched_refusal_phrase = refusal_phrase.is_match(&acc); - let grounded = !trimmed_answer.is_empty() + let grounded_unaware = !trimmed_answer.is_empty() && unknown_markers.is_empty() && !extracted.is_empty(); - let refusal_reason = if grounded { - None + // p9-fb-33: cancel takes priority over LlmSelfJudge — the + // caller bailed mid-stream, so the recorded reason should + // reflect that, not "model didn't cite". + let (grounded, refusal_reason) = if matches!(finish_reason, FinishReason::Cancelled) { + (false, Some(RefusalReason::LlmStreamAborted)) + } else if grounded_unaware { + (true, None) } else { // Spec §7: empty answer, unknown markers, silent ungrounded, // and explicit "근거가 부족" all collapse to LlmSelfJudge. - Some(RefusalReason::LlmSelfJudge) + (false, Some(RefusalReason::LlmSelfJudge)) }; // ── 8. Build Answer ──────────────────────────────────────────────── @@ -461,6 +497,17 @@ impl RagPipeline { "kb-rag: ask done" ); + // p9-fb-33: emit final on the success path. On cancel we + // skip Final — the receiver is gone and persistence still + // records the partial answer below. + if !cancelled + && let Some(sink) = &opts.stream_sink + { + let _ = sink.send(StreamEvent::Final { + answer: answer.clone(), + }); + } + // ── 9. 
Persist ───────────────────────────────────────────────────── let packed_chunks_json = if opts.explain { // Snapshot the packed entries as a portable list of objects so diff --git a/crates/kebab-rag/tests/pipeline.rs b/crates/kebab-rag/tests/pipeline.rs index 9dc9bd4..875e9d6 100644 --- a/crates/kebab-rag/tests/pipeline.rs +++ b/crates/kebab-rag/tests/pipeline.rs @@ -14,7 +14,7 @@ use kebab_core::{ FinishReason, LanguageModel, Retriever, SearchMode, TokenChunk, TokenUsage, }; use kebab_llm::MockLanguageModel; -use kebab_rag::{AskOpts, RagPipeline, RefusalReason}; +use kebab_rag::{AskOpts, RagPipeline, RefusalReason, StreamEvent}; /// LM ID used everywhere — kept short so snapshots stay stable. const TEST_LM_ID: &str = "mock-lm"; @@ -270,18 +270,32 @@ fn streaming_forwards_tokens_to_sink() { let lm: Arc = Arc::new(CountingLm::new(canned)); let pipeline = RagPipeline::new(env.config.clone(), retriever, lm, env.sqlite.clone()); - let (tx, rx) = std::sync::mpsc::channel::(); + let (tx, rx) = std::sync::mpsc::channel::(); let mut opts = default_opts(); opts.stream_sink = Some(tx); let _ = pipeline.ask("q", opts).unwrap(); - let collected: String = rx.into_iter().collect::>().join(""); + // p9-fb-33: extract Token deltas from the staged event stream. + let collected: String = rx + .into_iter() + .filter_map(|ev| match ev { + StreamEvent::Token { delta, .. } => Some(delta), + _ => None, + }) + .collect::>() + .join(""); assert_eq!(collected, canned); } -// ── 10. dropped receiver does NOT abort generation ──────────────────────── +// ── 10. dropped receiver aborts generation, records LlmStreamAborted ────── +// +// p9-fb-33: cancel semantics changed. Pre-fb-33 the pipeline drove +// the LM loop to completion and silently dropped sends. Now a +// SendError breaks the loop and stamps `RefusalReason::LlmStreamAborted` +// onto the persisted row — the partial answer (whatever was buffered +// before the cancel) still gets written for audit. 
#[test] -fn dropped_receiver_does_not_abort_generation() { +fn dropped_receiver_aborts_with_llm_stream_aborted() { let env = RagEnv::new(); let cid = id32("c1"); let did = id32("d1"); @@ -292,13 +306,17 @@ fn dropped_receiver_does_not_abort_generation() { let lm: Arc = Arc::new(CountingLm::new(canned)); let pipeline = RagPipeline::new(env.config.clone(), retriever, lm, env.sqlite.clone()); - let (tx, rx) = std::sync::mpsc::channel::(); - drop(rx); // receiver gone — every send fails silently + let (tx, rx) = std::sync::mpsc::channel::(); + drop(rx); // receiver gone — first Token send fails, loop breaks let mut opts = default_opts(); opts.stream_sink = Some(tx); let answer = pipeline.ask("q", opts).unwrap(); - assert_eq!(answer.answer, canned, "generation completes despite dead sink"); - assert!(answer.grounded); + assert!(!answer.grounded, "cancel takes priority over grounded"); + assert_eq!( + answer.refusal_reason, + Some(RefusalReason::LlmStreamAborted), + "cancel records LlmStreamAborted", + ); assert_eq!(env.count_answers(), 1, "answers row still persisted"); } diff --git a/crates/kebab-rag/tests/streaming_events.rs b/crates/kebab-rag/tests/streaming_events.rs new file mode 100644 index 0000000..05be1b0 --- /dev/null +++ b/crates/kebab-rag/tests/streaming_events.rs @@ -0,0 +1,131 @@ +//! p9-fb-33: pipeline-level streaming behavior — order invariants, +//! cancel propagation, refusal flagging. + +mod common; + +use std::sync::Arc; +use std::sync::atomic::Ordering; +use std::sync::mpsc; + +use common::{MockRetriever, RagEnv, id32, mk_hit}; +use kebab_core::{ + FinishReason, LanguageModel, RefusalReason, Retriever, SearchMode, TokenChunk, TokenUsage, +}; +use kebab_llm::MockLanguageModel; +use kebab_rag::{AskOpts, RagPipeline, StreamEvent}; + +const TEST_LM_ID: &str = "mock-lm"; + +/// Minimal LM mirroring `tests/pipeline.rs::CountingLm` so the +/// streaming-events suite stays self-contained. 
+struct CountingLm {
+    inner: MockLanguageModel,
+    calls: std::sync::atomic::AtomicUsize,
+}
+
+impl CountingLm {
+    fn new(canned: &str) -> Self {
+        Self {
+            inner: MockLanguageModel {
+                model_id: TEST_LM_ID.to_string(),
+                provider: "mock".to_string(),
+                context_tokens: 32_768,
+                canned_response: canned.to_string(),
+                canned_finish: FinishReason::Stop,
+                canned_usage: TokenUsage {
+                    prompt_tokens: 10,
+                    completion_tokens: 5,
+                    latency_ms: 7,
+                },
+            },
+            calls: std::sync::atomic::AtomicUsize::new(0),
+        }
+    }
+}
+
+impl LanguageModel for CountingLm {
+    fn model_ref(&self) -> kebab_core::ModelRef {
+        self.inner.model_ref()
+    }
+    fn context_tokens(&self) -> usize {
+        self.inner.context_tokens()
+    }
+    fn generate_stream(
+        &self,
+        req: kebab_core::GenerateRequest,
+    ) -> anyhow::Result<Box<dyn Iterator<Item = anyhow::Result<TokenChunk>> + Send>> {
+        self.calls.fetch_add(1, Ordering::SeqCst);
+        self.inner.generate_stream(req)
+    }
+}
+
+fn opts_with_sink(tx: mpsc::Sender<StreamEvent>) -> AskOpts {
+    AskOpts {
+        k: 3,
+        explain: false,
+        mode: SearchMode::Lexical,
+        temperature: Some(0.0),
+        seed: Some(0),
+        stream_sink: Some(tx),
+        history: Vec::new(),
+        conversation_id: None,
+        turn_index: None,
+    }
+}
+
+/// Build a pipeline with one seeded chunk + canned LM response so
+/// retrieval lands a single hit and the LM emits at least one token.
+fn env_with_one_hit(canned: &str) -> (RagEnv, RagPipeline) {
+    let env = RagEnv::new();
+    let cid = id32("c1");
+    let did = id32("d1");
+    env.seed_chunk(&cid, &did, "notes/a.md", "apples are red.", &["Intro"]);
+    let hits = vec![mk_hit(1, &cid, &did, "notes/a.md", 0.85, &["Intro"])];
+    let retriever: Arc<dyn Retriever> = Arc::new(MockRetriever::new(hits));
+    let lm: Arc<dyn LanguageModel> = Arc::new(CountingLm::new(canned));
+    let pipeline = RagPipeline::new(env.config.clone(), retriever, lm, env.sqlite.clone());
+    (env, pipeline)
+}
+
+#[test]
+fn ask_emits_retrieval_then_tokens_then_final() {
+    let (_env, pipeline) = env_with_one_hit("apples are red. [#1]");
+    let (tx, rx) = mpsc::channel::<StreamEvent>();
+    let _ans = pipeline.ask("apples", opts_with_sink(tx)).unwrap();
+    let events: Vec<StreamEvent> = rx.iter().collect();
+
+    // First event must be RetrievalDone.
+    assert!(
+        matches!(events.first(), Some(StreamEvent::RetrievalDone { .. })),
+        "first event must be RetrievalDone, got {:?}",
+        events.first()
+    );
+
+    // Last event must be Final.
+    assert!(
+        matches!(events.last(), Some(StreamEvent::Final { .. })),
+        "last event must be Final, got {:?}",
+        events.last()
+    );
+
+    // Everything in between is Token.
+    for ev in &events[1..events.len() - 1] {
+        assert!(
+            matches!(ev, StreamEvent::Token { .. }),
+            "middle events must be Token, got {ev:?}"
+        );
+    }
+}
+
+#[test]
+fn ask_records_llm_stream_aborted_when_receiver_drops() {
+    let (env, pipeline) = env_with_one_hit("apples are red. [#1]");
+    let (tx, rx) = mpsc::channel::<StreamEvent>();
+    // Drop the receiver immediately so the first Token send fails.
+    drop(rx);
+    let ans = pipeline.ask("apples", opts_with_sink(tx)).unwrap();
+    assert!(!ans.grounded);
+    assert_eq!(ans.refusal_reason, Some(RefusalReason::LlmStreamAborted));
+    // Persistence still happens on cancel — the row is the audit trail.
+    assert_eq!(env.count_answers(), 1, "answers row written on cancel");
+}