feat(p4-2): kb-llm-local crate — Ollama HTTP adapter (reqwest::blocking)
First real LanguageModel implementation. Wraps Ollama's local HTTP
API at POST {endpoint}/api/generate with stream:true, parses the
NDJSON streaming response into TokenChunk events, and maps Ollama
error states to a thiserror-derived LlmError with actionable hints.
Synchronous trait surface; reqwest::blocking handles the HTTP I/O.
Public surface:
- pub struct OllamaLanguageModel
- pub fn new(config: &Config) -> Result<Self> — lazy connect; never
hits the network. Spec line 96.
- pub enum LlmError { Unreachable, ModelNotPulled, Timeout, Stream,
Malformed }. Lives in this crate per spec — kb-core / kb-llm stay
free of error taxonomy.
- impl kb_core::LanguageModel via re-export from kb-llm.
Streaming:
- POST body shape per spec §11.2: model, prompt = system + "\n\n" +
user, stream: true, options { temperature, seed, num_ctx, stop }.
- OllamaStream owns BufReader<reqwest::blocking::Response>, reads
NDJSON lines via read_until(b'\n'), parses each as
{response, done, done_reason?, prompt_eval_count?, eval_count?,
total_duration?}. Token frame → TokenChunk::Token; done frame →
TokenChunk::Done { finish_reason, usage }.
- done_reason mapping: "length" → Length, "abort" → Aborted,
"stop" / missing / unknown → Stop (forward-compat with future
Ollama tags).
- Missing prompt_eval_count / eval_count default to 0 + tracing::warn
(do NOT fail). Spec line 135.
- EOF without a done line synthesizes Done { Aborted, zeros } so
downstream pipelines never deadlock waiting for a terminal frame.
- UTF-8: line-delimited framing means each JSON line is a complete
UTF-8 sequence — no cross-HTTP-chunk codepoint splits to worry
about. read_until accumulates whole lines regardless of how the
underlying reqwest body chunks.
Error mapping (LlmError):
- reqwest::Error::is_connect() → Unreachable { endpoint, source }
with hint "ensure `ollama serve` is running and reachable at
<endpoint>".
- reqwest::Error::is_timeout() → Timeout.
- 200 with non-NDJSON first line (e.g., transparent-proxy HTML
error page) → Stream(truncated body) — distinguished from
Malformed by the iterator's has_emitted flag.
- 404 with body containing model_id (case-insensitive) OR English
"model" + "not found" → ModelNotPulled(model_id) with hint
"ollama pull <model_id>". Tightened beyond spec to survive
Ollama localizing the error message (Korean / Japanese / etc.)
while keeping the original English-substring fallback.
- Other 4xx/5xx → Stream(truncated body).
- Mid-stream JSON parse failure (after at least one valid line) →
Malformed(line). Truncate all error bodies to 512 chars
(chars-based, multibyte safe) so an nginx 500 page can't blow up
the diagnostic.
- Trailing slash in endpoint stripped before formatting the URL —
endpoint = "http://x:1234/" produces .../api/generate, not
.../api//generate. Pinned by trailing-slash test.
Tokio note: reqwest 0.12's blocking feature internally wraps a
private current-thread tokio runtime, so cargo tree --edges normal
shows tokio. The auditable invariant is "no top-level tokio dep +
no async surface exposed to callers" — verified: src/ has zero
async/await/tokio::*. default-features = false drops default-tls
(rustls only) but does NOT drop tokio. Documented honestly in
Cargo.toml + lib.rs. Switching to ureq would remove tokio
entirely; deferred since reqwest is the spec's allowed dep.
Tests (24 total: 23 default + 1 ignored):
- 7 unit in src/ollama.rs: prompt-build, options-build, finish-
reason mapping, truncate_body bounds (under_cap / over_cap_marker
/ multibyte_chars_not_bytes), 404+model-id heuristic.
- 3 in tests/construction.rs: ModelRef shape, context_tokens
passthrough, lazy-connect proven via port-1 pointing.
- 13 in tests/streaming.rs: streamed tokens then Done, multibyte
chars within a line round-trip (renamed from "split across
chunks" to honestly reflect what's tested), Unreachable-with-
hint, 4xx→Stream, 404→ModelNotPulled, concat-equals-canned,
done_reason length / abort, missing eval counts default to zero,
missing done_reason defaults to Stop, determinism-by-mock,
trailing-slash endpoint, non-NDJSON 200 body → Stream not
Malformed.
- 1 #[ignore] in tests/integration.rs: real Ollama on
localhost:11434 with the configured model. Opt-in via
cargo test -p kb-llm-local -- --ignored after `ollama serve`
+ `ollama pull`.
Workspace: 288 passed / 25 ignored / 0 failed. cargo clippy
--workspace --all-targets -- -D warnings clean. No native-tls,
no openssl in the dep graph.
Allowed deps respected: kb-core, kb-config, kb-llm, reqwest 0.12
(default-features=false; blocking, json, rustls-tls), serde,
serde_json, tracing, thiserror plus anyhow (forced by trait return
type). wiremock + tokio in [dev-dependencies] only.
Out of scope: llama.cpp / candle adapters (P+), Ollama embed
endpoint (separate adapter inside kb-embed-local if requested),
cancellation / abort tokens (P+), connection-pool tuning.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
80
Cargo.lock
generated
80
Cargo.lock
generated
@@ -418,6 +418,16 @@ dependencies = [
|
||||
"stable_deref_trait",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "assert-json-diff"
|
||||
version = "2.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-channel"
|
||||
version = "2.5.0"
|
||||
@@ -683,7 +693,7 @@ version = "3.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "519bd3116aeeb42d5372c29d982d16d0170d3d4a5ed85fc7dd91642ffff3c67c"
|
||||
dependencies = [
|
||||
"darling 0.20.11",
|
||||
"darling 0.21.3",
|
||||
"ident_case",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
@@ -1854,6 +1864,24 @@ dependencies = [
|
||||
"sqlparser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deadpool"
|
||||
version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0be2b1d1d6ec8d846f05e137292d0b89133caf95ef33695424c09568bdd39b1b"
|
||||
dependencies = [
|
||||
"deadpool-runtime",
|
||||
"lazy_static",
|
||||
"num_cpus",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deadpool-runtime"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b"
|
||||
|
||||
[[package]]
|
||||
name = "deepsize"
|
||||
version = "0.2.0"
|
||||
@@ -2789,6 +2817,12 @@ version = "1.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
|
||||
|
||||
[[package]]
|
||||
name = "httpdate"
|
||||
version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
|
||||
|
||||
[[package]]
|
||||
name = "humantime"
|
||||
version = "2.3.0"
|
||||
@@ -2809,6 +2843,7 @@ dependencies = [
|
||||
"http",
|
||||
"http-body",
|
||||
"httparse",
|
||||
"httpdate",
|
||||
"itoa",
|
||||
"pin-project-lite",
|
||||
"smallvec",
|
||||
@@ -2830,6 +2865,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tower-service",
|
||||
"webpki-roots 1.0.7",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3448,6 +3484,23 @@ dependencies = [
|
||||
"proptest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "kb-llm-local"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"kb-config",
|
||||
"kb-core",
|
||||
"kb-llm",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "kb-normalize"
|
||||
version = "0.1.0"
|
||||
@@ -5855,6 +5908,7 @@ dependencies = [
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"encoding_rs",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"h2",
|
||||
@@ -5892,6 +5946,7 @@ dependencies = [
|
||||
"wasm-bindgen-futures",
|
||||
"wasm-streams",
|
||||
"web-sys",
|
||||
"webpki-roots 1.0.7",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8034,6 +8089,29 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wiremock"
|
||||
version = "0.6.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08db1edfb05d9b3c1542e521aea074442088292f00b5f28e435c714a98f85031"
|
||||
dependencies = [
|
||||
"assert-json-diff",
|
||||
"base64 0.22.1",
|
||||
"deadpool",
|
||||
"futures",
|
||||
"http",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"log",
|
||||
"once_cell",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen"
|
||||
version = "0.46.0"
|
||||
|
||||
@@ -14,6 +14,7 @@ members = [
|
||||
"crates/kb-embed",
|
||||
"crates/kb-embed-local",
|
||||
"crates/kb-llm",
|
||||
"crates/kb-llm-local",
|
||||
"crates/kb-app",
|
||||
"crates/kb-cli",
|
||||
]
|
||||
@@ -55,3 +56,7 @@ arrow-array = "56"
|
||||
arrow-schema = "56"
|
||||
tokio = { version = "1", features = ["rt", "macros"] }
|
||||
futures = "0.3"
|
||||
# Dev-only HTTP mock server for kb-llm-local Ollama adapter tests. Requires
|
||||
# a tokio runtime to host its mock server (the runtime adapter crate stays
|
||||
# sync via reqwest::blocking — wiremock is dev-only there).
|
||||
wiremock = "0.6"
|
||||
|
||||
36
crates/kb-llm-local/Cargo.toml
Normal file
36
crates/kb-llm-local/Cargo.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
[package]
|
||||
name = "kb-llm-local"
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
rust-version = { workspace = true }
|
||||
license = { workspace = true }
|
||||
repository = { workspace = true }
|
||||
description = "Ollama HTTP adapter implementing kb_core::LanguageModel via reqwest::blocking"
|
||||
|
||||
[dependencies]
|
||||
kb-core = { path = "../kb-core" }
|
||||
kb-config = { path = "../kb-config" }
|
||||
kb-llm = { path = "../kb-llm" }
|
||||
# `default-features = false` drops the `default-tls` (native-tls / openssl)
|
||||
# feature so we don't pull in a system OpenSSL; we explicitly pin rustls.
|
||||
# Note: `default-features = false` does NOT drop tokio — reqwest 0.12's
|
||||
# `blocking` feature internally wraps a private current-thread tokio
|
||||
# runtime, so `cargo tree -p kb-llm-local --edges normal | grep tokio`
|
||||
# will list tokio. The auditable invariant for this crate is "no
|
||||
# top-level tokio dep + no async surface (`async`/`await`/`tokio::*`)
|
||||
# exposed to callers" rather than "tokio absent from the tree".
|
||||
reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
# wiremock requires a tokio runtime to host the mock HTTP server. tokio
|
||||
# is also pulled transitively at runtime by reqwest's `blocking` feature
|
||||
# (private current-thread runtime); see the dependency comment above.
|
||||
# What we DO guarantee: this crate's source has zero `async`/`await`/
|
||||
# `tokio::*` symbols, so the public/runtime API stays sync.
|
||||
wiremock = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt"] }
|
||||
63
crates/kb-llm-local/src/error.rs
Normal file
63
crates/kb-llm-local/src/error.rs
Normal file
@@ -0,0 +1,63 @@
|
||||
//! [`LlmError`] — adapter-side error taxonomy mapping Ollama failure modes
|
||||
//! onto the variants downstream RAG / CLI code pattern-matches against.
|
||||
//!
|
||||
//! Living in this crate (rather than `kb-core` or `kb-llm`) is deliberate:
|
||||
//! the variants are LLM-adapter specific (e.g. "model not pulled" is an
|
||||
//! Ollama-ism), and surfacing them as `anyhow::Error` source values lets
|
||||
//! callers `downcast_ref::<LlmError>()` only if they actually care. Trait
|
||||
//! consumers stay generic over the error.
|
||||
//!
|
||||
//! Display strings follow design §10 — every variant is **actionable**: it
|
||||
//! tells the user the next command to run (`ollama serve`, `ollama pull`)
|
||||
//! when the cause is operational rather than programmatic.
|
||||
|
||||
/// Errors specific to the Ollama HTTP adapter.
|
||||
///
|
||||
/// Wrapped into `anyhow::Error` at API boundaries; downstream code that
|
||||
/// needs to render hints (e.g. `kb doctor`) can `downcast_ref::<LlmError>()`.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum LlmError {
|
||||
/// Ollama not running at the configured endpoint, or the host is
|
||||
/// unreachable. Detected via `reqwest::Error::is_connect()`.
|
||||
#[error(
|
||||
"ollama unreachable at {endpoint}: {source}\n\
|
||||
hint: ensure `ollama serve` is running and reachable at {endpoint}"
|
||||
)]
|
||||
Unreachable {
|
||||
endpoint: String,
|
||||
#[source]
|
||||
source: reqwest::Error,
|
||||
},
|
||||
|
||||
/// Server returned 404 with a body indicating the requested model is not
|
||||
/// pulled. Carries the model id so the hint is copy-pasteable.
|
||||
#[error(
|
||||
"ollama model `{0}` is not pulled\n\
|
||||
hint: run `ollama pull {0}`"
|
||||
)]
|
||||
ModelNotPulled(String),
|
||||
|
||||
/// Network read/write timed out. `reqwest::blocking::Client` is built
|
||||
/// with a 5-minute ceiling — cold-loading a 14B model can legitimately
|
||||
/// take >1 minute on first call.
|
||||
#[error("ollama timeout: {0}")]
|
||||
Timeout(#[source] reqwest::Error),
|
||||
|
||||
/// HTTP-level / server-shape error: a non-404 4xx/5xx response, or a
|
||||
/// 200 response whose body is not NDJSON at all (e.g. an HTML 500 page
|
||||
/// from a misrouted reverse proxy, or a `{"error":...}` envelope on a
|
||||
/// streaming frame). Carries the response body, **truncated to 512
|
||||
/// chars** at the construction site so a megabyte-sized nginx error
|
||||
/// page or Ollama panic dump cannot blow up logs / `Display`.
|
||||
#[error("ollama HTTP error: {0}")]
|
||||
Stream(String),
|
||||
|
||||
/// Mid-stream JSON parse failure on a line that should have been
|
||||
/// NDJSON: i.e. earlier lines in the same response parsed cleanly,
|
||||
/// then a later line was corrupt. Distinct from `Stream` (which covers
|
||||
/// "the server never spoke NDJSON to begin with") so callers can
|
||||
/// choose to skip vs. abort. Carries the offending line for
|
||||
/// `kb doctor`-style diagnostics.
|
||||
#[error("malformed response line: {0}")]
|
||||
Malformed(String),
|
||||
}
|
||||
49
crates/kb-llm-local/src/lib.rs
Normal file
49
crates/kb-llm-local/src/lib.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
//! `kb-llm-local` — Ollama HTTP adapter implementing
|
||||
//! [`kb_core::LanguageModel`] over the local `POST /api/generate` endpoint.
|
||||
//!
|
||||
//! ## Why a separate crate
|
||||
//!
|
||||
//! `kb-llm` re-exports the trait + [`MockLanguageModel`] for downstream tests.
|
||||
//! Real adapters (Ollama, llama.cpp, candle) live outside `kb-llm` so swapping
|
||||
//! providers stays config-only and so the trait crate has no heavy
|
||||
//! dependencies. p4-2 ("first real LM") is the home of [`OllamaLanguageModel`]
|
||||
//! and the [`LlmError`] enum the rest of the workspace will pattern-match
|
||||
//! against.
|
||||
//!
|
||||
//! ## Runtime contract
|
||||
//!
|
||||
//! - **Synchronous surface.** Built on `reqwest::blocking`. This crate's
|
||||
//! source contains zero `async`/`await`/`tokio::*` symbols and exposes
|
||||
//! no async surface to callers.
|
||||
//!
|
||||
//! Note on tokio: reqwest 0.12's `blocking` feature internally wraps a
|
||||
//! private current-thread tokio runtime, so
|
||||
//! `cargo tree -p kb-llm-local --edges normal | grep tokio` WILL show
|
||||
//! tokio in the runtime graph. The auditable invariant is "no top-level
|
||||
//! tokio dep + no async surface exposed to callers" rather than "tokio
|
||||
//! absent from the tree".
|
||||
//! - **Streaming.** The adapter posts `stream: true` and returns a
|
||||
//! `Box<dyn Iterator<Item = Result<TokenChunk>> + Send>` that reads
|
||||
//! line-delimited JSON frames lazily — tokens reach the caller as the
|
||||
//! server emits them.
|
||||
//! - **Lazy connect.** [`OllamaLanguageModel::new`] does not hit the network;
|
||||
//! the first error surfaces on [`LanguageModel::generate_stream`].
|
||||
//!
|
||||
//! See `docs/superpowers/specs/2026-04-27-kb-final-form-design.md` §7.2,
|
||||
//! §6.4 (`[models.llm]`), §0 Q5 (streaming), §10 (errors), and report §11.2
|
||||
//! (Ollama protocol notes).
|
||||
|
||||
mod error;
|
||||
mod ollama;
|
||||
|
||||
pub use error::LlmError;
|
||||
pub use ollama::OllamaLanguageModel;
|
||||
|
||||
// Re-export the trait surface so adapter consumers can `use kb_llm_local::*`
|
||||
// without also depending on `kb-llm` directly. These are the same symbols
|
||||
// `kb-llm` re-exports from `kb-core`; this crate adds **no new types** to
|
||||
// the trait surface (`LlmError` and `OllamaLanguageModel` are
|
||||
// implementation-side only).
|
||||
pub use kb_llm::{
|
||||
FinishReason, GenerateRequest, LanguageModel, ModelRef, TokenChunk, TokenUsage,
|
||||
};
|
||||
562
crates/kb-llm-local/src/ollama.rs
Normal file
562
crates/kb-llm-local/src/ollama.rs
Normal file
@@ -0,0 +1,562 @@
|
||||
//! [`OllamaLanguageModel`] — `reqwest::blocking` adapter for Ollama's
|
||||
//! `POST /api/generate` streaming endpoint.
|
||||
//!
|
||||
//! ## Wire shape
|
||||
//!
|
||||
//! Request body (design §11.2 / §6.4):
|
||||
//!
|
||||
//! ```json
|
||||
//! {
|
||||
//! "model": "<config.models.llm.model>",
|
||||
//! "prompt": "<system + '\n\n' + user>",
|
||||
//! "stream": true,
|
||||
//! "options": {
|
||||
//! "temperature": <float>,
|
||||
//! "seed": <u64>,
|
||||
//! "num_ctx": <usize>,
|
||||
//! "stop": ["<str>", ...]
|
||||
//! }
|
||||
//! }
|
||||
//! ```
|
||||
//!
|
||||
//! Response is line-delimited JSON; each non-empty line is one
|
||||
//! [`OllamaLine`]. Non-final lines carry `response: "..."` plus
|
||||
//! `done: false`; the final line carries `done: true` plus aggregate
|
||||
//! counters (`prompt_eval_count`, `eval_count`, `total_duration`).
|
||||
//!
|
||||
//! ## Streaming model
|
||||
//!
|
||||
//! [`generate_stream`] returns a stateful [`OllamaStream`] that owns the
|
||||
//! `BufReader<reqwest::blocking::Response>` and yields one `TokenChunk` per
|
||||
//! `Iterator::next` call — true streaming, no `collect` into a buffer. This
|
||||
//! matters because cold-loading a 14B model can take >1 minute and we want
|
||||
//! tokens to appear as the server emits them.
|
||||
//!
|
||||
//! ## Send-safety
|
||||
//!
|
||||
//! `reqwest::blocking::Response: Send`, so `BufReader<Response>: Send`, so
|
||||
//! the boxed iterator satisfies the trait's `+ Send` bound without any
|
||||
//! `Mutex` ceremony.
|
||||
|
||||
use std::io::{BufRead, BufReader};
|
||||
use std::time::Duration;
|
||||
|
||||
use kb_core::{
|
||||
FinishReason, GenerateRequest, LanguageModel, ModelRef, TokenChunk, TokenUsage,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::LlmError;
|
||||
|
||||
/// Hard ceiling on a single HTTP exchange. Cold-loading a 14B model on
|
||||
/// first call can take ~30s; 5 minutes is generous without being
|
||||
/// open-ended.
|
||||
const REQUEST_TIMEOUT: Duration = Duration::from_secs(300);
|
||||
|
||||
/// `reqwest::blocking` adapter implementing [`LanguageModel`] over Ollama's
|
||||
/// local HTTP API. Construction is cheap and offline; the first network
|
||||
/// call happens inside [`generate_stream`].
|
||||
pub struct OllamaLanguageModel {
|
||||
client: reqwest::blocking::Client,
|
||||
/// Already-validated endpoint URL string (e.g. `"http://127.0.0.1:11434"`).
|
||||
/// Stored as `String` rather than `url::Url` to keep the dep set minimal.
|
||||
endpoint: String,
|
||||
model_id: String,
|
||||
context_tokens: usize,
|
||||
default_temperature: f32,
|
||||
default_seed: u64,
|
||||
}
|
||||
|
||||
impl OllamaLanguageModel {
|
||||
/// Build an adapter from a workspace [`kb_config::Config`]. Reads
|
||||
/// `config.models.llm.{provider, model, endpoint, context_tokens,
|
||||
/// temperature, seed}`.
|
||||
///
|
||||
/// Does NOT touch the network — see module docs. The caller is
|
||||
/// expected to have validated `provider == "ollama"`; this constructor
|
||||
/// trusts the config and would happily build for an unknown provider.
|
||||
/// (Provider routing is the App layer's job, not the adapter's.)
|
||||
pub fn new(config: &kb_config::Config) -> anyhow::Result<Self> {
|
||||
let llm = &config.models.llm;
|
||||
let client = reqwest::blocking::Client::builder()
|
||||
.timeout(REQUEST_TIMEOUT)
|
||||
.build()?;
|
||||
Ok(Self {
|
||||
client,
|
||||
endpoint: llm.endpoint.clone(),
|
||||
model_id: llm.model.clone(),
|
||||
context_tokens: llm.context_tokens,
|
||||
default_temperature: llm.temperature,
|
||||
default_seed: llm.seed,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel for OllamaLanguageModel {
|
||||
fn model_ref(&self) -> ModelRef {
|
||||
ModelRef {
|
||||
id: self.model_id.clone(),
|
||||
// Per design §3.8 / §6.4 — adapters that route through Ollama
|
||||
// report `provider = "ollama"` regardless of which model id
|
||||
// they carry, so downstream `Answer.model` displays consistently.
|
||||
provider: "ollama".to_string(),
|
||||
// Chat models have no embedding dimension — see §3.8.
|
||||
dimensions: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn context_tokens(&self) -> usize {
|
||||
self.context_tokens
|
||||
}
|
||||
|
||||
fn generate_stream(
|
||||
&self,
|
||||
req: GenerateRequest,
|
||||
) -> anyhow::Result<Box<dyn Iterator<Item = anyhow::Result<TokenChunk>> + Send>> {
|
||||
// ── Resolve effective options ──────────────────────────────────
|
||||
//
|
||||
// Per design §6.4 the effective temperature/seed come from the
|
||||
// config defaults. `GenerateRequest` exposes a `temperature: f32`
|
||||
// (always present) and `seed: Option<u64>` so the request can
|
||||
// override on a per-call basis. Resolution order:
|
||||
// - temperature: `req.temperature` always wins. The field is
|
||||
// non-Optional, so the RAG layer always declares an intent
|
||||
// (typically `0.0` for grounded answers); the workspace
|
||||
// default merely seeds that field at the RAG construction
|
||||
// site. NaN is rejected → fall back to the config default
|
||||
// so a malformed RAG request can never reach Ollama.
|
||||
// - seed: `req.seed.unwrap_or(default_seed)` — Optional in the
|
||||
// request, so the config default is the natural fallback.
|
||||
let effective_temperature = if req.temperature.is_nan() {
|
||||
self.default_temperature
|
||||
} else {
|
||||
req.temperature
|
||||
};
|
||||
let effective_seed = req.seed.unwrap_or(self.default_seed);
|
||||
|
||||
let prompt = if req.system.is_empty() {
|
||||
req.user.clone()
|
||||
} else {
|
||||
format!("{}\n\n{}", req.system, req.user)
|
||||
};
|
||||
|
||||
let body = OllamaRequest {
|
||||
model: &self.model_id,
|
||||
prompt,
|
||||
stream: true,
|
||||
options: OllamaOptions {
|
||||
temperature: effective_temperature,
|
||||
seed: effective_seed,
|
||||
num_ctx: self.context_tokens,
|
||||
stop: &req.stop,
|
||||
},
|
||||
};
|
||||
|
||||
// ── Send (blocking) ────────────────────────────────────────────
|
||||
let url = format!("{}/api/generate", self.endpoint.trim_end_matches('/'));
|
||||
let response = match self.client.post(&url).json(&body).send() {
|
||||
Ok(r) => r,
|
||||
Err(e) => return Err(map_send_error(e, &self.endpoint).into()),
|
||||
};
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
// Read the body so we can pattern-match on Ollama's "model not
|
||||
// found" envelope (§11.2). Body read is bounded by the server
|
||||
// — Ollama only sends a short JSON envelope on error, no
|
||||
// streaming body to drain.
|
||||
let body_text = response.text().unwrap_or_default();
|
||||
return Err(map_status_error(status, &body_text, &self.model_id).into());
|
||||
}
|
||||
|
||||
// ── Hand off to the streaming iterator ─────────────────────────
|
||||
let stream = OllamaStream {
|
||||
reader: BufReader::new(response),
|
||||
line_buf: Vec::with_capacity(1024),
|
||||
done: false,
|
||||
has_emitted: false,
|
||||
};
|
||||
Ok(Box::new(stream))
|
||||
}
|
||||
}
|
||||
|
||||
// ── Wire types ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Ollama `/api/generate` request body. Borrowed model id + stop list keep
|
||||
/// allocations to one (the prompt) per call.
|
||||
#[derive(Serialize)]
|
||||
struct OllamaRequest<'a> {
|
||||
model: &'a str,
|
||||
prompt: String,
|
||||
stream: bool,
|
||||
options: OllamaOptions<'a>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct OllamaOptions<'a> {
|
||||
temperature: f32,
|
||||
seed: u64,
|
||||
num_ctx: usize,
|
||||
stop: &'a [String],
|
||||
}
|
||||
|
||||
/// One line of the streaming response. All counter fields are optional
|
||||
/// because older Ollama builds omit them on the final frame; see §10
|
||||
/// "actionable errors" — we degrade gracefully with `0` + a `tracing::warn`
|
||||
/// span rather than failing the stream.
|
||||
#[derive(Deserialize, Default, Debug)]
|
||||
struct OllamaLine {
|
||||
/// Token text. Absent / empty on the final `done: true` line.
|
||||
#[serde(default)]
|
||||
response: String,
|
||||
/// Terminal frame marker.
|
||||
#[serde(default)]
|
||||
done: bool,
|
||||
/// `"stop"` | `"length"` | `"abort"` | (older builds: missing).
|
||||
#[serde(default)]
|
||||
done_reason: Option<String>,
|
||||
/// Tokens consumed by the prompt. Older Ollama: absent → defaulted to 0.
|
||||
#[serde(default)]
|
||||
prompt_eval_count: Option<u64>,
|
||||
/// Tokens generated. Older Ollama: absent → defaulted to 0.
|
||||
#[serde(default)]
|
||||
eval_count: Option<u64>,
|
||||
/// Total wall-clock in nanoseconds. Older Ollama: absent → 0.
|
||||
#[serde(default)]
|
||||
total_duration: Option<u64>,
|
||||
/// Server-side error message (Ollama uses this on some 200-with-error
|
||||
/// frames). When present we surface it instead of treating the line as
|
||||
/// a token.
|
||||
#[serde(default)]
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
// ── Streaming iterator ────────────────────────────────────────────────────
|
||||
|
||||
/// Stateful iterator over Ollama's NDJSON stream.
|
||||
///
|
||||
/// Owns the `BufReader<Response>` so reading is incremental — `next()`
|
||||
/// blocks only as long as it takes Ollama to flush the next line.
|
||||
///
|
||||
/// Iterator semantics:
|
||||
/// - `Some(Ok(TokenChunk::Token(_)))` for each non-terminal frame.
|
||||
/// - One terminal `Some(Ok(TokenChunk::Done { .. }))` on the `done: true`
|
||||
/// line, after which `done == true` and subsequent calls return `None`.
|
||||
/// - `Some(Err(_))` on parse / I/O failure; the iterator does **not** yield
|
||||
/// `Done` after an error. Callers that need a guaranteed terminal frame
|
||||
/// should adapt with their own wrapper (the trait contract for streams
|
||||
/// ending in `Done` is enforced by the RAG pipeline, not the adapter).
|
||||
///
|
||||
/// Timeout invariant: the iterator has no inherent stop condition for an
|
||||
/// indefinitely-stalled server — only the underlying
|
||||
/// `reqwest::blocking::Client`'s read timeout (`REQUEST_TIMEOUT`, 300s)
|
||||
/// breaks the hang. Callers needing tighter cancellation should adjust
|
||||
/// the client timeout in [`OllamaLanguageModel::new`].
|
||||
struct OllamaStream {
|
||||
reader: BufReader<reqwest::blocking::Response>,
|
||||
line_buf: Vec<u8>,
|
||||
done: bool,
|
||||
/// Tracks whether we have parsed at least one valid NDJSON line. Used
|
||||
/// to discriminate "server never spoke NDJSON" (→ `LlmError::Stream`)
|
||||
/// from "mid-stream corruption" (→ `LlmError::Malformed`); see §10
|
||||
/// error taxonomy split.
|
||||
has_emitted: bool,
|
||||
}
|
||||
|
||||
impl Iterator for OllamaStream {
|
||||
type Item = anyhow::Result<TokenChunk>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.done {
|
||||
return None;
|
||||
}
|
||||
loop {
|
||||
self.line_buf.clear();
|
||||
// `read_until(b'\n', ...)` accumulates bytes across HTTP chunk
|
||||
// boundaries until it hits a newline (or EOF). UTF-8 multibyte
|
||||
// sequences inside a JSON `response` field are therefore
|
||||
// always whole by the time we attempt to decode — the line
|
||||
// boundary IS the safe re-sync point.
|
||||
let read = match self.reader.read_until(b'\n', &mut self.line_buf) {
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
self.done = true;
|
||||
return Some(Err(anyhow::anyhow!(e).context(
|
||||
"failed to read next line from ollama /api/generate stream",
|
||||
)));
|
||||
}
|
||||
};
|
||||
if read == 0 {
|
||||
// EOF without a `done: true` line. Treat as a stream
|
||||
// anomaly — synthesize an Aborted Done so downstream
|
||||
// pipelines that expect a terminal frame still terminate.
|
||||
self.done = true;
|
||||
tracing::warn!(
|
||||
target: "kb_llm_local",
|
||||
"ollama stream ended without a `done: true` frame; synthesizing Aborted",
|
||||
);
|
||||
return Some(Ok(TokenChunk::Done {
|
||||
finish_reason: FinishReason::Aborted,
|
||||
usage: TokenUsage {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
latency_ms: 0,
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
// Strip trailing `\n` / `\r\n`. Empty lines (keep-alive
|
||||
// heartbeats, blank separators) are skipped silently.
|
||||
let trimmed = trim_trailing_newline(&self.line_buf);
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let line: OllamaLine = match serde_json::from_slice(trimmed) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
self.done = true;
|
||||
let preview = String::from_utf8_lossy(trimmed).into_owned();
|
||||
if !self.has_emitted {
|
||||
// First line of the body did not parse as NDJSON
|
||||
// at all — the server clearly didn't speak the
|
||||
// protocol (e.g. an HTML 500 page from a
|
||||
// misrouted reverse proxy returning 200). Per §10
|
||||
// error taxonomy this is `Stream`, not
|
||||
// `Malformed`.
|
||||
return Some(Err(anyhow::Error::from(LlmError::Stream(
|
||||
truncate_body(&preview, 512),
|
||||
))));
|
||||
}
|
||||
// Mid-stream corruption — earlier lines parsed, this
|
||||
// one didn't. That's `Malformed`.
|
||||
return Some(Err(anyhow::Error::from(LlmError::Malformed(format!(
|
||||
"{e}: line={preview}"
|
||||
)))));
|
||||
}
|
||||
};
|
||||
// We've now parsed at least one structurally-valid NDJSON
|
||||
// line; subsequent parse failures count as mid-stream.
|
||||
self.has_emitted = true;
|
||||
|
||||
// Server-side error envelope on a 200 stream.
|
||||
if let Some(err) = line.error {
|
||||
self.done = true;
|
||||
return Some(Err(anyhow::Error::from(LlmError::Stream(
|
||||
truncate_body(&err, 512),
|
||||
))));
|
||||
}
|
||||
|
||||
if line.done {
|
||||
self.done = true;
|
||||
let finish_reason = match line.done_reason.as_deref() {
|
||||
Some("length") => FinishReason::Length,
|
||||
Some("abort") => FinishReason::Aborted,
|
||||
// Per §11.2 missing or unknown done_reason → Stop.
|
||||
// We treat unrecognised reasons as Stop rather than
|
||||
// surfacing them as `Error(_)` because Ollama
|
||||
// historically used "stop" as the only terminal value
|
||||
// and forward-compatible parsing should be lenient.
|
||||
_ => FinishReason::Stop,
|
||||
};
|
||||
let prompt_tokens = line.prompt_eval_count.unwrap_or_else(|| {
|
||||
tracing::warn!(
|
||||
target: "kb_llm_local",
|
||||
"ollama done frame missing prompt_eval_count; defaulting to 0",
|
||||
);
|
||||
0
|
||||
});
|
||||
let completion_tokens = line.eval_count.unwrap_or_else(|| {
|
||||
tracing::warn!(
|
||||
target: "kb_llm_local",
|
||||
"ollama done frame missing eval_count; defaulting to 0",
|
||||
);
|
||||
0
|
||||
});
|
||||
let total_duration_ns = line.total_duration.unwrap_or(0);
|
||||
let usage = TokenUsage {
|
||||
// u32 saturation: even ~4G tokens is implausible for a
|
||||
// single chat turn; we still saturate rather than
|
||||
// panic on the unlikely case.
|
||||
prompt_tokens: prompt_tokens.min(u32::MAX as u64) as u32,
|
||||
completion_tokens: completion_tokens.min(u32::MAX as u64) as u32,
|
||||
latency_ms: (total_duration_ns / 1_000_000).min(u32::MAX as u64) as u32,
|
||||
};
|
||||
return Some(Ok(TokenChunk::Done {
|
||||
finish_reason,
|
||||
usage,
|
||||
}));
|
||||
}
|
||||
|
||||
// Non-terminal frame. Older Ollama versions occasionally emit
|
||||
// empty `response` strings as keep-alive — don't surface
|
||||
// those as zero-length tokens.
|
||||
if line.response.is_empty() {
|
||||
continue;
|
||||
}
|
||||
return Some(Ok(TokenChunk::Token(line.response)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
fn trim_trailing_newline(bytes: &[u8]) -> &[u8] {
|
||||
let mut end = bytes.len();
|
||||
while end > 0 && (bytes[end - 1] == b'\n' || bytes[end - 1] == b'\r') {
|
||||
end -= 1;
|
||||
}
|
||||
&bytes[..end]
|
||||
}
|
||||
|
||||
/// Map a `reqwest::Error` from the initial `.send()` to an [`LlmError`]
|
||||
/// (returned to the caller as `anyhow::Error`).
|
||||
fn map_send_error(err: reqwest::Error, endpoint: &str) -> LlmError {
|
||||
if err.is_timeout() {
|
||||
return LlmError::Timeout(err);
|
||||
}
|
||||
if err.is_connect() {
|
||||
return LlmError::Unreachable {
|
||||
endpoint: endpoint.to_string(),
|
||||
source: err,
|
||||
};
|
||||
}
|
||||
// Other transport errors (DNS, body builder, etc.) — surface verbatim
|
||||
// (truncated; see `truncate_body`).
|
||||
LlmError::Stream(truncate_body(&err.to_string(), 512))
|
||||
}
|
||||
|
||||
/// Map a non-2xx HTTP response to an [`LlmError`]. Pattern-matches on the
|
||||
/// 404 + "model" / "not found" body envelope to surface the actionable
|
||||
/// `ollama pull <model>` hint.
|
||||
fn map_status_error(
|
||||
status: reqwest::StatusCode,
|
||||
body: &str,
|
||||
model_id: &str,
|
||||
) -> LlmError {
|
||||
if status == reqwest::StatusCode::NOT_FOUND {
|
||||
let lower = body.to_ascii_lowercase();
|
||||
// Heuristic: Ollama's "model not pulled" envelope is roughly
|
||||
// `{"error":"model 'qwen2.5:7b-instruct' not found, try pulling
|
||||
// it first"}`.
|
||||
//
|
||||
// Primary signal: the body mentions our exact model id —
|
||||
// language-agnostic, so a localized Ollama (e.g. Korean error
|
||||
// strings) still routes here. Fallback: the original English
|
||||
// "model" + "not found" substring check, kept for the case where
|
||||
// Ollama returns a generic envelope without echoing the model id.
|
||||
if lower.contains(&model_id.to_ascii_lowercase())
|
||||
|| (lower.contains("model") && lower.contains("not found"))
|
||||
{
|
||||
return LlmError::ModelNotPulled(model_id.to_string());
|
||||
}
|
||||
}
|
||||
LlmError::Stream(truncate_body(
|
||||
&format!("status={status} body={body}"),
|
||||
512,
|
||||
))
|
||||
}
|
||||
|
||||
/// Truncate a body / error string to `n` characters, appending an
|
||||
/// "(truncated, original N chars)" marker if the cap was hit. Counted in
|
||||
/// `chars()` rather than bytes so multibyte UTF-8 (Korean / Japanese /
|
||||
/// emoji) cannot land mid-codepoint.
|
||||
///
|
||||
/// Used at every `LlmError::Stream` construction site so a megabyte-sized
|
||||
/// nginx 500 page or Ollama panic dump cannot blow up `Display` / logs.
|
||||
fn truncate_body(s: &str, n: usize) -> String {
|
||||
if s.chars().count() <= n {
|
||||
return s.to_string();
|
||||
}
|
||||
let mut out: String = s.chars().take(n).collect();
|
||||
out.push_str(&format!("... (truncated, original {} chars)", s.chars().count()));
|
||||
out
|
||||
}
|
||||
|
||||
// ── Unit tests ────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn trim_trailing_newline_removes_lf_and_crlf() {
|
||||
assert_eq!(trim_trailing_newline(b"hi\n"), b"hi");
|
||||
assert_eq!(trim_trailing_newline(b"hi\r\n"), b"hi");
|
||||
assert_eq!(trim_trailing_newline(b"hi"), b"hi");
|
||||
assert_eq!(trim_trailing_newline(b""), b"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn map_status_error_404_with_model_not_found_returns_not_pulled() {
|
||||
let body = r#"{"error":"model 'qwen2.5:7b-instruct' not found, try pulling it first"}"#;
|
||||
let err = map_status_error(
|
||||
reqwest::StatusCode::NOT_FOUND,
|
||||
body,
|
||||
"qwen2.5:7b-instruct",
|
||||
);
|
||||
match err {
|
||||
LlmError::ModelNotPulled(m) => assert_eq!(m, "qwen2.5:7b-instruct"),
|
||||
other => panic!("expected ModelNotPulled, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn map_status_error_500_returns_stream() {
|
||||
let err = map_status_error(
|
||||
reqwest::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"boom",
|
||||
"qwen2.5:7b-instruct",
|
||||
);
|
||||
assert!(matches!(err, LlmError::Stream(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn map_status_error_404_with_model_id_in_localized_body_is_not_pulled() {
|
||||
// Localized Ollama: imagine a Korean build returning
|
||||
// `{"error":"모델 'qwen2.5:7b-instruct' 을(를) 찾을 수 없습니다"}`.
|
||||
// The English "not found" substring is absent, but the model id
|
||||
// is echoed — heuristic should still route to ModelNotPulled.
|
||||
let body = r#"{"error":"모델 'qwen2.5:7b-instruct' 을(를) 찾을 수 없습니다"}"#;
|
||||
let err = map_status_error(
|
||||
reqwest::StatusCode::NOT_FOUND,
|
||||
body,
|
||||
"qwen2.5:7b-instruct",
|
||||
);
|
||||
assert!(
|
||||
matches!(err, LlmError::ModelNotPulled(ref m) if m == "qwen2.5:7b-instruct"),
|
||||
"expected ModelNotPulled for localized 404 body, got {err:?}",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_body_under_cap_returns_input_unchanged() {
|
||||
assert_eq!(truncate_body("short", 512), "short");
|
||||
assert_eq!(truncate_body("", 512), "");
|
||||
// Boundary: exactly at the cap.
|
||||
let exact = "x".repeat(10);
|
||||
assert_eq!(truncate_body(&exact, 10), exact);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_body_over_cap_appends_marker_and_caps_chars() {
|
||||
let big = "y".repeat(1000);
|
||||
let out = truncate_body(&big, 512);
|
||||
// 512 chars of payload + the truncation marker.
|
||||
assert!(out.starts_with(&"y".repeat(512)));
|
||||
assert!(out.contains("(truncated, original 1000 chars)"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_body_counts_chars_not_bytes_for_multibyte() {
|
||||
// 600 Korean chars (each ~3 UTF-8 bytes). Slicing by bytes would
|
||||
// land mid-codepoint; chars() iteration is safe.
|
||||
let big: String = "한".repeat(600);
|
||||
let out = truncate_body(&big, 512);
|
||||
// Make sure the prefix is exactly 512 Korean chars, valid UTF-8.
|
||||
let prefix: String = out.chars().take(512).collect();
|
||||
assert_eq!(prefix.chars().count(), 512);
|
||||
assert!(prefix.chars().all(|c| c == '한'));
|
||||
assert!(out.contains("(truncated, original 600 chars)"));
|
||||
}
|
||||
}
|
||||
37
crates/kb-llm-local/tests/construction.rs
Normal file
37
crates/kb-llm-local/tests/construction.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
//! Construction-time tests — verify `OllamaLanguageModel::new` reads the
|
||||
//! relevant config fields and exposes them via the trait surface, all
|
||||
//! without touching the network (per design §7.2 lazy-connect contract).
|
||||
|
||||
use kb_config::Config;
|
||||
use kb_llm_local::{LanguageModel, OllamaLanguageModel};
|
||||
|
||||
#[test]
|
||||
fn construction_with_default_config_returns_expected_model_ref() {
|
||||
let cfg = Config::defaults();
|
||||
let llm = OllamaLanguageModel::new(&cfg).expect("construction should not hit network");
|
||||
let m = llm.model_ref();
|
||||
|
||||
assert_eq!(m.provider, "ollama");
|
||||
// Default model id from kb-config §6.4 — pinned here so a silent
|
||||
// default flip in kb-config is caught by this test.
|
||||
assert_eq!(m.id, cfg.models.llm.model);
|
||||
// Chat models have no embedding dimension (§3.8).
|
||||
assert_eq!(m.dimensions, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_tokens_returns_config_value() {
|
||||
let mut cfg = Config::defaults();
|
||||
cfg.models.llm.context_tokens = 16384;
|
||||
let llm = OllamaLanguageModel::new(&cfg).unwrap();
|
||||
assert_eq!(llm.context_tokens(), 16384);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn construction_does_not_require_a_running_ollama() {
|
||||
// Point the endpoint at a closed port. Construction must succeed —
|
||||
// the contract is "lazy connect on first generate_stream call".
|
||||
let mut cfg = Config::defaults();
|
||||
cfg.models.llm.endpoint = "http://127.0.0.1:1".to_string();
|
||||
let _llm = OllamaLanguageModel::new(&cfg).expect("new() must not hit the network");
|
||||
}
|
||||
53
crates/kb-llm-local/tests/integration.rs
Normal file
53
crates/kb-llm-local/tests/integration.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
//! Real-Ollama integration tests, gated behind `#[ignore]`.
|
||||
//!
|
||||
//! Run with:
|
||||
//!
|
||||
//! ```bash
|
||||
//! ollama serve & # if not already running
|
||||
//! ollama pull qwen2.5:7b-instruct
|
||||
//! cargo test -p kb-llm-local -- --ignored
|
||||
//! ```
|
||||
//!
|
||||
//! These hit `http://127.0.0.1:11434` directly and require an actual model
|
||||
//! pulled locally. CI runs default (non-ignored) tests only.
|
||||
|
||||
use kb_config::Config;
|
||||
use kb_core::{GenerateRequest, TokenChunk};
|
||||
use kb_llm_local::{LanguageModel, OllamaLanguageModel};
|
||||
|
||||
#[test]
|
||||
#[ignore = "requires a local Ollama daemon + pulled model"]
|
||||
fn real_ollama_streams_non_empty_response() {
|
||||
// Use whatever model the workspace defaults select. Override via the
|
||||
// KB_MODELS_LLM_MODEL env var if you want a different one for this run
|
||||
// (e.g. `KB_MODELS_LLM_MODEL=qwen2.5:7b-instruct cargo test ... -- --ignored`).
|
||||
let cfg = Config::load(None).expect("config should load");
|
||||
let llm = OllamaLanguageModel::new(&cfg).unwrap();
|
||||
|
||||
let req = GenerateRequest {
|
||||
system: "You are a terse assistant.".to_string(),
|
||||
user: "Say only the word 'ok'.".to_string(),
|
||||
stop: vec![],
|
||||
max_tokens: 8,
|
||||
temperature: 0.0,
|
||||
seed: Some(0),
|
||||
};
|
||||
|
||||
let stream = llm.generate_stream(req).expect("stream should start");
|
||||
let chunks: Vec<TokenChunk> = stream
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.expect("stream should not error");
|
||||
|
||||
let text: String = chunks
|
||||
.iter()
|
||||
.filter_map(|c| match c {
|
||||
TokenChunk::Token(t) => Some(t.as_str()),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
assert!(!text.is_empty(), "expected non-empty completion");
|
||||
assert!(
|
||||
matches!(chunks.last(), Some(TokenChunk::Done { .. })),
|
||||
"stream must end with Done"
|
||||
);
|
||||
}
|
||||
472
crates/kb-llm-local/tests/streaming.rs
Normal file
472
crates/kb-llm-local/tests/streaming.rs
Normal file
@@ -0,0 +1,472 @@
|
||||
//! End-to-end streaming tests against a `wiremock`-hosted mock HTTP server.
|
||||
//!
|
||||
//! `wiremock` is async, so the test functions are `#[tokio::test]`; the
|
||||
//! adapter under test stays sync and is called from `spawn_blocking` to
|
||||
//! preserve the "no async runtime in the runtime crate" invariant. Tokio
|
||||
//! is a `dev-dependency` only — `cargo tree -p kb-llm-local --edges no-dev`
|
||||
//! must not list it.
|
||||
//!
|
||||
//! Each test pins one behavior from design §7.2 / §11.2: streaming order,
|
||||
//! error mapping, finish-reason mapping, missing-counter degradation, and
|
||||
//! determinism semantics.
|
||||
|
||||
use kb_config::Config;
|
||||
use kb_core::{FinishReason, GenerateRequest, TokenChunk};
|
||||
use kb_llm_local::{LanguageModel, LlmError, OllamaLanguageModel};
|
||||
use wiremock::matchers::{method, path};
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
|
||||
/// Build a `Config` whose `models.llm.endpoint` points at the wiremock
|
||||
/// server. Other fields are left at their `Config::defaults()` values so
|
||||
/// tests pin the same `model` id the production code will use.
|
||||
fn cfg_for_endpoint(endpoint: &str) -> Config {
|
||||
let mut cfg = Config::defaults();
|
||||
cfg.models.llm.endpoint = endpoint.to_string();
|
||||
// Keep model id stable for the ModelNotPulled test below.
|
||||
cfg.models.llm.model = "qwen2.5:7b-instruct".to_string();
|
||||
cfg
|
||||
}
|
||||
|
||||
fn sample_request() -> GenerateRequest {
|
||||
GenerateRequest {
|
||||
system: "you are a test".to_string(),
|
||||
user: "hello".to_string(),
|
||||
stop: vec![],
|
||||
max_tokens: 64,
|
||||
temperature: 0.0,
|
||||
seed: Some(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper: drive `generate_stream` to completion on a blocking thread so
|
||||
/// the sync `OllamaLanguageModel` stays off the async runtime.
|
||||
async fn collect_chunks(
|
||||
cfg: Config,
|
||||
req: GenerateRequest,
|
||||
) -> anyhow::Result<Vec<TokenChunk>> {
|
||||
tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<TokenChunk>> {
|
||||
let llm = OllamaLanguageModel::new(&cfg)?;
|
||||
let stream = llm.generate_stream(req)?;
|
||||
stream.collect::<Result<Vec<_>, _>>()
|
||||
})
|
||||
.await
|
||||
.expect("blocking task panicked")
|
||||
}
|
||||
|
||||
/// Same as `collect_chunks`, but returns the boxed `anyhow::Error` from
|
||||
/// `generate_stream` itself (rather than a stream-mid error). Used by the
|
||||
/// "unreachable endpoint" / "model not pulled" tests where the error
|
||||
/// surfaces on `.send()` before any chunks flow.
|
||||
async fn run_expecting_request_error(
|
||||
cfg: Config,
|
||||
req: GenerateRequest,
|
||||
) -> anyhow::Error {
|
||||
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
|
||||
let llm = OllamaLanguageModel::new(&cfg)?;
|
||||
let _stream = llm.generate_stream(req)?;
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.expect("blocking task panicked")
|
||||
.expect_err("expected generate_stream / new to return Err")
|
||||
}
|
||||
|
||||
// ── Happy path ────────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn streamed_response_produces_tokens_then_done() {
|
||||
let server = MockServer::start().await;
|
||||
let body = concat!(
|
||||
r#"{"response":"hi","done":false}"#, "\n",
|
||||
r#"{"response":" there","done":false}"#, "\n",
|
||||
r#"{"response":"","done":true,"done_reason":"stop","prompt_eval_count":3,"eval_count":2,"total_duration":1500000}"#, "\n",
|
||||
);
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string(body))
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let chunks = collect_chunks(cfg_for_endpoint(&server.uri()), sample_request())
|
||||
.await
|
||||
.expect("stream should complete");
|
||||
|
||||
assert_eq!(chunks.len(), 3, "expected 2 tokens + 1 done");
|
||||
assert!(matches!(&chunks[0], TokenChunk::Token(t) if t == "hi"));
|
||||
assert!(matches!(&chunks[1], TokenChunk::Token(t) if t == " there"));
|
||||
match &chunks[2] {
|
||||
TokenChunk::Done { finish_reason, usage } => {
|
||||
assert!(matches!(finish_reason, FinishReason::Stop));
|
||||
assert_eq!(usage.prompt_tokens, 3);
|
||||
assert_eq!(usage.completion_tokens, 2);
|
||||
// 1_500_000 ns / 1_000_000 = 1 ms.
|
||||
assert_eq!(usage.latency_ms, 1);
|
||||
}
|
||||
other => panic!("expected Done, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn concat_of_streamed_tokens_equals_full_text() {
|
||||
let server = MockServer::start().await;
|
||||
let pieces = ["The ", "quick ", "brown ", "fox"];
|
||||
let mut body = String::new();
|
||||
for p in &pieces {
|
||||
body.push_str(&format!(r#"{{"response":"{p}","done":false}}"#));
|
||||
body.push('\n');
|
||||
}
|
||||
body.push_str(r#"{"response":"","done":true,"done_reason":"stop","prompt_eval_count":1,"eval_count":4,"total_duration":0}"#);
|
||||
body.push('\n');
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string(body))
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let chunks = collect_chunks(cfg_for_endpoint(&server.uri()), sample_request())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let joined: String = chunks
|
||||
.iter()
|
||||
.filter_map(|c| match c {
|
||||
TokenChunk::Token(t) => Some(t.as_str()),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
assert_eq!(joined, "The quick brown fox");
|
||||
}
|
||||
|
||||
// ── UTF-8 / Korean ────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn multibyte_chars_within_a_line_round_trip() {
|
||||
// The "split across HTTP chunks" concern in the spec is about
|
||||
// reqwest's transport-level chunk boundaries; for line-delimited
|
||||
// JSON, the BufReader's `read_until(b'\n')` accumulates until newline
|
||||
// regardless of HTTP chunk boundary, so the UTF-8 boundary issue is
|
||||
// moot for *complete* lines. This test verifies that multi-byte
|
||||
// payloads inside a single line round-trip correctly — covering the
|
||||
// common case where a Korean / Japanese / emoji token spans 3+ bytes.
|
||||
// (Test name is honest about scope: it does NOT exercise cross-HTTP
|
||||
// -chunk reassembly — that's structurally infeasible to set up given
|
||||
// the line-delimited framing.)
|
||||
let server = MockServer::start().await;
|
||||
let body = concat!(
|
||||
// "한국어" (Korean) — each char is 3 bytes in UTF-8.
|
||||
r#"{"response":"한국어","done":false}"#, "\n",
|
||||
// Followed by an emoji ZWJ sequence (4 bytes per scalar).
|
||||
r#"{"response":"🦀","done":false}"#, "\n",
|
||||
r#"{"response":"","done":true,"done_reason":"stop","prompt_eval_count":1,"eval_count":4,"total_duration":0}"#, "\n",
|
||||
);
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string(body))
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let chunks = collect_chunks(cfg_for_endpoint(&server.uri()), sample_request())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let joined: String = chunks
|
||||
.iter()
|
||||
.filter_map(|c| match c {
|
||||
TokenChunk::Token(t) => Some(t.as_str()),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
assert_eq!(joined, "한국어🦀");
|
||||
}
|
||||
|
||||
// ── Error mapping ─────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn unreachable_endpoint_maps_to_unreachable_error() {
|
||||
// Port 1 is reserved (tcpmux) and almost never bound on a dev box —
|
||||
// a synchronous `connect` returns ECONNREFUSED immediately, which
|
||||
// reqwest reports as `is_connect()`.
|
||||
let mut cfg = Config::defaults();
|
||||
cfg.models.llm.endpoint = "http://127.0.0.1:1".to_string();
|
||||
|
||||
let err = run_expecting_request_error(cfg, sample_request()).await;
|
||||
let llm_err = err
|
||||
.downcast_ref::<LlmError>()
|
||||
.unwrap_or_else(|| panic!("expected LlmError, got: {err:?}"));
|
||||
match llm_err {
|
||||
LlmError::Unreachable { endpoint, .. } => {
|
||||
assert_eq!(endpoint, "http://127.0.0.1:1");
|
||||
}
|
||||
other => panic!("expected LlmError::Unreachable, got {other:?}"),
|
||||
}
|
||||
// The Display string MUST carry the actionable hint per §10.
|
||||
let rendered = format!("{err}");
|
||||
assert!(
|
||||
rendered.contains("ollama serve"),
|
||||
"missing actionable hint in: {rendered}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn model_not_found_maps_to_model_not_pulled() {
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(ResponseTemplate::new(404).set_body_string(
|
||||
r#"{"error":"model 'qwen2.5:7b-instruct' not found, try pulling it first"}"#,
|
||||
))
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let err = run_expecting_request_error(cfg_for_endpoint(&server.uri()), sample_request()).await;
|
||||
let llm_err = err
|
||||
.downcast_ref::<LlmError>()
|
||||
.unwrap_or_else(|| panic!("expected LlmError, got: {err:?}"));
|
||||
match llm_err {
|
||||
LlmError::ModelNotPulled(model) => assert_eq!(model, "qwen2.5:7b-instruct"),
|
||||
other => panic!("expected LlmError::ModelNotPulled, got {other:?}"),
|
||||
}
|
||||
let rendered = format!("{err}");
|
||||
assert!(
|
||||
rendered.contains("ollama pull qwen2.5:7b-instruct"),
|
||||
"missing pull hint in: {rendered}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn other_4xx_maps_to_stream_error() {
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(ResponseTemplate::new(400).set_body_string("bad request"))
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let err = run_expecting_request_error(cfg_for_endpoint(&server.uri()), sample_request()).await;
|
||||
let llm_err = err.downcast_ref::<LlmError>().expect("expected LlmError");
|
||||
assert!(
|
||||
matches!(llm_err, LlmError::Stream(_)),
|
||||
"expected Stream variant, got {llm_err:?}"
|
||||
);
|
||||
}
|
||||
|
||||
// ── Finish-reason mapping ─────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn done_reason_length_maps_to_finish_reason_length() {
|
||||
let server = MockServer::start().await;
|
||||
let body = concat!(
|
||||
r#"{"response":"a","done":false}"#, "\n",
|
||||
r#"{"response":"","done":true,"done_reason":"length","prompt_eval_count":1,"eval_count":1,"total_duration":0}"#, "\n",
|
||||
);
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string(body))
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let chunks = collect_chunks(cfg_for_endpoint(&server.uri()), sample_request())
|
||||
.await
|
||||
.unwrap();
|
||||
match chunks.last().unwrap() {
|
||||
TokenChunk::Done { finish_reason, .. } => {
|
||||
assert!(matches!(finish_reason, FinishReason::Length));
|
||||
}
|
||||
other => panic!("expected Done, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn done_reason_abort_maps_to_finish_reason_aborted() {
|
||||
let server = MockServer::start().await;
|
||||
let body = concat!(
|
||||
r#"{"response":"a","done":false}"#, "\n",
|
||||
r#"{"response":"","done":true,"done_reason":"abort","prompt_eval_count":1,"eval_count":1,"total_duration":0}"#, "\n",
|
||||
);
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string(body))
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let chunks = collect_chunks(cfg_for_endpoint(&server.uri()), sample_request())
|
||||
.await
|
||||
.unwrap();
|
||||
match chunks.last().unwrap() {
|
||||
TokenChunk::Done { finish_reason, .. } => {
|
||||
assert!(matches!(finish_reason, FinishReason::Aborted));
|
||||
}
|
||||
other => panic!("expected Done, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Resilience ────────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_eval_counts_default_to_zero() {
|
||||
// Older Ollama (< ~0.1.40) sometimes omitted prompt_eval_count /
|
||||
// eval_count entirely. Per §10 we degrade gracefully + warn rather
|
||||
// than failing the stream. Test asserts the zero default; the warn
|
||||
// is observed only via tracing-subscriber, which we do not wire up
|
||||
// here — the comment documents the intent.
|
||||
let server = MockServer::start().await;
|
||||
let body = concat!(
|
||||
r#"{"response":"hi","done":false}"#, "\n",
|
||||
// No prompt_eval_count / eval_count / total_duration.
|
||||
r#"{"response":"","done":true,"done_reason":"stop"}"#, "\n",
|
||||
);
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string(body))
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let chunks = collect_chunks(cfg_for_endpoint(&server.uri()), sample_request())
|
||||
.await
|
||||
.unwrap();
|
||||
match chunks.last().unwrap() {
|
||||
TokenChunk::Done { usage, .. } => {
|
||||
assert_eq!(usage.prompt_tokens, 0);
|
||||
assert_eq!(usage.completion_tokens, 0);
|
||||
assert_eq!(usage.latency_ms, 0);
|
||||
}
|
||||
other => panic!("expected Done, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_done_reason_defaults_to_stop() {
|
||||
let server = MockServer::start().await;
|
||||
let body = concat!(
|
||||
r#"{"response":"hi","done":false}"#, "\n",
|
||||
// Final frame omits done_reason entirely.
|
||||
r#"{"response":"","done":true,"prompt_eval_count":1,"eval_count":1,"total_duration":0}"#, "\n",
|
||||
);
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string(body))
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let chunks = collect_chunks(cfg_for_endpoint(&server.uri()), sample_request())
|
||||
.await
|
||||
.unwrap();
|
||||
match chunks.last().unwrap() {
|
||||
TokenChunk::Done { finish_reason, .. } => {
|
||||
assert!(matches!(finish_reason, FinishReason::Stop));
|
||||
}
|
||||
other => panic!("expected Done, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Non-NDJSON 200 body ───────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn non_ndjson_200_body_maps_to_stream_not_malformed() {
|
||||
// Misrouted reverse proxy returning a 200 with an HTML error page is
|
||||
// the canonical case: status code says "ok", body is nowhere near
|
||||
// NDJSON. Per §10 taxonomy the first-line parse failure on such a
|
||||
// response surfaces as `LlmError::Stream`, not `Malformed`
|
||||
// ("Malformed" is reserved for mid-stream corruption after at least
|
||||
// one valid NDJSON line).
|
||||
let server = MockServer::start().await;
|
||||
let html = "<html><body><h1>500 Internal Server Error</h1></body></html>";
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string(html))
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let chunks = collect_chunks(cfg_for_endpoint(&server.uri()), sample_request()).await;
|
||||
let err = chunks.expect_err("expected the iterator to surface an error");
|
||||
let llm_err = err
|
||||
.downcast_ref::<LlmError>()
|
||||
.unwrap_or_else(|| panic!("expected LlmError, got: {err:?}"));
|
||||
assert!(
|
||||
matches!(llm_err, LlmError::Stream(_)),
|
||||
"first-line non-NDJSON should be Stream, got {llm_err:?}",
|
||||
);
|
||||
}
|
||||
|
||||
// ── Endpoint URL handling ─────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn endpoint_with_trailing_slash_does_not_double_slash() {
|
||||
// The adapter does `format!("{}/api/generate", endpoint.trim_end_matches('/'))`,
|
||||
// so an endpoint configured with a trailing slash must still resolve
|
||||
// to a single-slash URL. Two layers of evidence:
|
||||
// 1. The wiremock matcher `path("/api/generate")` would NOT match a
|
||||
// request to `//api/generate`, so a successful response itself
|
||||
// proves the URL is correctly normalized.
|
||||
// 2. We additionally inspect `MockServer::received_requests()` and
|
||||
// assert the recorded `Request::url` path is exactly
|
||||
// `/api/generate` — pinning the invariant explicitly so a future
|
||||
// regression that "works" via a different mismatch would still
|
||||
// fail the assertion.
|
||||
let server = MockServer::start().await;
|
||||
let body = concat!(
|
||||
r#"{"response":"ok","done":false}"#, "\n",
|
||||
r#"{"response":"","done":true,"done_reason":"stop","prompt_eval_count":1,"eval_count":1,"total_duration":0}"#, "\n",
|
||||
);
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string(body))
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
// Append the trailing slash to the wiremock URI.
|
||||
let endpoint_with_slash = format!("{}/", server.uri());
|
||||
let cfg = cfg_for_endpoint(&endpoint_with_slash);
|
||||
|
||||
let chunks = collect_chunks(cfg, sample_request())
|
||||
.await
|
||||
.expect("stream should complete despite trailing slash on endpoint");
|
||||
// Smoke-check: we got the canned tokens — proves matcher (1) above.
|
||||
assert!(matches!(&chunks[0], TokenChunk::Token(t) if t == "ok"));
|
||||
|
||||
// Evidence (2): inspect the recorded request URL.
|
||||
let recorded = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("wiremock should record requests by default");
|
||||
assert_eq!(recorded.len(), 1, "expected exactly one request");
|
||||
let url = &recorded[0].url;
|
||||
assert_eq!(
|
||||
url.path(),
|
||||
"/api/generate",
|
||||
"request path should be exactly /api/generate (single slash), got {url}",
|
||||
);
|
||||
}
|
||||
|
||||
// ── Determinism ───────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn determinism_seed_zero_temp_zero_two_runs_identical() {
|
||||
// Determinism test against a *mock* — wiremock just replays the canned
|
||||
// response so byte-equality is trivially satisfied. The point of the
|
||||
// test is to lock in the request shape: when `temperature == 0` and a
|
||||
// fixed seed are sent, we expect identical client-observed output.
|
||||
// Real-Ollama determinism is asserted in `tests/integration.rs`
|
||||
// (#[ignore]) where reproducibility is modulo model-internal nondet.
|
||||
let server = MockServer::start().await;
|
||||
let body = concat!(
|
||||
r#"{"response":"deterministic","done":false}"#, "\n",
|
||||
r#"{"response":"","done":true,"done_reason":"stop","prompt_eval_count":1,"eval_count":1,"total_duration":0}"#, "\n",
|
||||
);
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string(body))
|
||||
// expect 2 calls so wiremock does not reset between them
|
||||
.expect(2)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let cfg = cfg_for_endpoint(&server.uri());
|
||||
let req1 = sample_request();
|
||||
let req2 = sample_request();
|
||||
|
||||
let chunks_a = collect_chunks(cfg.clone(), req1).await.unwrap();
|
||||
let chunks_b = collect_chunks(cfg, req2).await.unwrap();
|
||||
|
||||
assert_eq!(chunks_a, chunks_b);
|
||||
}
|
||||
Reference in New Issue
Block a user