AI: Improve model configuration and documentation #5123 #5232 #5322

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2025-11-24 14:41:13 +01:00
parent 46b3a126f0
commit a02162846b
18 changed files with 614 additions and 84 deletions

View File

@@ -0,0 +1,165 @@
## PhotoPrism — Vision Package
**Last Updated:** November 24, 2025
### Overview
`internal/ai/vision` provides the shared model registry, request builders, and parsers that power PhotoPrisms caption, label, face, NSFW, and future generate workflows. It reads `vision.yml`, normalizes models, and dispatches calls to one of three engines:
- **TensorFlow (builtin)** — default Nasnet / NSFW / Facenet models, no remote service required.
- **Ollama** — local or proxied multimodal LLMs. See [`ollama/README.md`](ollama/README.md) for tuning and schema details.
- **OpenAI** — cloud Responses API. See [`openai/README.md`](openai/README.md) for prompts, schema variants, and header requirements.
### Configuration
#### Models
The `vision.yml` file is usually kept in the `storage/config` directory (override with `PHOTOPRISM_VISION_YAML`). It defines a list of models under `Models:`. Key fields are captured below:
| Field | Default | Notes |
|-------------------------|-----------------------------|------------------------------------------------------------------------------------|
| `Type` (required) | — | `labels`, `caption`, `face`, `nsfw`, `generate`. Drives routing & scheduling. |
| `Name` | derived from type/version | Display name; lower-cased by helpers. |
| `Model` | `""` | Raw identifier override; precedence: `Service.Model``Model``Name`. |
| `Version` | `latest` (non-OpenAI) | OpenAI payloads omit version. |
| `Engine` | inferred from service/alias | Aliases set formats, file scheme, resolution. Explicit `Service` values still win. |
| `Run` | `auto` | See Run modes table below. |
| `Default` | `false` | Keep one per type for TensorFlow fallbacks. |
| `Disabled` | `false` | Registered but inactive. |
| `Resolution` | 224 (720 for Ollama/OpenAI) | Thumbnail edge in px. |
| `System` / `Prompt` | engine defaults | Override prompts per model. |
| `Format` | `""` | Response hint (`json`, `text`, `markdown`). |
| `Schema` / `SchemaFile` | engine defaults / empty | Inline vs file JSON schema (labels). |
| `TensorFlow` | nil | Local TF model info (paths, tags). |
| `Options` | nil | Sampling/settings merged with engine defaults. |
| `Service` | nil | Remote endpoint config (see below). |
#### Run Modes
| Value | When it runs | Recommended use |
|-----------------|------------------------------------------------------------------|------------------------------------------------|
| `auto` | TensorFlow defaults during index; external via metadata/schedule | Leave as-is for most setups. |
| `manual` | Only when explicitly invoked (CLI/API) | Experiments and diagnostics. |
| `on-index` | During indexing + manual | Fast local models only. |
| `newly-indexed` | Metadata worker after indexing + manual | External/Ollama/OpenAI without slowing import. |
| `on-demand` | Manual, metadata worker, and scheduled jobs | Broad coverage without index path. |
| `on-schedule` | Scheduled jobs + manual | Nightly/cron-style runs. |
| `always` | Indexing, metadata, scheduled, manual | High-priority models; watch resource use. |
| `never` | Never executes | Keep definition without running it. |
#### Model Options
| Option | Default | Description |
|-------------------|-----------------------------------------------------------------------------------------|------------------------------------------------------------------------------------|
| `Temperature` | engine default (`0.1` for Ollama; unset for OpenAI) | Controls randomness; clamped to `[0,2]`. `gpt-5*` OpenAI models are forced to `0`. |
| `TopP` | engine default (`0.9` for some Ollama label defaults; unset for OpenAI) | Nucleus sampling parameter. |
| `MaxOutputTokens` | engine default (OpenAI caption 512, labels 1024; Ollama label default 256) | Upper bound on generated tokens; adapters raise low values to defaults. |
| `ForceJson` | engine-specific (`true` for OpenAI labels; `false` for Ollama labels; captions `false`) | Forces structured output when enabled. |
| `SchemaVersion` | derived from schema name | Override when coordinating schema migrations. |
| `Stop` | engine default | Array of stop sequences (e.g., `["\\n\\n"]`). |
| `NumThread` | runtime auto | Caps CPU threads for local engines. |
| `NumCtx` | engine default | Context window length (tokens). |
#### Model Service
Used for Ollama/OpenAI (and any future HTTP engines). All credentials and identifiers support `${ENV_VAR}` expansion.
| Field | Default | Notes |
|------------------------------------|------------------------------------------|------------------------------------------------------|
| `Uri` | required for remote | Endpoint base. Empty keeps model local (TensorFlow). |
| `Method` | `POST` | Override verb if provider needs it. |
| `Key` | `""` | Bearer token; prefer env expansion. |
| `Username` / `Password` | `""` | Injected as basic auth when URI lacks userinfo. |
| `Model` | `""` | Endpoint-specific override; wins over model/name. |
| `Org` / `Project` | `""` | OpenAI headers. |
| `RequestFormat` / `ResponseFormat` | set by engine alias | Explicit values win over alias defaults. |
| `FileScheme` | set by engine alias (`data` or `base64`) | Controls image transport. |
| `Disabled` | `false` | Disable the endpoint without removing the model. |
### Field Behavior & Precedence
- Model identifier resolution order: `Service.Model``Model``Name`. `Model.GetModel()` returns `(id, name, version)` where Ollama receives `name:version` and other engines receive `name` plus a separate `Version`.
- Env expansion runs for all `Service` credentials and `Model` overrides; empty or disabled models return empty identifiers.
- Options merging: engine defaults fill missing fields; explicit values always win. Temperature is capped at `MaxTemperature`.
- Authentication: `Service.Key` sets `Authorization: Bearer <token>`; `Username`/`Password` inject HTTP basic auth into the service URI when not already present.
### Minimal Examples
#### TensorFlow (builtin defaults)
```yaml
Models:
- Type: labels
Default: true
Run: auto
- Type: nsfw
Default: true
Run: auto
- Type: face
Default: true
Run: auto
```
#### Ollama Labels
```yaml
Models:
- Type: labels
Model: qwen2.5vl:7b
Engine: ollama
Run: newly-indexed
Service:
Uri: http://ollama:11434/api/generate
```
More Ollama guidance: [`internal/ai/vision/ollama/README.md`](ollama/README.md).
#### OpenAI Captions
```yaml
Models:
- Type: caption
Model: gpt-5-mini
Engine: openai
Run: newly-indexed
Service:
Uri: https://api.openai.com/v1/responses
Org: ${OPENAI_ORG}
Project: ${OPENAI_PROJECT}
Key: ${OPENAI_API_KEY}
```
More OpenAI guidance: [`internal/ai/vision/openai/README.md`](openai/README.md).
#### Custom TensorFlow Caption (local file model)
```yaml
Models:
- Type: caption
Name: custom-caption
Engine: tensorflow
Path: storage/models/custom-caption
Resolution: 448
Run: manual
```
### CLI Quick Reference
- List models: `photoprism vision ls` (shows resolved IDs, engines, options, run mode, disabled flag).
- Run a model: `photoprism vision run -m labels --count 5` (use `--force` to bypass `Run` rules).
- Validate config: `photoprism vision ls --json` to confirm env-expanded values without triggering calls.
### When to Choose Each Engine
- **TensorFlow**: fast, offline defaults for core features (labels, faces, NSFW). Zero external deps.
- **Ollama**: private, GPU/CPU-hosted multimodal LLMs; best for richer captions/labels without cloud traffic.
- **OpenAI**: highest quality reasoning and multimodal support; requires API key and network access.
### Related Docs
- Ollama specifics: [`internal/ai/vision/ollama/README.md`](ollama/README.md)
- OpenAI specifics: [`internal/ai/vision/openai/README.md`](openai/README.md)
- REST API reference: https://docs.photoprism.dev/
- Developer guide (Vision): https://docs.photoprism.app/developer-guide/api/

View File

@@ -35,13 +35,19 @@ func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResp
// Add "application/json" content type header.
header.SetContentType(req, header.ContentTypeJson)
// Add an authentication header if an access token is configured.
if reqErr != nil {
return apiResponse, reqErr
}
// Add an authentication header if an access token is provided.
if key != "" {
header.SetAuthorization(req, key)
}
if reqErr != nil {
return apiResponse, reqErr
// Add custom OpenAI organization and project headers.
if apiRequest.GetResponseFormat() == ApiFormatOpenAI {
header.SetOpenAIOrg(req, apiRequest.Org)
header.SetOpenAIProject(req, apiRequest.Project)
}
// Perform API request.

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/ai/vision/ollama"
"github.com/photoprism/photoprism/pkg/http/header"
"github.com/photoprism/photoprism/pkg/http/scheme"
)
@@ -119,3 +120,44 @@ func TestPerformApiRequestOllama(t *testing.T) {
}
})
}
func TestPerformApiRequestOpenAIHeaders(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "org-123", r.Header.Get(header.OpenAIOrg))
assert.Equal(t, "proj-abc", r.Header.Get(header.OpenAIProject))
response := map[string]any{
"id": "resp_123",
"model": "gpt-5-mini",
"output": []any{
map[string]any{
"role": "assistant",
"content": []any{
map[string]any{
"type": "output_text",
"text": "A scenic mountain view.",
},
},
},
},
}
assert.NoError(t, json.NewEncoder(w).Encode(response))
}))
defer server.Close()
req := &ApiRequest{
Id: "headers",
Model: "gpt-5-mini",
Images: []string{""},
ResponseFormat: ApiFormatOpenAI,
Org: "org-123",
Project: "proj-abc",
}
resp, err := PerformApiRequest(req, server.URL, http.MethodPost, "")
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.NotNil(t, resp.Result.Caption)
assert.Equal(t, "A scenic mountain view.", resp.Result.Caption.Text)
}

View File

@@ -82,6 +82,8 @@ type ApiRequest struct {
Suffix string `form:"suffix" yaml:"Suffix,omitempty" json:"suffix"`
Format string `form:"format" yaml:"Format,omitempty" json:"format,omitempty"`
Url string `form:"url" yaml:"Url,omitempty" json:"url,omitempty"`
Org string `form:"org" yaml:"Org,omitempty" json:"org,omitempty"`
Project string `form:"project" yaml:"Project,omitempty" json:"project,omitempty"`
Options *ApiRequestOptions `form:"options" yaml:"Options,omitempty" json:"options,omitempty"`
Context *ApiRequestContext `form:"context" yaml:"Context,omitempty" json:"context,omitempty"`
Stream bool `form:"stream" yaml:"Stream,omitempty" json:"stream"`

View File

@@ -43,14 +43,11 @@ func captionInternal(images Files, mediaSrc media.Src) (result *CaptionResult, m
}
if apiRequest.Model == "" {
switch model.Service.RequestFormat {
case ApiFormatOllama:
apiRequest.Model, _, _ = model.Model()
default:
_, apiRequest.Model, apiRequest.Version = model.Model()
}
apiRequest.Model, _, apiRequest.Version = model.GetModel()
}
model.ApplyService(apiRequest)
apiRequest.System = model.GetSystemPrompt()
apiRequest.Prompt = model.GetPrompt()

View File

@@ -117,9 +117,9 @@ func (ollamaBuilder) Build(ctx context.Context, model *Model, files Files) (*Api
}
if model.Service.RequestFormat == ApiFormatOllama {
req.Model, _, _ = model.Model()
req.Model, _, _ = model.GetModel()
} else {
_, req.Model, req.Version = model.Model()
_, req.Model, req.Version = model.GetModel()
}
return req, nil

View File

@@ -87,35 +87,14 @@ func (openaiDefaults) Options(model *Model) *ApiRequestOptions {
switch model.Type {
case ModelTypeCaption:
/*
Options:
Detail: low
MaxOutputTokens: 512
Temperature: 0.1
TopP: 0.9
(Sampling values are zeroed for GPT-5 models in openaiBuilder.Build.)
*/
return &ApiRequestOptions{
Detail: openai.DefaultDetail,
MaxOutputTokens: openai.CaptionMaxTokens,
Temperature: openai.DefaultTemperature,
TopP: openai.DefaultTopP,
}
case ModelTypeLabels:
/*
Options:
Detail: low
MaxOutputTokens: 1024
Temperature: 0.1
ForceJson: true
SchemaVersion: "photoprism_vision_labels_v1"
(Sampling values are zeroed for GPT-5 models in openaiBuilder.Build.)
*/
return &ApiRequestOptions{
Detail: openai.DefaultDetail,
MaxOutputTokens: openai.LabelsMaxTokens,
Temperature: openai.DefaultTemperature,
TopP: openai.DefaultTopP,
ForceJson: true,
}
default:

View File

@@ -53,7 +53,8 @@ func DetectFaces(fileName string, minSize int, cacheCrop bool, expected int) (re
return result, err
}
_, apiRequest.Model, apiRequest.Version = model.Model()
_, apiRequest.Model, apiRequest.Version = model.GetModel()
model.ApplyService(apiRequest)
if model.System != "" {
apiRequest.System = model.System

View File

@@ -71,14 +71,11 @@ func labelsInternal(images Files, mediaSrc media.Src, labelSrc entity.Src) (resu
}
if apiRequest.Model == "" {
switch model.Service.RequestFormat {
case ApiFormatOllama:
apiRequest.Model, _, _ = model.Model()
default:
_, apiRequest.Model, apiRequest.Version = model.Model()
}
apiRequest.Model, _, apiRequest.Version = model.GetModel()
}
model.ApplyService(apiRequest)
if system := model.GetSystemPrompt(); system != "" {
apiRequest.System = system
}

View File

@@ -34,6 +34,7 @@ var (
type Model struct {
Type ModelType `yaml:"Type,omitempty" json:"type,omitempty"`
Default bool `yaml:"Default,omitempty" json:"default,omitempty"`
Model string `yaml:"Model,omitempty" json:"model,omitempty"`
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
Engine ModelEngine `yaml:"Engine,omitempty" json:"engine,omitempty"`
@@ -59,43 +60,55 @@ type Model struct {
// Models represents a set of computer vision models.
type Models []*Model
// Model returns the parsed and normalized identifier, name, and version
// strings. Nil receivers return empty values so callers can destructure the
// tuple without additional nil checks.
func (m *Model) Model() (model, name, version string) {
// GetModel returns the normalized model identifier, name, and version strings
// used in service requests. Callers can always destructure the tuple because
// nil receivers return empty values.
func (m *Model) GetModel() (model, name, version string) {
if m == nil {
return "", "", ""
}
// Return empty identifier string if no name was set.
if m.Name == "" {
return "", "", clean.TypeLowerDash(m.Version)
}
// Normalize model name.
// Normalise the configured values.
name = clean.TypeLower(m.Name)
// Split name to check if it contains the version.
s := strings.SplitN(name, ":", 2)
// Return if name contains both model name and version.
if len(s) == 2 && s[0] != "" && s[1] != "" {
return name, s[0], s[1]
}
// Normalize model version.
version = clean.TypeLowerDash(m.Version)
// Default to "latest" if no specific version was set.
// Build a base name from the highest-priority override:
// 1) Service-specific override (expanded for env vars)
// 2) Model-specific override
// 3) Declarative model name
serviceModel := m.Service.GetModel()
switch {
case serviceModel != "":
name = serviceModel
case strings.TrimSpace(m.Model) != "":
name = clean.TypeLower(m.Model)
}
// Return if no model is configured.
if name == "" {
return "", "", ""
}
// Split "name:version" strings so callers can access versioned models
// without repeating parsing logic at each call site.
if parts := strings.SplitN(name, ":", 2); len(parts) == 2 && parts[0] != "" && parts[1] != "" {
name = parts[0]
version = parts[1]
}
// Default to "latest" for non-OpenAI engines when no version was set.
if version == "" {
version = VersionLatest
}
// Create model identifier from model name and version.
model = strings.Join([]string{s[0], version}, ":")
// Return normalized model identifier, name, and version.
return model, name, version
switch m.Engine {
case openai.EngineName:
return name, name, ""
case ollama.EngineName:
return strings.Join([]string{name, version}, ":"), name, version
default:
return name, name, version
}
}
// IsDefault reports whether the model refers to one of the built-in defaults.
@@ -145,6 +158,19 @@ func (m *Model) Endpoint() (uri, method string) {
}
}
// ApplyService updates the ApiRequest with service-specific
// values when configured.
func (m *Model) ApplyService(apiRequest *ApiRequest) {
if m == nil || apiRequest == nil {
return
}
if m.Engine == openai.EngineName {
apiRequest.Org = m.Service.EndpointOrg()
apiRequest.Project = m.Service.EndpointProject()
}
}
// EndpointKey returns the access token belonging to the remote service
// endpoint, or an empty string for nil receivers.
func (m *Model) EndpointKey() (key string) {
@@ -347,6 +373,10 @@ func mergeOptionDefaults(target, defaults *ApiRequestOptions) {
target.TopP = defaults.TopP
}
if target.Temperature <= 0 && defaults.Temperature > 0 {
target.Temperature = defaults.Temperature
}
if len(target.Stop) == 0 && len(defaults.Stop) > 0 {
target.Stop = append([]string(nil), defaults.Stop...)
}
@@ -377,9 +407,7 @@ func normalizeOptions(opts *ApiRequestOptions) {
return
}
if opts.Temperature <= 0 {
opts.Temperature = DefaultTemperature
} else if opts.Temperature > MaxTemperature {
if opts.Temperature > MaxTemperature {
opts.Temperature = MaxTemperature
}
}

View File

@@ -3,6 +3,7 @@ package vision
import (
"os"
"path/filepath"
"sync"
"testing"
"github.com/stretchr/testify/assert"
@@ -25,7 +26,7 @@ func TestModelGetOptionsDefaultsOllamaLabels(t *testing.T) {
model.ApplyEngineDefaults()
m, n, v := model.Model()
m, n, v := model.GetModel()
assert.Equal(t, ollamaModel, m)
assert.Equal(t, "redule26/huihui_ai_qwen2.5-vl-7b-abliterated", n)
@@ -53,6 +54,106 @@ func TestModelGetOptionsDefaultsOllamaLabels(t *testing.T) {
}
}
func TestModel_GetModel(t *testing.T) {
tests := []struct {
name string
model *Model
wantModel string
wantName string
wantVersion string
}{
{
name: "Nil",
wantModel: "",
wantName: "",
wantVersion: "",
},
{
name: "OpenAINameOnly",
model: &Model{
Name: "gpt-5-mini",
Engine: openai.EngineName,
},
wantModel: "gpt-5-mini",
wantName: "gpt-5-mini",
wantVersion: "",
},
{
name: "NonOpenAIAddsLatest",
model: &Model{
Name: "gemma3",
Engine: ollama.EngineName,
},
wantModel: "gemma3:latest",
wantName: "gemma3",
wantVersion: "latest",
},
{
name: "ExplicitVersion",
model: &Model{
Name: "gemma3",
Version: "2",
Engine: ollama.EngineName,
},
wantModel: "gemma3:2",
wantName: "gemma3",
wantVersion: "2",
},
{
name: "NameContainsVersion",
model: &Model{
Name: "qwen2.5vl:7b",
Engine: ollama.EngineName,
},
wantModel: "qwen2.5vl:7b",
wantName: "qwen2.5vl",
wantVersion: "7b",
},
{
name: "ModelFieldFallback",
model: &Model{
Model: "CUSTOM-MODEL",
Engine: ollama.EngineName,
},
wantModel: "custom-model:latest",
wantName: "custom-model",
wantVersion: "latest",
},
{
name: "ServiceOverrideWithVersion",
model: &Model{
Name: "ignored",
Engine: ollama.EngineName,
Service: Service{Model: "mixtral:8x7b"},
},
wantModel: "mixtral:8x7b",
wantName: "mixtral",
wantVersion: "8x7b",
},
{
name: "ServiceOverrideOpenAI",
model: &Model{
Name: "gpt-4.1",
Engine: openai.EngineName,
Service: Service{Model: "gpt-5-mini"},
},
wantModel: "gpt-5-mini",
wantName: "gpt-5-mini",
wantVersion: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
model, name, version := tt.model.GetModel()
assert.Equal(t, tt.wantModel, model)
assert.Equal(t, tt.wantName, name)
assert.Equal(t, tt.wantVersion, version)
})
}
}
func TestModelGetOptionsRespectsCustomValues(t *testing.T) {
model := &Model{
Type: ModelTypeLabels,
@@ -152,6 +253,9 @@ func TestModelEndpointKeyOpenAIFallbacks(t *testing.T) {
t.Fatalf("write key file: %v", err)
}
// Reset ensureEnvOnce.
ensureEnvOnce = sync.Once{}
t.Setenv("OPENAI_API_KEY", "")
t.Setenv("OPENAI_API_KEY_FILE", path)
@@ -218,6 +322,30 @@ func TestModelGetSource(t *testing.T) {
})
}
func TestModelApplyService(t *testing.T) {
t.Run("OpenAIHeaders", func(t *testing.T) {
req := &ApiRequest{}
model := &Model{
Engine: openai.EngineName,
Service: Service{Org: "org-123", Project: "proj-abc"},
}
model.ApplyService(req)
assert.Equal(t, "org-123", req.Org)
assert.Equal(t, "proj-abc", req.Project)
})
t.Run("OtherEngineNoop", func(t *testing.T) {
req := &ApiRequest{Org: "keep", Project: "keep"}
model := &Model{Engine: ollama.EngineName, Service: Service{Org: "new", Project: "new"}}
model.ApplyService(req)
assert.Equal(t, "keep", req.Org)
assert.Equal(t, "keep", req.Project)
})
}
func TestModel_IsDefault(t *testing.T) {
nasnetCopy := *NasnetModel //nolint:govet // copy for test inspection only
nasnetCopy.Default = false

View File

@@ -47,13 +47,12 @@ func nsfwInternal(images Files, mediaSrc media.Src) (result []nsfw.Result, err e
return result, err
}
switch model.Service.RequestFormat {
case ApiFormatOllama:
apiRequest.Model, _, _ = model.Model()
default:
_, apiRequest.Model, apiRequest.Version = model.Model()
if apiRequest.Model == "" {
apiRequest.Model, _, apiRequest.Version = model.GetModel()
}
model.ApplyService(apiRequest)
if model.System != "" {
apiRequest.System = model.System
}

View File

@@ -29,7 +29,7 @@ This package provides PhotoPrisms native adapter for Ollama-compatible multim
### Architecture & Request Flow
1. **Model Selection**`Config.Model(ModelType)` returns the top-most enabled entry. When `Engine: ollama`, `ApplyEngineDefaults()` fills in the request/response format, base64 file scheme, and a 720px resolution unless overridden.
2. **Request Build**`ollamaBuilder.Build` wraps thumbnails with `NewApiRequestOllama`, which encodes them as base64 strings. `Model.Model()` resolves the exact Ollama tag (`gemma3:4b`, `qwen2.5vl:7b`, etc.).
2. **Request Build**`ollamaBuilder.Build` wraps thumbnails with `NewApiRequestOllama`, which encodes them as base64 strings. `Model.GetModel()` resolves the exact Ollama tag (`gemma3:4b`, `qwen2.5vl:7b`, etc.).
3. **Transport**`PerformApiRequest` uses a single HTTP POST (default timeout 10min). Authentication is optional; provide `Service.Key` if you proxy through an API gateway.
4. **Parsing**`ollamaParser.Parse` converts payloads into `ApiResponse`. It normalizes confidences (`LabelConfidenceDefault = 0.5` when missing), copies NSFW scores, and canonicalizes label names via `normalizeLabelResult`.
5. **Persistence**`entity.SrcOllama` is stamped on labels/captions so UI badges and audits reflect the new source.

View File

@@ -1,9 +1,11 @@
package vision
import (
"net/url"
"os"
"strings"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/http/scheme"
)
@@ -11,7 +13,12 @@ import (
type Service struct {
Uri string `yaml:"Uri,omitempty" json:"uri"`
Method string `yaml:"Method,omitempty" json:"method"`
Model string `yaml:"Model,omitempty" json:"model,omitempty"` // Optional endpoint-specific model override.
Username string `yaml:"Username,omitempty" json:"-"` // Optional basic auth user injected into Endpoint URLs.
Password string `yaml:"Password,omitempty" json:"-"`
Key string `yaml:"Key,omitempty" json:"-"`
Org string `yaml:"Org,omitempty" json:"org,omitempty"` // Optional organization header (e.g. OpenAI).
Project string `yaml:"Project,omitempty" json:"project,omitempty"` // Optional project header (e.g. OpenAI).
FileScheme string `yaml:"FileScheme,omitempty" json:"fileScheme,omitempty"`
RequestFormat ApiFormat `yaml:"RequestFormat,omitempty" json:"requestFormat,omitempty"`
ResponseFormat ApiFormat `yaml:"ResponseFormat,omitempty" json:"responseFormat,omitempty"`
@@ -20,7 +27,7 @@ type Service struct {
// Endpoint returns the remote service request method and endpoint URL, if any.
func (m *Service) Endpoint() (uri, method string) {
if m.Disabled || m.Uri == "" {
if m.Disabled || strings.TrimSpace(m.Uri) == "" {
return "", ""
}
@@ -30,7 +37,37 @@ func (m *Service) Endpoint() (uri, method string) {
method = ServiceMethod
}
return m.Uri, method
uri = strings.TrimSpace(m.Uri)
if username, password := m.BasicAuth(); username != "" || password != "" {
if parsed, err := url.Parse(uri); err == nil {
if parsed.User == nil {
switch {
case username != "" && password != "":
parsed.User = url.UserPassword(username, password)
case username != "":
parsed.User = url.User(username)
}
if parsed.User != nil {
uri = parsed.String()
}
}
}
}
return uri, method
}
// GetModel returns the model identifier override for the endpoint, if any.
func (m *Service) GetModel() string {
if m.Disabled {
return ""
}
ensureEnv()
return clean.TypeLower(os.ExpandEnv(m.Model))
}
// EndpointKey returns the access token belonging to the remote service endpoint, if any.
@@ -44,6 +81,36 @@ func (m *Service) EndpointKey() string {
return strings.TrimSpace(os.ExpandEnv(m.Key))
}
// EndpointOrg returns the organization identifier for the endpoint, if any.
func (m *Service) EndpointOrg() string {
if m.Disabled {
return ""
}
ensureEnv()
return strings.TrimSpace(os.ExpandEnv(m.Org))
}
// EndpointProject returns the project identifier for the endpoint, if any.
func (m *Service) EndpointProject() string {
if m.Disabled {
return ""
}
ensureEnv()
return strings.TrimSpace(os.ExpandEnv(m.Project))
}
// BasicAuth returns the username and password for basic authentication.
func (m *Service) BasicAuth() (username, password string) {
ensureEnv()
username = strings.TrimSpace(os.ExpandEnv(m.Username))
password = strings.TrimSpace(os.ExpandEnv(m.Password))
return username, password
}
// EndpointFileScheme returns the endpoint API file scheme type.
func (m *Service) EndpointFileScheme() scheme.Type {
if m.Disabled {

View File

@@ -0,0 +1,82 @@
package vision
import "testing"
func TestServiceEndpoint(t *testing.T) {
tests := []struct {
name string
svc Service
wantURI string
wantMethod string
}{
{
name: "Disabled",
svc: Service{Disabled: true, Uri: "https://vision.example.com"},
wantURI: "",
wantMethod: "",
},
{
name: "WithBasicAuth",
svc: Service{Uri: "https://vision.example.com/api", Username: "user", Password: "secret"},
wantURI: "https://user:secret@vision.example.com/api",
wantMethod: ServiceMethod,
},
{
name: "UsernameOnly",
svc: Service{Uri: "https://vision.example.com/", Username: "scoped"},
wantURI: "https://scoped@vision.example.com/",
wantMethod: ServiceMethod,
},
{
name: "PreserveExistingUser",
svc: Service{Uri: "https://keep:me@vision.example.com", Username: "ignored", Password: "ignored"},
wantURI: "https://keep:me@vision.example.com",
wantMethod: ServiceMethod,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
uri, method := tt.svc.Endpoint()
if uri != tt.wantURI {
t.Fatalf("uri: got %q want %q", uri, tt.wantURI)
}
if method != tt.wantMethod {
t.Fatalf("method: got %q want %q", method, tt.wantMethod)
}
})
}
}
func TestServiceCredentialsAndHeaders(t *testing.T) {
t.Setenv("VISION_USER", "alice")
t.Setenv("VISION_PASS", "hunter2")
t.Setenv("VISION_MODEL", "GEMMA3:Latest")
t.Setenv("VISION_ORG", "org-123")
t.Setenv("VISION_PROJECT", "proj-abc")
svc := Service{
Username: "${VISION_USER}",
Password: "${VISION_PASS}",
Model: "${VISION_MODEL}",
Org: "${VISION_ORG}",
Project: "${VISION_PROJECT}",
}
user, pass := svc.BasicAuth()
if user != "alice" || pass != "hunter2" {
t.Fatalf("basic auth: got %q/%q", user, pass)
}
if got := svc.GetModel(); got != "gemma3:latest" {
t.Fatalf("model override: got %q", got)
}
if got := svc.EndpointOrg(); got != "org-123" {
t.Fatalf("org: got %q", got)
}
if got := svc.EndpointProject(); got != "proj-abc" {
t.Fatalf("project: got %q", got)
}
}

View File

@@ -27,8 +27,8 @@ func visionListAction(ctx *cli.Context) error {
var rows [][]string
cols := []string{
"Type",
"Model",
"Type",
"Engine",
"Endpoint",
"Format",
@@ -52,7 +52,7 @@ func visionListAction(ctx *cli.Context) error {
modelUri, modelMethod := model.Endpoint()
tags := ""
name, _, _ := model.Model()
name, _, _ := model.GetModel()
if model.TensorFlow != nil && model.TensorFlow.Tags != nil {
tags = strings.Join(model.TensorFlow.Tags, ", ")
@@ -92,13 +92,13 @@ func visionListAction(ctx *cli.Context) error {
engine := model.EngineName()
rows[i] = []string{
model.Type,
name,
model.Type,
engine,
fmt.Sprintf("%s %s", modelMethod, modelUri),
format,
fmt.Sprintf("%d", model.Resolution),
report.Bool(len(options) == 0, "tags: "+tags, string(options)),
report.Bool(model.TensorFlow != nil, fmt.Sprintf(`{"tags":"%s"}`, tags), string(options)),
run,
report.Bool(model.Disabled, report.Disabled, report.Enabled),
}

View File

@@ -15,6 +15,8 @@ const (
Auth = "Authorization" // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
XAuthToken = "X-Auth-Token" //nolint:gosec // header name, not a secret
XSessionID = "X-Session-ID"
OpenAIOrg = "OpenAI-Organization"
OpenAIProject = "OpenAI-Project"
)
// Authentication header values.
@@ -74,6 +76,22 @@ func SetAuthorization(r *http.Request, authToken string) {
}
}
// SetOpenAIOrg adds the organization header expected by the OpenAI API if a
// non-empty value is provided.
func SetOpenAIOrg(r *http.Request, org string) {
if org = strings.TrimSpace(org); org != "" {
r.Header.Add(OpenAIOrg, org)
}
}
// SetOpenAIProject adds the project header expected by the OpenAI API if a
// non-empty value is provided.
func SetOpenAIProject(r *http.Request, project string) {
if project = strings.TrimSpace(project); project != "" {
r.Header.Add(OpenAIProject, project)
}
}
// BasicAuth checks the basic authorization header for credentials and returns them if found.
//
// Note that OAuth 2.0 defines basic authentication differently than RFC 7617, however, this

View File

@@ -64,6 +64,25 @@ func TestAuth(t *testing.T) {
})
}
func TestOpenAIHeaders(t *testing.T) {
t.Run("SetOrg", func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/", nil)
SetOpenAIOrg(r, " org-123 ")
assert.Equal(t, "org-123", r.Header.Get(OpenAIOrg))
SetOpenAIOrg(r, "")
assert.Equal(t, "org-123", r.Header.Get(OpenAIOrg))
})
t.Run("SetProject", func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/", nil)
SetOpenAIProject(r, "proj-abc")
assert.Equal(t, "proj-abc", r.Header.Get(OpenAIProject))
SetOpenAIProject(r, " ")
assert.Equal(t, "proj-abc", r.Header.Get(OpenAIProject))
})
}
func TestAuthToken(t *testing.T) {
t.Run("None", func(t *testing.T) {
gin.SetMode(gin.TestMode)