feat(rag): pipeline emits StreamEvent + cancel on SendError (fb-33)

The pipeline now emits RetrievalDone after retrieval and stale-stamping,
one Token per LM chunk, and Final on success. If a Token send returns
SendError, the LM loop breaks, the finish reason becomes
FinishReason::Cancelled, and the answer is recorded with
RefusalReason::LlmStreamAborted. The answers row still persists on
cancel for audit. Adds FinishReason::Cancelled, re-exports StreamEvent
from kebab_rag, and migrates two pre-fb-33 sink tests in
tests/pipeline.rs to the new StreamEvent type (the "dropped receiver
does not abort" test inverts to record cancel).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
th-kim0823
2026-05-09 14:49:55 +09:00
parent 31475f0312
commit 307fd8d527
5 changed files with 219 additions and 18 deletions

View File

@@ -98,6 +98,11 @@ pub enum FinishReason {
Stop,
Length,
Aborted,
/// p9-fb-33: caller-side cancel. The pipeline breaks the LM loop
/// when a `Token` send into `AskOpts.stream_sink` returns
/// `SendError` (receiver dropped). The persisted answer is
/// flagged with `RefusalReason::LlmStreamAborted`.
Cancelled,
Error(String),
}

View File

@@ -22,4 +22,4 @@ pub use kebab_core::{Answer, AnswerCitation, AnswerRetrievalSummary, RefusalReas
mod pipeline;
pub use pipeline::{AskOpts, RagPipeline};
pub use pipeline::{AskOpts, RagPipeline, StreamEvent};

View File

@@ -83,6 +83,12 @@ type PackedContext = (String, Vec<PackedCitation>, usize);
/// `RefusalReason::LlmStreamAborted`).
#[derive(Clone, Debug, serde::Serialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
// `Final.answer` carries a full `Answer` (~320B) and is the largest
// variant; `Token` is the hot path. Size mismatch is unavoidable
// without boxing the wire-shape, which would force every consumer
// (TUI / CLI / future MCP) to deref. The sink is short-lived (one
// per ask) so the per-event overhead is not material.
#[allow(clippy::large_enum_variant)]
pub enum StreamEvent {
RetrievalDone {
hits: Vec<SearchHit>,
@@ -231,6 +237,16 @@ impl RagPipeline {
for h in &mut hits {
h.stale = compute_stale(h.indexed_at, now, stale_threshold_days);
}
// p9-fb-33: emit retrieval_done as soon as the hit list is
// ready (post stale-stamp so consumers see the same `stale`
// values the App-level wire path emits). Cancel is best-effort
// here — if the caller already dropped the receiver we just
// skip and let the LLM-loop SendError handle it consistently.
if let Some(sink) = &opts.stream_sink {
let _ = sink.send(StreamEvent::RetrievalDone {
hits: hits.clone(),
});
}
let chunks_returned = u32::try_from(hits.len()).unwrap_or(u32::MAX);
let top_score = hits.first().map(|h| h.retrieval.fusion_score).unwrap_or(0.0);
@@ -329,16 +345,28 @@ impl RagPipeline {
.llm
.generate_stream(req)
.context("kb-rag: llm.generate_stream")?;
let mut cancelled = false;
for item in stream {
let chunk = item.context("kb-rag: stream item")?;
match chunk {
TokenChunk::Token(t) => {
acc.push_str(&t);
if let Some(sink) = &opts.stream_sink {
// SendError silently dropped — caller cancelled but the
// pipeline still drives generation to completion so the
// `answers` row gets a faithful record.
let _ = sink.send(t);
// p9-fb-33: SendError → caller dropped the
// receiver (probably a closed stdout downstream).
// Stop generation, mark the answer cancelled so
// the persistence path records refusal_reason =
// LlmStreamAborted.
if sink
.send(StreamEvent::Token {
delta: t,
turn_index: opts.turn_index,
})
.is_err()
{
cancelled = true;
break;
}
}
}
TokenChunk::Done {
@@ -351,6 +379,9 @@ impl RagPipeline {
}
}
}
if cancelled {
finish_reason = FinishReason::Cancelled;
}
// ── 6. Citation extract ────────────────────────────────────────────
let extracted: Vec<u32> = extract_markers(&acc);
@@ -375,15 +406,20 @@ impl RagPipeline {
});
let trimmed_answer = acc.trim();
let matched_refusal_phrase = refusal_phrase.is_match(&acc);
let grounded = !trimmed_answer.is_empty()
let grounded_unaware = !trimmed_answer.is_empty()
&& unknown_markers.is_empty()
&& !extracted.is_empty();
let refusal_reason = if grounded {
None
// p9-fb-33: cancel takes priority over LlmSelfJudge — the
// caller bailed mid-stream, so the recorded reason should
// reflect that, not "model didn't cite".
let (grounded, refusal_reason) = if matches!(finish_reason, FinishReason::Cancelled) {
(false, Some(RefusalReason::LlmStreamAborted))
} else if grounded_unaware {
(true, None)
} else {
// Spec §7: empty answer, unknown markers, silent ungrounded,
// and explicit "근거가 부족" all collapse to LlmSelfJudge.
Some(RefusalReason::LlmSelfJudge)
(false, Some(RefusalReason::LlmSelfJudge))
};
// ── 8. Build Answer ────────────────────────────────────────────────
@@ -461,6 +497,17 @@ impl RagPipeline {
"kb-rag: ask done"
);
// p9-fb-33: emit final on the success path. On cancel we
// skip Final — the receiver is gone and persistence still
// records the partial answer below.
if !cancelled
&& let Some(sink) = &opts.stream_sink
{
let _ = sink.send(StreamEvent::Final {
answer: answer.clone(),
});
}
// ── 9. Persist ─────────────────────────────────────────────────────
let packed_chunks_json = if opts.explain {
// Snapshot the packed entries as a portable list of objects so

View File

@@ -14,7 +14,7 @@ use kebab_core::{
FinishReason, LanguageModel, Retriever, SearchMode, TokenChunk, TokenUsage,
};
use kebab_llm::MockLanguageModel;
use kebab_rag::{AskOpts, RagPipeline, RefusalReason};
use kebab_rag::{AskOpts, RagPipeline, RefusalReason, StreamEvent};
/// LM ID used everywhere — kept short so snapshots stay stable.
const TEST_LM_ID: &str = "mock-lm";
@@ -270,18 +270,32 @@ fn streaming_forwards_tokens_to_sink() {
let lm: Arc<dyn LanguageModel> = Arc::new(CountingLm::new(canned));
let pipeline = RagPipeline::new(env.config.clone(), retriever, lm, env.sqlite.clone());
let (tx, rx) = std::sync::mpsc::channel::<String>();
let (tx, rx) = std::sync::mpsc::channel::<StreamEvent>();
let mut opts = default_opts();
opts.stream_sink = Some(tx);
let _ = pipeline.ask("q", opts).unwrap();
let collected: String = rx.into_iter().collect::<Vec<_>>().join("");
// p9-fb-33: extract Token deltas from the staged event stream.
let collected: String = rx
.into_iter()
.filter_map(|ev| match ev {
StreamEvent::Token { delta, .. } => Some(delta),
_ => None,
})
.collect::<Vec<_>>()
.join("");
assert_eq!(collected, canned);
}
// ── 10. dropped receiver does NOT abort generation ────────────────────────
// ── 10. dropped receiver aborts generation, records LlmStreamAborted ──────
//
// p9-fb-33: cancel semantics changed. Pre-fb-33 the pipeline drove
// the LM loop to completion and silently dropped sends. Now a
// SendError breaks the loop and stamps `RefusalReason::LlmStreamAborted`
// onto the persisted row — the partial answer (whatever was buffered
// before the cancel) still gets written for audit.
#[test]
fn dropped_receiver_does_not_abort_generation() {
fn dropped_receiver_aborts_with_llm_stream_aborted() {
let env = RagEnv::new();
let cid = id32("c1");
let did = id32("d1");
@@ -292,13 +306,17 @@ fn dropped_receiver_does_not_abort_generation() {
let lm: Arc<dyn LanguageModel> = Arc::new(CountingLm::new(canned));
let pipeline = RagPipeline::new(env.config.clone(), retriever, lm, env.sqlite.clone());
let (tx, rx) = std::sync::mpsc::channel::<String>();
drop(rx); // receiver gone — every send fails silently
let (tx, rx) = std::sync::mpsc::channel::<StreamEvent>();
drop(rx); // receiver gone — first Token send fails, loop breaks
let mut opts = default_opts();
opts.stream_sink = Some(tx);
let answer = pipeline.ask("q", opts).unwrap();
assert_eq!(answer.answer, canned, "generation completes despite dead sink");
assert!(answer.grounded);
assert!(!answer.grounded, "cancel takes priority over grounded");
assert_eq!(
answer.refusal_reason,
Some(RefusalReason::LlmStreamAborted),
"cancel records LlmStreamAborted",
);
assert_eq!(env.count_answers(), 1, "answers row still persisted");
}

View File

@@ -0,0 +1,131 @@
//! p9-fb-33: pipeline-level streaming behavior — order invariants,
//! cancel propagation, refusal flagging.
mod common;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::sync::mpsc;
use common::{MockRetriever, RagEnv, id32, mk_hit};
use kebab_core::{
FinishReason, LanguageModel, RefusalReason, Retriever, SearchMode, TokenChunk, TokenUsage,
};
use kebab_llm::MockLanguageModel;
use kebab_rag::{AskOpts, RagPipeline, StreamEvent};
const TEST_LM_ID: &str = "mock-lm";
/// Minimal LM mirroring `tests/pipeline.rs::CountingLm` so the
/// streaming-events suite stays self-contained.
///
/// Wraps a `MockLanguageModel` and keeps a counter that is bumped on
/// every `generate_stream` call made through the `LanguageModel` impl.
struct CountingLm {
// Mock model that every `LanguageModel` method on this type delegates to.
inner: MockLanguageModel,
// Incremented once per `generate_stream` call. The tests visible in this
// file never read it back.
calls: std::sync::atomic::AtomicUsize,
}
impl CountingLm {
    /// Creates a counting wrapper whose inner mock replies with `canned`.
    ///
    /// The mock is configured with the shared `TEST_LM_ID`, provider
    /// `"mock"`, a 32 768-token context window, `FinishReason::Stop`, and
    /// fixed usage figures (10 prompt tokens, 5 completion tokens, 7 ms).
    /// The call counter starts at zero.
    fn new(canned: &str) -> Self {
        // Fixed usage numbers reported by the mock on every generation.
        let canned_usage = TokenUsage {
            prompt_tokens: 10,
            completion_tokens: 5,
            latency_ms: 7,
        };
        let inner = MockLanguageModel {
            model_id: String::from(TEST_LM_ID),
            provider: String::from("mock"),
            context_tokens: 32_768,
            canned_response: String::from(canned),
            canned_finish: FinishReason::Stop,
            canned_usage,
        };
        Self {
            inner,
            calls: std::sync::atomic::AtomicUsize::new(0),
        }
    }
}
impl LanguageModel for CountingLm {
    /// Returns the wrapped mock's model reference unchanged.
    fn model_ref(&self) -> kebab_core::ModelRef {
        let Self { inner, .. } = self;
        inner.model_ref()
    }

    /// Returns the wrapped mock's context-window size unchanged.
    fn context_tokens(&self) -> usize {
        let Self { inner, .. } = self;
        inner.context_tokens()
    }

    /// Records the invocation, then hands the request to the wrapped mock.
    ///
    /// The counter is bumped before delegating, so a call is counted even
    /// when the mock returns an error.
    fn generate_stream(
        &self,
        req: kebab_core::GenerateRequest,
    ) -> anyhow::Result<Box<dyn Iterator<Item = anyhow::Result<TokenChunk>> + Send>> {
        let Self { inner, calls } = self;
        calls.fetch_add(1, Ordering::SeqCst);
        inner.generate_stream(req)
    }
}
/// Builds the `AskOpts` used by every test in this suite, with `tx`
/// installed as the stream sink.
///
/// Fixed settings: `k = 3`, lexical search, temperature `0.0`, seed `0`,
/// `explain` off, empty history, and no conversation id or turn index.
fn opts_with_sink(tx: mpsc::Sender<StreamEvent>) -> AskOpts {
    let stream_sink = Some(tx);
    AskOpts {
        // Retrieval settings.
        k: 3,
        mode: SearchMode::Lexical,
        explain: false,
        // Generation settings.
        temperature: Some(0.0),
        seed: Some(0),
        // Streaming target under test.
        stream_sink,
        // No multi-turn context.
        history: Vec::new(),
        conversation_id: None,
        turn_index: None,
    }
}
/// Sets up a fresh `RagEnv` with a single seeded chunk and returns it
/// together with a `RagPipeline` wired to a one-hit `MockRetriever` and a
/// `CountingLm` that replies with `canned`.
///
/// The environment is returned alongside the pipeline so callers can
/// inspect it after the ask (for example via `count_answers`).
fn env_with_one_hit(canned: &str) -> (RagEnv, RagPipeline) {
    const DOC_PATH: &str = "notes/a.md";

    let env = RagEnv::new();
    let (chunk_id, doc_id) = (id32("c1"), id32("d1"));
    env.seed_chunk(&chunk_id, &doc_id, DOC_PATH, "apples are red.", &["Intro"]);

    // The retriever always answers with this one hit, which points at the
    // chunk seeded above.
    let retriever: Arc<dyn Retriever> = Arc::new(MockRetriever::new(vec![mk_hit(
        1,
        &chunk_id,
        &doc_id,
        DOC_PATH,
        0.85,
        &["Intro"],
    )]));
    let lm: Arc<dyn LanguageModel> = Arc::new(CountingLm::new(canned));

    let pipeline = RagPipeline::new(env.config.clone(), retriever, lm, env.sqlite.clone());
    (env, pipeline)
}
/// Checks the event order on a successful ask: exactly one leading
/// `RetrievalDone`, then one or more `Token` events, then a trailing
/// `Final`.
#[test]
fn ask_emits_retrieval_then_tokens_then_final() {
    let (_env, pipeline) = env_with_one_hit("apples are red. [#1]");
    let (tx, rx) = mpsc::channel::<StreamEvent>();
    let _ans = pipeline.ask("apples", opts_with_sink(tx)).unwrap();

    // Draining only terminates once every sender is gone; this relies on
    // `ask` not keeping the sink alive past its return.
    let events: Vec<StreamEvent> = rx.iter().collect();

    // Split into first / middle / last without index arithmetic, so a
    // too-short stream fails with a readable message.
    let [first, middle @ .., last] = events.as_slice() else {
        panic!("expected at least RetrievalDone and Final, got {events:?}");
    };

    // First event must be RetrievalDone.
    assert!(
        matches!(first, StreamEvent::RetrievalDone { .. }),
        "first event must be RetrievalDone, got {first:?}"
    );
    // Last event must be Final.
    assert!(
        matches!(last, StreamEvent::Final { .. }),
        "last event must be Final, got {last:?}"
    );
    // Without this check the loop below passes vacuously when the
    // pipeline emits no Token at all.
    assert!(
        !middle.is_empty(),
        "expected at least one Token between RetrievalDone and Final"
    );
    // Everything in between is Token.
    for ev in middle {
        assert!(
            matches!(ev, StreamEvent::Token { .. }),
            "middle events must be Token, got {ev:?}"
        );
    }
}
/// A sink whose receiver is already gone must turn the ask into a
/// cancelled answer: not grounded, flagged `LlmStreamAborted`, and still
/// written to the answers table.
#[test]
fn ask_records_llm_stream_aborted_when_receiver_drops() {
    let (env, pipeline) = env_with_one_hit("apples are red. [#1]");

    // Hand the pipeline a sender with no live receiver, so every send
    // on it fails, including the first Token send.
    let dead_sink = {
        let (tx, rx) = mpsc::channel::<StreamEvent>();
        drop(rx);
        tx
    };

    let answer = pipeline.ask("apples", opts_with_sink(dead_sink)).unwrap();

    assert!(!answer.grounded);
    assert_eq!(answer.refusal_reason, Some(RefusalReason::LlmStreamAborted));
    // Persistence still happens on cancel — the row is the audit trail.
    assert_eq!(env.count_answers(), 1, "answers row written on cancel");
}