diff --git a/src/main/services/VisionDetect.ts b/src/main/services/VisionDetect.ts new file mode 100644 index 0000000..12adb91 --- /dev/null +++ b/src/main/services/VisionDetect.ts @@ -0,0 +1,44 @@ +import type { SettingsService } from './SettingsService.js'; + +const VISION_FAMILIES = new Set(['gemma3', 'llava', 'llama3.2-vision', 'minicpm-v', 'pixtral']); +const VISION_NAME_HINTS = ['vision', 'vl', 'multimodal', 'gemma3']; + +export interface OllamaModel { + name: string; + details?: { family?: string; families?: string[] }; +} + +export function isVisionCapable(model: OllamaModel): boolean { + if (model.details?.family && VISION_FAMILIES.has(model.details.family)) return true; + if (model.details?.families?.some((f) => VISION_FAMILIES.has(f))) return true; + const lower = model.name.toLowerCase(); + return VISION_NAME_HINTS.some((h) => lower.includes(h)); +} + +export interface RefreshDeps { + settings: SettingsService; + endpoint: string; + now?: () => Date; + fetchImpl?: typeof fetch; +} + +export async function refreshVisionCache( + deps: RefreshDeps +): Promise<{ ok: true; models: string[] } | { ok: false; reason: string }> { + if (!(await deps.settings.isAiEnabled())) { + return { ok: false, reason: 'ai_disabled' }; + } + const fetchFn = deps.fetchImpl ?? fetch; + let body: { models?: OllamaModel[] }; + try { + const r = await fetchFn(`${deps.endpoint}/api/tags`); + if (!r.ok) return { ok: false, reason: `tags http ${r.status}` }; + body = (await r.json()) as { models?: OllamaModel[] }; + } catch (e) { + return { ok: false, reason: `unreachable: ${(e as Error).message}` }; + } + const capable = (body.models ?? []).filter(isVisionCapable).map((m) => m.name); + const now = deps.now ? deps.now() : new Date(); + await deps.settings.setVisionCapableCache(capable, now); + return { ok: true, models: capable }; +} diff --git a/tests/unit/VisionDetect.test.ts b/tests/unit/VisionDetect.test.ts new file mode 100644 index 0000000..c2bcf1c --- /dev/null +++ b/tests/unit/VisionDetect.test.ts @@ -0,0 +1,110 @@ +import { describe, it, expect, vi } from 'vitest'; +import { isVisionCapable, refreshVisionCache } from '@main/services/VisionDetect.js'; +import type { OllamaModel } from '@main/services/VisionDetect.js'; + +// --------------------------------------------------------------------------- +// isVisionCapable +// --------------------------------------------------------------------------- +describe('isVisionCapable', () => { + it('returns true when details.family is in VISION_FAMILIES', () => { + const model: OllamaModel = { name: 'some-model', details: { family: 'llava' } }; + expect(isVisionCapable(model)).toBe(true); + }); + + it('returns true when details.families contains a vision family', () => { + const model: OllamaModel = { name: 'some-model', details: { families: ['text', 'minicpm-v'] } }; + expect(isVisionCapable(model)).toBe(true); + }); + + it('returns true when name contains a vision hint (case-insensitive)', () => { + const model: OllamaModel = { name: 'My-Vision-Model:latest' }; + expect(isVisionCapable(model)).toBe(true); + }); + + it('returns true when name contains "vl" hint', () => { + const model: OllamaModel = { name: 'qwen2-vl:7b' }; + expect(isVisionCapable(model)).toBe(true); + }); + + it('returns false for a plain text model with no vision signals', () => { + const model: OllamaModel = { name: 'gemma2:9b', details: { family: 'gemma', families: ['gemma'] } }; + expect(isVisionCapable(model)).toBe(false); + }); +}); + +// --------------------------------------------------------------------------- +// refreshVisionCache +// --------------------------------------------------------------------------- +describe('refreshVisionCache', () => { + function makeSettings(overrides: Partial<{ + isAiEnabled: boolean; + setCalled: { models: string[]; at: Date } | null; + }> = {}) { + const setCalled: { models: string[]; at: Date } | null = null; + const settings = { + isAiEnabled: vi.fn().mockResolvedValue(overrides.isAiEnabled ?? true), + setVisionCapableCache: vi.fn().mockImplementation(async () => undefined), + }; + return settings; + } + + it('returns ok:false with reason "ai_disabled" when AI is off', async () => { + const settings = makeSettings({ isAiEnabled: false }); + const result = await refreshVisionCache({ + settings: settings as never, + endpoint: 'http://localhost:11434', + }); + expect(result).toEqual({ ok: false, reason: 'ai_disabled' }); + expect(settings.setVisionCapableCache).not.toHaveBeenCalled(); + }); + + it('returns ok:false with http reason on non-ok response', async () => { + const settings = makeSettings(); + const fetchImpl = vi.fn().mockResolvedValue({ ok: false, status: 503 }); + const result = await refreshVisionCache({ + settings: settings as never, + endpoint: 'http://localhost:11434', + fetchImpl: fetchImpl as never, + }); + expect(result).toEqual({ ok: false, reason: 'tags http 503' }); + }); + + it('returns ok:false with unreachable reason on fetch throw', async () => { + const settings = makeSettings(); + const fetchImpl = vi.fn().mockRejectedValue(new Error('ECONNREFUSED')); + const result = await refreshVisionCache({ + settings: settings as never, + endpoint: 'http://localhost:11434', + fetchImpl: fetchImpl as never, + }); + expect(result.ok).toBe(false); + if (!result.ok) expect(result.reason).toMatch(/unreachable/); + }); + + it('filters vision-capable models, persists cache, returns ok:true + models', async () => { + const settings = makeSettings(); + const fixedNow = new Date('2026-05-09T00:00:00.000Z'); + const responseBody = { + models: [ + { name: 'llava:13b', details: { family: 'llava' } }, + { name: 'gemma2:9b', details: { family: 'gemma' } }, + { name: 'qwen2-vl:7b' }, + ], + }; + const fetchImpl = vi.fn().mockResolvedValue({ + ok: true, + json: () => Promise.resolve(responseBody), + }); + const result = await refreshVisionCache({ + settings: settings as never, + endpoint: 'http://localhost:11434', + fetchImpl: fetchImpl as never, + now: () => fixedNow, + }); + expect(result).toEqual({ ok: true, models: ['llava:13b', 'qwen2-vl:7b'] }); + expect(settings.setVisionCapableCache).toHaveBeenCalledWith( + ['llava:13b', 'qwen2-vl:7b'], + fixedNow + ); + }); +});