Mirror of https://github.com/photoprism/photoprism.git, synced 2025-12-12 08:44:04 +01:00
Signed-off-by: Michael Mayer <michael@photoprism.app>

New file: internal/ai/vision/README.md (165 lines)

@@ -0,0 +1,165 @@
## PhotoPrism — Vision Package

**Last Updated:** November 24, 2025

### Overview

`internal/ai/vision` provides the shared model registry, request builders, and parsers that power PhotoPrism’s caption, label, face, NSFW, and future generate workflows. It reads `vision.yml`, normalizes models, and dispatches calls to one of three engines:

- **TensorFlow (built‑in)** — default Nasnet / NSFW / Facenet models; no remote service required.
- **Ollama** — local or proxied multimodal LLMs. See [`ollama/README.md`](ollama/README.md) for tuning and schema details.
- **OpenAI** — cloud Responses API. See [`openai/README.md`](openai/README.md) for prompts, schema variants, and header requirements.
### Configuration

#### Models

The `vision.yml` file is usually kept in the `storage/config` directory (override with `PHOTOPRISM_VISION_YAML`). It defines a list of models under `Models:`. Key fields are captured below:

| Field                   | Default                     | Notes                                                                                  |
|-------------------------|-----------------------------|----------------------------------------------------------------------------------------|
| `Type` (required)       | —                           | `labels`, `caption`, `face`, `nsfw`, `generate`. Drives routing & scheduling.          |
| `Name`                  | derived from type/version   | Display name; lower-cased by helpers.                                                  |
| `Model`                 | `""`                        | Raw identifier override; precedence: `Service.Model` → `Model` → `Name`.               |
| `Version`               | `latest` (non-OpenAI)       | OpenAI payloads omit the version.                                                      |
| `Engine`                | inferred from service/alias | Aliases set formats, file scheme, and resolution. Explicit `Service` values still win. |
| `Run`                   | `auto`                      | See the Run Modes table below.                                                         |
| `Default`               | `false`                     | Keep one per type for TensorFlow fallbacks.                                            |
| `Disabled`              | `false`                     | Registered but inactive.                                                               |
| `Resolution`            | 224 (720 for Ollama/OpenAI) | Thumbnail edge in px.                                                                  |
| `System` / `Prompt`     | engine defaults             | Override prompts per model.                                                            |
| `Format`                | `""`                        | Response hint (`json`, `text`, `markdown`).                                            |
| `Schema` / `SchemaFile` | engine defaults / empty     | Inline vs. file-based JSON schema (labels).                                            |
| `TensorFlow`            | nil                         | Local TF model info (paths, tags).                                                     |
| `Options`               | nil                         | Sampling settings merged with engine defaults.                                         |
| `Service`               | nil                         | Remote endpoint config (see below).                                                    |
#### Run Modes

| Value           | When it runs                                                     | Recommended use                                |
|-----------------|------------------------------------------------------------------|------------------------------------------------|
| `auto`          | TensorFlow defaults during index; external via metadata/schedule | Leave as-is for most setups.                   |
| `manual`        | Only when explicitly invoked (CLI/API)                           | Experiments and diagnostics.                   |
| `on-index`      | During indexing + manual                                         | Fast local models only.                        |
| `newly-indexed` | Metadata worker after indexing + manual                          | External/Ollama/OpenAI without slowing import. |
| `on-demand`     | Manual, metadata worker, and scheduled jobs                      | Broad coverage without the index path.         |
| `on-schedule`   | Scheduled jobs + manual                                          | Nightly/cron-style runs.                       |
| `always`        | Indexing, metadata, scheduled, manual                            | High-priority models; watch resource use.      |
| `never`         | Never executes                                                   | Keep a definition without running it.          |
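For example, a heavy external model can be restricted to scheduled runs so it never slows down indexing (the model name below is a hypothetical example, not a recommendation):

```yaml
Models:
  - Type: caption
    Model: llava:13b      # hypothetical example model
    Engine: ollama
    Run: on-schedule      # executes only via scheduled jobs (and manual runs)
    Service:
      Uri: http://ollama:11434/api/generate
```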
#### Model Options

| Option            | Default                                                                                 | Description                                                                        |
|-------------------|-----------------------------------------------------------------------------------------|------------------------------------------------------------------------------------|
| `Temperature`     | engine default (`0.1` for Ollama; unset for OpenAI)                                     | Controls randomness; clamped to `[0,2]`. `gpt-5*` OpenAI models are forced to `0`. |
| `TopP`            | engine default (`0.9` for some Ollama label defaults; unset for OpenAI)                 | Nucleus sampling parameter.                                                        |
| `MaxOutputTokens` | engine default (OpenAI caption 512, labels 1024; Ollama label default 256)              | Upper bound on generated tokens; adapters raise low values to defaults.            |
| `ForceJson`       | engine-specific (`true` for OpenAI labels; `false` for Ollama labels; captions `false`) | Forces structured output when enabled.                                             |
| `SchemaVersion`   | derived from schema name                                                                | Override when coordinating schema migrations.                                      |
| `Stop`            | engine default                                                                          | Array of stop sequences (e.g., `["\\n\\n"]`).                                      |
| `NumThread`       | runtime auto                                                                            | Caps CPU threads for local engines.                                                |
| `NumCtx`          | engine default                                                                          | Context window length (tokens).                                                    |
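As a sketch, several of these options might be combined for an Ollama labels model like this (the values are illustrative, not tuned recommendations):

```yaml
Models:
  - Type: labels
    Model: qwen2.5vl:7b
    Engine: ollama
    Options:
      Temperature: 0.2       # clamped to [0,2]
      TopP: 0.9
      MaxOutputTokens: 512   # adapters raise values that are too low
      ForceJson: true
      NumCtx: 4096
```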
#### Model Service

Used for Ollama/OpenAI (and any future HTTP engines). All credentials and identifiers support `${ENV_VAR}` expansion.

| Field                              | Default                                  | Notes                                                    |
|------------------------------------|------------------------------------------|----------------------------------------------------------|
| `Uri`                              | required for remote                      | Endpoint base. Empty keeps the model local (TensorFlow). |
| `Method`                           | `POST`                                   | Override the verb if a provider needs it.                |
| `Key`                              | `""`                                     | Bearer token; prefer env expansion.                      |
| `Username` / `Password`            | `""`                                     | Injected as basic auth when the URI lacks userinfo.      |
| `Model`                            | `""`                                     | Endpoint-specific override; wins over model/name.        |
| `Org` / `Project`                  | `""`                                     | OpenAI headers.                                          |
| `RequestFormat` / `ResponseFormat` | set by engine alias                      | Explicit values win over alias defaults.                 |
| `FileScheme`                       | set by engine alias (`data` or `base64`) | Controls image transport.                                |
| `Disabled`                         | `false`                                  | Disable the endpoint without removing the model.         |
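A service block for a model behind an authenticating proxy could look like this (endpoint and variable names are illustrative assumptions):

```yaml
Service:
  Uri: http://gateway.local:8080/api/generate
  Method: POST
  Key: ${VISION_GATEWAY_TOKEN}   # sent as "Authorization: Bearer <token>"
  FileScheme: base64
```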
### Field Behavior & Precedence

- Model identifier resolution order: `Service.Model` → `Model` → `Name`. `Model.GetModel()` returns `(id, name, version)`, where Ollama receives `name:version` and other engines receive `name` plus a separate `Version`.
- Env expansion runs for all `Service` credentials and `Model` overrides; empty or disabled models return empty identifiers.
- Options merging: engine defaults fill missing fields; explicit values always win. Temperature is capped at `MaxTemperature`.
- Authentication: `Service.Key` sets `Authorization: Bearer <token>`; `Username`/`Password` inject HTTP basic auth into the service URI when not already present.
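The resolution order above can be sketched as a standalone Go function. This is a simplified, hypothetical `getModel` helper (no env expansion, no `clean` package helpers), not the actual photoprism implementation:

```go
package main

import (
	"fmt"
	"strings"
)

// getModel mirrors the documented precedence: Service.Model wins over Model,
// which wins over Name; "name:version" strings are split; non-OpenAI engines
// default to the "latest" version; OpenAI payloads omit the version entirely.
func getModel(serviceModel, modelField, name, version, engine string) (id, base, ver string) {
	base = strings.ToLower(strings.TrimSpace(name))
	ver = strings.ToLower(strings.TrimSpace(version))

	switch {
	case strings.TrimSpace(serviceModel) != "":
		base = strings.ToLower(strings.TrimSpace(serviceModel))
	case strings.TrimSpace(modelField) != "":
		base = strings.ToLower(strings.TrimSpace(modelField))
	}

	if base == "" {
		return "", "", ""
	}

	// Split "name:version" if the base already carries a version tag.
	if parts := strings.SplitN(base, ":", 2); len(parts) == 2 && parts[0] != "" && parts[1] != "" {
		base, ver = parts[0], parts[1]
	}

	if ver == "" {
		ver = "latest"
	}

	switch engine {
	case "openai":
		return base, base, "" // OpenAI payloads omit the version
	case "ollama":
		return base + ":" + ver, base, ver // Ollama expects "name:version"
	default:
		return base, base, ver
	}
}

func main() {
	fmt.Println(getModel("", "", "gemma3", "", "ollama"))        // gemma3:latest gemma3 latest
	fmt.Println(getModel("mixtral:8x7b", "", "x", "", "ollama")) // mixtral:8x7b mixtral 8x7b
	fmt.Println(getModel("", "", "gpt-5-mini", "", "openai"))
}
```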
### Minimal Examples

#### TensorFlow (built‑in defaults)

```yaml
Models:
  - Type: labels
    Default: true
    Run: auto

  - Type: nsfw
    Default: true
    Run: auto

  - Type: face
    Default: true
    Run: auto
```
#### Ollama Labels

```yaml
Models:
  - Type: labels
    Model: qwen2.5vl:7b
    Engine: ollama
    Run: newly-indexed
    Service:
      Uri: http://ollama:11434/api/generate
```

More Ollama guidance: [`internal/ai/vision/ollama/README.md`](ollama/README.md).
#### OpenAI Captions

```yaml
Models:
  - Type: caption
    Model: gpt-5-mini
    Engine: openai
    Run: newly-indexed
    Service:
      Uri: https://api.openai.com/v1/responses
      Org: ${OPENAI_ORG}
      Project: ${OPENAI_PROJECT}
      Key: ${OPENAI_API_KEY}
```

More OpenAI guidance: [`internal/ai/vision/openai/README.md`](openai/README.md).
#### Custom TensorFlow Caption (local file model)

```yaml
Models:
  - Type: caption
    Name: custom-caption
    Engine: tensorflow
    Path: storage/models/custom-caption
    Resolution: 448
    Run: manual
```
### CLI Quick Reference

- List models: `photoprism vision ls` (shows resolved IDs, engines, options, run mode, and the disabled flag).
- Run a model: `photoprism vision run -m labels --count 5` (use `--force` to bypass `Run` rules).
- Validate config: `photoprism vision ls --json` confirms env-expanded values without triggering calls.
### When to Choose Each Engine

- **TensorFlow**: fast, offline defaults for core features (labels, faces, NSFW) with zero external dependencies.
- **Ollama**: private, GPU/CPU-hosted multimodal LLMs; best for richer captions/labels without cloud traffic.
- **OpenAI**: highest-quality reasoning and multimodal support; requires an API key and network access.
### Related Docs

- Ollama specifics: [`internal/ai/vision/ollama/README.md`](ollama/README.md)
- OpenAI specifics: [`internal/ai/vision/openai/README.md`](openai/README.md)
- REST API reference: https://docs.photoprism.dev/
- Developer guide (Vision): https://docs.photoprism.app/developer-guide/api/
@@ -35,13 +35,19 @@ func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResp
 	// Add "application/json" content type header.
 	header.SetContentType(req, header.ContentTypeJson)
 
+	if reqErr != nil {
+		return apiResponse, reqErr
+	}
+
-	// Add an authentication header if an access token is configured.
+	// Add an authentication header if an access token is provided.
 	if key != "" {
 		header.SetAuthorization(req, key)
 	}
 
-	if reqErr != nil {
-		return apiResponse, reqErr
+	// Add custom OpenAI organization and project headers.
+	if apiRequest.GetResponseFormat() == ApiFormatOpenAI {
+		header.SetOpenAIOrg(req, apiRequest.Org)
+		header.SetOpenAIProject(req, apiRequest.Project)
 	}
 
 	// Perform API request.
@@ -9,6 +9,7 @@ import (
 	"github.com/stretchr/testify/assert"
 
 	"github.com/photoprism/photoprism/internal/ai/vision/ollama"
+	"github.com/photoprism/photoprism/pkg/http/header"
 	"github.com/photoprism/photoprism/pkg/http/scheme"
 )
 
@@ -119,3 +120,44 @@ func TestPerformApiRequestOllama(t *testing.T) {
 		}
 	})
 }
+
+func TestPerformApiRequestOpenAIHeaders(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		assert.Equal(t, "org-123", r.Header.Get(header.OpenAIOrg))
+		assert.Equal(t, "proj-abc", r.Header.Get(header.OpenAIProject))
+
+		response := map[string]any{
+			"id":    "resp_123",
+			"model": "gpt-5-mini",
+			"output": []any{
+				map[string]any{
+					"role": "assistant",
+					"content": []any{
+						map[string]any{
+							"type": "output_text",
+							"text": "A scenic mountain view.",
+						},
+					},
+				},
+			},
+		}
+
+		assert.NoError(t, json.NewEncoder(w).Encode(response))
+	}))
+	defer server.Close()
+
+	req := &ApiRequest{
+		Id:             "headers",
+		Model:          "gpt-5-mini",
+		Images:         []string{"data:image/jpeg;base64,AA=="},
+		ResponseFormat: ApiFormatOpenAI,
+		Org:            "org-123",
+		Project:        "proj-abc",
+	}
+
+	resp, err := PerformApiRequest(req, server.URL, http.MethodPost, "")
+	assert.NoError(t, err)
+	assert.NotNil(t, resp)
+	assert.NotNil(t, resp.Result.Caption)
+	assert.Equal(t, "A scenic mountain view.", resp.Result.Caption.Text)
+}
@@ -82,6 +82,8 @@ type ApiRequest struct {
 	Suffix  string             `form:"suffix" yaml:"Suffix,omitempty" json:"suffix"`
 	Format  string             `form:"format" yaml:"Format,omitempty" json:"format,omitempty"`
 	Url     string             `form:"url" yaml:"Url,omitempty" json:"url,omitempty"`
+	Org     string             `form:"org" yaml:"Org,omitempty" json:"org,omitempty"`
+	Project string             `form:"project" yaml:"Project,omitempty" json:"project,omitempty"`
 	Options *ApiRequestOptions `form:"options" yaml:"Options,omitempty" json:"options,omitempty"`
 	Context *ApiRequestContext `form:"context" yaml:"Context,omitempty" json:"context,omitempty"`
 	Stream  bool               `form:"stream" yaml:"Stream,omitempty" json:"stream"`
@@ -43,14 +43,11 @@ func captionInternal(images Files, mediaSrc media.Src) (result *CaptionResult, m
 	}
 
 	if apiRequest.Model == "" {
-		switch model.Service.RequestFormat {
-		case ApiFormatOllama:
-			apiRequest.Model, _, _ = model.Model()
-		default:
-			_, apiRequest.Model, apiRequest.Version = model.Model()
-		}
+		apiRequest.Model, _, apiRequest.Version = model.GetModel()
 	}
 
+	model.ApplyService(apiRequest)
+
 	apiRequest.System = model.GetSystemPrompt()
 	apiRequest.Prompt = model.GetPrompt()
 
@@ -117,9 +117,9 @@ func (ollamaBuilder) Build(ctx context.Context, model *Model, files Files) (*Api
 	}
 
 	if model.Service.RequestFormat == ApiFormatOllama {
-		req.Model, _, _ = model.Model()
+		req.Model, _, _ = model.GetModel()
 	} else {
-		_, req.Model, req.Version = model.Model()
+		_, req.Model, req.Version = model.GetModel()
 	}
 
 	return req, nil
@@ -87,35 +87,14 @@ func (openaiDefaults) Options(model *Model) *ApiRequestOptions {
 
 	switch model.Type {
 	case ModelTypeCaption:
-		/*
-			Options:
-			  Detail: low
-			  MaxOutputTokens: 512
-			  Temperature: 0.1
-			  TopP: 0.9
-			(Sampling values are zeroed for GPT-5 models in openaiBuilder.Build.)
-		*/
 		return &ApiRequestOptions{
 			Detail:          openai.DefaultDetail,
 			MaxOutputTokens: openai.CaptionMaxTokens,
-			Temperature:     openai.DefaultTemperature,
-			TopP:            openai.DefaultTopP,
 		}
 	case ModelTypeLabels:
-		/*
-			Options:
-			  Detail: low
-			  MaxOutputTokens: 1024
-			  Temperature: 0.1
-			  ForceJson: true
-			  SchemaVersion: "photoprism_vision_labels_v1"
-			(Sampling values are zeroed for GPT-5 models in openaiBuilder.Build.)
-		*/
 		return &ApiRequestOptions{
 			Detail:          openai.DefaultDetail,
 			MaxOutputTokens: openai.LabelsMaxTokens,
-			Temperature:     openai.DefaultTemperature,
-			TopP:            openai.DefaultTopP,
 			ForceJson:       true,
 		}
 	default:
@@ -53,7 +53,8 @@ func DetectFaces(fileName string, minSize int, cacheCrop bool, expected int) (re
 		return result, err
 	}
 
-	_, apiRequest.Model, apiRequest.Version = model.Model()
+	_, apiRequest.Model, apiRequest.Version = model.GetModel()
+	model.ApplyService(apiRequest)
 
 	if model.System != "" {
 		apiRequest.System = model.System
@@ -71,14 +71,11 @@ func labelsInternal(images Files, mediaSrc media.Src, labelSrc entity.Src) (resu
 	}
 
 	if apiRequest.Model == "" {
-		switch model.Service.RequestFormat {
-		case ApiFormatOllama:
-			apiRequest.Model, _, _ = model.Model()
-		default:
-			_, apiRequest.Model, apiRequest.Version = model.Model()
-		}
+		apiRequest.Model, _, apiRequest.Version = model.GetModel()
 	}
 
+	model.ApplyService(apiRequest)
+
 	if system := model.GetSystemPrompt(); system != "" {
 		apiRequest.System = system
 	}
@@ -34,6 +34,7 @@ var (
 type Model struct {
 	Type    ModelType   `yaml:"Type,omitempty" json:"type,omitempty"`
 	Default bool        `yaml:"Default,omitempty" json:"default,omitempty"`
+	Model   string      `yaml:"Model,omitempty" json:"model,omitempty"`
 	Name    string      `yaml:"Name,omitempty" json:"name,omitempty"`
 	Version string      `yaml:"Version,omitempty" json:"version,omitempty"`
 	Engine  ModelEngine `yaml:"Engine,omitempty" json:"engine,omitempty"`
@@ -59,43 +60,55 @@ type Model struct {
 // Models represents a set of computer vision models.
 type Models []*Model
 
-// Model returns the parsed and normalized identifier, name, and version
-// strings. Nil receivers return empty values so callers can destructure the
-// tuple without additional nil checks.
-func (m *Model) Model() (model, name, version string) {
+// GetModel returns the normalized model identifier, name, and version strings
+// used in service requests. Callers can always destructure the tuple because
+// nil receivers return empty values.
+func (m *Model) GetModel() (model, name, version string) {
 	if m == nil {
 		return "", "", ""
 	}
 
-	// Return empty identifier string if no name was set.
-	if m.Name == "" {
-		return "", "", clean.TypeLowerDash(m.Version)
-	}
-
-	// Normalize model name.
+	// Normalise the configured values.
 	name = clean.TypeLower(m.Name)
-
-	// Split name to check if it contains the version.
-	s := strings.SplitN(name, ":", 2)
-
-	// Return if name contains both model name and version.
-	if len(s) == 2 && s[0] != "" && s[1] != "" {
-		return name, s[0], s[1]
-	}
-
-	// Normalize model version.
 	version = clean.TypeLowerDash(m.Version)
 
-	// Default to "latest" if no specific version was set.
+	// Build a base name from the highest-priority override:
+	// 1) Service-specific override (expanded for env vars)
+	// 2) Model-specific override
+	// 3) Declarative model name
+	serviceModel := m.Service.GetModel()
+	switch {
+	case serviceModel != "":
+		name = serviceModel
+	case strings.TrimSpace(m.Model) != "":
+		name = clean.TypeLower(m.Model)
+	}
+
+	// Return if no model is configured.
+	if name == "" {
+		return "", "", ""
+	}
+
+	// Split "name:version" strings so callers can access versioned models
+	// without repeating parsing logic at each call site.
+	if parts := strings.SplitN(name, ":", 2); len(parts) == 2 && parts[0] != "" && parts[1] != "" {
+		name = parts[0]
+		version = parts[1]
+	}
+
+	// Default to "latest" for non-OpenAI engines when no version was set.
 	if version == "" {
 		version = VersionLatest
 	}
 
-	// Create model identifier from model name and version.
-	model = strings.Join([]string{s[0], version}, ":")
-
-	// Return normalized model identifier, name, and version.
-	return model, name, version
+	switch m.Engine {
+	case openai.EngineName:
+		return name, name, ""
+	case ollama.EngineName:
+		return strings.Join([]string{name, version}, ":"), name, version
+	default:
+		return name, name, version
+	}
 }
 
 // IsDefault reports whether the model refers to one of the built-in defaults.
@@ -145,6 +158,19 @@ func (m *Model) Endpoint() (uri, method string) {
 	}
 }
 
+// ApplyService updates the ApiRequest with service-specific
+// values when configured.
+func (m *Model) ApplyService(apiRequest *ApiRequest) {
+	if m == nil || apiRequest == nil {
+		return
+	}
+
+	if m.Engine == openai.EngineName {
+		apiRequest.Org = m.Service.EndpointOrg()
+		apiRequest.Project = m.Service.EndpointProject()
+	}
+}
+
 // EndpointKey returns the access token belonging to the remote service
 // endpoint, or an empty string for nil receivers.
 func (m *Model) EndpointKey() (key string) {
@@ -347,6 +373,10 @@ func mergeOptionDefaults(target, defaults *ApiRequestOptions) {
 		target.TopP = defaults.TopP
 	}
 
+	if target.Temperature <= 0 && defaults.Temperature > 0 {
+		target.Temperature = defaults.Temperature
+	}
+
 	if len(target.Stop) == 0 && len(defaults.Stop) > 0 {
 		target.Stop = append([]string(nil), defaults.Stop...)
 	}
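The merge semantics in this hunk can be exercised standalone. The sketch below uses a simplified `Options` struct and a hypothetical `mergeDefaults` helper rather than the actual `ApiRequestOptions`/`mergeOptionDefaults` code; it only illustrates the "defaults fill unset fields, explicit values win" rule:

```go
package main

import "fmt"

// Options is a simplified stand-in for the request options; only the
// fields needed to show the merge rule are included.
type Options struct {
	Temperature float64
	TopP        float64
	Stop        []string
}

// mergeDefaults fills zero-valued fields from defaults; explicit values win.
func mergeDefaults(target, defaults *Options) {
	if target.TopP <= 0 && defaults.TopP > 0 {
		target.TopP = defaults.TopP
	}
	if target.Temperature <= 0 && defaults.Temperature > 0 {
		target.Temperature = defaults.Temperature
	}
	if len(target.Stop) == 0 && len(defaults.Stop) > 0 {
		// Copy the slice so later mutation of defaults does not leak into target.
		target.Stop = append([]string(nil), defaults.Stop...)
	}
}

func main() {
	target := &Options{Temperature: 0.3} // explicit temperature, TopP/Stop unset
	defaults := &Options{Temperature: 0.1, TopP: 0.9, Stop: []string{"\n\n"}}

	mergeDefaults(target, defaults)
	fmt.Println(target.Temperature, target.TopP, len(target.Stop)) // 0.3 0.9 1
}
```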
@@ -377,9 +407,7 @@ func normalizeOptions(opts *ApiRequestOptions) {
 		return
 	}
 
-	if opts.Temperature <= 0 {
-		opts.Temperature = DefaultTemperature
-	} else if opts.Temperature > MaxTemperature {
+	if opts.Temperature > MaxTemperature {
 		opts.Temperature = MaxTemperature
 	}
 }
@@ -3,6 +3,7 @@ package vision
 import (
 	"os"
 	"path/filepath"
+	"sync"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -25,7 +26,7 @@ func TestModelGetOptionsDefaultsOllamaLabels(t *testing.T) {
 
 	model.ApplyEngineDefaults()
 
-	m, n, v := model.Model()
+	m, n, v := model.GetModel()
 
 	assert.Equal(t, ollamaModel, m)
 	assert.Equal(t, "redule26/huihui_ai_qwen2.5-vl-7b-abliterated", n)
@@ -53,6 +54,106 @@ func TestModelGetOptionsDefaultsOllamaLabels(t *testing.T) {
 	}
 }
 
+func TestModel_GetModel(t *testing.T) {
+	tests := []struct {
+		name        string
+		model       *Model
+		wantModel   string
+		wantName    string
+		wantVersion string
+	}{
+		{
+			name:        "Nil",
+			wantModel:   "",
+			wantName:    "",
+			wantVersion: "",
+		},
+		{
+			name: "OpenAINameOnly",
+			model: &Model{
+				Name:   "gpt-5-mini",
+				Engine: openai.EngineName,
+			},
+			wantModel:   "gpt-5-mini",
+			wantName:    "gpt-5-mini",
+			wantVersion: "",
+		},
+		{
+			name: "NonOpenAIAddsLatest",
+			model: &Model{
+				Name:   "gemma3",
+				Engine: ollama.EngineName,
+			},
+			wantModel:   "gemma3:latest",
+			wantName:    "gemma3",
+			wantVersion: "latest",
+		},
+		{
+			name: "ExplicitVersion",
+			model: &Model{
+				Name:    "gemma3",
+				Version: "2",
+				Engine:  ollama.EngineName,
+			},
+			wantModel:   "gemma3:2",
+			wantName:    "gemma3",
+			wantVersion: "2",
+		},
+		{
+			name: "NameContainsVersion",
+			model: &Model{
+				Name:   "qwen2.5vl:7b",
+				Engine: ollama.EngineName,
+			},
+			wantModel:   "qwen2.5vl:7b",
+			wantName:    "qwen2.5vl",
+			wantVersion: "7b",
+		},
+		{
+			name: "ModelFieldFallback",
+			model: &Model{
+				Model:  "CUSTOM-MODEL",
+				Engine: ollama.EngineName,
+			},
+			wantModel:   "custom-model:latest",
+			wantName:    "custom-model",
+			wantVersion: "latest",
+		},
+		{
+			name: "ServiceOverrideWithVersion",
+			model: &Model{
+				Name:    "ignored",
+				Engine:  ollama.EngineName,
+				Service: Service{Model: "mixtral:8x7b"},
+			},
+			wantModel:   "mixtral:8x7b",
+			wantName:    "mixtral",
+			wantVersion: "8x7b",
+		},
+		{
+			name: "ServiceOverrideOpenAI",
+			model: &Model{
+				Name:    "gpt-4.1",
+				Engine:  openai.EngineName,
+				Service: Service{Model: "gpt-5-mini"},
+			},
+			wantModel:   "gpt-5-mini",
+			wantName:    "gpt-5-mini",
+			wantVersion: "",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			model, name, version := tt.model.GetModel()
+
+			assert.Equal(t, tt.wantModel, model)
+			assert.Equal(t, tt.wantName, name)
+			assert.Equal(t, tt.wantVersion, version)
+		})
+	}
+}
+
 func TestModelGetOptionsRespectsCustomValues(t *testing.T) {
 	model := &Model{
 		Type: ModelTypeLabels,
@@ -152,6 +253,9 @@ func TestModelEndpointKeyOpenAIFallbacks(t *testing.T) {
 		t.Fatalf("write key file: %v", err)
 	}
 
+	// Reset ensureEnvOnce.
+	ensureEnvOnce = sync.Once{}
+
 	t.Setenv("OPENAI_API_KEY", "")
 	t.Setenv("OPENAI_API_KEY_FILE", path)
 
@@ -218,6 +322,30 @@ func TestModelGetSource(t *testing.T) {
 	})
 }
 
+func TestModelApplyService(t *testing.T) {
+	t.Run("OpenAIHeaders", func(t *testing.T) {
+		req := &ApiRequest{}
+		model := &Model{
+			Engine:  openai.EngineName,
+			Service: Service{Org: "org-123", Project: "proj-abc"},
+		}
+
+		model.ApplyService(req)
+
+		assert.Equal(t, "org-123", req.Org)
+		assert.Equal(t, "proj-abc", req.Project)
+	})
+	t.Run("OtherEngineNoop", func(t *testing.T) {
+		req := &ApiRequest{Org: "keep", Project: "keep"}
+		model := &Model{Engine: ollama.EngineName, Service: Service{Org: "new", Project: "new"}}
+
+		model.ApplyService(req)
+
+		assert.Equal(t, "keep", req.Org)
+		assert.Equal(t, "keep", req.Project)
+	})
+}
+
 func TestModel_IsDefault(t *testing.T) {
 	nasnetCopy := *NasnetModel //nolint:govet // copy for test inspection only
 	nasnetCopy.Default = false
@@ -47,13 +47,12 @@ func nsfwInternal(images Files, mediaSrc media.Src) (result []nsfw.Result, err e
 		return result, err
 	}
 
-	switch model.Service.RequestFormat {
-	case ApiFormatOllama:
-		apiRequest.Model, _, _ = model.Model()
-	default:
-		_, apiRequest.Model, apiRequest.Version = model.Model()
+	if apiRequest.Model == "" {
+		apiRequest.Model, _, apiRequest.Version = model.GetModel()
 	}
 
+	model.ApplyService(apiRequest)
+
 	if model.System != "" {
 		apiRequest.System = model.System
 	}
@@ -29,7 +29,7 @@ This package provides PhotoPrism’s native adapter for Ollama-compatible multim
 ### Architecture & Request Flow
 
 1. **Model Selection** — `Config.Model(ModelType)` returns the top-most enabled entry. When `Engine: ollama`, `ApplyEngineDefaults()` fills in the request/response format, base64 file scheme, and a 720 px resolution unless overridden.
-2. **Request Build** — `ollamaBuilder.Build` wraps thumbnails with `NewApiRequestOllama`, which encodes them as base64 strings. `Model.Model()` resolves the exact Ollama tag (`gemma3:4b`, `qwen2.5vl:7b`, etc.).
+2. **Request Build** — `ollamaBuilder.Build` wraps thumbnails with `NewApiRequestOllama`, which encodes them as base64 strings. `Model.GetModel()` resolves the exact Ollama tag (`gemma3:4b`, `qwen2.5vl:7b`, etc.).
 3. **Transport** — `PerformApiRequest` uses a single HTTP POST (default timeout 10 min). Authentication is optional; provide `Service.Key` if you proxy through an API gateway.
 4. **Parsing** — `ollamaParser.Parse` converts payloads into `ApiResponse`. It normalizes confidences (`LabelConfidenceDefault = 0.5` when missing), copies NSFW scores, and canonicalizes label names via `normalizeLabelResult`.
 5. **Persistence** — `entity.SrcOllama` is stamped on labels/captions so UI badges and audits reflect the new source.
@@ -1,9 +1,11 @@
 package vision
 
 import (
+	"net/url"
 	"os"
 	"strings"
 
+	"github.com/photoprism/photoprism/pkg/clean"
 	"github.com/photoprism/photoprism/pkg/http/scheme"
 )
 
@@ -11,7 +13,12 @@ import (
 type Service struct {
 	Uri string `yaml:"Uri,omitempty" json:"uri"`
 	Method string `yaml:"Method,omitempty" json:"method"`
+	Model string `yaml:"Model,omitempty" json:"model,omitempty"` // Optional endpoint-specific model override.
+	Username string `yaml:"Username,omitempty" json:"-"` // Optional basic auth user injected into Endpoint URLs.
+	Password string `yaml:"Password,omitempty" json:"-"`
 	Key string `yaml:"Key,omitempty" json:"-"`
+	Org string `yaml:"Org,omitempty" json:"org,omitempty"` // Optional organization header (e.g. OpenAI).
+	Project string `yaml:"Project,omitempty" json:"project,omitempty"` // Optional project header (e.g. OpenAI).
 	FileScheme string `yaml:"FileScheme,omitempty" json:"fileScheme,omitempty"`
 	RequestFormat ApiFormat `yaml:"RequestFormat,omitempty" json:"requestFormat,omitempty"`
 	ResponseFormat ApiFormat `yaml:"ResponseFormat,omitempty" json:"responseFormat,omitempty"`
@@ -20,7 +27,7 @@ type Service struct {
 
 // Endpoint returns the remote service request method and endpoint URL, if any.
 func (m *Service) Endpoint() (uri, method string) {
-	if m.Disabled || m.Uri == "" {
+	if m.Disabled || strings.TrimSpace(m.Uri) == "" {
 		return "", ""
 	}
 
@@ -30,7 +37,37 @@ func (m *Service) Endpoint() (uri, method string) {
 		method = ServiceMethod
 	}
 
-	return m.Uri, method
+	uri = strings.TrimSpace(m.Uri)
+
+	if username, password := m.BasicAuth(); username != "" || password != "" {
+		if parsed, err := url.Parse(uri); err == nil {
+			if parsed.User == nil {
+				switch {
+				case username != "" && password != "":
+					parsed.User = url.UserPassword(username, password)
+				case username != "":
+					parsed.User = url.User(username)
+				}
+
+				if parsed.User != nil {
+					uri = parsed.String()
+				}
+			}
+		}
+	}
+
+	return uri, method
+}
+
+// GetModel returns the model identifier override for the endpoint, if any.
+func (m *Service) GetModel() string {
+	if m.Disabled {
+		return ""
+	}
+
+	ensureEnv()
+
+	return clean.TypeLower(os.ExpandEnv(m.Model))
 }
 
 // EndpointKey returns the access token belonging to the remote service endpoint, if any.
@@ -44,6 +81,36 @@ func (m *Service) EndpointKey() string {
 	return strings.TrimSpace(os.ExpandEnv(m.Key))
 }
 
+// EndpointOrg returns the organization identifier for the endpoint, if any.
+func (m *Service) EndpointOrg() string {
+	if m.Disabled {
+		return ""
+	}
+
+	ensureEnv()
+
+	return strings.TrimSpace(os.ExpandEnv(m.Org))
+}
+
+// EndpointProject returns the project identifier for the endpoint, if any.
+func (m *Service) EndpointProject() string {
+	if m.Disabled {
+		return ""
+	}
+
+	ensureEnv()
+
+	return strings.TrimSpace(os.ExpandEnv(m.Project))
+}
+
+// BasicAuth returns the username and password for basic authentication.
+func (m *Service) BasicAuth() (username, password string) {
+	ensureEnv()
+	username = strings.TrimSpace(os.ExpandEnv(m.Username))
+	password = strings.TrimSpace(os.ExpandEnv(m.Password))
+	return username, password
+}
+
 // EndpointFileScheme returns the endpoint API file scheme type.
 func (m *Service) EndpointFileScheme() scheme.Type {
 	if m.Disabled {
82
internal/ai/vision/service_test.go
Normal file
@@ -0,0 +1,82 @@
+package vision
+
+import "testing"
+
+func TestServiceEndpoint(t *testing.T) {
+	tests := []struct {
+		name       string
+		svc        Service
+		wantURI    string
+		wantMethod string
+	}{
+		{
+			name:       "Disabled",
+			svc:        Service{Disabled: true, Uri: "https://vision.example.com"},
+			wantURI:    "",
+			wantMethod: "",
+		},
+		{
+			name:       "WithBasicAuth",
+			svc:        Service{Uri: "https://vision.example.com/api", Username: "user", Password: "secret"},
+			wantURI:    "https://user:secret@vision.example.com/api",
+			wantMethod: ServiceMethod,
+		},
+		{
+			name:       "UsernameOnly",
+			svc:        Service{Uri: "https://vision.example.com/", Username: "scoped"},
+			wantURI:    "https://scoped@vision.example.com/",
+			wantMethod: ServiceMethod,
+		},
+		{
+			name:       "PreserveExistingUser",
+			svc:        Service{Uri: "https://keep:me@vision.example.com", Username: "ignored", Password: "ignored"},
+			wantURI:    "https://keep:me@vision.example.com",
+			wantMethod: ServiceMethod,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			uri, method := tt.svc.Endpoint()
+			if uri != tt.wantURI {
+				t.Fatalf("uri: got %q want %q", uri, tt.wantURI)
+			}
+			if method != tt.wantMethod {
+				t.Fatalf("method: got %q want %q", method, tt.wantMethod)
+			}
+		})
+	}
+}
+
+func TestServiceCredentialsAndHeaders(t *testing.T) {
+	t.Setenv("VISION_USER", "alice")
+	t.Setenv("VISION_PASS", "hunter2")
+	t.Setenv("VISION_MODEL", "GEMMA3:Latest")
+	t.Setenv("VISION_ORG", "org-123")
+	t.Setenv("VISION_PROJECT", "proj-abc")
+
+	svc := Service{
+		Username: "${VISION_USER}",
+		Password: "${VISION_PASS}",
+		Model:    "${VISION_MODEL}",
+		Org:      "${VISION_ORG}",
+		Project:  "${VISION_PROJECT}",
+	}
+
+	user, pass := svc.BasicAuth()
+	if user != "alice" || pass != "hunter2" {
+		t.Fatalf("basic auth: got %q/%q", user, pass)
+	}
+
+	if got := svc.GetModel(); got != "gemma3:latest" {
+		t.Fatalf("model override: got %q", got)
+	}
+
+	if got := svc.EndpointOrg(); got != "org-123" {
+		t.Fatalf("org: got %q", got)
+	}
+
+	if got := svc.EndpointProject(); got != "proj-abc" {
+		t.Fatalf("project: got %q", got)
+	}
+}
@@ -27,8 +27,8 @@ func visionListAction(ctx *cli.Context) error {
 	var rows [][]string
 
 	cols := []string{
-		"Type",
 		"Model",
+		"Type",
 		"Engine",
 		"Endpoint",
 		"Format",
@@ -52,7 +52,7 @@ func visionListAction(ctx *cli.Context) error {
 		modelUri, modelMethod := model.Endpoint()
 		tags := ""
 
-		name, _, _ := model.Model()
+		name, _, _ := model.GetModel()
 
 		if model.TensorFlow != nil && model.TensorFlow.Tags != nil {
 			tags = strings.Join(model.TensorFlow.Tags, ", ")
@@ -92,13 +92,13 @@ func visionListAction(ctx *cli.Context) error {
 		engine := model.EngineName()
 
 		rows[i] = []string{
-			model.Type,
 			name,
+			model.Type,
 			engine,
 			fmt.Sprintf("%s %s", modelMethod, modelUri),
 			format,
 			fmt.Sprintf("%d", model.Resolution),
-			report.Bool(len(options) == 0, "tags: "+tags, string(options)),
+			report.Bool(model.TensorFlow != nil, fmt.Sprintf(`{"tags":"%s"}`, tags), string(options)),
 			run,
 			report.Bool(model.Disabled, report.Disabled, report.Enabled),
 		}
@@ -12,9 +12,11 @@ import (
 
 // Authentication header names.
 const (
 	Auth       = "Authorization" // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
 	XAuthToken = "X-Auth-Token" //nolint:gosec // header name, not a secret
 	XSessionID = "X-Session-ID"
+	OpenAIOrg     = "OpenAI-Organization"
+	OpenAIProject = "OpenAI-Project"
 )
 
 // Authentication header values.
@@ -74,6 +76,22 @@ func SetAuthorization(r *http.Request, authToken string) {
 	}
 }
 
+// SetOpenAIOrg adds the organization header expected by the OpenAI API if a
+// non-empty value is provided.
+func SetOpenAIOrg(r *http.Request, org string) {
+	if org = strings.TrimSpace(org); org != "" {
+		r.Header.Add(OpenAIOrg, org)
+	}
+}
+
+// SetOpenAIProject adds the project header expected by the OpenAI API if a
+// non-empty value is provided.
+func SetOpenAIProject(r *http.Request, project string) {
+	if project = strings.TrimSpace(project); project != "" {
+		r.Header.Add(OpenAIProject, project)
+	}
+}
+
 // BasicAuth checks the basic authorization header for credentials and returns them if found.
 //
 // Note that OAuth 2.0 defines basic authentication differently than RFC 7617, however, this
@@ -64,6 +64,25 @@ func TestAuth(t *testing.T) {
 	})
 }
 
+func TestOpenAIHeaders(t *testing.T) {
+	t.Run("SetOrg", func(t *testing.T) {
+		r := httptest.NewRequest(http.MethodGet, "/", nil)
+		SetOpenAIOrg(r, " org-123 ")
+		assert.Equal(t, "org-123", r.Header.Get(OpenAIOrg))
+
+		SetOpenAIOrg(r, "")
+		assert.Equal(t, "org-123", r.Header.Get(OpenAIOrg))
+	})
+	t.Run("SetProject", func(t *testing.T) {
+		r := httptest.NewRequest(http.MethodGet, "/", nil)
+		SetOpenAIProject(r, "proj-abc")
+		assert.Equal(t, "proj-abc", r.Header.Get(OpenAIProject))
+
+		SetOpenAIProject(r, " ")
+		assert.Equal(t, "proj-abc", r.Header.Get(OpenAIProject))
+	})
+}
+
 func TestAuthToken(t *testing.T) {
 	t.Run("None", func(t *testing.T) {
 		gin.SetMode(gin.TestMode)