feat(rag): pipeline emits StreamEvent + cancel on SendError (fb-33)
RetrievalDone after retrieve+stale-stamp, Token per LM chunk (SendError → break, FinishReason::Cancelled, RefusalReason:: LlmStreamAborted), Final on success. answers row still persists on cancel for audit. Adds FinishReason::Cancelled, re-exports StreamEvent from kebab_rag, migrates two pre-fb-33 sink tests in tests/pipeline.rs to the new StreamEvent type (the "dropped receiver does not abort" test inverts to record cancel). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -98,6 +98,11 @@ pub enum FinishReason {
|
||||
Stop,
|
||||
Length,
|
||||
Aborted,
|
||||
/// p9-fb-33: caller-side cancel. The pipeline breaks the LM loop
|
||||
/// when a `Token` send into `AskOpts.stream_sink` returns
|
||||
/// `SendError` (receiver dropped). The persisted answer is
|
||||
/// flagged with `RefusalReason::LlmStreamAborted`.
|
||||
Cancelled,
|
||||
Error(String),
|
||||
}
|
||||
|
||||
|
||||
@@ -22,4 +22,4 @@ pub use kebab_core::{Answer, AnswerCitation, AnswerRetrievalSummary, RefusalReas
|
||||
|
||||
mod pipeline;
|
||||
|
||||
pub use pipeline::{AskOpts, RagPipeline};
|
||||
pub use pipeline::{AskOpts, RagPipeline, StreamEvent};
|
||||
|
||||
@@ -83,6 +83,12 @@ type PackedContext = (String, Vec<PackedCitation>, usize);
|
||||
/// `RefusalReason::LlmStreamAborted`).
|
||||
#[derive(Clone, Debug, serde::Serialize)]
|
||||
#[serde(tag = "kind", rename_all = "snake_case")]
|
||||
// `Final.answer` carries a full `Answer` (~320B) and is the largest
|
||||
// variant; `Token` is the hot path. Size mismatch is unavoidable
|
||||
// without boxing the wire-shape, which would force every consumer
|
||||
// (TUI / CLI / future MCP) to deref. The sink is short-lived (one
|
||||
// per ask) so the per-event overhead is not material.
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub enum StreamEvent {
|
||||
RetrievalDone {
|
||||
hits: Vec<SearchHit>,
|
||||
@@ -231,6 +237,16 @@ impl RagPipeline {
|
||||
for h in &mut hits {
|
||||
h.stale = compute_stale(h.indexed_at, now, stale_threshold_days);
|
||||
}
|
||||
// p9-fb-33: emit retrieval_done as soon as the hit list is
|
||||
// ready (post stale-stamp so consumers see the same `stale`
|
||||
// values the App-level wire path emits). Cancel is best-effort
|
||||
// here — if the caller already dropped the receiver we just
|
||||
// skip and let the LLM-loop SendError handle it consistently.
|
||||
if let Some(sink) = &opts.stream_sink {
|
||||
let _ = sink.send(StreamEvent::RetrievalDone {
|
||||
hits: hits.clone(),
|
||||
});
|
||||
}
|
||||
let chunks_returned = u32::try_from(hits.len()).unwrap_or(u32::MAX);
|
||||
let top_score = hits.first().map(|h| h.retrieval.fusion_score).unwrap_or(0.0);
|
||||
|
||||
@@ -329,16 +345,28 @@ impl RagPipeline {
|
||||
.llm
|
||||
.generate_stream(req)
|
||||
.context("kb-rag: llm.generate_stream")?;
|
||||
let mut cancelled = false;
|
||||
for item in stream {
|
||||
let chunk = item.context("kb-rag: stream item")?;
|
||||
match chunk {
|
||||
TokenChunk::Token(t) => {
|
||||
acc.push_str(&t);
|
||||
if let Some(sink) = &opts.stream_sink {
|
||||
// SendError silently dropped — caller cancelled but the
|
||||
// pipeline still drives generation to completion so the
|
||||
// `answers` row gets a faithful record.
|
||||
let _ = sink.send(t);
|
||||
// p9-fb-33: SendError → caller dropped the
|
||||
// receiver (probably a closed stdout downstream).
|
||||
// Stop generation, mark the answer cancelled so
|
||||
// the persistence path records refusal_reason =
|
||||
// LlmStreamAborted.
|
||||
if sink
|
||||
.send(StreamEvent::Token {
|
||||
delta: t,
|
||||
turn_index: opts.turn_index,
|
||||
})
|
||||
.is_err()
|
||||
{
|
||||
cancelled = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
TokenChunk::Done {
|
||||
@@ -351,6 +379,9 @@ impl RagPipeline {
|
||||
}
|
||||
}
|
||||
}
|
||||
if cancelled {
|
||||
finish_reason = FinishReason::Cancelled;
|
||||
}
|
||||
|
||||
// ── 6. Citation extract ────────────────────────────────────────────
|
||||
let extracted: Vec<u32> = extract_markers(&acc);
|
||||
@@ -375,15 +406,20 @@ impl RagPipeline {
|
||||
});
|
||||
let trimmed_answer = acc.trim();
|
||||
let matched_refusal_phrase = refusal_phrase.is_match(&acc);
|
||||
let grounded = !trimmed_answer.is_empty()
|
||||
let grounded_unaware = !trimmed_answer.is_empty()
|
||||
&& unknown_markers.is_empty()
|
||||
&& !extracted.is_empty();
|
||||
let refusal_reason = if grounded {
|
||||
None
|
||||
// p9-fb-33: cancel takes priority over LlmSelfJudge — the
|
||||
// caller bailed mid-stream, so the recorded reason should
|
||||
// reflect that, not "model didn't cite".
|
||||
let (grounded, refusal_reason) = if matches!(finish_reason, FinishReason::Cancelled) {
|
||||
(false, Some(RefusalReason::LlmStreamAborted))
|
||||
} else if grounded_unaware {
|
||||
(true, None)
|
||||
} else {
|
||||
// Spec §7: empty answer, unknown markers, silent ungrounded,
|
||||
// and explicit "근거가 부족" all collapse to LlmSelfJudge.
|
||||
Some(RefusalReason::LlmSelfJudge)
|
||||
(false, Some(RefusalReason::LlmSelfJudge))
|
||||
};
|
||||
|
||||
// ── 8. Build Answer ────────────────────────────────────────────────
|
||||
@@ -461,6 +497,17 @@ impl RagPipeline {
|
||||
"kb-rag: ask done"
|
||||
);
|
||||
|
||||
// p9-fb-33: emit final on the success path. On cancel we
|
||||
// skip Final — the receiver is gone and persistence still
|
||||
// records the partial answer below.
|
||||
if !cancelled
|
||||
&& let Some(sink) = &opts.stream_sink
|
||||
{
|
||||
let _ = sink.send(StreamEvent::Final {
|
||||
answer: answer.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
// ── 9. Persist ─────────────────────────────────────────────────────
|
||||
let packed_chunks_json = if opts.explain {
|
||||
// Snapshot the packed entries as a portable list of objects so
|
||||
|
||||
@@ -14,7 +14,7 @@ use kebab_core::{
|
||||
FinishReason, LanguageModel, Retriever, SearchMode, TokenChunk, TokenUsage,
|
||||
};
|
||||
use kebab_llm::MockLanguageModel;
|
||||
use kebab_rag::{AskOpts, RagPipeline, RefusalReason};
|
||||
use kebab_rag::{AskOpts, RagPipeline, RefusalReason, StreamEvent};
|
||||
|
||||
/// LM ID used everywhere — kept short so snapshots stay stable.
|
||||
const TEST_LM_ID: &str = "mock-lm";
|
||||
@@ -270,18 +270,32 @@ fn streaming_forwards_tokens_to_sink() {
|
||||
let lm: Arc<dyn LanguageModel> = Arc::new(CountingLm::new(canned));
|
||||
let pipeline = RagPipeline::new(env.config.clone(), retriever, lm, env.sqlite.clone());
|
||||
|
||||
let (tx, rx) = std::sync::mpsc::channel::<String>();
|
||||
let (tx, rx) = std::sync::mpsc::channel::<StreamEvent>();
|
||||
let mut opts = default_opts();
|
||||
opts.stream_sink = Some(tx);
|
||||
let _ = pipeline.ask("q", opts).unwrap();
|
||||
let collected: String = rx.into_iter().collect::<Vec<_>>().join("");
|
||||
// p9-fb-33: extract Token deltas from the staged event stream.
|
||||
let collected: String = rx
|
||||
.into_iter()
|
||||
.filter_map(|ev| match ev {
|
||||
StreamEvent::Token { delta, .. } => Some(delta),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("");
|
||||
assert_eq!(collected, canned);
|
||||
}
|
||||
|
||||
// ── 10. dropped receiver does NOT abort generation ────────────────────────
|
||||
// ── 10. dropped receiver aborts generation, records LlmStreamAborted ──────
|
||||
//
|
||||
// p9-fb-33: cancel semantics changed. Pre-fb-33 the pipeline drove
|
||||
// the LM loop to completion and silently dropped sends. Now a
|
||||
// SendError breaks the loop and stamps `RefusalReason::LlmStreamAborted`
|
||||
// onto the persisted row — the partial answer (whatever was buffered
|
||||
// before the cancel) still gets written for audit.
|
||||
|
||||
#[test]
|
||||
fn dropped_receiver_does_not_abort_generation() {
|
||||
fn dropped_receiver_aborts_with_llm_stream_aborted() {
|
||||
let env = RagEnv::new();
|
||||
let cid = id32("c1");
|
||||
let did = id32("d1");
|
||||
@@ -292,13 +306,17 @@ fn dropped_receiver_does_not_abort_generation() {
|
||||
let lm: Arc<dyn LanguageModel> = Arc::new(CountingLm::new(canned));
|
||||
let pipeline = RagPipeline::new(env.config.clone(), retriever, lm, env.sqlite.clone());
|
||||
|
||||
let (tx, rx) = std::sync::mpsc::channel::<String>();
|
||||
drop(rx); // receiver gone — every send fails silently
|
||||
let (tx, rx) = std::sync::mpsc::channel::<StreamEvent>();
|
||||
drop(rx); // receiver gone — first Token send fails, loop breaks
|
||||
let mut opts = default_opts();
|
||||
opts.stream_sink = Some(tx);
|
||||
let answer = pipeline.ask("q", opts).unwrap();
|
||||
assert_eq!(answer.answer, canned, "generation completes despite dead sink");
|
||||
assert!(answer.grounded);
|
||||
assert!(!answer.grounded, "cancel takes priority over grounded");
|
||||
assert_eq!(
|
||||
answer.refusal_reason,
|
||||
Some(RefusalReason::LlmStreamAborted),
|
||||
"cancel records LlmStreamAborted",
|
||||
);
|
||||
assert_eq!(env.count_answers(), 1, "answers row still persisted");
|
||||
}
|
||||
|
||||
|
||||
131
crates/kebab-rag/tests/streaming_events.rs
Normal file
131
crates/kebab-rag/tests/streaming_events.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
//! p9-fb-33: pipeline-level streaming behavior — order invariants,
|
||||
//! cancel propagation, refusal flagging.
|
||||
|
||||
mod common;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::mpsc;
|
||||
|
||||
use common::{MockRetriever, RagEnv, id32, mk_hit};
|
||||
use kebab_core::{
|
||||
FinishReason, LanguageModel, RefusalReason, Retriever, SearchMode, TokenChunk, TokenUsage,
|
||||
};
|
||||
use kebab_llm::MockLanguageModel;
|
||||
use kebab_rag::{AskOpts, RagPipeline, StreamEvent};
|
||||
|
||||
const TEST_LM_ID: &str = "mock-lm";
|
||||
|
||||
/// Minimal LM mirroring `tests/pipeline.rs::CountingLm` so the
|
||||
/// streaming-events suite stays self-contained.
|
||||
struct CountingLm {
|
||||
inner: MockLanguageModel,
|
||||
calls: std::sync::atomic::AtomicUsize,
|
||||
}
|
||||
|
||||
impl CountingLm {
|
||||
fn new(canned: &str) -> Self {
|
||||
Self {
|
||||
inner: MockLanguageModel {
|
||||
model_id: TEST_LM_ID.to_string(),
|
||||
provider: "mock".to_string(),
|
||||
context_tokens: 32_768,
|
||||
canned_response: canned.to_string(),
|
||||
canned_finish: FinishReason::Stop,
|
||||
canned_usage: TokenUsage {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
latency_ms: 7,
|
||||
},
|
||||
},
|
||||
calls: std::sync::atomic::AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel for CountingLm {
|
||||
fn model_ref(&self) -> kebab_core::ModelRef {
|
||||
self.inner.model_ref()
|
||||
}
|
||||
fn context_tokens(&self) -> usize {
|
||||
self.inner.context_tokens()
|
||||
}
|
||||
fn generate_stream(
|
||||
&self,
|
||||
req: kebab_core::GenerateRequest,
|
||||
) -> anyhow::Result<Box<dyn Iterator<Item = anyhow::Result<TokenChunk>> + Send>> {
|
||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||
self.inner.generate_stream(req)
|
||||
}
|
||||
}
|
||||
|
||||
fn opts_with_sink(tx: mpsc::Sender<StreamEvent>) -> AskOpts {
|
||||
AskOpts {
|
||||
k: 3,
|
||||
explain: false,
|
||||
mode: SearchMode::Lexical,
|
||||
temperature: Some(0.0),
|
||||
seed: Some(0),
|
||||
stream_sink: Some(tx),
|
||||
history: Vec::new(),
|
||||
conversation_id: None,
|
||||
turn_index: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a pipeline with one seeded chunk + canned LM response so
|
||||
/// retrieval lands a single hit and the LM emits at least one token.
|
||||
fn env_with_one_hit(canned: &str) -> (RagEnv, RagPipeline) {
|
||||
let env = RagEnv::new();
|
||||
let cid = id32("c1");
|
||||
let did = id32("d1");
|
||||
env.seed_chunk(&cid, &did, "notes/a.md", "apples are red.", &["Intro"]);
|
||||
let hits = vec![mk_hit(1, &cid, &did, "notes/a.md", 0.85, &["Intro"])];
|
||||
let retriever: Arc<dyn Retriever> = Arc::new(MockRetriever::new(hits));
|
||||
let lm: Arc<dyn LanguageModel> = Arc::new(CountingLm::new(canned));
|
||||
let pipeline = RagPipeline::new(env.config.clone(), retriever, lm, env.sqlite.clone());
|
||||
(env, pipeline)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ask_emits_retrieval_then_tokens_then_final() {
|
||||
let (_env, pipeline) = env_with_one_hit("apples are red. [#1]");
|
||||
let (tx, rx) = mpsc::channel::<StreamEvent>();
|
||||
let _ans = pipeline.ask("apples", opts_with_sink(tx)).unwrap();
|
||||
let events: Vec<StreamEvent> = rx.iter().collect();
|
||||
|
||||
// First event must be RetrievalDone.
|
||||
assert!(
|
||||
matches!(events.first(), Some(StreamEvent::RetrievalDone { .. })),
|
||||
"first event must be RetrievalDone, got {:?}",
|
||||
events.first()
|
||||
);
|
||||
|
||||
// Last event must be Final.
|
||||
assert!(
|
||||
matches!(events.last(), Some(StreamEvent::Final { .. })),
|
||||
"last event must be Final, got {:?}",
|
||||
events.last()
|
||||
);
|
||||
|
||||
// Everything in between is Token.
|
||||
for ev in &events[1..events.len() - 1] {
|
||||
assert!(
|
||||
matches!(ev, StreamEvent::Token { .. }),
|
||||
"middle events must be Token, got {ev:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ask_records_llm_stream_aborted_when_receiver_drops() {
|
||||
let (env, pipeline) = env_with_one_hit("apples are red. [#1]");
|
||||
let (tx, rx) = mpsc::channel::<StreamEvent>();
|
||||
// Drop the receiver immediately so the first Token send fails.
|
||||
drop(rx);
|
||||
let ans = pipeline.ask("apples", opts_with_sink(tx)).unwrap();
|
||||
assert!(!ans.grounded);
|
||||
assert_eq!(ans.refusal_reason, Some(RefusalReason::LlmStreamAborted));
|
||||
// Persistence still happens on cancel — the row is the audit trail.
|
||||
assert_eq!(env.count_answers(), 1, "answers row written on cancel");
|
||||
}
|
||||
Reference in New Issue
Block a user