diff --git a/src/main/ai/LocalOllamaProvider.ts b/src/main/ai/LocalOllamaProvider.ts index ea43858..9e117b0 100644 --- a/src/main/ai/LocalOllamaProvider.ts +++ b/src/main/ai/LocalOllamaProvider.ts @@ -1,7 +1,8 @@ import { request } from 'undici'; import { parseAiResponse, type AiResponse } from './schema.js'; import { buildPrompt } from './prompt.js'; -import type { GenerateInput, HealthResult, InferenceProvider } from './InferenceProvider.js'; +import { buildVisionPrompt } from './visionPrompt.js'; +import type { GenerateInput, GenerateOptions, HealthResult, InferenceProvider } from './InferenceProvider.js'; import { DEFAULT_OLLAMA_ENDPOINT, DEFAULT_OLLAMA_MODEL } from '../../shared/constants.js'; export interface LocalOllamaOptions { @@ -30,29 +31,39 @@ export class LocalOllamaProvider implements InferenceProvider { this.name = `local-ollama/${this.model}`; } - async generate(input: GenerateInput): Promise { + async generate(input: GenerateInput, opts?: GenerateOptions): Promise { + const useVision = !!opts?.visionModel && (input.images?.length ?? 0) > 0; + const model = useVision ? opts!.visionModel! : this.model; + const prompt = useVision + ? buildVisionPrompt(input.text, input.todayKst, input.dueDateCandidates.map((c) => c.iso ?? c.matchedToken ?? ''), input.vocab ?? []) + : buildPrompt(input.text, input.todayKst, input.dueDateCandidates, input.vocab ?? []); + this.abortController = new AbortController(); const timer = setTimeout(() => this.abortController?.abort(), this.timeoutMs); try { + const body: Record = { + model, + prompt, + format: 'json', + stream: false, + options: { temperature: this.temperature, num_predict: this.numPredict } + }; + if (useVision) { + body.images = input.images!.map((i) => i.base64); + } const res = await request(`${this.endpoint}/api/generate`, { method: 'POST', headers: { 'content-type': 'application/json' }, - body: JSON.stringify({ - model: this.model, - prompt: buildPrompt(input.text, input.todayKst, input.dueDateCandidates, input.vocab ?? []), - format: 'json', - stream: false, - options: { temperature: this.temperature, num_predict: this.numPredict } - }), + body: JSON.stringify(body), signal: this.abortController.signal }); if (res.statusCode < 200 || res.statusCode >= 300) { throw new Error(`ollama http ${res.statusCode}`); } - const body = (await res.body.json()) as { response?: string }; - if (!body.response) throw new Error('missing response field'); + const responseBody = (await res.body.json()) as { response?: string }; + if (!responseBody.response) throw new Error('missing response field'); let parsed: unknown; - try { parsed = JSON.parse(body.response); } + try { parsed = JSON.parse(responseBody.response); } catch (err) { throw new Error(`invalid json in response: ${String(err)}`); } return parseAiResponse(parsed); } finally { diff --git a/tests/unit/LocalOllamaProvider.test.ts b/tests/unit/LocalOllamaProvider.test.ts index eea7c5c..f254bb5 100644 --- a/tests/unit/LocalOllamaProvider.test.ts +++ b/tests/unit/LocalOllamaProvider.test.ts @@ -109,4 +109,58 @@ describe('LocalOllamaProvider', () => { const provider = new LocalOllamaProvider({ model: 'gemma4:26b' }); expect(provider.name).toBe('local-ollama/gemma4:26b'); }); + + describe('vision path (v0.3.1 Cut F)', () => { + it('visionModel + images → body.images + model=visionModel + buildVisionPrompt', async () => { + let capturedBody: string = ''; + mock.get('http://x').intercept({ path: '/api/generate', method: 'POST' }).reply((opts) => { + capturedBody = opts.body as string; + return { statusCode: 200, data: JSON.stringify({ + response: JSON.stringify({ title: '비전테스트', summary: 'a\nb\nc', tags: [], due_date: null }) + }) }; + }); + const provider = new LocalOllamaProvider({ endpoint: 'http://x', model: 'gemma4:e4b' }); + await provider.generate( + { text: 'hi', todayKst: '2026-05-10', dueDateCandidates: [], images: [{ base64: 'AAAA', mime: 'image/png' }] }, + { visionModel: 'gemma3:12b-vision' } + ); + const parsed = JSON.parse(capturedBody) as { model: string; prompt: string; images?: string[] }; + expect(parsed.model).toBe('gemma3:12b-vision'); + expect(parsed.prompt).toContain('이미지'); + expect(parsed.images).toEqual(['AAAA']); + }); + + it('visionModel 있어도 images 없으면 text-only (model = this.model, no body.images)', async () => { + let capturedBody: string = ''; + mock.get('http://x').intercept({ path: '/api/generate', method: 'POST' }).reply((opts) => { + capturedBody = opts.body as string; + return { statusCode: 200, data: JSON.stringify({ + response: JSON.stringify({ title: '텍스트전용', summary: 'a\nb\nc', tags: [], due_date: null }) + }) }; + }); + const provider = new LocalOllamaProvider({ endpoint: 'http://x', model: 'gemma4:e4b' }); + await provider.generate( + { text: 'hi', todayKst: '2026-05-10', dueDateCandidates: [] }, + { visionModel: 'gemma3:12b-vision' } + ); + const parsed = JSON.parse(capturedBody) as { model: string; images?: string[] }; + expect(parsed.model).toBe('gemma4:e4b'); + expect(parsed.images).toBeUndefined(); + }); + + it('opts 미전달 → 기존 text-only (회귀)', async () => { + let capturedBody: string = ''; + mock.get('http://x').intercept({ path: '/api/generate', method: 'POST' }).reply((opts) => { + capturedBody = opts.body as string; + return { statusCode: 200, data: JSON.stringify({ + response: JSON.stringify({ title: '기본텍스트', summary: 'a\nb\nc', tags: [], due_date: null }) + }) }; + }); + const provider = new LocalOllamaProvider({ endpoint: 'http://x', model: 'gemma4:e4b' }); + await provider.generate({ text: 'hi', todayKst: '2026-05-10', dueDateCandidates: [] }); + const parsed = JSON.parse(capturedBody) as { model: string; images?: string[] }; + expect(parsed.model).toBe('gemma4:e4b'); + expect(parsed.images).toBeUndefined(); + }); + }); });