feat(embed): kebab-embed-ollama 신규 크레이트 — Ollama /api/embed Embedder

arctic-embed-l-v2.0 의 폴백 백엔드(측정에 쓴 경로 그대로). reqwest::blocking
POST {endpoint}/api/embed {model,input:[...]} → embeddings. batch 48 +
fail-soft 재시도 3, 결과 L2 정규화(Ollama raw 반환 → 일관성), dim 검증.
query/doc prefix 는 모델 태그로 추론(arctic-embed→query:/무접두어, e5→query:/passage:).
model_version=ollama:{model}. endpoint=models.embedding.endpoint ?? models.llm.endpoint.
wiremock 테스트 3종(L2 정규화/dim mismatch/empty no-op) + 단위 5.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-03 04:59:11 +00:00
parent e2ae9a4589
commit 7505645008
3 changed files with 439 additions and 0 deletions

View File

@@ -0,0 +1,30 @@
[package]
name = "kebab-embed-ollama"
version = { workspace = true }
edition = { workspace = true }
rust-version = { workspace = true }
license = { workspace = true }
repository = { workspace = true }
description = "Ollama HTTP adapter implementing kebab_core::Embedder (POST /api/embed, L2-normalized, batched + fail-soft)"
[dependencies]
kebab-core = { path = "../kebab-core" }
kebab-config = { path = "../kebab-config" }
# `default-features = false` drops native-tls (system OpenSSL); we pin rustls.
# reqwest 0.12's `blocking` feature wraps a private current-thread tokio
# runtime — this crate exposes NO async surface (no `async`/`await`/`tokio::*`
# symbols), matching the kebab-llm-local invariant.
reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tracing = { workspace = true }
anyhow = { workspace = true }
[dev-dependencies]
# wiremock hosts the mock /api/embed server (needs a tokio runtime); tokio is
# also pulled transitively at runtime by reqwest's `blocking` feature.
wiremock = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt"] }
[lints]
workspace = true

View File

@@ -0,0 +1,310 @@
//! `kebab-embed-ollama` — [`OllamaEmbedder`], a `reqwest::blocking` adapter
//! implementing [`Embedder`](kebab_core::Embedder) over Ollama's
//! `POST /api/embed` endpoint.
//!
//! ## Why this exists
//!
//! The candle backend ([`kebab-embed-candle`]) runs arctic-embed-l-v2.0
//! in-process (pure Rust, NUMA-safe). This crate is the **fallback** path:
//! it offloads embedding to a local/remote Ollama daemon (`snowflake-arctic-embed2`),
//! which is exactly the route the recall measurements used — so it reproduces
//! the measured numbers (recall@10 130/132) byte-for-route. Opt-in via
//! `config.models.embedding.provider = "ollama"`.
//!
//! ## Wire shape
//!
//! Request (`POST {endpoint}/api/embed`):
//!
//! ```json
//! { "model": "snowflake-arctic-embed2", "input": ["query: 스택", "후입선출 ..."] }
//! ```
//!
//! Response:
//!
//! ```json
//! { "model": "...", "embeddings": [[0.01, ...], [0.02, ...]] }
//! ```
//!
//! ## Pipeline
//!
//! 1. instruction prefix per model ([`prefixes_for`] — arctic: `query: ` on
//! queries, no prefix on documents; e5: `query: `/`passage: `);
//! 2. batch into `BATCH` (48) inputs per request;
//! 3. `POST /api/embed`, with fail-soft retry (`MAX_RETRIES`);
//! 4. **L2 normalize** each returned vector — Ollama returns raw (un-normalized)
//! embeddings, so we normalize for cosine consistency with the candle path;
//! 5. dim check against `config.models.embedding.dimensions`.
//!
//! ## Send-safety
//!
//! `reqwest::blocking::Client: Send + Sync`; the adapter holds only the client,
//! an endpoint string, and small config scalars, so it is trivially `Send + Sync`
//! as the [`Embedder`] trait requires.
use std::time::Duration;
use anyhow::{Context, Result};
use kebab_core::{Embedder, EmbeddingInput, EmbeddingKind, EmbeddingModelId, EmbeddingVersion};
use serde::{Deserialize, Serialize};
/// Inputs per `/api/embed` request. Ollama handles arbitrary batch sizes, but
/// a cap keeps a single HTTP body bounded and lets a partial failure retry a
/// smaller unit.
const BATCH: usize = 48;
/// Fail-soft retry attempts per batch before the error propagates. Cold model
/// load on the Ollama side can transiently 500/timeout; a couple of retries
/// smooth that over without masking a hard misconfiguration.
const MAX_RETRIES: u32 = 3;
/// Default per-request HTTP timeout (seconds). Cold-loading an embedding model
/// on first call can take tens of seconds; this matches the generous default
/// used by the LLM adapter.
const REQUEST_TIMEOUT_SECS: u64 = 300;
/// Resolve the (query_prefix, doc_prefix) for an Ollama embedding model tag.
///
/// Mirrors `kebab-embed-candle`'s `MODEL_REGISTRY`, but keyed on the **Ollama
/// model tag** (which differs from the HF id — e.g. `snowflake-arctic-embed2`
/// vs `Snowflake/snowflake-arctic-embed-l-v2.0`). Kept here rather than shared
/// so this crate does not depend on the candle backend.
///
/// An unrecognized model gets no prefix (`("", "")`): many embedding models
/// are not instruction-tuned, so embedding the raw text is the correct default
/// — and a misspelled known model surfaces as a recall regression, not a silent
/// wrong-prefix, because the dim check still passes either way.
fn prefixes_for(model: &str) -> (&'static str, &'static str) {
let m = model.to_ascii_lowercase();
if m.contains("arctic-embed") {
// arctic-embed v2.0: `query: ` on queries, documents embedded raw.
("query: ", "")
} else if m.contains("e5") {
// multilingual-e5: `query: ` / `passage: `.
("query: ", "passage: ")
} else {
("", "")
}
}
/// `reqwest::blocking` adapter implementing [`Embedder`] over Ollama's
/// `/api/embed`. Construction is offline; the first network call happens in
/// [`Embedder::embed`].
pub struct OllamaEmbedder {
client: reqwest::blocking::Client,
/// Validated endpoint base (e.g. `"http://127.0.0.1:11434"`).
endpoint: String,
/// Ollama model tag (e.g. `"snowflake-arctic-embed2"`).
model: String,
query_prefix: &'static str,
doc_prefix: &'static str,
model_id: EmbeddingModelId,
version: EmbeddingVersion,
dimensions: usize,
}
impl OllamaEmbedder {
/// Build from a workspace [`kebab_config::Config`]. Reads
/// `config.models.embedding.{model, dimensions}` and resolves the endpoint
/// as `models.embedding.endpoint` → fallback `models.llm.endpoint`.
///
/// Does NOT touch the network. The caller (app layer) is expected to have
/// validated `provider == "ollama"`.
pub fn new(config: &kebab_config::Config) -> Result<Self> {
let emb = &config.models.embedding;
let endpoint = emb
.endpoint
.clone()
.filter(|e| !e.is_empty())
.unwrap_or_else(|| config.models.llm.endpoint.clone());
if endpoint.is_empty() {
anyhow::bail!(
"ollama embedding provider needs an endpoint: set \
`models.embedding.endpoint` (or `models.llm.endpoint`)"
);
}
let client = reqwest::blocking::Client::builder()
.timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS))
.build()
.context("kb-embed-ollama: build reqwest client")?;
let (query_prefix, doc_prefix) = prefixes_for(&emb.model);
Ok(Self {
client,
endpoint,
model: emb.model.clone(),
query_prefix,
doc_prefix,
model_id: EmbeddingModelId(emb.model.clone()),
// model_version = `ollama:{model}` so a provider/model switch
// triggers the embedding_version cascade and never collides with
// the candle path's version string for the same model.
version: EmbeddingVersion(format!("ollama:{}", emb.model)),
dimensions: emb.dimensions,
})
}
/// Embed one already-prefixed batch via `/api/embed`, with fail-soft retry.
fn embed_batch(&self, prefixed: &[String]) -> Result<Vec<Vec<f32>>> {
let url = format!("{}/api/embed", self.endpoint.trim_end_matches('/'));
let body = EmbedRequest {
model: &self.model,
input: prefixed,
};
let mut last_err: Option<anyhow::Error> = None;
for attempt in 1..=MAX_RETRIES {
match self.try_once(&url, &body) {
Ok(resp) => return self.finalize(resp, prefixed.len()),
Err(e) => {
tracing::warn!(
target: "kebab-embed-ollama",
attempt,
max = MAX_RETRIES,
error = %e,
"ollama /api/embed attempt failed; retrying"
);
last_err = Some(e);
}
}
}
Err(last_err.unwrap_or_else(|| {
anyhow::anyhow!("kb-embed-ollama: all {MAX_RETRIES} attempts failed")
}))
}
/// One HTTP round-trip. Network / non-2xx / decode errors all map to
/// `Err` so the retry loop can decide.
fn try_once(&self, url: &str, body: &EmbedRequest<'_>) -> Result<EmbedResponse> {
let resp = self
.client
.post(url)
.json(body)
.send()
.with_context(|| format!("kb-embed-ollama: POST {url}"))?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().unwrap_or_default();
anyhow::bail!("kb-embed-ollama: /api/embed returned {status}: {text}");
}
resp.json::<EmbedResponse>()
.context("kb-embed-ollama: decode /api/embed response")
}
/// Validate count + dim, then L2-normalize each vector.
fn finalize(&self, resp: EmbedResponse, expected: usize) -> Result<Vec<Vec<f32>>> {
if resp.embeddings.len() != expected {
anyhow::bail!(
"kb-embed-ollama: expected {expected} embeddings, got {}",
resp.embeddings.len()
);
}
let mut out = Vec::with_capacity(resp.embeddings.len());
for v in resp.embeddings {
if v.len() != self.dimensions {
anyhow::bail!(
"kb-embed-ollama: model returned dim {} but config expects {} \
(check models.embedding.dimensions vs the Ollama model)",
v.len(),
self.dimensions
);
}
out.push(l2_normalize(v));
}
Ok(out)
}
}
impl Embedder for OllamaEmbedder {
fn model_id(&self) -> EmbeddingModelId {
self.model_id.clone()
}
fn model_version(&self) -> EmbeddingVersion {
self.version.clone()
}
fn dimensions(&self) -> usize {
self.dimensions
}
fn embed(&self, inputs: &[EmbeddingInput<'_>]) -> Result<Vec<Vec<f32>>> {
if inputs.is_empty() {
return Ok(Vec::new());
}
let prefixed: Vec<String> = inputs.iter().map(|i| self.prefix(i)).collect();
let mut out = Vec::with_capacity(prefixed.len());
for chunk in prefixed.chunks(BATCH) {
out.extend(self.embed_batch(chunk)?);
}
debug_assert_eq!(out.len(), inputs.len());
Ok(out)
}
}
impl OllamaEmbedder {
/// Prefix one input per the resolved model prefixes.
fn prefix(&self, input: &EmbeddingInput<'_>) -> String {
match input.kind {
EmbeddingKind::Document => format!("{}{}", self.doc_prefix, input.text),
EmbeddingKind::Query => format!("{}{}", self.query_prefix, input.text),
}
}
}
/// L2-normalize a vector in place-ish (consumes + returns). A zero vector is
/// returned unchanged (norm 0 → no division) so a degenerate embedding can
/// never produce NaNs.
fn l2_normalize(mut v: Vec<f32>) -> Vec<f32> {
let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for x in &mut v {
*x /= norm;
}
}
v
}
// ── Wire types ──────────────────────────────────────────────────────────────
#[derive(Serialize)]
struct EmbedRequest<'a> {
model: &'a str,
input: &'a [String],
}
#[derive(Deserialize)]
struct EmbedResponse {
embeddings: Vec<Vec<f32>>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn prefixes_for_arctic_is_query_only() {
assert_eq!(prefixes_for("snowflake-arctic-embed2"), ("query: ", ""));
assert_eq!(prefixes_for("snowflake-arctic-embed2:latest"), ("query: ", ""));
}
#[test]
fn prefixes_for_e5_is_query_passage() {
assert_eq!(prefixes_for("multilingual-e5-large"), ("query: ", "passage: "));
}
#[test]
fn prefixes_for_unknown_is_bare() {
assert_eq!(prefixes_for("nomic-embed-text"), ("", ""));
}
#[test]
fn l2_normalize_unit_length() {
let v = l2_normalize(vec![3.0, 4.0]);
let norm = (v[0] * v[0] + v[1] * v[1]).sqrt();
assert!((norm - 1.0).abs() < 1e-6, "norm = {norm}");
}
#[test]
fn l2_normalize_zero_vector_is_unchanged() {
assert_eq!(l2_normalize(vec![0.0, 0.0, 0.0]), vec![0.0, 0.0, 0.0]);
}
}

View File

@@ -0,0 +1,99 @@
//! `/api/embed` behavior against a `wiremock`-hosted mock server.
//!
//! `wiremock` is async, so the tests are `#[tokio::test]`; the sync
//! [`OllamaEmbedder`] is driven from `spawn_blocking` to keep `reqwest::blocking`
//! off the async runtime (same pattern as `kebab-llm-local`'s streaming tests).
//! tokio is a `dev-dependency` only.
use kebab_config::Config;
use kebab_core::{Embedder, EmbeddingInput, EmbeddingKind};
use kebab_embed_ollama::OllamaEmbedder;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Config pointing at the mock server, with a small dim so the mock body is
/// tiny. `model` is an arctic tag so prefix resolution is exercised.
fn cfg_for(endpoint: &str, dim: usize) -> Config {
let mut cfg = Config::defaults();
cfg.models.embedding.provider = "ollama".to_string();
cfg.models.embedding.model = "snowflake-arctic-embed2".to_string();
cfg.models.embedding.dimensions = dim;
cfg.models.embedding.endpoint = Some(endpoint.to_string());
cfg
}
async fn embed_blocking(
cfg: Config,
inputs: Vec<(String, EmbeddingKind)>,
) -> anyhow::Result<Vec<Vec<f32>>> {
tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<Vec<f32>>> {
let emb = OllamaEmbedder::new(&cfg)?;
let refs: Vec<EmbeddingInput<'_>> = inputs
.iter()
.map(|(t, k)| EmbeddingInput { text: t, kind: *k })
.collect();
emb.embed(&refs)
})
.await
.expect("blocking task panicked")
}
#[tokio::test]
async fn embed_returns_l2_normalized_vectors() {
let server = MockServer::start().await;
// Two raw (un-normalized) vectors of dim 2; the adapter must L2-normalize.
let body = r#"{"model":"snowflake-arctic-embed2","embeddings":[[3.0,4.0],[0.0,5.0]]}"#;
Mock::given(method("POST"))
.and(path("/api/embed"))
.respond_with(ResponseTemplate::new(200).set_body_string(body))
.mount(&server)
.await;
let out = embed_blocking(
cfg_for(&server.uri(), 2),
vec![
("스택 자료구조".to_string(), EmbeddingKind::Query),
("후입선출".to_string(), EmbeddingKind::Document),
],
)
.await
.expect("embed should succeed");
assert_eq!(out.len(), 2);
for v in &out {
let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-5, "expected unit norm, got {norm}");
}
// [3,4] → [0.6, 0.8].
assert!((out[0][0] - 0.6).abs() < 1e-5 && (out[0][1] - 0.8).abs() < 1e-5);
}
#[tokio::test]
async fn embed_rejects_dim_mismatch() {
let server = MockServer::start().await;
// Server returns dim 3, config expects dim 2 → hard error.
let body = r#"{"model":"snowflake-arctic-embed2","embeddings":[[1.0,2.0,3.0]]}"#;
Mock::given(method("POST"))
.and(path("/api/embed"))
.respond_with(ResponseTemplate::new(200).set_body_string(body))
.mount(&server)
.await;
let err = embed_blocking(
cfg_for(&server.uri(), 2),
vec![("q".to_string(), EmbeddingKind::Query)],
)
.await
.expect_err("dim mismatch must error");
let msg = format!("{err:#}");
assert!(msg.contains("dim"), "expected dim error, got: {msg}");
}
#[tokio::test]
async fn embed_empty_input_is_noop() {
// No mock needed — empty input must never hit the network.
let out = embed_blocking(cfg_for("http://127.0.0.1:1", 2), vec![])
.await
.expect("empty embed should be Ok(empty)");
assert!(out.is_empty());
}