AI: Configure vision model execution and scheduling #5232 #5233 #5234

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2025-09-30 15:51:48 +02:00
parent 0c9f82a641
commit d782a43c2b
46 changed files with 1094 additions and 420 deletions

View File

@@ -1,6 +1,6 @@
# PhotoPrism® Repository Guidelines
**Last Updated:** September 29, 2025
**Last Updated:** September 30, 2025
## Purpose
@@ -273,6 +273,7 @@ If anything in this file conflicts with the `Makefile` or the Developer Guide, t
- Respect precedence: `options.yml` overrides CLI/env values, which override defaults. When adding a new option, update `internal/config/options.go` (yaml/flag tags), register it in `internal/config/flags.go`, expose a getter, surface it in `*config.Report()`, and write generated values back to `options.yml` by setting `c.options.OptionsYaml` before persisting. Use `CliTestContext` in `internal/config/test.go` to exercise new flags.
- When touching configuration in Go code, use the public accessors on `*config.Config` (e.g. `Config.JWKSUrl()`, `Config.SetJWKSUrl()`, `Config.ClusterUUID()`) instead of mutating `Config.Options()` directly; reserve raw option tweaks for test fixtures only.
- Vision worker scheduling is controlled via `VisionSchedule` / `VisionFilter` and the `Run` property set in `vision.yml`. Utilities like `vision.FilterModels` and `entity.Photo.ShouldGenerateLabels/Caption` help decide when work is required before loading media files.
- Logging: use the shared logger (`event.Log`) via the package-level `log` variable (see `internal/auth/jwt/logger.go`) instead of direct `fmt.Print*` or ad-hoc loggers.
- Cluster registry tests (`internal/service/cluster/registry`) currently rely on a full test config because they persist `entity.Client` rows. They run migrations and seed the SQLite DB, so they are intentionally slow. If you refactor them, consider sharing a single `config.TestConfig()` across subtests or building a lightweight schema harness; do not swap to the minimal config helper unless the tests stop touching the database.
- Favor explicit CLI flags: check `c.cliCtx.IsSet("<flag>")` before overriding user-supplied values, and follow the `ClusterUUID` pattern (`options.yml` → CLI/env → generated UUIDv4 persisted).

View File

@@ -60,12 +60,12 @@ func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResp
format := apiRequest.GetResponseFormat()
if provider, ok := ProviderFor(format); ok && provider.Parser != nil {
if engine, ok := EngineFor(format); ok && engine.Parser != nil {
if clientResp.StatusCode >= 300 {
log.Debugf("vision: %s (status code %d)", body, clientResp.StatusCode)
}
parsed, parseErr := provider.Parser.Parse(context.Background(), apiRequest, body, clientResp.StatusCode)
parsed, parseErr := engine.Parser.Parse(context.Background(), apiRequest, body, clientResp.StatusCode)
if parseErr != nil {
return nil, parseErr
}

View File

@@ -4,14 +4,27 @@ import (
"context"
"errors"
"github.com/photoprism/photoprism/internal/ai/vision/ollama"
"github.com/photoprism/photoprism/internal/ai/vision/openai"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/pkg/media"
)
var captionFunc = captionInternal
// SetCaptionFunc overrides the caption generator. Intended for tests.
func SetCaptionFunc(fn func(Files, media.Src) (*CaptionResult, *Model, error)) {
if fn == nil {
captionFunc = captionInternal
return
}
captionFunc = fn
}
// Caption returns generated captions for the specified images.
func Caption(images Files, mediaSrc media.Src) (result *CaptionResult, model *Model, err error) {
func Caption(images Files, mediaSrc media.Src) (*CaptionResult, *Model, error) {
return captionFunc(images, mediaSrc)
}
func captionInternal(images Files, mediaSrc media.Src) (result *CaptionResult, model *Model, err error) {
// Return if there is no configuration or no image classification models are configured.
if Config == nil {
return result, model, errors.New("vision service is not configured")
@@ -21,8 +34,8 @@ func Caption(images Files, mediaSrc media.Src) (result *CaptionResult, model *Mo
var apiRequest *ApiRequest
var apiResponse *ApiResponse
if provider, ok := ProviderFor(model.EndpointRequestFormat()); ok && provider.Builder != nil {
if apiRequest, err = provider.Builder.Build(context.Background(), model, images); err != nil {
if engine, ok := EngineFor(model.EndpointRequestFormat()); ok && engine.Builder != nil {
if apiRequest, err = engine.Builder.Build(context.Background(), model, images); err != nil {
return result, model, err
}
} else if apiRequest, err = NewApiRequest(model.EndpointRequestFormat(), images, model.EndpointFileScheme()); err != nil {
@@ -50,15 +63,8 @@ func Caption(images Files, mediaSrc media.Src) (result *CaptionResult, model *Mo
}
// Set image as the default caption source.
if apiResponse.Result.Caption.Text != "" && apiResponse.Result.Caption.Source == "" {
switch model.Provider {
case ollama.ProviderName:
apiResponse.Result.Caption.Source = entity.SrcOllama
case openai.ProviderName:
apiResponse.Result.Caption.Source = entity.SrcOpenAI
default:
apiResponse.Result.Caption.Source = entity.SrcImage
}
if apiResponse.Result.Caption.Source == "" {
apiResponse.Result.Caption.Source = model.GetSource()
}
result = apiResponse.Result.Caption

View File

@@ -47,7 +47,7 @@ func NewConfig() *ConfigValues {
}
for _, model := range cfg.Models {
model.ApplyProviderDefaults()
model.ApplyEngineDefaults()
}
return cfg
@@ -95,7 +95,7 @@ func (c *ConfigValues) Load(fileName string) error {
}
for _, model := range c.Models {
model.ApplyProviderDefaults()
model.ApplyEngineDefaults()
}
if c.Thresholds.Confidence <= 0 || c.Thresholds.Confidence > 100 {
@@ -124,7 +124,9 @@ func (c *ConfigValues) Save(fileName string) error {
return nil
}
// Model returns the first enabled model with the matching type from the configuration.
// Model returns the first enabled model with the matching type.
// It returns nil if no matching model is available or every model of that
// type is disabled, allowing callers to chain nil-safe Model methods.
func (c *ConfigValues) Model(t ModelType) *Model {
for i := len(c.Models) - 1; i >= 0; i-- {
m := c.Models[i]
@@ -136,7 +138,9 @@ func (c *ConfigValues) Model(t ModelType) *Model {
return nil
}
// ShouldRun checks when the specified model type should run.
// ShouldRun reports whether the configured model for the given type is
// allowed to run in the specified context. It returns false when no
// suitable model exists or when execution is explicitly disabled.
func (c *ConfigValues) ShouldRun(t ModelType, when RunType) bool {
m := c.Model(t)

View File

@@ -37,7 +37,7 @@ func TestConfigModelPrefersLastEnabled(t *testing.T) {
customModel := &Model{
Type: ModelTypeLabels,
Name: "ollama-labels",
Provider: "ollama",
Engine: "ollama",
Disabled: false,
}
@@ -74,7 +74,7 @@ func TestConfigValues_IsDefaultAndIsCustom(t *testing.T) {
}
})
t.Run("CustomOverridesDefault", func(t *testing.T) {
custom := &Model{Type: ModelTypeLabels, Name: "custom", Provider: "ollama"}
custom := &Model{Type: ModelTypeLabels, Name: "custom", Engine: "ollama"}
cfg := &ConfigValues{Models: Models{defaultModel, custom}}
if cfg.IsDefault(ModelTypeLabels) {
t.Fatalf("expected custom model to disable default detection")
@@ -84,7 +84,7 @@ func TestConfigValues_IsDefaultAndIsCustom(t *testing.T) {
}
})
t.Run("DisabledCustomFallsBackToDefault", func(t *testing.T) {
custom := &Model{Type: ModelTypeLabels, Name: "custom", Provider: "ollama", Disabled: true}
custom := &Model{Type: ModelTypeLabels, Name: "custom", Engine: "ollama", Disabled: true}
cfg := &ConfigValues{Models: Models{defaultModel, custom}}
if !cfg.IsDefault(ModelTypeLabels) {
t.Fatalf("expected disabled custom model to fall back to default")

View File

@@ -0,0 +1,121 @@
package vision
import (
"context"
"strings"
"sync"
"github.com/photoprism/photoprism/internal/ai/vision/openai"
"github.com/photoprism/photoprism/pkg/service/http/scheme"
)
// ModelEngine represents the canonical identifier for a computer vision service engine.
type ModelEngine = string
const (
// EngineVision represents the default PhotoPrism vision service endpoints.
EngineVision ModelEngine = "vision"
// EngineTensorFlow represents on-device TensorFlow models.
EngineTensorFlow ModelEngine = "tensorflow"
// EngineLocal is used when no explicit engine can be determined.
EngineLocal ModelEngine = "local"
)
// RequestBuilder builds an API request for an engine based on the model configuration and input files.
type RequestBuilder interface {
Build(ctx context.Context, model *Model, files Files) (*ApiRequest, error)
}
// ResponseParser parses a raw engine response into the generic ApiResponse structure.
type ResponseParser interface {
Parse(ctx context.Context, req *ApiRequest, raw []byte, status int) (*ApiResponse, error)
}
// EngineDefaults supplies engine-specific prompt and schema defaults when they are not configured explicitly.
type EngineDefaults interface {
SystemPrompt(model *Model) string
UserPrompt(model *Model) string
SchemaTemplate(model *Model) string
Options(model *Model) *ApiRequestOptions
}
// Engine groups the callbacks required to integrate a third-party vision service.
type Engine struct {
Builder RequestBuilder
Parser ResponseParser
Defaults EngineDefaults
}
var (
engineRegistry = make(map[ApiFormat]Engine)
engineAliasIndex = make(map[string]EngineInfo)
engineMu sync.RWMutex
)
// init wires up the built-in aliases so configuration files can reference the
// human-friendly engine names without duplicating adapter metadata.
func init() {
RegisterEngineAlias(EngineVision, EngineInfo{
RequestFormat: ApiFormatVision,
ResponseFormat: ApiFormatVision,
FileScheme: string(scheme.Data),
Resolution: DefaultResolution,
})
RegisterEngineAlias(openai.EngineName, EngineInfo{
RequestFormat: ApiFormatOpenAI,
ResponseFormat: ApiFormatOpenAI,
FileScheme: string(scheme.Data),
Resolution: openai.Resolution,
})
}
// RegisterEngine adds/overrides an engine implementation for a specific API format.
func RegisterEngine(format ApiFormat, engine Engine) {
engineMu.Lock()
defer engineMu.Unlock()
engineRegistry[format] = engine
}
// EngineInfo describes metadata that can be associated with an engine alias.
type EngineInfo struct {
RequestFormat ApiFormat
ResponseFormat ApiFormat
FileScheme string
Resolution int
}
// RegisterEngineAlias maps a logical engine name (e.g., "ollama") to a
// request/response format pair.
func RegisterEngineAlias(name string, info EngineInfo) {
name = strings.TrimSpace(strings.ToLower(name))
if name == "" || info.RequestFormat == "" {
return
}
if info.ResponseFormat == "" {
info.ResponseFormat = info.RequestFormat
}
engineMu.Lock()
engineAliasIndex[name] = info
engineMu.Unlock()
}
// EngineInfoFor returns the metadata associated with a logical engine name.
func EngineInfoFor(name string) (EngineInfo, bool) {
name = strings.TrimSpace(strings.ToLower(name))
engineMu.RLock()
info, ok := engineAliasIndex[name]
engineMu.RUnlock()
return info, ok
}
// EngineFor returns the registered engine implementation for the given API
// format, if any.
func EngineFor(format ApiFormat) (Engine, bool) {
engineMu.RLock()
defer engineMu.RUnlock()
engine, ok := engineRegistry[format]
return engine, ok
}

View File

@@ -18,20 +18,23 @@ type ollamaBuilder struct{}
type ollamaParser struct{}
func init() {
RegisterProvider(ApiFormatOllama, Provider{
RegisterEngine(ApiFormatOllama, Engine{
Builder: ollamaBuilder{},
Parser: ollamaParser{},
Defaults: ollamaDefaults{},
})
RegisterProviderAlias(ollama.ProviderName, ProviderInfo{
// Register the human-friendly engine name so configuration can simply use
// `Engine: "ollama"` and inherit adapter defaults.
RegisterEngineAlias(ollama.EngineName, EngineInfo{
RequestFormat: ApiFormatOllama,
ResponseFormat: ApiFormatOllama,
FileScheme: string(scheme.Base64),
Resolution: ollama.Resolution,
})
CaptionModel.Provider = ollama.ProviderName
CaptionModel.ApplyProviderDefaults()
CaptionModel.Engine = ollama.EngineName
CaptionModel.ApplyEngineDefaults()
}
func (ollamaDefaults) SystemPrompt(model *Model) string {

View File

@@ -0,0 +1,18 @@
package vision
import (
"github.com/photoprism/photoprism/internal/ai/vision/openai"
"github.com/photoprism/photoprism/pkg/service/http/scheme"
)
// init registers the OpenAI engine alias so models can set Engine: "openai"
// and inherit sensible defaults (request/response formats, file scheme, and
// preferred thumbnail resolution).
func init() {
RegisterEngineAlias(openai.EngineName, EngineInfo{
RequestFormat: ApiFormatOpenAI,
ResponseFormat: ApiFormatOpenAI,
FileScheme: string(scheme.Base64),
Resolution: openai.Resolution,
})
}

View File

@@ -0,0 +1,78 @@
package vision
import "testing"
func TestRegisterEngineAlias(t *testing.T) {
const alias = "unit-test"
engineMu.Lock()
prev, had := engineAliasIndex[alias]
if had {
delete(engineAliasIndex, alias)
}
engineMu.Unlock()
t.Cleanup(func() {
engineMu.Lock()
if had {
engineAliasIndex[alias] = prev
} else {
delete(engineAliasIndex, alias)
}
engineMu.Unlock()
})
RegisterEngineAlias(" Unit-Test ", EngineInfo{RequestFormat: ApiFormat("custom"), ResponseFormat: "", FileScheme: "data", Resolution: 512})
info, ok := EngineInfoFor(alias)
if !ok {
t.Fatalf("expected engine alias %q to be registered", alias)
}
if info.RequestFormat != ApiFormat("custom") {
t.Errorf("unexpected request format: %s", info.RequestFormat)
}
if info.ResponseFormat != ApiFormat("custom") {
t.Errorf("expected response format default to request, got %s", info.ResponseFormat)
}
if info.FileScheme != "data" {
t.Errorf("unexpected file scheme: %s", info.FileScheme)
}
if info.Resolution != 512 {
t.Errorf("unexpected resolution: %d", info.Resolution)
}
}
func TestRegisterEngine(t *testing.T) {
format := ApiFormat("unit-format")
engine := Engine{}
engineMu.Lock()
prev, had := engineRegistry[format]
if had {
delete(engineRegistry, format)
}
engineMu.Unlock()
t.Cleanup(func() {
engineMu.Lock()
if had {
engineRegistry[format] = prev
} else {
delete(engineRegistry, format)
}
engineMu.Unlock()
})
RegisterEngine(format, engine)
got, ok := EngineFor(format)
if !ok {
t.Fatalf("expected engine for %s", format)
}
if got != engine {
t.Errorf("unexpected engine value: %#v", got)
}
}

View File

@@ -13,10 +13,26 @@ import (
"github.com/photoprism/photoprism/pkg/media"
)
var labelsFunc = labelsInternal
// SetLabelsFunc overrides the labels generator. Intended for tests.
func SetLabelsFunc(fn func(Files, media.Src, string) (classify.Labels, error)) {
if fn == nil {
labelsFunc = labelsInternal
return
}
labelsFunc = fn
}
// Labels finds matching labels for the specified image.
// Caller must pass the appropriate metadata source string (e.g., entity.SrcOllama, entity.SrcOpenAI)
// so that downstream indexing can record where the labels originated.
func Labels(images Files, mediaSrc media.Src, labelSrc string) (result classify.Labels, err error) {
func Labels(images Files, mediaSrc media.Src, labelSrc string) (classify.Labels, error) {
return labelsFunc(images, mediaSrc, labelSrc)
}
func labelsInternal(images Files, mediaSrc media.Src, labelSrc string) (result classify.Labels, err error) {
// Return if no thumbnail filenames were given.
if len(images) == 0 {
return result, errors.New("at least one image required")
@@ -42,8 +58,8 @@ func Labels(images Files, mediaSrc media.Src, labelSrc string) (result classify.
var apiRequest *ApiRequest
var apiResponse *ApiResponse
if provider, ok := ProviderFor(model.EndpointRequestFormat()); ok && provider.Builder != nil {
if apiRequest, err = provider.Builder.Build(context.Background(), model, images); err != nil {
if engine, ok := EngineFor(model.EndpointRequestFormat()); ok && engine.Builder != nil {
if apiRequest, err = engine.Builder.Build(context.Background(), model, images); err != nil {
return result, err
}
} else if apiRequest, err = NewApiRequest(model.EndpointRequestFormat(), images, model.EndpointFileScheme()); err != nil {

View File

@@ -13,6 +13,7 @@ import (
"github.com/photoprism/photoprism/internal/ai/vision/ollama"
"github.com/photoprism/photoprism/internal/ai/vision/openai"
visionschema "github.com/photoprism/photoprism/internal/ai/vision/schema"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/service/http/scheme"
@@ -35,7 +36,7 @@ type Model struct {
Default bool `yaml:"Default,omitempty" json:"default,omitempty"`
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
Provider ModelProvider `yaml:"Provider,omitempty" json:"provider,omitempty"`
Engine ModelEngine `yaml:"Engine,omitempty" json:"engine,omitempty"`
Run RunType `yaml:"Run,omitempty" json:"Run,omitempty"` // "auto", "never", "manual", "always", "newly-indexed", "on-schedule"
System string `yaml:"System,omitempty" json:"system,omitempty"`
Prompt string `yaml:"Prompt,omitempty" json:"prompt,omitempty"`
@@ -58,8 +59,14 @@ type Model struct {
// Models represents a set of computer vision models.
type Models []*Model
// Model returns the parsed and normalized model identifier, name, and version strings.
// Model returns the parsed and normalized identifier, name, and version
// strings. Nil receivers return empty values so callers can destructure the
// tuple without additional nil checks.
func (m *Model) Model() (model, name, version string) {
if m == nil {
return "", "", ""
}
// Return empty identifier string if no name was set.
if m.Name == "" {
return "", "", clean.TypeLowerDash(m.Version)
@@ -91,8 +98,13 @@ func (m *Model) Model() (model, name, version string) {
return model, name, version
}
// IsDefault checks if this is a built-in default model.
// IsDefault reports whether the model refers to one of the built-in defaults.
// Nil receivers return false.
func (m *Model) IsDefault() bool {
if m == nil {
return false
}
if m.Default {
return true
}
@@ -115,8 +127,13 @@ func (m *Model) IsDefault() bool {
return false
}
// Endpoint returns the remote service request method and endpoint URL, if any.
// Endpoint returns the remote service request method and endpoint URL. Nil
// receivers return empty strings.
func (m *Model) Endpoint() (uri, method string) {
if m == nil {
return uri, method
}
if uri, method = m.Service.Endpoint(); uri != "" && method != "" {
return uri, method
} else if ServiceUri == "" {
@@ -128,8 +145,13 @@ func (m *Model) Endpoint() (uri, method string) {
}
}
// EndpointKey returns the access token belonging to the remote service endpoint, if any.
// EndpointKey returns the access token belonging to the remote service
// endpoint, or an empty string for nil receivers.
func (m *Model) EndpointKey() (key string) {
if m == nil {
return ""
}
if key = m.Service.EndpointKey(); key != "" {
return key
} else {
@@ -137,8 +159,13 @@ func (m *Model) EndpointKey() (key string) {
}
}
// EndpointFileScheme returns the endpoint API request file scheme type.
// EndpointFileScheme returns the endpoint API request file scheme type. Nil
// receivers fall back to the global default scheme.
func (m *Model) EndpointFileScheme() (fileScheme scheme.Type) {
if m == nil {
return ""
}
if fileScheme = m.Service.EndpointFileScheme(); fileScheme != "" {
return fileScheme
}
@@ -146,8 +173,13 @@ func (m *Model) EndpointFileScheme() (fileScheme scheme.Type) {
return ServiceFileScheme
}
// EndpointRequestFormat returns the endpoint API request format.
// EndpointRequestFormat returns the endpoint API request format. Nil receivers
// fall back to the global default format.
func (m *Model) EndpointRequestFormat() (format ApiFormat) {
if m == nil {
return ""
}
if format = m.Service.EndpointRequestFormat(); format != "" {
return format
}
@@ -155,8 +187,13 @@ func (m *Model) EndpointRequestFormat() (format ApiFormat) {
return ServiceRequestFormat
}
// EndpointResponseFormat returns the endpoint API response format.
// EndpointResponseFormat returns the endpoint API response format. Nil
// receivers fall back to the global default format.
func (m *Model) EndpointResponseFormat() (format ApiFormat) {
if m == nil {
return ""
}
if format = m.Service.EndpointResponseFormat(); format != "" {
return format
}
@@ -164,13 +201,18 @@ func (m *Model) EndpointResponseFormat() (format ApiFormat) {
return ServiceResponseFormat
}
// GetPrompt returns the configured model prompt, or the default prompt if none is specified.
// GetPrompt returns the configured model prompt, using engine defaults when
// none is specified. Nil receivers return an empty string.
func (m *Model) GetPrompt() string {
if m == nil {
return ""
}
if m.Prompt != "" {
return m.Prompt
}
if defaults := m.defaultsProvider(); defaults != nil {
if defaults := m.engineDefaults(); defaults != nil {
if prompt := defaults.UserPrompt(m); prompt != "" {
return prompt
}
@@ -186,13 +228,19 @@ func (m *Model) GetPrompt() string {
}
}
// GetSystemPrompt returns the configured system model prompt, or the default system prompt if none is specified.
// GetSystemPrompt returns the configured system prompt, falling back to
// engine defaults when none is specified. Nil receivers return an empty
// string.
func (m *Model) GetSystemPrompt() string {
if m == nil {
return ""
}
if m.System != "" {
return m.System
}
if defaults := m.defaultsProvider(); defaults != nil {
if defaults := m.engineDefaults(); defaults != nil {
if system := defaults.SystemPrompt(m); system != "" {
return system
}
@@ -206,8 +254,13 @@ func (m *Model) GetSystemPrompt() string {
}
}
// GetFormat returns the configured response format or a sensible default.
// GetFormat returns the configured response format or a sensible default. Nil
// receivers return an empty string.
func (m *Model) GetFormat() string {
if m == nil {
return ""
}
if f := strings.TrimSpace(strings.ToLower(m.Format)); f != "" {
return f
}
@@ -219,28 +272,56 @@ func (m *Model) GetFormat() string {
return ""
}
// GetOptions returns the API request options.
// GetSource returns the default entity src based on the model configuration.
func (m *Model) GetSource() string {
if m == nil {
return entity.SrcAuto
}
switch m.EngineName() {
case ollama.EngineName:
return entity.SrcOllama
case openai.EngineName:
return entity.SrcOpenAI
}
switch m.EndpointRequestFormat() {
case ApiFormatOllama:
return entity.SrcOllama
case ApiFormatOpenAI:
return entity.SrcOpenAI
}
return entity.SrcImage
}
// GetOptions returns the API request options, applying engine defaults on
// demand. Nil receivers return nil.
func (m *Model) GetOptions() *ApiRequestOptions {
var providerDefaults *ApiRequestOptions
if defaults := m.defaultsProvider(); defaults != nil {
providerDefaults = cloneOptions(defaults.Options(m))
if m == nil {
return nil
}
var engineDefaults *ApiRequestOptions
if defaults := m.engineDefaults(); defaults != nil {
engineDefaults = cloneOptions(defaults.Options(m))
}
if m.Options == nil {
switch m.Type {
case ModelTypeLabels, ModelTypeCaption, ModelTypeGenerate:
if providerDefaults == nil {
providerDefaults = &ApiRequestOptions{}
if engineDefaults == nil {
engineDefaults = &ApiRequestOptions{}
}
normalizeOptions(providerDefaults)
m.Options = providerDefaults
normalizeOptions(engineDefaults)
m.Options = engineDefaults
return m.Options
default:
return nil
}
}
mergeOptionDefaults(m.Options, providerDefaults)
mergeOptionDefaults(m.Options, engineDefaults)
normalizeOptions(m.Options)
return m.Options
@@ -286,14 +367,15 @@ func cloneOptions(opts *ApiRequestOptions) *ApiRequestOptions {
return &clone
}
// ProviderName returns the configured provider or infers a sensible default based on the model settings.
func (m *Model) ProviderName() string {
// EngineName returns the normalized engine identifier or infers one from the
// request configuration. Nil receivers return an empty string.
func (m *Model) EngineName() string {
if m == nil {
return ""
}
if provider := strings.TrimSpace(strings.ToLower(m.Provider)); provider != "" {
return provider
if engine := strings.TrimSpace(strings.ToLower(m.Engine)); engine != "" {
return engine
}
uri, method := m.Endpoint()
@@ -301,36 +383,36 @@ func (m *Model) ProviderName() string {
format := m.EndpointRequestFormat()
switch format {
case ApiFormatOllama:
return ollama.ProviderName
return ollama.EngineName
case ApiFormatOpenAI:
return openai.ProviderName
return openai.EngineName
case ApiFormatVision, "":
return ProviderVision
return EngineVision
default:
return strings.ToLower(string(format))
}
}
if m.TensorFlow != nil {
return ProviderTensorFlow
return EngineTensorFlow
}
return ProviderLocal
return EngineLocal
}
// ApplyProviderDefaults normalizes the provider name and applies registered provider defaults
// for request/response formats and file schemes when these are not explicitly configured.
func (m *Model) ApplyProviderDefaults() {
// ApplyEngineDefaults normalizes the engine name and applies registered engine
// defaults (formats, schemes, resolution) when these are not explicitly configured.
func (m *Model) ApplyEngineDefaults() {
if m == nil {
return
}
provider := strings.TrimSpace(strings.ToLower(m.Provider))
if provider == "" {
engine := strings.TrimSpace(strings.ToLower(m.Engine))
if engine == "" {
return
}
if info, ok := ProviderInfoFor(provider); ok {
if info, ok := EngineInfoFor(engine); ok {
if m.Service.RequestFormat == "" {
m.Service.RequestFormat = info.RequestFormat
}
@@ -342,13 +424,22 @@ func (m *Model) ApplyProviderDefaults() {
if info.FileScheme != "" && m.Service.FileScheme == "" {
m.Service.FileScheme = info.FileScheme
}
if info.Resolution > 0 && m.Resolution <= 0 {
m.Resolution = info.Resolution
}
}
m.Provider = provider
m.Engine = engine
}
// SchemaTemplate returns the model-specific JSON schema template, if any.
// SchemaTemplate returns the model-specific JSON schema template, if any. Nil
// receivers return an empty string.
func (m *Model) SchemaTemplate() string {
if m == nil {
return ""
}
m.schemaOnce.Do(func() {
var schemaText string
@@ -385,7 +476,7 @@ func (m *Model) SchemaTemplate() string {
m.schema = strings.TrimSpace(schemaText)
if m.schema == "" && m.Type == ModelTypeLabels {
if defaults := m.defaultsProvider(); defaults != nil {
if defaults := m.engineDefaults(); defaults != nil {
m.schema = strings.TrimSpace(defaults.SchemaTemplate(m))
}
}
@@ -398,21 +489,30 @@ func (m *Model) SchemaTemplate() string {
return m.schema
}
func (m *Model) defaultsProvider() DefaultsProvider {
if provider, ok := ProviderFor(m.EndpointRequestFormat()); ok {
return provider.Defaults
func (m *Model) engineDefaults() EngineDefaults {
if m == nil {
return nil
}
if info, ok := ProviderInfoFor(m.ProviderName()); ok {
if provider, ok := ProviderFor(info.RequestFormat); ok {
return provider.Defaults
if engine, ok := EngineFor(m.EndpointRequestFormat()); ok {
return engine.Defaults
}
if info, ok := EngineInfoFor(m.EngineName()); ok {
if engine, ok := EngineFor(info.RequestFormat); ok {
return engine.Defaults
}
}
return nil
}
// SchemaInstructions returns a helper string that can be appended to prompts.
// Nil receivers return an empty string.
func (m *Model) SchemaInstructions() string {
if m == nil {
return ""
}
if schema := m.SchemaTemplate(); schema != "" {
return fmt.Sprintf("Return JSON that matches this schema:\n%s", schema)
}
@@ -420,8 +520,13 @@ func (m *Model) SchemaInstructions() string {
return ""
}
// ClassifyModel returns the matching classify model instance, if any.
// ClassifyModel returns the matching classify model instance, if any. Nil
// receivers return nil.
func (m *Model) ClassifyModel() *classify.Model {
if m == nil {
return nil
}
// Use mutex to prevent models from being loaded and
// initialized twice by different indexing workers.
modelMutex.Lock()
@@ -481,8 +586,13 @@ func (m *Model) ClassifyModel() *classify.Model {
return m.classifyModel
}
// FaceModel returns the matching face model instance, if any.
// FaceModel returns the matching face recognition model instance, if any. Nil
// receivers return nil.
func (m *Model) FaceModel() *face.Model {
if m == nil {
return nil
}
// Use mutex to prevent models from being loaded and
// initialized twice by different indexing workers.
modelMutex.Lock()
@@ -536,8 +646,13 @@ func (m *Model) FaceModel() *face.Model {
return m.faceModel
}
// NsfwModel returns the matching nsfw model instance, if any.
// NsfwModel returns the matching nsfw model instance, if any. Nil receivers
// return nil.
func (m *Model) NsfwModel() *nsfw.Model {
if m == nil {
return nil
}
// Use mutex to prevent models from being loaded and
// initialized twice by different indexing workers.
modelMutex.Lock()
@@ -597,8 +712,12 @@ func (m *Model) NsfwModel() *nsfw.Model {
return m.nsfwModel
}
// Clone returns a clone of this model.
// Clone returns a shallow copy of the model. Nil receivers return nil.
func (m *Model) Clone() *Model {
if m == nil {
return nil
}
c := *m
return &c
}

View File

@@ -0,0 +1,29 @@
package vision
import (
"strings"
)
// FilterModels takes a list of model type names and a scheduling context, and
// returns only the types that are allowed to run according to the supplied
// predicate. Empty or unknown names are ignored.
func FilterModels(models []string, when RunType, allow func(ModelType, RunType) bool) []string {
if len(models) == 0 {
return models
}
filtered := make([]string, 0, len(models))
for _, name := range models {
modelType := ModelType(strings.TrimSpace(name))
if modelType == "" {
continue
}
if allow == nil || allow(modelType, when) {
filtered = append(filtered, string(modelType))
}
}
return filtered
}

View File

@@ -0,0 +1,56 @@
package vision
import (
"testing"
)
func TestFilterModels(t *testing.T) {
cases := []struct {
name string
models []string
when RunType
allow func(ModelType, RunType) bool
expected []string
}{
{
name: "NilPredicate",
models: []string{"caption", "labels"},
when: RunManual,
allow: nil,
expected: []string{"caption", "labels"},
},
{
name: "SkipUnknown",
models: []string{"caption", "", "unknown", "labels"},
when: RunManual,
allow: func(mt ModelType, when RunType) bool {
return mt == ModelTypeLabels
},
expected: []string{"labels"},
},
{
name: "ContextAware",
models: []string{"caption", "labels"},
when: RunOnSchedule,
allow: func(mt ModelType, when RunType) bool {
return mt == ModelTypeCaption
},
expected: []string{"caption"},
},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
got := FilterModels(tc.models, tc.when, tc.allow)
if len(got) != len(tc.expected) {
t.Fatalf("expected %d models, got %d", len(tc.expected), len(got))
}
for i := range got {
if got[i] != tc.expected[i] {
t.Fatalf("expected %v, got %v", tc.expected, got)
}
}
})
}
}

View File

@@ -38,7 +38,8 @@ var RunTypes = map[string]RunType{
"index": RunOnIndex,
}
// ParseRunType parses a run type string.
// ParseRunType parses a run type string into the canonical RunType constant.
// Unknown or empty values default to RunAuto.
func ParseRunType(s string) RunType {
if t, ok := RunTypes[clean.TypeLowerDash(s)]; ok {
return t
@@ -47,13 +48,23 @@ func ParseRunType(s string) RunType {
return RunAuto
}
// RunType returns a normalized type that specifies when a vision model should run.
// RunType returns the normalized run type configured for the model. Nil
// receivers default to RunAuto.
func (m *Model) RunType() RunType {
if m == nil {
return RunAuto
}
return ParseRunType(m.Run)
}
// ShouldRun checks when the model should run based on the specified type.
// ShouldRun reports whether the model should execute in the specified
// scheduling context. Nil receivers always return false.
func (m *Model) ShouldRun(when RunType) bool {
if m == nil {
return false
}
when = ParseRunType(when)
switch m.RunType() {

View File

@@ -35,6 +35,11 @@ func TestModel_RunType(t *testing.T) {
model *Model
want RunType
}{
{
name: "Nil",
model: nil,
want: RunAuto,
},
{
name: "Manual",
model: &Model{Run: "manual"},
@@ -135,6 +140,13 @@ func TestModel_ShouldRun_RunNever(t *testing.T) {
assertShouldRun(t, model, RunOnDemand, false)
}
func TestModel_ShouldRun_NilModel(t *testing.T) {
var model *Model
if model.ShouldRun(RunManual) {
t.Fatalf("expected nil model to never run")
}
}
func TestModel_ShouldRun_RunOnIndex(t *testing.T) {
model := &Model{Run: string(RunOnIndex)}

View File

@@ -5,15 +5,16 @@ import (
"github.com/photoprism/photoprism/internal/ai/tensorflow"
"github.com/photoprism/photoprism/internal/ai/vision/ollama"
"github.com/photoprism/photoprism/internal/entity"
)
func TestModelGetOptionsDefaultsOllamaLabels(t *testing.T) {
model := &Model{
Type: ModelTypeLabels,
Provider: ollama.ProviderName,
Type: ModelTypeLabels,
Engine: ollama.EngineName,
}
model.ApplyProviderDefaults()
model.ApplyEngineDefaults()
opts := model.GetOptions()
if opts == nil {
@@ -39,8 +40,8 @@ func TestModelGetOptionsDefaultsOllamaLabels(t *testing.T) {
func TestModelGetOptionsRespectsCustomValues(t *testing.T) {
model := &Model{
Type: ModelTypeLabels,
Provider: ollama.ProviderName,
Type: ModelTypeLabels,
Engine: ollama.EngineName,
Options: &ApiRequestOptions{
Temperature: 5,
TopP: 0.95,
@@ -48,7 +49,7 @@ func TestModelGetOptionsRespectsCustomValues(t *testing.T) {
},
}
model.ApplyProviderDefaults()
model.ApplyEngineDefaults()
opts := model.GetOptions()
if opts.Temperature != MaxTemperature {
@@ -64,12 +65,12 @@ func TestModelGetOptionsRespectsCustomValues(t *testing.T) {
func TestModelGetOptionsFillsMissingFields(t *testing.T) {
model := &Model{
Type: ModelTypeLabels,
Provider: ollama.ProviderName,
Options: &ApiRequestOptions{},
Type: ModelTypeLabels,
Engine: ollama.EngineName,
Options: &ApiRequestOptions{},
}
model.ApplyProviderDefaults()
model.ApplyEngineDefaults()
opts := model.GetOptions()
if opts.TopP != 0.9 {
@@ -80,6 +81,52 @@ func TestModelGetOptionsFillsMissingFields(t *testing.T) {
}
}
func TestModelApplyEngineDefaultsSetsResolution(t *testing.T) {
model := &Model{Type: ModelTypeLabels, Engine: ollama.EngineName}
model.ApplyEngineDefaults()
if model.Resolution != ollama.Resolution {
t.Fatalf("expected resolution %d, got %d", ollama.Resolution, model.Resolution)
}
model.Resolution = 1024
model.ApplyEngineDefaults()
if model.Resolution != 1024 {
t.Fatalf("expected custom resolution to be preserved, got %d", model.Resolution)
}
}
func TestModelGetSource(t *testing.T) {
t.Run("NilModel", func(t *testing.T) {
var model *Model
if src := model.GetSource(); src != entity.SrcAuto {
t.Fatalf("expected SrcAuto for nil model, got %s", src)
}
})
t.Run("EngineAlias", func(t *testing.T) {
model := &Model{Engine: ollama.EngineName}
if src := model.GetSource(); src != entity.SrcOllama {
t.Fatalf("expected SrcOllama, got %s", src)
}
})
t.Run("RequestFormat", func(t *testing.T) {
model := &Model{Service: Service{RequestFormat: ApiFormatOpenAI}}
if src := model.GetSource(); src != entity.SrcOpenAI {
t.Fatalf("expected SrcOpenAI, got %s", src)
}
})
t.Run("DefaultImage", func(t *testing.T) {
model := &Model{}
if src := model.GetSource(); src != entity.SrcImage {
t.Fatalf("expected SrcImage fallback, got %s", src)
}
})
}
func TestModel_IsDefault(t *testing.T) {
nasnetCopy := *NasnetModel
nasnetCopy.Default = false
@@ -111,9 +158,9 @@ func TestModel_IsDefault(t *testing.T) {
{
name: "RemoteService",
model: &Model{
Type: ModelTypeCaption,
Name: "custom-caption",
Provider: ollama.ProviderName,
Type: ModelTypeCaption,
Name: "custom-caption",
Engine: ollama.EngineName,
},
want: false,
},

View File

@@ -91,7 +91,7 @@ var (
Type: ModelTypeCaption,
Name: ollama.CaptionModel,
Version: VersionLatest,
Provider: ollama.ProviderName,
Engine: ollama.EngineName,
Resolution: 720, // Original aspect ratio, with a max size of 720 x 720 pixels.
Service: Service{
Uri: "http://ollama:11434/api/generate",

View File

@@ -1,8 +1,8 @@
package ollama
const (
// ProviderName is the canonical identifier for Ollama-based vision services.
ProviderName = "ollama"
// EngineName is the canonical identifier for Ollama-based vision services.
EngineName = "ollama"
// ApiFormat identifies Ollama-compatible request and response payloads.
ApiFormat = "ollama"
)

View File

@@ -7,6 +7,7 @@ const (
CaptionModel = "gemma3"
LabelSystem = "You are a PhotoPrism vision model. Output concise JSON that matches the schema."
LabelPrompt = "Analyze the image and return label objects with name, confidence (0-1), and topicality (0-1)."
Resolution = 720
)
func LabelSchema() string {

View File

@@ -1,6 +1,6 @@
/*
Package ollama integrates PhotoPrism's vision pipeline with Ollama-compatible
multi-modal models so adapters can share logging and provider helpers.
multi-modal models so adapters can share logging and engine helpers.
Copyright (c) 2018 - 2025 PhotoPrism UG. All rights reserved.

View File

@@ -1,8 +1,8 @@
package openai
const (
// ProviderName is the canonical identifier for OpenAI-based vision services.
ProviderName = "openai"
// EngineName is the canonical identifier for OpenAI-based vision services.
EngineName = "openai"
// ApiFormat identifies OpenAI-compatible request and response payloads.
ApiFormat = "openai"
)

View File

@@ -0,0 +1,12 @@
package openai
import "github.com/photoprism/photoprism/internal/ai/vision/schema"
const (
DefaultModel = "gpt-5-mini"
Resolution = 720
)
func LabelSchema() string {
return schema.LabelDefaultV1
}

View File

@@ -1,97 +0,0 @@
package vision
import (
"context"
"strings"
"sync"
)
// ModelProvider represents the canonical identifier for a computer vision service provider.
type ModelProvider = string
const (
// ProviderVision represents the default PhotoPrism vision service endpoints.
ProviderVision ModelProvider = "vision"
// ProviderTensorFlow represents on-device TensorFlow models.
ProviderTensorFlow ModelProvider = "tensorflow"
// ProviderLocal is used when no explicit provider can be determined.
ProviderLocal ModelProvider = "local"
)
// RequestBuilder builds an API request for a provider based on the model configuration and input files.
type RequestBuilder interface {
Build(ctx context.Context, model *Model, files Files) (*ApiRequest, error)
}
// ResponseParser parses a raw provider response into the generic ApiResponse structure.
type ResponseParser interface {
Parse(ctx context.Context, req *ApiRequest, raw []byte, status int) (*ApiResponse, error)
}
// DefaultsProvider supplies provider-specific prompt and schema defaults when they are not configured explicitly.
type DefaultsProvider interface {
SystemPrompt(model *Model) string
UserPrompt(model *Model) string
SchemaTemplate(model *Model) string
Options(model *Model) *ApiRequestOptions
}
// Provider groups the callbacks required to integrate a third-party vision service.
type Provider struct {
Builder RequestBuilder
Parser ResponseParser
Defaults DefaultsProvider
}
var (
providerRegistry = make(map[ApiFormat]Provider)
providerAliasIndex = make(map[string]ProviderInfo)
providerMu sync.RWMutex
)
// RegisterProvider adds/overrides a provider implementation for a specific API format.
func RegisterProvider(format ApiFormat, provider Provider) {
providerMu.Lock()
defer providerMu.Unlock()
providerRegistry[format] = provider
}
// ProviderInfo describes metadata that can be associated with a provider alias.
type ProviderInfo struct {
RequestFormat ApiFormat
ResponseFormat ApiFormat
FileScheme string
}
// RegisterProviderAlias maps a logical provider name (e.g. "ollama") to a request/response format pair.
func RegisterProviderAlias(name string, info ProviderInfo) {
name = strings.TrimSpace(strings.ToLower(name))
if name == "" || info.RequestFormat == "" {
return
}
if info.ResponseFormat == "" {
info.ResponseFormat = info.RequestFormat
}
providerMu.Lock()
providerAliasIndex[name] = info
providerMu.Unlock()
}
// ProviderInfoFor returns the metadata associated with a logical provider name.
func ProviderInfoFor(name string) (ProviderInfo, bool) {
name = strings.TrimSpace(strings.ToLower(name))
providerMu.RLock()
info, ok := providerAliasIndex[name]
providerMu.RUnlock()
return info, ok
}
// ProviderFor returns the registered provider implementation for the given API format, if any.
func ProviderFor(format ApiFormat) (Provider, bool) {
providerMu.RLock()
defer providerMu.RUnlock()
provider, ok := providerRegistry[format]
return provider, ok
}

View File

@@ -1,74 +0,0 @@
package vision
import "testing"
func TestRegisterProviderAlias(t *testing.T) {
const alias = "unit-test"
providerMu.Lock()
prev, had := providerAliasIndex[alias]
if had {
delete(providerAliasIndex, alias)
}
providerMu.Unlock()
t.Cleanup(func() {
providerMu.Lock()
if had {
providerAliasIndex[alias] = prev
} else {
delete(providerAliasIndex, alias)
}
providerMu.Unlock()
})
RegisterProviderAlias(" Unit-Test ", ProviderInfo{RequestFormat: ApiFormat("custom"), ResponseFormat: "", FileScheme: "data"})
info, ok := ProviderInfoFor(alias)
if !ok {
t.Fatalf("expected provider alias %q to be registered", alias)
}
if info.RequestFormat != ApiFormat("custom") {
t.Errorf("unexpected request format: %s", info.RequestFormat)
}
if info.ResponseFormat != ApiFormat("custom") {
t.Errorf("expected response format default to request, got %s", info.ResponseFormat)
}
if info.FileScheme != "data" {
t.Errorf("unexpected file scheme: %s", info.FileScheme)
}
}
func TestRegisterProvider(t *testing.T) {
format := ApiFormat("unit-format")
provider := Provider{}
providerMu.Lock()
prev, had := providerRegistry[format]
if had {
delete(providerRegistry, format)
}
providerMu.Unlock()
t.Cleanup(func() {
providerMu.Lock()
if had {
providerRegistry[format] = prev
} else {
delete(providerRegistry, format)
}
providerMu.Unlock()
})
RegisterProvider(format, provider)
got, ok := ProviderFor(format)
if !ok {
t.Fatalf("expected provider for %s", format)
}
if got != provider {
t.Errorf("unexpected provider value: %#v", got)
}
}

View File

@@ -1,4 +1,4 @@
package schema
// LabelDefaultV1 provides the minimal JSON schema for label responses used across providers.
// LabelDefaultV1 provides the minimal JSON schema for label responses used across engines.
const LabelDefaultV1 = "{\n \"labels\": [{\n \"name\": \"\",\n \"confidence\": 0,\n \"topicality\": 0\n }]\n}"

View File

@@ -67,7 +67,7 @@ Models:
- Type: caption
Name: gemma3
Version: latest
Provider: ollama
Engine: ollama
Resolution: 720
Service:
Uri: http://ollama:11434/api/generate

View File

@@ -28,16 +28,14 @@ func visionListAction(ctx *cli.Context) error {
cols := []string{
"Type",
"Name",
"Version",
"Model",
"Engine",
"Endpoint",
"Format",
"Resolution",
"Provider",
"Service Endpoint",
"Request Format",
"Response Format",
"Options",
"Tags",
"Disabled",
"Schedule",
"Status",
}
// Show log message.
@@ -54,47 +52,55 @@ func visionListAction(ctx *cli.Context) error {
modelUri, modelMethod := model.Endpoint()
tags := ""
_, name, version := model.Model()
name, _, _ := model.Model()
if model.TensorFlow != nil && model.TensorFlow.Tags != nil {
tags = strings.Join(model.TensorFlow.Tags, ", ")
}
if model.Default {
version = "default"
}
var options []byte
if o := model.GetOptions(); o != nil {
options, _ = json.Marshal(*o)
}
var responseFormat, requestFormat string
var format string
if modelUri != "" && modelMethod != "" {
if f := strings.TrimSpace(string(model.EndpointRequestFormat())); f != "" {
requestFormat = f
}
if f := strings.TrimSpace(string(model.EndpointResponseFormat())); f != "" {
responseFormat = f
if f := model.EndpointRequestFormat(); f != "" {
format = f
}
}
provider := model.ProviderName()
if responseFormat := model.GetFormat(); responseFormat != "" {
if format != "" {
format = fmt.Sprintf("%s:%s", format, responseFormat)
} else {
format = responseFormat
}
}
if format == "" && model.Default {
format = "default"
}
var run string
if run = model.RunType(); run == "" {
run = "auto"
}
engine := model.EngineName()
rows[i] = []string{
model.Type,
name,
version,
fmt.Sprintf("%d", model.Resolution),
provider,
engine,
fmt.Sprintf("%s %s", modelMethod, modelUri),
requestFormat,
responseFormat,
string(options),
tags,
report.Bool(model.Disabled, report.Yes, report.No),
format,
fmt.Sprintf("%d", model.Resolution),
report.Bool(len(options) == 0, "tags: "+tags, string(options)),
run,
report.Bool(model.Disabled, report.Disabled, report.Enabled),
}
}

View File

@@ -51,6 +51,7 @@ func visionRunAction(ctx *cli.Context) error {
vision.ParseModelTypes(ctx.String("models")),
string(source),
ctx.Bool("force"),
vision.RunManual,
)
})
}

View File

@@ -3,6 +3,7 @@ package config
import (
"os"
"path/filepath"
"strings"
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/pkg/clean"
@@ -22,6 +23,16 @@ func (c *Config) VisionYaml() string {
}
}
// VisionSchedule returns the cron schedule configured for the vision worker, or "" if disabled.
func (c *Config) VisionSchedule() string {
return Schedule(c.options.VisionSchedule)
}
// VisionFilter returns the search filter to use for scheduled vision runs.
func (c *Config) VisionFilter() string {
return strings.TrimSpace(c.options.VisionFilter)
}
// VisionModelShouldRun checks when the specified model type should run.
func (c *Config) VisionModelShouldRun(t vision.ModelType, when vision.RunType) bool {
if t == vision.ModelTypeLabels && c.DisableClassification() {

View File

@@ -113,3 +113,25 @@ func TestConfig_VisionModelShouldRun(t *testing.T) {
}
})
}
func TestConfig_VisionSchedule(t *testing.T) {
c := NewConfig(CliTestContext())
c.options.VisionSchedule = ""
assert.Equal(t, "", c.VisionSchedule())
c.options.VisionSchedule = "0 6 * * *"
assert.Equal(t, "0 6 * * *", c.VisionSchedule())
c.options.VisionSchedule = "invalid"
assert.Equal(t, "", c.VisionSchedule())
}
func TestConfig_VisionFilter(t *testing.T) {
c := NewConfig(CliTestContext())
c.options.VisionFilter = " private:false "
assert.Equal(t, "private:false", c.VisionFilter())
c.options.VisionFilter = ""
assert.Equal(t, "", c.VisionFilter())
}

View File

@@ -4,5 +4,5 @@ package feat
var (
VisionModelGenerate = false // controls exposure of the generate endpoint and CLI commands
VisionModelMarkers = false // gates marker generation/return until downstream UI and reconciliation paths are ready
VisionServiceOpenAI = false // controls whether users are able to configure OpenAI as vision service provider
VisionServiceOpenAI = false // controls whether users are able to configure OpenAI as a vision service engine
)

View File

@@ -1151,6 +1151,17 @@ var Flags = CliFlags{
Value: "",
EnvVars: EnvVars("VISION_KEY"),
}}, {
Flag: &cli.StringFlag{
Name: "vision-schedule",
Usage: "vision worker `SCHEDULE` for background processing (e.g. \"0 12 * * *\" for daily at noon) or at a random time (daily, weekly)",
EnvVars: EnvVars("VISION_SCHEDULE"),
}}, {
Flag: &cli.StringFlag{
Name: "vision-filter",
Usage: "vision worker search `FILTER` applied to scheduled runs (same syntax as photoprism vision run)",
Value: "public:true",
EnvVars: EnvVars("VISION_FILTER"),
}}, {
Flag: &cli.BoolFlag{
Name: "detect-nsfw",
Usage: "flags newly added pictures as private if they might be offensive (requires TensorFlow)",

View File

@@ -226,6 +226,8 @@ type Options struct {
VisionApi bool `yaml:"VisionApi" json:"-" flag:"vision-api"`
VisionUri string `yaml:"VisionUri" json:"-" flag:"vision-uri"`
VisionKey string `yaml:"VisionKey" json:"-" flag:"vision-key"`
VisionSchedule string `yaml:"VisionSchedule" json:"VisionSchedule" flag:"vision-schedule"`
VisionFilter string `yaml:"VisionFilter" json:"VisionFilter" flag:"vision-filter"`
DetectNSFW bool `yaml:"DetectNSFW" json:"DetectNSFW" flag:"detect-nsfw"`
FaceSize int `yaml:"-" json:"-" flag:"face-size"`
FaceScore float64 `yaml:"-" json:"-" flag:"face-score"`

View File

@@ -275,6 +275,8 @@ func (c *Config) Report() (rows [][]string, cols []string) {
{"vision-api", fmt.Sprintf("%t", c.VisionApi())},
{"vision-uri", c.VisionUri()},
{"vision-key", strings.Repeat("*", utf8.RuneCountInString(c.VisionKey()))},
{"vision-schedule", c.VisionSchedule()},
{"vision-filter", c.VisionFilter()},
{"nasnet-model-path", c.NasnetModelPath()},
{"facenet-model-path", c.FacenetModelPath()},
{"nsfw-model-path", c.NsfwModelPath()},

View File

@@ -16,6 +16,7 @@ import (
"github.com/photoprism/photoprism/internal/event"
"github.com/photoprism/photoprism/internal/form"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/list"
"github.com/photoprism/photoprism/pkg/media"
"github.com/photoprism/photoprism/pkg/react"
"github.com/photoprism/photoprism/pkg/rnd"
@@ -752,6 +753,25 @@ func (m *Photo) SaveDetails() error {
}
}
// ShouldGenerateLabels checks if labels should be generated for this model.
func (m *Photo) ShouldGenerateLabels(force bool) bool {
// Return true if force is set or there are no labels yet.
if len(m.Labels) == 0 || force {
return true
}
// Check if any of the existing labels were generated using a vision model.
for _, l := range m.Labels {
if list.Contains(VisionSrcList, l.LabelSrc) {
return false
} else if l.LabelSrc == SrcCaption && list.Contains(VisionSrcList, m.CaptionSrc) {
return false
}
}
return true
}
// AddLabels updates the entity with additional or updated label information.
func (m *Photo) AddLabels(labels classify.Labels) {
for _, classifyLabel := range labels {

View File

@@ -19,6 +19,11 @@ func (m *Photo) NoCaption() bool {
return strings.TrimSpace(m.GetCaption()) == ""
}
// ShouldGenerateCaption checks if a caption should be generated for this model.
func (m *Photo) ShouldGenerateCaption(src Src, force bool) bool {
return SrcPriority[src] >= SrcPriority[m.CaptionSrc] && (m.NoCaption() || force)
}
// GetCaption returns the photo caption, if any.
func (m *Photo) GetCaption() string {
return m.PhotoCaption

View File

@@ -241,6 +241,94 @@ func TestPhoto_SaveLabels(t *testing.T) {
})
}
func TestPhoto_ShouldGenerateLabels(t *testing.T) {
t.Run("NoLabels", func(t *testing.T) {
p := Photo{}
assert.True(t, p.ShouldGenerateLabels(false))
})
t.Run("Force", func(t *testing.T) {
p := Photo{Labels: []PhotoLabel{{LabelSrc: string(SrcManual)}}}
assert.True(t, p.ShouldGenerateLabels(true))
})
t.Run("ExistingVisionLabel", func(t *testing.T) {
p := Photo{Labels: []PhotoLabel{{LabelSrc: string(SrcOllama)}}}
assert.False(t, p.ShouldGenerateLabels(false))
})
t.Run("CaptionGeneratedLabels", func(t *testing.T) {
p := Photo{
Labels: []PhotoLabel{{LabelSrc: string(SrcCaption)}},
CaptionSrc: SrcOllama,
}
assert.False(t, p.ShouldGenerateLabels(false))
})
t.Run("ManualLabels", func(t *testing.T) {
p := Photo{Labels: []PhotoLabel{{LabelSrc: string(SrcManual)}}}
assert.True(t, p.ShouldGenerateLabels(false))
})
t.Run("CaptionManualWithoutVision", func(t *testing.T) {
p := Photo{
Labels: []PhotoLabel{{LabelSrc: string(SrcCaption)}},
CaptionSrc: SrcManual,
}
assert.True(t, p.ShouldGenerateLabels(false))
})
}
func TestPhoto_ShouldGenerateCaption(t *testing.T) {
ctx := []struct {
name string
photo Photo
source Src
force bool
expect bool
}{
{
name: "NoCaptionAutoSource",
photo: Photo{CaptionSrc: SrcAuto},
source: SrcOllama,
expect: true,
},
{
name: "LowerPriority",
photo: Photo{CaptionSrc: SrcOllama},
source: SrcImage,
expect: false,
},
{
name: "HigherPriority",
photo: Photo{CaptionSrc: SrcImage},
source: SrcOllama,
expect: true,
},
{
name: "ForceOverrides",
photo: Photo{CaptionSrc: SrcImage, PhotoCaption: "existing"},
source: SrcImage,
force: true,
expect: true,
},
{
name: "SamePriorityNoForce",
photo: Photo{CaptionSrc: SrcOllama, PhotoCaption: "existing"},
source: SrcOllama,
expect: false,
},
}
for _, tc := range ctx {
tc := tc
t.Run(tc.name, func(t *testing.T) {
result := tc.photo.ShouldGenerateCaption(tc.source, tc.force)
assert.Equal(t, tc.expect, result)
})
}
}
func TestPhoto_ClassifyLabels(t *testing.T) {
t.Run("NewPhoto", func(t *testing.T) {
m := PhotoFixtures.Get("Photo19")

View File

@@ -95,8 +95,8 @@ var VisionSrcNames = SrcMap{
SrcVision: SrcVision,
}
// VisionSrc contains all the sources commonly used by computer vision models and services.
var VisionSrc = []Src{
// VisionSrcList contains all the sources commonly used by computer vision models and services.
var VisionSrcList = []Src{
SrcMarker,
SrcImage,
SrcOllama,

View File

@@ -1,17 +1,32 @@
package photoprism
import (
"errors"
"time"
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/media"
)
// Caption returns generated caption for the specified media file.
func (ind *Index) Caption(file *MediaFile) (caption *vision.CaptionResult, err error) {
// Caption generates a caption for the provided media file using the active
// vision model. When captionSrc is SrcAuto the model's declared source is used;
// otherwise the explicit source is recorded on the returned caption.
func (ind *Index) Caption(file *MediaFile, captionSrc entity.Src) (caption *vision.CaptionResult, err error) {
start := time.Now()
model := vision.Config.Model(vision.ModelTypeCaption)
// No caption generation model configured or usable.
if model == nil {
return caption, errors.New("no caption model configured")
}
if captionSrc == entity.SrcAuto {
captionSrc = model.GetSource()
}
size := vision.Thumb(vision.ModelTypeCaption)
// Get thumbnail filenames for the selected sizes.
@@ -22,9 +37,14 @@ func (ind *Index) Caption(file *MediaFile) (caption *vision.CaptionResult, err e
}
// Get matching labels from computer vision model.
// Generate a caption using the configured vision model.
if caption, _, err = vision.Caption(vision.Files{fileName}, media.SrcLocal); err != nil {
// Failed.
} else if caption.Text != "" {
if captionSrc != entity.SrcAuto {
caption.Source = captionSrc
}
log.Infof("vision: generated caption for %s [%s]", clean.Log(file.BaseName()), time.Since(start))
}

View File

@@ -13,7 +13,9 @@ import (
"github.com/photoprism/photoprism/pkg/media"
)
// Labels classifies a JPEG image and returns matching labels.
// Labels classifies the media file and returns matching labels. When labelSrc
// is SrcAuto the model's declared source is used; otherwise the provided source
// is applied to every returned label.
func (ind *Index) Labels(file *MediaFile, labelSrc entity.Src) (labels classify.Labels) {
start := time.Now()
@@ -21,8 +23,24 @@ func (ind *Index) Labels(file *MediaFile, labelSrc entity.Src) (labels classify.
var sizes []thumb.Name
var thumbnails []string
model := vision.Config.Model(vision.ModelTypeLabels)
// No label generation model configured or usable.
if model == nil {
return labels
}
if labelSrc == entity.SrcAuto {
labelSrc = model.GetSource()
}
size := vision.Thumb(vision.ModelTypeLabels)
// The thumbnail size may need to be adjusted to use other models.
if file.Square() {
if size.Name != "" && size.Name != thumb.Tile224 {
sizes = []thumb.Name{size.Name}
thumbnails = make([]string, 0, 1)
} else if file.Square() {
// Only one thumbnail is required for square images.
sizes = []thumb.Name{thumb.Tile224}
thumbnails = make([]string, 0, 1)
@@ -33,8 +51,8 @@ func (ind *Index) Labels(file *MediaFile, labelSrc entity.Src) (labels classify.
}
// Get thumbnail filenames for the selected sizes.
for _, size := range sizes {
if thumbnail, fileErr := file.Thumbnail(Config().ThumbCachePath(), size); fileErr != nil {
for _, s := range sizes {
if thumbnail, fileErr := file.Thumbnail(Config().ThumbCachePath(), s); fileErr != nil {
log.Debugf("index: %s in %s", err, clean.Log(file.BaseName()))
continue
} else {
@@ -42,7 +60,7 @@ func (ind *Index) Labels(file *MediaFile, labelSrc entity.Src) (labels classify.
}
}
// Get matching labels from computer vision model.
// Run the configured vision model to obtain labels for the generated thumbnails.
if labels, err = vision.Labels(thumbnails, media.SrcLocal, labelSrc); err != nil {
log.Debugf("labels: %s in %s", err, clean.Log(file.BaseName()))
return labels

View File

@@ -815,7 +815,7 @@ func (ind *Index) UserMediaFile(m *MediaFile, o IndexOptions, originalName, phot
// Classify images with TensorFlow?
if ind.findLabels {
labels = ind.Labels(m, entity.SrcImage)
labels = ind.Labels(m, entity.SrcAuto)
// Append labels from other sources such as face detection.
if len(extraLabels) > 0 {

View File

@@ -0,0 +1,111 @@
package photoprism
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/photoprism/photoprism/internal/ai/classify"
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/pkg/media"
)
func TestIndexCaptionSource(t *testing.T) {
if testing.Short() {
t.Skip("skipping vision-dependent test in short mode")
}
cfg := config.TestConfig()
require.NoError(t, cfg.InitializeTestData())
ind := NewIndex(cfg, NewConvert(cfg), NewFiles(), NewPhotos())
mediaFile, err := NewMediaFile("testdata/flash.jpg")
require.NoError(t, err)
originalConfig := vision.Config
t.Cleanup(func() {
vision.Config = originalConfig
vision.SetCaptionFunc(nil)
})
captionModel := &vision.Model{Type: vision.ModelTypeCaption, Engine: vision.ApiFormatOpenAI}
captionModel.ApplyEngineDefaults()
vision.Config = &vision.ConfigValues{Models: vision.Models{captionModel}}
t.Run("AutoUsesModelSource", func(t *testing.T) {
vision.SetCaptionFunc(func(files vision.Files, mediaSrc media.Src) (*vision.CaptionResult, *vision.Model, error) {
return &vision.CaptionResult{Text: "stub", Source: captionModel.GetSource()}, captionModel, nil
})
t.Cleanup(func() { vision.SetCaptionFunc(nil) })
caption, err := ind.Caption(mediaFile, entity.SrcAuto)
require.NoError(t, err)
require.NotNil(t, caption)
assert.Equal(t, captionModel.GetSource(), caption.Source)
})
t.Run("CustomSource", func(t *testing.T) {
originalSource := captionModel.GetSource()
vision.SetCaptionFunc(func(files vision.Files, mediaSrc media.Src) (*vision.CaptionResult, *vision.Model, error) {
return &vision.CaptionResult{Text: "stub", Source: originalSource}, captionModel, nil
})
t.Cleanup(func() { vision.SetCaptionFunc(nil) })
caption, err := ind.Caption(mediaFile, entity.SrcManual)
require.NoError(t, err)
require.NotNil(t, caption)
assert.Equal(t, entity.SrcManual, caption.Source)
})
}
func TestIndexLabelsSource(t *testing.T) {
if testing.Short() {
t.Skip("skipping vision-dependent test in short mode")
}
cfg := config.TestConfig()
require.NoError(t, cfg.InitializeTestData())
ind := NewIndex(cfg, NewConvert(cfg), NewFiles(), NewPhotos())
mediaFile, err := NewMediaFile("testdata/flash.jpg")
require.NoError(t, err)
originalConfig := vision.Config
t.Cleanup(func() {
vision.Config = originalConfig
vision.SetLabelsFunc(nil)
})
labelModel := &vision.Model{Type: vision.ModelTypeLabels, Engine: vision.ApiFormatOllama}
labelModel.ApplyEngineDefaults()
vision.Config = &vision.ConfigValues{Models: vision.Models{labelModel}}
t.Run("AutoUsesModelSource", func(t *testing.T) {
var captured string
vision.SetLabelsFunc(func(files vision.Files, mediaSrc media.Src, src string) (classify.Labels, error) {
captured = src
return classify.Labels{{Name: "stub", Source: src, Uncertainty: 0}}, nil
})
t.Cleanup(func() { vision.SetLabelsFunc(nil) })
labels := ind.Labels(mediaFile, entity.SrcAuto)
assert.NotEmpty(t, labels)
assert.Equal(t, labelModel.GetSource(), captured)
})
t.Run("CustomSource", func(t *testing.T) {
var captured string
vision.SetLabelsFunc(func(files vision.Files, mediaSrc media.Src, src string) (classify.Labels, error) {
captured = src
return classify.Labels{{Name: "stub", Source: src, Uncertainty: 0}}, nil
})
t.Cleanup(func() { vision.SetLabelsFunc(nil) })
labels := ind.Labels(mediaFile, entity.SrcManual)
assert.NotEmpty(t, labels)
assert.Equal(t, entity.SrcManual, captured)
})
}

View File

@@ -57,8 +57,8 @@ func (w *Meta) Start(delay, interval time.Duration, force bool) (err error) {
log.Debugf("index: running face recognition")
if faces := photoprism.NewFaces(w.conf); faces.Disabled() {
log.Debugf("index: skipping face recognition")
} else if err := faces.Start(photoprism.FacesOptions{}); err != nil {
log.Warn(err)
} else if facesErr := faces.Start(photoprism.FacesOptions{}); facesErr != nil {
log.Warn(facesErr)
}
}
@@ -74,8 +74,8 @@ func (w *Meta) Start(delay, interval time.Duration, force bool) (err error) {
ind := get.Index()
generateLabels := w.conf.VisionModelShouldRun(vision.ModelTypeLabels, vision.RunNewlyIndexed)
generateCaptions := w.conf.VisionModelShouldRun(vision.ModelTypeCaption, vision.RunNewlyIndexed)
labelsModelShouldRun := w.conf.VisionModelShouldRun(vision.ModelTypeLabels, vision.RunNewlyIndexed)
captionModelShouldRun := w.conf.VisionModelShouldRun(vision.ModelTypeCaption, vision.RunNewlyIndexed)
for {
photos, queryErr := query.PhotosMetadataUpdate(limit, offset, delay, interval)
@@ -99,8 +99,11 @@ func (w *Meta) Start(delay, interval time.Duration, force bool) (err error) {
done[photo.PhotoUID] = true
generateLabels := labelsModelShouldRun && photo.ShouldGenerateLabels(false)
generateCaption := captionModelShouldRun && photo.ShouldGenerateCaption(entity.SrcAuto, false)
// If configured, generate metadata for newly indexed photos using external vision services.
if photo.IsNewlyIndexed() && (generateLabels || generateCaptions) {
if photo.IsNewlyIndexed() && (generateLabels || generateCaption) {
primaryFile, fileErr := photo.PrimaryFile()
if fileErr != nil {
@@ -120,16 +123,13 @@ func (w *Meta) Start(delay, interval time.Duration, force bool) (err error) {
}
}
if generateCaptions && photo.PhotoCaption == "" && photo.CaptionSrc == entity.SrcAuto {
if caption, captionErr := ind.Caption(mediaFile); captionErr != nil {
if generateCaption {
if caption, captionErr := ind.Caption(mediaFile, entity.SrcAuto); captionErr != nil {
log.Debugf("index: %s (generate caption for %s)", clean.Error(captionErr), photo.PhotoUID)
} else if caption != nil {
text := strings.TrimSpace(caption.Text)
if text != "" && caption.Source != "" {
photo.SetCaption(text, clean.ShortTypeLower(caption.Source))
if updateErr := photo.UpdateCaptionLabels(); updateErr != nil {
log.Warnf("index: %s (update caption labels for %s)", clean.Error(updateErr), photo.PhotoUID)
}
} else if text := strings.TrimSpace(caption.Text); text != "" {
photo.SetCaption(text, caption.Source)
if updateErr := photo.UpdateCaptionLabels(); updateErr != nil {
log.Warnf("index: %s (update caption labels for %s)", clean.Error(updateErr), photo.PhotoUID)
}
}
}

View File

@@ -25,45 +25,67 @@ import (
"github.com/photoprism/photoprism/pkg/txt"
)
// Vision represents a computer vision worker.
// Vision orchestrates background computer-vision tasks (labels, captions,
// NSFW detection). It wraps configuration lookups and scheduling helpers.
type Vision struct {
conf *config.Config
}
// NewVision returns a new Vision worker.
// NewVision constructs a Vision worker bound to the provided configuration.
func NewVision(conf *config.Config) *Vision {
return &Vision{conf: conf}
}
func captionSourceFromModel(model *vision.Model) string {
if model == nil {
return entity.SrcImage
// StartScheduled executes the worker in scheduled mode, selecting models that
// are allowed to run in the RunOnSchedule context.
func (w *Vision) StartScheduled() {
models := w.scheduledModels()
if len(models) == 0 {
return
}
switch model.EndpointRequestFormat() {
case vision.ApiFormatOllama:
return entity.SrcOllama
case vision.ApiFormatOpenAI:
return entity.SrcOpenAI
if err := w.Start(
w.conf.VisionFilter(),
0,
models,
entity.SrcAuto,
false,
vision.RunOnSchedule,
); err != nil {
log.Errorf("scheduler: %s (vision)", err)
}
switch model.ProviderName() {
case "ollama":
return entity.SrcOllama
case "openai":
return entity.SrcOpenAI
}
return entity.SrcImage
}
// originalsPath returns the original media files path as string.
// scheduledModels returns the model types that should run for scheduled jobs.
func (w *Vision) scheduledModels() []string {
models := make([]string, 0, 3)
if w.conf.VisionModelShouldRun(vision.ModelTypeLabels, vision.RunOnSchedule) {
models = append(models, vision.ModelTypeLabels)
}
if w.conf.VisionModelShouldRun(vision.ModelTypeNsfw, vision.RunOnSchedule) {
models = append(models, vision.ModelTypeNsfw)
}
if w.conf.VisionModelShouldRun(vision.ModelTypeCaption, vision.RunOnSchedule) {
models = append(models, vision.ModelTypeCaption)
}
return models
}
// originalsPath returns the path that holds original media files.
func (w *Vision) originalsPath() string {
return w.conf.OriginalsPath()
}
// Start runs the specified model types for photos matching the search query filter string.
func (w *Vision) Start(filter string, count int, models []string, customSrc string, force bool) (err error) {
// Start runs the requested vision models against photos matching the search
// filter. `customSrc` allows the caller to override the metadata source string,
// `force` regenerates metadata regardless of existing values, and `runType`
// describes the scheduling context (manual, scheduled, etc.).
func (w *Vision) Start(filter string, count int, models []string, customSrc string, force bool, runType vision.RunType) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("vision: %s (worker panic)\nstack: %s", r, debug.Stack())
@@ -77,6 +99,10 @@ func (w *Vision) Start(filter string, count int, models []string, customSrc stri
defer mutex.VisionWorker.Stop()
models = vision.FilterModels(models, runType, func(mt vision.ModelType, when vision.RunType) bool {
return w.conf.VisionModelShouldRun(mt, when)
})
updateLabels := slices.Contains(models, vision.ModelTypeLabels)
updateNsfw := slices.Contains(models, vision.ModelTypeNsfw)
updateCaptions := slices.Contains(models, vision.ModelTypeCaption)
@@ -90,21 +116,6 @@ func (w *Vision) Start(filter string, count int, models []string, customSrc stri
}
customSrc = clean.ShortTypeLower(customSrc)
useAutoSource := customSrc == entity.SrcAuto
labelSource := customSrc
if useAutoSource {
labelSource = entity.SrcAuto
}
if labelSource == entity.SrcImage {
labelSource = entity.SrcAuto
}
captionSource := customSrc
if useAutoSource {
captionSource = captionSourceFromModel(vision.Config.Model(vision.ModelTypeCaption))
}
// Check time when worker was last executed.
updateIndex := false
@@ -161,13 +172,6 @@ func (w *Vision) Start(filter string, count int, models []string, customSrc stri
done[photo.PhotoUID] = true
photoName := path.Join(photo.PhotoPath, photo.PhotoName)
fileName := photoprism.FileName(photo.FileRoot, photo.FileName)
file, fileErr := photoprism.NewMediaFile(fileName)
if fileErr != nil {
log.Errorf("vision: failed to open %s (%s)", photoName, fileErr)
continue
}
m, loadErr := query.PhotoByUID(photo.PhotoUID)
@@ -176,33 +180,48 @@ func (w *Vision) Start(filter string, count int, models []string, customSrc stri
continue
}
generateLabels := updateLabels && m.ShouldGenerateLabels(force)
generateCaptions := updateCaptions && m.ShouldGenerateCaption(customSrc, force)
generateNsfw := updateNsfw && (!photo.PhotoPrivate || force)
if !(generateLabels || generateCaptions || generateNsfw) {
continue
}
fileName := photoprism.FileName(photo.FileRoot, photo.FileName)
file, fileErr := photoprism.NewMediaFile(fileName)
if fileErr != nil {
log.Errorf("vision: failed to open %s (%s)", photoName, fileErr)
continue
}
changed := false
// Generate labels.
if updateLabels && (len(m.Labels) == 0 || force) {
labelSrc := labelSource
if labels := ind.Labels(file, labelSrc); len(labels) > 0 {
if generateLabels {
if labels := ind.Labels(file, customSrc); len(labels) > 0 {
m.AddLabels(labels)
changed = true
}
}
// Detect NSFW content.
if updateNsfw && (!photo.PhotoPrivate || force) {
if isNsfw := ind.IsNsfw(file); photo.PhotoPrivate != isNsfw {
photo.PhotoPrivate = isNsfw
if generateNsfw {
if isNsfw := ind.IsNsfw(file); m.PhotoPrivate != isNsfw {
m.PhotoPrivate = isNsfw
changed = true
log.Infof("vision: changed private flag of %s to %t", photoName, photo.PhotoPrivate)
log.Infof("vision: changed private flag of %s to %t", photoName, m.PhotoPrivate)
}
}
// Generate a caption if none exists or the force flag is used,
// and only if no caption was set or removed by a higher-priority source.
if updateCaptions && entity.SrcPriority[captionSource] >= entity.SrcPriority[m.CaptionSrc] && (m.NoCaption() || force) {
if caption, captionErr := ind.Caption(file); captionErr != nil {
if generateCaptions {
if caption, captionErr := ind.Caption(file, customSrc); captionErr != nil {
log.Warnf("vision: %s in %s (generate caption)", clean.Error(captionErr), photoName)
} else if caption.Text = strings.TrimSpace(caption.Text); caption.Text != "" {
m.SetCaption(caption.Text, captionSource)
} else if text := strings.TrimSpace(caption.Text); text != "" {
m.SetCaption(text, caption.Source)
if updateErr := m.UpdateCaptionLabels(); updateErr != nil {
log.Warnf("vision: %s in %s (update caption labels)", clean.Error(updateErr), photoName)
}

View File

@@ -1,32 +0,0 @@
package workers
import (
"testing"
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/internal/entity"
)
func TestCaptionSourceFromModel(t *testing.T) {
if got := captionSourceFromModel(nil); got != entity.SrcImage {
t.Fatalf("expected SrcImage for nil model, got %s", got)
}
openAIModel := &vision.Model{
Service: vision.Service{RequestFormat: vision.ApiFormatOpenAI},
}
if got := captionSourceFromModel(openAIModel); got != entity.SrcOpenAI {
t.Fatalf("expected SrcOpenAI for openai model, got %s", got)
}
providerModel := &vision.Model{Provider: "ollama"}
if got := captionSourceFromModel(providerModel); got != entity.SrcOllama {
t.Fatalf("expected SrcOllama from provider, got %s", got)
}
fallbackModel := &vision.Model{}
if got := captionSourceFromModel(fallbackModel); got != entity.SrcImage {
t.Fatalf("expected SrcImage fallback, got %s", got)
}
}

View File

@@ -38,7 +38,9 @@ import (
var log = event.Log
var stop = make(chan bool, 1)
// Start starts the execution of background workers and scheduled tasks based on the current configuration.
// Start launches background workers and scheduled tasks based on the current
// configuration. It sets up the cron scheduler and the periodic metadata/share
// workers.
func Start(conf *config.Config) {
if scheduler, err := gocron.NewScheduler(gocron.WithLocation(conf.DefaultTimezone())); err != nil {
log.Errorf("scheduler: %s (start)", err)
@@ -56,6 +58,11 @@ func Start(conf *config.Config) {
log.Errorf("scheduler: %s (index)", err)
}
// Schedule vision job.
if err = NewJob("vision", conf.VisionSchedule(), NewVision(conf).StartScheduled); err != nil {
log.Errorf("scheduler: %s (vision)", err)
}
// Start the scheduler.
Scheduler.Start()
}
@@ -89,7 +96,7 @@ func Start(conf *config.Config) {
}()
}
// Shutdown stops the background workers and scheduled tasks.
// Shutdown stops the background workers and shuts down the scheduler.
func Shutdown() {
log.Info("shutting down workers")