mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# PhotoPrism® Repository Guidelines
|
||||
|
||||
**Last Updated:** September 29, 2025
|
||||
**Last Updated:** September 30, 2025
|
||||
|
||||
## Purpose
|
||||
|
||||
@@ -273,6 +273,7 @@ If anything in this file conflicts with the `Makefile` or the Developer Guide, t
|
||||
|
||||
- Respect precedence: `options.yml` overrides CLI/env values, which override defaults. When adding a new option, update `internal/config/options.go` (yaml/flag tags), register it in `internal/config/flags.go`, expose a getter, surface it in `*config.Report()`, and write generated values back to `options.yml` by setting `c.options.OptionsYaml` before persisting. Use `CliTestContext` in `internal/config/test.go` to exercise new flags.
|
||||
- When touching configuration in Go code, use the public accessors on `*config.Config` (e.g. `Config.JWKSUrl()`, `Config.SetJWKSUrl()`, `Config.ClusterUUID()`) instead of mutating `Config.Options()` directly; reserve raw option tweaks for test fixtures only.
|
||||
- Vision worker scheduling is controlled via `VisionSchedule` / `VisionFilter` and the `Run` property set in `vision.yml`. Utilities like `vision.FilterModels` and `entity.Photo.ShouldGenerateLabels/Caption` help decide when work is required before loading media files.
|
||||
- Logging: use the shared logger (`event.Log`) via the package-level `log` variable (see `internal/auth/jwt/logger.go`) instead of direct `fmt.Print*` or ad-hoc loggers.
|
||||
- Cluster registry tests (`internal/service/cluster/registry`) currently rely on a full test config because they persist `entity.Client` rows. They run migrations and seed the SQLite DB, so they are intentionally slow. If you refactor them, consider sharing a single `config.TestConfig()` across subtests or building a lightweight schema harness; do not swap to the minimal config helper unless the tests stop touching the database.
|
||||
- Favor explicit CLI flags: check `c.cliCtx.IsSet("<flag>")` before overriding user-supplied values, and follow the `ClusterUUID` pattern (`options.yml` → CLI/env → generated UUIDv4 persisted).
|
||||
|
||||
@@ -60,12 +60,12 @@ func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResp
|
||||
|
||||
format := apiRequest.GetResponseFormat()
|
||||
|
||||
if provider, ok := ProviderFor(format); ok && provider.Parser != nil {
|
||||
if engine, ok := EngineFor(format); ok && engine.Parser != nil {
|
||||
if clientResp.StatusCode >= 300 {
|
||||
log.Debugf("vision: %s (status code %d)", body, clientResp.StatusCode)
|
||||
}
|
||||
|
||||
parsed, parseErr := provider.Parser.Parse(context.Background(), apiRequest, body, clientResp.StatusCode)
|
||||
parsed, parseErr := engine.Parser.Parse(context.Background(), apiRequest, body, clientResp.StatusCode)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
|
||||
@@ -4,14 +4,27 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/vision/ollama"
|
||||
"github.com/photoprism/photoprism/internal/ai/vision/openai"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
)
|
||||
|
||||
var captionFunc = captionInternal
|
||||
|
||||
// SetCaptionFunc overrides the caption generator. Intended for tests.
|
||||
func SetCaptionFunc(fn func(Files, media.Src) (*CaptionResult, *Model, error)) {
|
||||
if fn == nil {
|
||||
captionFunc = captionInternal
|
||||
return
|
||||
}
|
||||
|
||||
captionFunc = fn
|
||||
}
|
||||
|
||||
// Caption returns generated captions for the specified images.
|
||||
func Caption(images Files, mediaSrc media.Src) (result *CaptionResult, model *Model, err error) {
|
||||
func Caption(images Files, mediaSrc media.Src) (*CaptionResult, *Model, error) {
|
||||
return captionFunc(images, mediaSrc)
|
||||
}
|
||||
|
||||
func captionInternal(images Files, mediaSrc media.Src) (result *CaptionResult, model *Model, err error) {
|
||||
// Return if there is no configuration or no image classification models are configured.
|
||||
if Config == nil {
|
||||
return result, model, errors.New("vision service is not configured")
|
||||
@@ -21,8 +34,8 @@ func Caption(images Files, mediaSrc media.Src) (result *CaptionResult, model *Mo
|
||||
var apiRequest *ApiRequest
|
||||
var apiResponse *ApiResponse
|
||||
|
||||
if provider, ok := ProviderFor(model.EndpointRequestFormat()); ok && provider.Builder != nil {
|
||||
if apiRequest, err = provider.Builder.Build(context.Background(), model, images); err != nil {
|
||||
if engine, ok := EngineFor(model.EndpointRequestFormat()); ok && engine.Builder != nil {
|
||||
if apiRequest, err = engine.Builder.Build(context.Background(), model, images); err != nil {
|
||||
return result, model, err
|
||||
}
|
||||
} else if apiRequest, err = NewApiRequest(model.EndpointRequestFormat(), images, model.EndpointFileScheme()); err != nil {
|
||||
@@ -50,15 +63,8 @@ func Caption(images Files, mediaSrc media.Src) (result *CaptionResult, model *Mo
|
||||
}
|
||||
|
||||
// Set image as the default caption source.
|
||||
if apiResponse.Result.Caption.Text != "" && apiResponse.Result.Caption.Source == "" {
|
||||
switch model.Provider {
|
||||
case ollama.ProviderName:
|
||||
apiResponse.Result.Caption.Source = entity.SrcOllama
|
||||
case openai.ProviderName:
|
||||
apiResponse.Result.Caption.Source = entity.SrcOpenAI
|
||||
default:
|
||||
apiResponse.Result.Caption.Source = entity.SrcImage
|
||||
}
|
||||
if apiResponse.Result.Caption.Source == "" {
|
||||
apiResponse.Result.Caption.Source = model.GetSource()
|
||||
}
|
||||
|
||||
result = apiResponse.Result.Caption
|
||||
|
||||
@@ -47,7 +47,7 @@ func NewConfig() *ConfigValues {
|
||||
}
|
||||
|
||||
for _, model := range cfg.Models {
|
||||
model.ApplyProviderDefaults()
|
||||
model.ApplyEngineDefaults()
|
||||
}
|
||||
|
||||
return cfg
|
||||
@@ -95,7 +95,7 @@ func (c *ConfigValues) Load(fileName string) error {
|
||||
}
|
||||
|
||||
for _, model := range c.Models {
|
||||
model.ApplyProviderDefaults()
|
||||
model.ApplyEngineDefaults()
|
||||
}
|
||||
|
||||
if c.Thresholds.Confidence <= 0 || c.Thresholds.Confidence > 100 {
|
||||
@@ -124,7 +124,9 @@ func (c *ConfigValues) Save(fileName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Model returns the first enabled model with the matching type from the configuration.
|
||||
// Model returns the first enabled model with the matching type.
|
||||
// It returns nil if no matching model is available or every model of that
|
||||
// type is disabled, allowing callers to chain nil-safe Model methods.
|
||||
func (c *ConfigValues) Model(t ModelType) *Model {
|
||||
for i := len(c.Models) - 1; i >= 0; i-- {
|
||||
m := c.Models[i]
|
||||
@@ -136,7 +138,9 @@ func (c *ConfigValues) Model(t ModelType) *Model {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ShouldRun checks when the specified model type should run.
|
||||
// ShouldRun reports whether the configured model for the given type is
|
||||
// allowed to run in the specified context. It returns false when no
|
||||
// suitable model exists or when execution is explicitly disabled.
|
||||
func (c *ConfigValues) ShouldRun(t ModelType, when RunType) bool {
|
||||
m := c.Model(t)
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ func TestConfigModelPrefersLastEnabled(t *testing.T) {
|
||||
customModel := &Model{
|
||||
Type: ModelTypeLabels,
|
||||
Name: "ollama-labels",
|
||||
Provider: "ollama",
|
||||
Engine: "ollama",
|
||||
Disabled: false,
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ func TestConfigValues_IsDefaultAndIsCustom(t *testing.T) {
|
||||
}
|
||||
})
|
||||
t.Run("CustomOverridesDefault", func(t *testing.T) {
|
||||
custom := &Model{Type: ModelTypeLabels, Name: "custom", Provider: "ollama"}
|
||||
custom := &Model{Type: ModelTypeLabels, Name: "custom", Engine: "ollama"}
|
||||
cfg := &ConfigValues{Models: Models{defaultModel, custom}}
|
||||
if cfg.IsDefault(ModelTypeLabels) {
|
||||
t.Fatalf("expected custom model to disable default detection")
|
||||
@@ -84,7 +84,7 @@ func TestConfigValues_IsDefaultAndIsCustom(t *testing.T) {
|
||||
}
|
||||
})
|
||||
t.Run("DisabledCustomFallsBackToDefault", func(t *testing.T) {
|
||||
custom := &Model{Type: ModelTypeLabels, Name: "custom", Provider: "ollama", Disabled: true}
|
||||
custom := &Model{Type: ModelTypeLabels, Name: "custom", Engine: "ollama", Disabled: true}
|
||||
cfg := &ConfigValues{Models: Models{defaultModel, custom}}
|
||||
if !cfg.IsDefault(ModelTypeLabels) {
|
||||
t.Fatalf("expected disabled custom model to fall back to default")
|
||||
|
||||
121
internal/ai/vision/engine.go
Normal file
121
internal/ai/vision/engine.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/vision/openai"
|
||||
"github.com/photoprism/photoprism/pkg/service/http/scheme"
|
||||
)
|
||||
|
||||
// ModelEngine represents the canonical identifier for a computer vision service engine.
|
||||
type ModelEngine = string
|
||||
|
||||
const (
|
||||
// EngineVision represents the default PhotoPrism vision service endpoints.
|
||||
EngineVision ModelEngine = "vision"
|
||||
// EngineTensorFlow represents on-device TensorFlow models.
|
||||
EngineTensorFlow ModelEngine = "tensorflow"
|
||||
// EngineLocal is used when no explicit engine can be determined.
|
||||
EngineLocal ModelEngine = "local"
|
||||
)
|
||||
|
||||
// RequestBuilder builds an API request for an engine based on the model configuration and input files.
|
||||
type RequestBuilder interface {
|
||||
Build(ctx context.Context, model *Model, files Files) (*ApiRequest, error)
|
||||
}
|
||||
|
||||
// ResponseParser parses a raw engine response into the generic ApiResponse structure.
|
||||
type ResponseParser interface {
|
||||
Parse(ctx context.Context, req *ApiRequest, raw []byte, status int) (*ApiResponse, error)
|
||||
}
|
||||
|
||||
// EngineDefaults supplies engine-specific prompt and schema defaults when they are not configured explicitly.
|
||||
type EngineDefaults interface {
|
||||
SystemPrompt(model *Model) string
|
||||
UserPrompt(model *Model) string
|
||||
SchemaTemplate(model *Model) string
|
||||
Options(model *Model) *ApiRequestOptions
|
||||
}
|
||||
|
||||
// Engine groups the callbacks required to integrate a third-party vision service.
|
||||
type Engine struct {
|
||||
Builder RequestBuilder
|
||||
Parser ResponseParser
|
||||
Defaults EngineDefaults
|
||||
}
|
||||
|
||||
var (
|
||||
engineRegistry = make(map[ApiFormat]Engine)
|
||||
engineAliasIndex = make(map[string]EngineInfo)
|
||||
engineMu sync.RWMutex
|
||||
)
|
||||
|
||||
// init wires up the built-in aliases so configuration files can reference the
|
||||
// human-friendly engine names without duplicating adapter metadata.
|
||||
func init() {
|
||||
RegisterEngineAlias(EngineVision, EngineInfo{
|
||||
RequestFormat: ApiFormatVision,
|
||||
ResponseFormat: ApiFormatVision,
|
||||
FileScheme: string(scheme.Data),
|
||||
Resolution: DefaultResolution,
|
||||
})
|
||||
|
||||
RegisterEngineAlias(openai.EngineName, EngineInfo{
|
||||
RequestFormat: ApiFormatOpenAI,
|
||||
ResponseFormat: ApiFormatOpenAI,
|
||||
FileScheme: string(scheme.Data),
|
||||
Resolution: openai.Resolution,
|
||||
})
|
||||
}
|
||||
|
||||
// RegisterEngine adds/overrides an engine implementation for a specific API format.
|
||||
func RegisterEngine(format ApiFormat, engine Engine) {
|
||||
engineMu.Lock()
|
||||
defer engineMu.Unlock()
|
||||
engineRegistry[format] = engine
|
||||
}
|
||||
|
||||
// EngineInfo describes metadata that can be associated with an engine alias.
|
||||
type EngineInfo struct {
|
||||
RequestFormat ApiFormat
|
||||
ResponseFormat ApiFormat
|
||||
FileScheme string
|
||||
Resolution int
|
||||
}
|
||||
|
||||
// RegisterEngineAlias maps a logical engine name (e.g., "ollama") to a
|
||||
// request/response format pair.
|
||||
func RegisterEngineAlias(name string, info EngineInfo) {
|
||||
name = strings.TrimSpace(strings.ToLower(name))
|
||||
if name == "" || info.RequestFormat == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if info.ResponseFormat == "" {
|
||||
info.ResponseFormat = info.RequestFormat
|
||||
}
|
||||
|
||||
engineMu.Lock()
|
||||
engineAliasIndex[name] = info
|
||||
engineMu.Unlock()
|
||||
}
|
||||
|
||||
// EngineInfoFor returns the metadata associated with a logical engine name.
|
||||
func EngineInfoFor(name string) (EngineInfo, bool) {
|
||||
name = strings.TrimSpace(strings.ToLower(name))
|
||||
engineMu.RLock()
|
||||
info, ok := engineAliasIndex[name]
|
||||
engineMu.RUnlock()
|
||||
return info, ok
|
||||
}
|
||||
|
||||
// EngineFor returns the registered engine implementation for the given API
|
||||
// format, if any.
|
||||
func EngineFor(format ApiFormat) (Engine, bool) {
|
||||
engineMu.RLock()
|
||||
defer engineMu.RUnlock()
|
||||
engine, ok := engineRegistry[format]
|
||||
return engine, ok
|
||||
}
|
||||
@@ -18,20 +18,23 @@ type ollamaBuilder struct{}
|
||||
type ollamaParser struct{}
|
||||
|
||||
func init() {
|
||||
RegisterProvider(ApiFormatOllama, Provider{
|
||||
RegisterEngine(ApiFormatOllama, Engine{
|
||||
Builder: ollamaBuilder{},
|
||||
Parser: ollamaParser{},
|
||||
Defaults: ollamaDefaults{},
|
||||
})
|
||||
|
||||
RegisterProviderAlias(ollama.ProviderName, ProviderInfo{
|
||||
// Register the human-friendly engine name so configuration can simply use
|
||||
// `Engine: "ollama"` and inherit adapter defaults.
|
||||
RegisterEngineAlias(ollama.EngineName, EngineInfo{
|
||||
RequestFormat: ApiFormatOllama,
|
||||
ResponseFormat: ApiFormatOllama,
|
||||
FileScheme: string(scheme.Base64),
|
||||
Resolution: ollama.Resolution,
|
||||
})
|
||||
|
||||
CaptionModel.Provider = ollama.ProviderName
|
||||
CaptionModel.ApplyProviderDefaults()
|
||||
CaptionModel.Engine = ollama.EngineName
|
||||
CaptionModel.ApplyEngineDefaults()
|
||||
}
|
||||
|
||||
func (ollamaDefaults) SystemPrompt(model *Model) string {
|
||||
18
internal/ai/vision/engine_openai.go
Normal file
18
internal/ai/vision/engine_openai.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"github.com/photoprism/photoprism/internal/ai/vision/openai"
|
||||
"github.com/photoprism/photoprism/pkg/service/http/scheme"
|
||||
)
|
||||
|
||||
// init registers the OpenAI engine alias so models can set Engine: "openai"
|
||||
// and inherit sensible defaults (request/response formats, file scheme, and
|
||||
// preferred thumbnail resolution).
|
||||
func init() {
|
||||
RegisterEngineAlias(openai.EngineName, EngineInfo{
|
||||
RequestFormat: ApiFormatOpenAI,
|
||||
ResponseFormat: ApiFormatOpenAI,
|
||||
FileScheme: string(scheme.Base64),
|
||||
Resolution: openai.Resolution,
|
||||
})
|
||||
}
|
||||
78
internal/ai/vision/engine_test.go
Normal file
78
internal/ai/vision/engine_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package vision
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestRegisterEngineAlias(t *testing.T) {
|
||||
const alias = "unit-test"
|
||||
engineMu.Lock()
|
||||
prev, had := engineAliasIndex[alias]
|
||||
if had {
|
||||
delete(engineAliasIndex, alias)
|
||||
}
|
||||
engineMu.Unlock()
|
||||
|
||||
t.Cleanup(func() {
|
||||
engineMu.Lock()
|
||||
if had {
|
||||
engineAliasIndex[alias] = prev
|
||||
} else {
|
||||
delete(engineAliasIndex, alias)
|
||||
}
|
||||
engineMu.Unlock()
|
||||
})
|
||||
|
||||
RegisterEngineAlias(" Unit-Test ", EngineInfo{RequestFormat: ApiFormat("custom"), ResponseFormat: "", FileScheme: "data", Resolution: 512})
|
||||
|
||||
info, ok := EngineInfoFor(alias)
|
||||
if !ok {
|
||||
t.Fatalf("expected engine alias %q to be registered", alias)
|
||||
}
|
||||
|
||||
if info.RequestFormat != ApiFormat("custom") {
|
||||
t.Errorf("unexpected request format: %s", info.RequestFormat)
|
||||
}
|
||||
|
||||
if info.ResponseFormat != ApiFormat("custom") {
|
||||
t.Errorf("expected response format default to request, got %s", info.ResponseFormat)
|
||||
}
|
||||
|
||||
if info.FileScheme != "data" {
|
||||
t.Errorf("unexpected file scheme: %s", info.FileScheme)
|
||||
}
|
||||
|
||||
if info.Resolution != 512 {
|
||||
t.Errorf("unexpected resolution: %d", info.Resolution)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterEngine(t *testing.T) {
|
||||
format := ApiFormat("unit-format")
|
||||
engine := Engine{}
|
||||
|
||||
engineMu.Lock()
|
||||
prev, had := engineRegistry[format]
|
||||
if had {
|
||||
delete(engineRegistry, format)
|
||||
}
|
||||
engineMu.Unlock()
|
||||
|
||||
t.Cleanup(func() {
|
||||
engineMu.Lock()
|
||||
if had {
|
||||
engineRegistry[format] = prev
|
||||
} else {
|
||||
delete(engineRegistry, format)
|
||||
}
|
||||
engineMu.Unlock()
|
||||
})
|
||||
|
||||
RegisterEngine(format, engine)
|
||||
got, ok := EngineFor(format)
|
||||
if !ok {
|
||||
t.Fatalf("expected engine for %s", format)
|
||||
}
|
||||
|
||||
if got != engine {
|
||||
t.Errorf("unexpected engine value: %#v", got)
|
||||
}
|
||||
}
|
||||
@@ -13,10 +13,26 @@ import (
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
)
|
||||
|
||||
var labelsFunc = labelsInternal
|
||||
|
||||
// SetLabelsFunc overrides the labels generator. Intended for tests.
|
||||
func SetLabelsFunc(fn func(Files, media.Src, string) (classify.Labels, error)) {
|
||||
if fn == nil {
|
||||
labelsFunc = labelsInternal
|
||||
return
|
||||
}
|
||||
|
||||
labelsFunc = fn
|
||||
}
|
||||
|
||||
// Labels finds matching labels for the specified image.
|
||||
// Caller must pass the appropriate metadata source string (e.g., entity.SrcOllama, entity.SrcOpenAI)
|
||||
// so that downstream indexing can record where the labels originated.
|
||||
func Labels(images Files, mediaSrc media.Src, labelSrc string) (result classify.Labels, err error) {
|
||||
func Labels(images Files, mediaSrc media.Src, labelSrc string) (classify.Labels, error) {
|
||||
return labelsFunc(images, mediaSrc, labelSrc)
|
||||
}
|
||||
|
||||
func labelsInternal(images Files, mediaSrc media.Src, labelSrc string) (result classify.Labels, err error) {
|
||||
// Return if no thumbnail filenames were given.
|
||||
if len(images) == 0 {
|
||||
return result, errors.New("at least one image required")
|
||||
@@ -42,8 +58,8 @@ func Labels(images Files, mediaSrc media.Src, labelSrc string) (result classify.
|
||||
var apiRequest *ApiRequest
|
||||
var apiResponse *ApiResponse
|
||||
|
||||
if provider, ok := ProviderFor(model.EndpointRequestFormat()); ok && provider.Builder != nil {
|
||||
if apiRequest, err = provider.Builder.Build(context.Background(), model, images); err != nil {
|
||||
if engine, ok := EngineFor(model.EndpointRequestFormat()); ok && engine.Builder != nil {
|
||||
if apiRequest, err = engine.Builder.Build(context.Background(), model, images); err != nil {
|
||||
return result, err
|
||||
}
|
||||
} else if apiRequest, err = NewApiRequest(model.EndpointRequestFormat(), images, model.EndpointFileScheme()); err != nil {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/photoprism/photoprism/internal/ai/vision/ollama"
|
||||
"github.com/photoprism/photoprism/internal/ai/vision/openai"
|
||||
visionschema "github.com/photoprism/photoprism/internal/ai/vision/schema"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/service/http/scheme"
|
||||
@@ -35,7 +36,7 @@ type Model struct {
|
||||
Default bool `yaml:"Default,omitempty" json:"default,omitempty"`
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
|
||||
Provider ModelProvider `yaml:"Provider,omitempty" json:"provider,omitempty"`
|
||||
Engine ModelEngine `yaml:"Engine,omitempty" json:"engine,omitempty"`
|
||||
Run RunType `yaml:"Run,omitempty" json:"Run,omitempty"` // "auto", "never", "manual", "always", "newly-indexed", "on-schedule"
|
||||
System string `yaml:"System,omitempty" json:"system,omitempty"`
|
||||
Prompt string `yaml:"Prompt,omitempty" json:"prompt,omitempty"`
|
||||
@@ -58,8 +59,14 @@ type Model struct {
|
||||
// Models represents a set of computer vision models.
|
||||
type Models []*Model
|
||||
|
||||
// Model returns the parsed and normalized model identifier, name, and version strings.
|
||||
// Model returns the parsed and normalized identifier, name, and version
|
||||
// strings. Nil receivers return empty values so callers can destructure the
|
||||
// tuple without additional nil checks.
|
||||
func (m *Model) Model() (model, name, version string) {
|
||||
if m == nil {
|
||||
return "", "", ""
|
||||
}
|
||||
|
||||
// Return empty identifier string if no name was set.
|
||||
if m.Name == "" {
|
||||
return "", "", clean.TypeLowerDash(m.Version)
|
||||
@@ -91,8 +98,13 @@ func (m *Model) Model() (model, name, version string) {
|
||||
return model, name, version
|
||||
}
|
||||
|
||||
// IsDefault checks if this is a built-in default model.
|
||||
// IsDefault reports whether the model refers to one of the built-in defaults.
|
||||
// Nil receivers return false.
|
||||
func (m *Model) IsDefault() bool {
|
||||
if m == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if m.Default {
|
||||
return true
|
||||
}
|
||||
@@ -115,8 +127,13 @@ func (m *Model) IsDefault() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Endpoint returns the remote service request method and endpoint URL, if any.
|
||||
// Endpoint returns the remote service request method and endpoint URL. Nil
|
||||
// receivers return empty strings.
|
||||
func (m *Model) Endpoint() (uri, method string) {
|
||||
if m == nil {
|
||||
return uri, method
|
||||
}
|
||||
|
||||
if uri, method = m.Service.Endpoint(); uri != "" && method != "" {
|
||||
return uri, method
|
||||
} else if ServiceUri == "" {
|
||||
@@ -128,8 +145,13 @@ func (m *Model) Endpoint() (uri, method string) {
|
||||
}
|
||||
}
|
||||
|
||||
// EndpointKey returns the access token belonging to the remote service endpoint, if any.
|
||||
// EndpointKey returns the access token belonging to the remote service
|
||||
// endpoint, or an empty string for nil receivers.
|
||||
func (m *Model) EndpointKey() (key string) {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if key = m.Service.EndpointKey(); key != "" {
|
||||
return key
|
||||
} else {
|
||||
@@ -137,8 +159,13 @@ func (m *Model) EndpointKey() (key string) {
|
||||
}
|
||||
}
|
||||
|
||||
// EndpointFileScheme returns the endpoint API request file scheme type.
|
||||
// EndpointFileScheme returns the endpoint API request file scheme type. Nil
|
||||
// receivers fall back to the global default scheme.
|
||||
func (m *Model) EndpointFileScheme() (fileScheme scheme.Type) {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if fileScheme = m.Service.EndpointFileScheme(); fileScheme != "" {
|
||||
return fileScheme
|
||||
}
|
||||
@@ -146,8 +173,13 @@ func (m *Model) EndpointFileScheme() (fileScheme scheme.Type) {
|
||||
return ServiceFileScheme
|
||||
}
|
||||
|
||||
// EndpointRequestFormat returns the endpoint API request format.
|
||||
// EndpointRequestFormat returns the endpoint API request format. Nil receivers
|
||||
// fall back to the global default format.
|
||||
func (m *Model) EndpointRequestFormat() (format ApiFormat) {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if format = m.Service.EndpointRequestFormat(); format != "" {
|
||||
return format
|
||||
}
|
||||
@@ -155,8 +187,13 @@ func (m *Model) EndpointRequestFormat() (format ApiFormat) {
|
||||
return ServiceRequestFormat
|
||||
}
|
||||
|
||||
// EndpointResponseFormat returns the endpoint API response format.
|
||||
// EndpointResponseFormat returns the endpoint API response format. Nil
|
||||
// receivers fall back to the global default format.
|
||||
func (m *Model) EndpointResponseFormat() (format ApiFormat) {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if format = m.Service.EndpointResponseFormat(); format != "" {
|
||||
return format
|
||||
}
|
||||
@@ -164,13 +201,18 @@ func (m *Model) EndpointResponseFormat() (format ApiFormat) {
|
||||
return ServiceResponseFormat
|
||||
}
|
||||
|
||||
// GetPrompt returns the configured model prompt, or the default prompt if none is specified.
|
||||
// GetPrompt returns the configured model prompt, using engine defaults when
|
||||
// none is specified. Nil receivers return an empty string.
|
||||
func (m *Model) GetPrompt() string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if m.Prompt != "" {
|
||||
return m.Prompt
|
||||
}
|
||||
|
||||
if defaults := m.defaultsProvider(); defaults != nil {
|
||||
if defaults := m.engineDefaults(); defaults != nil {
|
||||
if prompt := defaults.UserPrompt(m); prompt != "" {
|
||||
return prompt
|
||||
}
|
||||
@@ -186,13 +228,19 @@ func (m *Model) GetPrompt() string {
|
||||
}
|
||||
}
|
||||
|
||||
// GetSystemPrompt returns the configured system model prompt, or the default system prompt if none is specified.
|
||||
// GetSystemPrompt returns the configured system prompt, falling back to
|
||||
// engine defaults when none is specified. Nil receivers return an empty
|
||||
// string.
|
||||
func (m *Model) GetSystemPrompt() string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if m.System != "" {
|
||||
return m.System
|
||||
}
|
||||
|
||||
if defaults := m.defaultsProvider(); defaults != nil {
|
||||
if defaults := m.engineDefaults(); defaults != nil {
|
||||
if system := defaults.SystemPrompt(m); system != "" {
|
||||
return system
|
||||
}
|
||||
@@ -206,8 +254,13 @@ func (m *Model) GetSystemPrompt() string {
|
||||
}
|
||||
}
|
||||
|
||||
// GetFormat returns the configured response format or a sensible default.
|
||||
// GetFormat returns the configured response format or a sensible default. Nil
|
||||
// receivers return an empty string.
|
||||
func (m *Model) GetFormat() string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if f := strings.TrimSpace(strings.ToLower(m.Format)); f != "" {
|
||||
return f
|
||||
}
|
||||
@@ -219,28 +272,56 @@ func (m *Model) GetFormat() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetOptions returns the API request options.
|
||||
// GetSource returns the default entity src based on the model configuration.
|
||||
func (m *Model) GetSource() string {
|
||||
if m == nil {
|
||||
return entity.SrcAuto
|
||||
}
|
||||
|
||||
switch m.EngineName() {
|
||||
case ollama.EngineName:
|
||||
return entity.SrcOllama
|
||||
case openai.EngineName:
|
||||
return entity.SrcOpenAI
|
||||
}
|
||||
|
||||
switch m.EndpointRequestFormat() {
|
||||
case ApiFormatOllama:
|
||||
return entity.SrcOllama
|
||||
case ApiFormatOpenAI:
|
||||
return entity.SrcOpenAI
|
||||
}
|
||||
|
||||
return entity.SrcImage
|
||||
}
|
||||
|
||||
// GetOptions returns the API request options, applying engine defaults on
|
||||
// demand. Nil receivers return nil.
|
||||
func (m *Model) GetOptions() *ApiRequestOptions {
|
||||
var providerDefaults *ApiRequestOptions
|
||||
if defaults := m.defaultsProvider(); defaults != nil {
|
||||
providerDefaults = cloneOptions(defaults.Options(m))
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var engineDefaults *ApiRequestOptions
|
||||
if defaults := m.engineDefaults(); defaults != nil {
|
||||
engineDefaults = cloneOptions(defaults.Options(m))
|
||||
}
|
||||
|
||||
if m.Options == nil {
|
||||
switch m.Type {
|
||||
case ModelTypeLabels, ModelTypeCaption, ModelTypeGenerate:
|
||||
if providerDefaults == nil {
|
||||
providerDefaults = &ApiRequestOptions{}
|
||||
if engineDefaults == nil {
|
||||
engineDefaults = &ApiRequestOptions{}
|
||||
}
|
||||
normalizeOptions(providerDefaults)
|
||||
m.Options = providerDefaults
|
||||
normalizeOptions(engineDefaults)
|
||||
m.Options = engineDefaults
|
||||
return m.Options
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
mergeOptionDefaults(m.Options, providerDefaults)
|
||||
mergeOptionDefaults(m.Options, engineDefaults)
|
||||
normalizeOptions(m.Options)
|
||||
|
||||
return m.Options
|
||||
@@ -286,14 +367,15 @@ func cloneOptions(opts *ApiRequestOptions) *ApiRequestOptions {
|
||||
return &clone
|
||||
}
|
||||
|
||||
// ProviderName returns the configured provider or infers a sensible default based on the model settings.
|
||||
func (m *Model) ProviderName() string {
|
||||
// EngineName returns the normalized engine identifier or infers one from the
|
||||
// request configuration. Nil receivers return an empty string.
|
||||
func (m *Model) EngineName() string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if provider := strings.TrimSpace(strings.ToLower(m.Provider)); provider != "" {
|
||||
return provider
|
||||
if engine := strings.TrimSpace(strings.ToLower(m.Engine)); engine != "" {
|
||||
return engine
|
||||
}
|
||||
|
||||
uri, method := m.Endpoint()
|
||||
@@ -301,36 +383,36 @@ func (m *Model) ProviderName() string {
|
||||
format := m.EndpointRequestFormat()
|
||||
switch format {
|
||||
case ApiFormatOllama:
|
||||
return ollama.ProviderName
|
||||
return ollama.EngineName
|
||||
case ApiFormatOpenAI:
|
||||
return openai.ProviderName
|
||||
return openai.EngineName
|
||||
case ApiFormatVision, "":
|
||||
return ProviderVision
|
||||
return EngineVision
|
||||
default:
|
||||
return strings.ToLower(string(format))
|
||||
}
|
||||
}
|
||||
|
||||
if m.TensorFlow != nil {
|
||||
return ProviderTensorFlow
|
||||
return EngineTensorFlow
|
||||
}
|
||||
|
||||
return ProviderLocal
|
||||
return EngineLocal
|
||||
}
|
||||
|
||||
// ApplyProviderDefaults normalizes the provider name and applies registered provider defaults
|
||||
// for request/response formats and file schemes when these are not explicitly configured.
|
||||
func (m *Model) ApplyProviderDefaults() {
|
||||
// ApplyEngineDefaults normalizes the engine name and applies registered engine
|
||||
// defaults (formats, schemes, resolution) when these are not explicitly configured.
|
||||
func (m *Model) ApplyEngineDefaults() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
|
||||
provider := strings.TrimSpace(strings.ToLower(m.Provider))
|
||||
if provider == "" {
|
||||
engine := strings.TrimSpace(strings.ToLower(m.Engine))
|
||||
if engine == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if info, ok := ProviderInfoFor(provider); ok {
|
||||
if info, ok := EngineInfoFor(engine); ok {
|
||||
if m.Service.RequestFormat == "" {
|
||||
m.Service.RequestFormat = info.RequestFormat
|
||||
}
|
||||
@@ -342,13 +424,22 @@ func (m *Model) ApplyProviderDefaults() {
|
||||
if info.FileScheme != "" && m.Service.FileScheme == "" {
|
||||
m.Service.FileScheme = info.FileScheme
|
||||
}
|
||||
|
||||
if info.Resolution > 0 && m.Resolution <= 0 {
|
||||
m.Resolution = info.Resolution
|
||||
}
|
||||
}
|
||||
|
||||
m.Provider = provider
|
||||
m.Engine = engine
|
||||
}
|
||||
|
||||
// SchemaTemplate returns the model-specific JSON schema template, if any.
|
||||
// SchemaTemplate returns the model-specific JSON schema template, if any. Nil
|
||||
// receivers return an empty string.
|
||||
func (m *Model) SchemaTemplate() string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
m.schemaOnce.Do(func() {
|
||||
var schemaText string
|
||||
|
||||
@@ -385,7 +476,7 @@ func (m *Model) SchemaTemplate() string {
|
||||
m.schema = strings.TrimSpace(schemaText)
|
||||
|
||||
if m.schema == "" && m.Type == ModelTypeLabels {
|
||||
if defaults := m.defaultsProvider(); defaults != nil {
|
||||
if defaults := m.engineDefaults(); defaults != nil {
|
||||
m.schema = strings.TrimSpace(defaults.SchemaTemplate(m))
|
||||
}
|
||||
}
|
||||
@@ -398,21 +489,30 @@ func (m *Model) SchemaTemplate() string {
|
||||
return m.schema
|
||||
}
|
||||
|
||||
func (m *Model) defaultsProvider() DefaultsProvider {
|
||||
if provider, ok := ProviderFor(m.EndpointRequestFormat()); ok {
|
||||
return provider.Defaults
|
||||
func (m *Model) engineDefaults() EngineDefaults {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if info, ok := ProviderInfoFor(m.ProviderName()); ok {
|
||||
if provider, ok := ProviderFor(info.RequestFormat); ok {
|
||||
return provider.Defaults
|
||||
if engine, ok := EngineFor(m.EndpointRequestFormat()); ok {
|
||||
return engine.Defaults
|
||||
}
|
||||
|
||||
if info, ok := EngineInfoFor(m.EngineName()); ok {
|
||||
if engine, ok := EngineFor(info.RequestFormat); ok {
|
||||
return engine.Defaults
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SchemaInstructions returns a helper string that can be appended to prompts.
|
||||
// Nil receivers return an empty string.
|
||||
func (m *Model) SchemaInstructions() string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if schema := m.SchemaTemplate(); schema != "" {
|
||||
return fmt.Sprintf("Return JSON that matches this schema:\n%s", schema)
|
||||
}
|
||||
@@ -420,8 +520,13 @@ func (m *Model) SchemaInstructions() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// ClassifyModel returns the matching classify model instance, if any.
|
||||
// ClassifyModel returns the matching classify model instance, if any. Nil
|
||||
// receivers return nil.
|
||||
func (m *Model) ClassifyModel() *classify.Model {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use mutex to prevent models from being loaded and
|
||||
// initialized twice by different indexing workers.
|
||||
modelMutex.Lock()
|
||||
@@ -481,8 +586,13 @@ func (m *Model) ClassifyModel() *classify.Model {
|
||||
return m.classifyModel
|
||||
}
|
||||
|
||||
// FaceModel returns the matching face model instance, if any.
|
||||
// FaceModel returns the matching face recognition model instance, if any. Nil
|
||||
// receivers return nil.
|
||||
func (m *Model) FaceModel() *face.Model {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use mutex to prevent models from being loaded and
|
||||
// initialized twice by different indexing workers.
|
||||
modelMutex.Lock()
|
||||
@@ -536,8 +646,13 @@ func (m *Model) FaceModel() *face.Model {
|
||||
return m.faceModel
|
||||
}
|
||||
|
||||
// NsfwModel returns the matching nsfw model instance, if any.
|
||||
// NsfwModel returns the matching nsfw model instance, if any. Nil receivers
|
||||
// return nil.
|
||||
func (m *Model) NsfwModel() *nsfw.Model {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use mutex to prevent models from being loaded and
|
||||
// initialized twice by different indexing workers.
|
||||
modelMutex.Lock()
|
||||
@@ -597,8 +712,12 @@ func (m *Model) NsfwModel() *nsfw.Model {
|
||||
return m.nsfwModel
|
||||
}
|
||||
|
||||
// Clone returns a clone of this model.
|
||||
// Clone returns a shallow copy of the model. Nil receivers return nil.
|
||||
func (m *Model) Clone() *Model {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := *m
|
||||
return &c
|
||||
}
|
||||
|
||||
29
internal/ai/vision/model_filters.go
Normal file
29
internal/ai/vision/model_filters.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FilterModels takes a list of model type names and a scheduling context, and
|
||||
// returns only the types that are allowed to run according to the supplied
|
||||
// predicate. Empty or unknown names are ignored.
|
||||
func FilterModels(models []string, when RunType, allow func(ModelType, RunType) bool) []string {
|
||||
if len(models) == 0 {
|
||||
return models
|
||||
}
|
||||
|
||||
filtered := make([]string, 0, len(models))
|
||||
|
||||
for _, name := range models {
|
||||
modelType := ModelType(strings.TrimSpace(name))
|
||||
if modelType == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if allow == nil || allow(modelType, when) {
|
||||
filtered = append(filtered, string(modelType))
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
56
internal/ai/vision/model_filters_test.go
Normal file
56
internal/ai/vision/model_filters_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFilterModels(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
models []string
|
||||
when RunType
|
||||
allow func(ModelType, RunType) bool
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "NilPredicate",
|
||||
models: []string{"caption", "labels"},
|
||||
when: RunManual,
|
||||
allow: nil,
|
||||
expected: []string{"caption", "labels"},
|
||||
},
|
||||
{
|
||||
name: "SkipUnknown",
|
||||
models: []string{"caption", "", "unknown", "labels"},
|
||||
when: RunManual,
|
||||
allow: func(mt ModelType, when RunType) bool {
|
||||
return mt == ModelTypeLabels
|
||||
},
|
||||
expected: []string{"labels"},
|
||||
},
|
||||
{
|
||||
name: "ContextAware",
|
||||
models: []string{"caption", "labels"},
|
||||
when: RunOnSchedule,
|
||||
allow: func(mt ModelType, when RunType) bool {
|
||||
return mt == ModelTypeCaption
|
||||
},
|
||||
expected: []string{"caption"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := FilterModels(tc.models, tc.when, tc.allow)
|
||||
if len(got) != len(tc.expected) {
|
||||
t.Fatalf("expected %d models, got %d", len(tc.expected), len(got))
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tc.expected[i] {
|
||||
t.Fatalf("expected %v, got %v", tc.expected, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -38,7 +38,8 @@ var RunTypes = map[string]RunType{
|
||||
"index": RunOnIndex,
|
||||
}
|
||||
|
||||
// ParseRunType parses a run type string.
|
||||
// ParseRunType parses a run type string into the canonical RunType constant.
|
||||
// Unknown or empty values default to RunAuto.
|
||||
func ParseRunType(s string) RunType {
|
||||
if t, ok := RunTypes[clean.TypeLowerDash(s)]; ok {
|
||||
return t
|
||||
@@ -47,13 +48,23 @@ func ParseRunType(s string) RunType {
|
||||
return RunAuto
|
||||
}
|
||||
|
||||
// RunType returns a normalized type that specifies when a vision model should run.
|
||||
// RunType returns the normalized run type configured for the model. Nil
|
||||
// receivers default to RunAuto.
|
||||
func (m *Model) RunType() RunType {
|
||||
if m == nil {
|
||||
return RunAuto
|
||||
}
|
||||
|
||||
return ParseRunType(m.Run)
|
||||
}
|
||||
|
||||
// ShouldRun checks when the model should run based on the specified type.
|
||||
// ShouldRun reports whether the model should execute in the specified
|
||||
// scheduling context. Nil receivers always return false.
|
||||
func (m *Model) ShouldRun(when RunType) bool {
|
||||
if m == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
when = ParseRunType(when)
|
||||
|
||||
switch m.RunType() {
|
||||
|
||||
@@ -35,6 +35,11 @@ func TestModel_RunType(t *testing.T) {
|
||||
model *Model
|
||||
want RunType
|
||||
}{
|
||||
{
|
||||
name: "Nil",
|
||||
model: nil,
|
||||
want: RunAuto,
|
||||
},
|
||||
{
|
||||
name: "Manual",
|
||||
model: &Model{Run: "manual"},
|
||||
@@ -135,6 +140,13 @@ func TestModel_ShouldRun_RunNever(t *testing.T) {
|
||||
assertShouldRun(t, model, RunOnDemand, false)
|
||||
}
|
||||
|
||||
func TestModel_ShouldRun_NilModel(t *testing.T) {
|
||||
var model *Model
|
||||
if model.ShouldRun(RunManual) {
|
||||
t.Fatalf("expected nil model to never run")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModel_ShouldRun_RunOnIndex(t *testing.T) {
|
||||
model := &Model{Run: string(RunOnIndex)}
|
||||
|
||||
|
||||
@@ -5,15 +5,16 @@ import (
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/tensorflow"
|
||||
"github.com/photoprism/photoprism/internal/ai/vision/ollama"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
)
|
||||
|
||||
func TestModelGetOptionsDefaultsOllamaLabels(t *testing.T) {
|
||||
model := &Model{
|
||||
Type: ModelTypeLabels,
|
||||
Provider: ollama.ProviderName,
|
||||
Type: ModelTypeLabels,
|
||||
Engine: ollama.EngineName,
|
||||
}
|
||||
|
||||
model.ApplyProviderDefaults()
|
||||
model.ApplyEngineDefaults()
|
||||
|
||||
opts := model.GetOptions()
|
||||
if opts == nil {
|
||||
@@ -39,8 +40,8 @@ func TestModelGetOptionsDefaultsOllamaLabels(t *testing.T) {
|
||||
|
||||
func TestModelGetOptionsRespectsCustomValues(t *testing.T) {
|
||||
model := &Model{
|
||||
Type: ModelTypeLabels,
|
||||
Provider: ollama.ProviderName,
|
||||
Type: ModelTypeLabels,
|
||||
Engine: ollama.EngineName,
|
||||
Options: &ApiRequestOptions{
|
||||
Temperature: 5,
|
||||
TopP: 0.95,
|
||||
@@ -48,7 +49,7 @@ func TestModelGetOptionsRespectsCustomValues(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
model.ApplyProviderDefaults()
|
||||
model.ApplyEngineDefaults()
|
||||
|
||||
opts := model.GetOptions()
|
||||
if opts.Temperature != MaxTemperature {
|
||||
@@ -64,12 +65,12 @@ func TestModelGetOptionsRespectsCustomValues(t *testing.T) {
|
||||
|
||||
func TestModelGetOptionsFillsMissingFields(t *testing.T) {
|
||||
model := &Model{
|
||||
Type: ModelTypeLabels,
|
||||
Provider: ollama.ProviderName,
|
||||
Options: &ApiRequestOptions{},
|
||||
Type: ModelTypeLabels,
|
||||
Engine: ollama.EngineName,
|
||||
Options: &ApiRequestOptions{},
|
||||
}
|
||||
|
||||
model.ApplyProviderDefaults()
|
||||
model.ApplyEngineDefaults()
|
||||
|
||||
opts := model.GetOptions()
|
||||
if opts.TopP != 0.9 {
|
||||
@@ -80,6 +81,52 @@ func TestModelGetOptionsFillsMissingFields(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelApplyEngineDefaultsSetsResolution(t *testing.T) {
|
||||
model := &Model{Type: ModelTypeLabels, Engine: ollama.EngineName}
|
||||
|
||||
model.ApplyEngineDefaults()
|
||||
|
||||
if model.Resolution != ollama.Resolution {
|
||||
t.Fatalf("expected resolution %d, got %d", ollama.Resolution, model.Resolution)
|
||||
}
|
||||
|
||||
model.Resolution = 1024
|
||||
model.ApplyEngineDefaults()
|
||||
if model.Resolution != 1024 {
|
||||
t.Fatalf("expected custom resolution to be preserved, got %d", model.Resolution)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelGetSource(t *testing.T) {
|
||||
t.Run("NilModel", func(t *testing.T) {
|
||||
var model *Model
|
||||
if src := model.GetSource(); src != entity.SrcAuto {
|
||||
t.Fatalf("expected SrcAuto for nil model, got %s", src)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EngineAlias", func(t *testing.T) {
|
||||
model := &Model{Engine: ollama.EngineName}
|
||||
if src := model.GetSource(); src != entity.SrcOllama {
|
||||
t.Fatalf("expected SrcOllama, got %s", src)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RequestFormat", func(t *testing.T) {
|
||||
model := &Model{Service: Service{RequestFormat: ApiFormatOpenAI}}
|
||||
if src := model.GetSource(); src != entity.SrcOpenAI {
|
||||
t.Fatalf("expected SrcOpenAI, got %s", src)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DefaultImage", func(t *testing.T) {
|
||||
model := &Model{}
|
||||
if src := model.GetSource(); src != entity.SrcImage {
|
||||
t.Fatalf("expected SrcImage fallback, got %s", src)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModel_IsDefault(t *testing.T) {
|
||||
nasnetCopy := *NasnetModel
|
||||
nasnetCopy.Default = false
|
||||
@@ -111,9 +158,9 @@ func TestModel_IsDefault(t *testing.T) {
|
||||
{
|
||||
name: "RemoteService",
|
||||
model: &Model{
|
||||
Type: ModelTypeCaption,
|
||||
Name: "custom-caption",
|
||||
Provider: ollama.ProviderName,
|
||||
Type: ModelTypeCaption,
|
||||
Name: "custom-caption",
|
||||
Engine: ollama.EngineName,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
|
||||
@@ -91,7 +91,7 @@ var (
|
||||
Type: ModelTypeCaption,
|
||||
Name: ollama.CaptionModel,
|
||||
Version: VersionLatest,
|
||||
Provider: ollama.ProviderName,
|
||||
Engine: ollama.EngineName,
|
||||
Resolution: 720, // Original aspect ratio, with a max size of 720 x 720 pixels.
|
||||
Service: Service{
|
||||
Uri: "http://ollama:11434/api/generate",
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package ollama
|
||||
|
||||
const (
|
||||
// ProviderName is the canonical identifier for Ollama-based vision services.
|
||||
ProviderName = "ollama"
|
||||
// EngineName is the canonical identifier for Ollama-based vision services.
|
||||
EngineName = "ollama"
|
||||
// ApiFormat identifies Ollama-compatible request and response payloads.
|
||||
ApiFormat = "ollama"
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ const (
|
||||
CaptionModel = "gemma3"
|
||||
LabelSystem = "You are a PhotoPrism vision model. Output concise JSON that matches the schema."
|
||||
LabelPrompt = "Analyze the image and return label objects with name, confidence (0-1), and topicality (0-1)."
|
||||
Resolution = 720
|
||||
)
|
||||
|
||||
func LabelSchema() string {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/*
|
||||
Package ollama integrates PhotoPrism's vision pipeline with Ollama-compatible
|
||||
multi-modal models so adapters can share logging and provider helpers.
|
||||
multi-modal models so adapters can share logging and engine helpers.
|
||||
|
||||
Copyright (c) 2018 - 2025 PhotoPrism UG. All rights reserved.
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package openai
|
||||
|
||||
const (
|
||||
// ProviderName is the canonical identifier for OpenAI-based vision services.
|
||||
ProviderName = "openai"
|
||||
// EngineName is the canonical identifier for OpenAI-based vision services.
|
||||
EngineName = "openai"
|
||||
// ApiFormat identifies OpenAI-compatible request and response payloads.
|
||||
ApiFormat = "openai"
|
||||
)
|
||||
|
||||
12
internal/ai/vision/openai/defaults.go
Normal file
12
internal/ai/vision/openai/defaults.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package openai
|
||||
|
||||
import "github.com/photoprism/photoprism/internal/ai/vision/schema"
|
||||
|
||||
const (
|
||||
DefaultModel = "gpt-5-mini"
|
||||
Resolution = 720
|
||||
)
|
||||
|
||||
func LabelSchema() string {
|
||||
return schema.LabelDefaultV1
|
||||
}
|
||||
@@ -1,97 +0,0 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ModelProvider represents the canonical identifier for a computer vision service provider.
|
||||
type ModelProvider = string
|
||||
|
||||
const (
|
||||
// ProviderVision represents the default PhotoPrism vision service endpoints.
|
||||
ProviderVision ModelProvider = "vision"
|
||||
// ProviderTensorFlow represents on-device TensorFlow models.
|
||||
ProviderTensorFlow ModelProvider = "tensorflow"
|
||||
// ProviderLocal is used when no explicit provider can be determined.
|
||||
ProviderLocal ModelProvider = "local"
|
||||
)
|
||||
|
||||
// RequestBuilder builds an API request for a provider based on the model configuration and input files.
|
||||
type RequestBuilder interface {
|
||||
Build(ctx context.Context, model *Model, files Files) (*ApiRequest, error)
|
||||
}
|
||||
|
||||
// ResponseParser parses a raw provider response into the generic ApiResponse structure.
|
||||
type ResponseParser interface {
|
||||
Parse(ctx context.Context, req *ApiRequest, raw []byte, status int) (*ApiResponse, error)
|
||||
}
|
||||
|
||||
// DefaultsProvider supplies provider-specific prompt and schema defaults when they are not configured explicitly.
|
||||
type DefaultsProvider interface {
|
||||
SystemPrompt(model *Model) string
|
||||
UserPrompt(model *Model) string
|
||||
SchemaTemplate(model *Model) string
|
||||
Options(model *Model) *ApiRequestOptions
|
||||
}
|
||||
|
||||
// Provider groups the callbacks required to integrate a third-party vision service.
|
||||
type Provider struct {
|
||||
Builder RequestBuilder
|
||||
Parser ResponseParser
|
||||
Defaults DefaultsProvider
|
||||
}
|
||||
|
||||
var (
|
||||
providerRegistry = make(map[ApiFormat]Provider)
|
||||
providerAliasIndex = make(map[string]ProviderInfo)
|
||||
providerMu sync.RWMutex
|
||||
)
|
||||
|
||||
// RegisterProvider adds/overrides a provider implementation for a specific API format.
|
||||
func RegisterProvider(format ApiFormat, provider Provider) {
|
||||
providerMu.Lock()
|
||||
defer providerMu.Unlock()
|
||||
providerRegistry[format] = provider
|
||||
}
|
||||
|
||||
// ProviderInfo describes metadata that can be associated with a provider alias.
|
||||
type ProviderInfo struct {
|
||||
RequestFormat ApiFormat
|
||||
ResponseFormat ApiFormat
|
||||
FileScheme string
|
||||
}
|
||||
|
||||
// RegisterProviderAlias maps a logical provider name (e.g. "ollama") to a request/response format pair.
|
||||
func RegisterProviderAlias(name string, info ProviderInfo) {
|
||||
name = strings.TrimSpace(strings.ToLower(name))
|
||||
if name == "" || info.RequestFormat == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if info.ResponseFormat == "" {
|
||||
info.ResponseFormat = info.RequestFormat
|
||||
}
|
||||
|
||||
providerMu.Lock()
|
||||
providerAliasIndex[name] = info
|
||||
providerMu.Unlock()
|
||||
}
|
||||
|
||||
// ProviderInfoFor returns the metadata associated with a logical provider name.
|
||||
func ProviderInfoFor(name string) (ProviderInfo, bool) {
|
||||
name = strings.TrimSpace(strings.ToLower(name))
|
||||
providerMu.RLock()
|
||||
info, ok := providerAliasIndex[name]
|
||||
providerMu.RUnlock()
|
||||
return info, ok
|
||||
}
|
||||
|
||||
// ProviderFor returns the registered provider implementation for the given API format, if any.
|
||||
func ProviderFor(format ApiFormat) (Provider, bool) {
|
||||
providerMu.RLock()
|
||||
defer providerMu.RUnlock()
|
||||
provider, ok := providerRegistry[format]
|
||||
return provider, ok
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
package vision
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestRegisterProviderAlias(t *testing.T) {
|
||||
const alias = "unit-test"
|
||||
providerMu.Lock()
|
||||
prev, had := providerAliasIndex[alias]
|
||||
if had {
|
||||
delete(providerAliasIndex, alias)
|
||||
}
|
||||
providerMu.Unlock()
|
||||
|
||||
t.Cleanup(func() {
|
||||
providerMu.Lock()
|
||||
if had {
|
||||
providerAliasIndex[alias] = prev
|
||||
} else {
|
||||
delete(providerAliasIndex, alias)
|
||||
}
|
||||
providerMu.Unlock()
|
||||
})
|
||||
|
||||
RegisterProviderAlias(" Unit-Test ", ProviderInfo{RequestFormat: ApiFormat("custom"), ResponseFormat: "", FileScheme: "data"})
|
||||
|
||||
info, ok := ProviderInfoFor(alias)
|
||||
if !ok {
|
||||
t.Fatalf("expected provider alias %q to be registered", alias)
|
||||
}
|
||||
|
||||
if info.RequestFormat != ApiFormat("custom") {
|
||||
t.Errorf("unexpected request format: %s", info.RequestFormat)
|
||||
}
|
||||
|
||||
if info.ResponseFormat != ApiFormat("custom") {
|
||||
t.Errorf("expected response format default to request, got %s", info.ResponseFormat)
|
||||
}
|
||||
|
||||
if info.FileScheme != "data" {
|
||||
t.Errorf("unexpected file scheme: %s", info.FileScheme)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterProvider(t *testing.T) {
|
||||
format := ApiFormat("unit-format")
|
||||
provider := Provider{}
|
||||
|
||||
providerMu.Lock()
|
||||
prev, had := providerRegistry[format]
|
||||
if had {
|
||||
delete(providerRegistry, format)
|
||||
}
|
||||
providerMu.Unlock()
|
||||
|
||||
t.Cleanup(func() {
|
||||
providerMu.Lock()
|
||||
if had {
|
||||
providerRegistry[format] = prev
|
||||
} else {
|
||||
delete(providerRegistry, format)
|
||||
}
|
||||
providerMu.Unlock()
|
||||
})
|
||||
|
||||
RegisterProvider(format, provider)
|
||||
got, ok := ProviderFor(format)
|
||||
if !ok {
|
||||
t.Fatalf("expected provider for %s", format)
|
||||
}
|
||||
|
||||
if got != provider {
|
||||
t.Errorf("unexpected provider value: %#v", got)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package schema
|
||||
|
||||
// LabelDefaultV1 provides the minimal JSON schema for label responses used across providers.
|
||||
// LabelDefaultV1 provides the minimal JSON schema for label responses used across engines.
|
||||
const LabelDefaultV1 = "{\n \"labels\": [{\n \"name\": \"\",\n \"confidence\": 0,\n \"topicality\": 0\n }]\n}"
|
||||
|
||||
2
internal/ai/vision/testdata/vision.yml
vendored
2
internal/ai/vision/testdata/vision.yml
vendored
@@ -67,7 +67,7 @@ Models:
|
||||
- Type: caption
|
||||
Name: gemma3
|
||||
Version: latest
|
||||
Provider: ollama
|
||||
Engine: ollama
|
||||
Resolution: 720
|
||||
Service:
|
||||
Uri: http://ollama:11434/api/generate
|
||||
|
||||
@@ -28,16 +28,14 @@ func visionListAction(ctx *cli.Context) error {
|
||||
|
||||
cols := []string{
|
||||
"Type",
|
||||
"Name",
|
||||
"Version",
|
||||
"Model",
|
||||
"Engine",
|
||||
"Endpoint",
|
||||
"Format",
|
||||
"Resolution",
|
||||
"Provider",
|
||||
"Service Endpoint",
|
||||
"Request Format",
|
||||
"Response Format",
|
||||
"Options",
|
||||
"Tags",
|
||||
"Disabled",
|
||||
"Schedule",
|
||||
"Status",
|
||||
}
|
||||
|
||||
// Show log message.
|
||||
@@ -54,47 +52,55 @@ func visionListAction(ctx *cli.Context) error {
|
||||
modelUri, modelMethod := model.Endpoint()
|
||||
tags := ""
|
||||
|
||||
_, name, version := model.Model()
|
||||
name, _, _ := model.Model()
|
||||
|
||||
if model.TensorFlow != nil && model.TensorFlow.Tags != nil {
|
||||
tags = strings.Join(model.TensorFlow.Tags, ", ")
|
||||
}
|
||||
|
||||
if model.Default {
|
||||
version = "default"
|
||||
}
|
||||
|
||||
var options []byte
|
||||
if o := model.GetOptions(); o != nil {
|
||||
options, _ = json.Marshal(*o)
|
||||
}
|
||||
|
||||
var responseFormat, requestFormat string
|
||||
var format string
|
||||
|
||||
if modelUri != "" && modelMethod != "" {
|
||||
if f := strings.TrimSpace(string(model.EndpointRequestFormat())); f != "" {
|
||||
requestFormat = f
|
||||
}
|
||||
|
||||
if f := strings.TrimSpace(string(model.EndpointResponseFormat())); f != "" {
|
||||
responseFormat = f
|
||||
if f := model.EndpointRequestFormat(); f != "" {
|
||||
format = f
|
||||
}
|
||||
}
|
||||
|
||||
provider := model.ProviderName()
|
||||
if responseFormat := model.GetFormat(); responseFormat != "" {
|
||||
if format != "" {
|
||||
format = fmt.Sprintf("%s:%s", format, responseFormat)
|
||||
} else {
|
||||
format = responseFormat
|
||||
}
|
||||
}
|
||||
|
||||
if format == "" && model.Default {
|
||||
format = "default"
|
||||
}
|
||||
|
||||
var run string
|
||||
|
||||
if run = model.RunType(); run == "" {
|
||||
run = "auto"
|
||||
}
|
||||
|
||||
engine := model.EngineName()
|
||||
|
||||
rows[i] = []string{
|
||||
model.Type,
|
||||
name,
|
||||
version,
|
||||
fmt.Sprintf("%d", model.Resolution),
|
||||
provider,
|
||||
engine,
|
||||
fmt.Sprintf("%s %s", modelMethod, modelUri),
|
||||
requestFormat,
|
||||
responseFormat,
|
||||
string(options),
|
||||
tags,
|
||||
report.Bool(model.Disabled, report.Yes, report.No),
|
||||
format,
|
||||
fmt.Sprintf("%d", model.Resolution),
|
||||
report.Bool(len(options) == 0, "tags: "+tags, string(options)),
|
||||
run,
|
||||
report.Bool(model.Disabled, report.Disabled, report.Enabled),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -51,6 +51,7 @@ func visionRunAction(ctx *cli.Context) error {
|
||||
vision.ParseModelTypes(ctx.String("models")),
|
||||
string(source),
|
||||
ctx.Bool("force"),
|
||||
vision.RunManual,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package config
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
@@ -22,6 +23,16 @@ func (c *Config) VisionYaml() string {
|
||||
}
|
||||
}
|
||||
|
||||
// VisionSchedule returns the cron schedule configured for the vision worker, or "" if disabled.
|
||||
func (c *Config) VisionSchedule() string {
|
||||
return Schedule(c.options.VisionSchedule)
|
||||
}
|
||||
|
||||
// VisionFilter returns the search filter to use for scheduled vision runs.
|
||||
func (c *Config) VisionFilter() string {
|
||||
return strings.TrimSpace(c.options.VisionFilter)
|
||||
}
|
||||
|
||||
// VisionModelShouldRun checks when the specified model type should run.
|
||||
func (c *Config) VisionModelShouldRun(t vision.ModelType, when vision.RunType) bool {
|
||||
if t == vision.ModelTypeLabels && c.DisableClassification() {
|
||||
|
||||
@@ -113,3 +113,25 @@ func TestConfig_VisionModelShouldRun(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_VisionSchedule(t *testing.T) {
|
||||
c := NewConfig(CliTestContext())
|
||||
|
||||
c.options.VisionSchedule = ""
|
||||
assert.Equal(t, "", c.VisionSchedule())
|
||||
|
||||
c.options.VisionSchedule = "0 6 * * *"
|
||||
assert.Equal(t, "0 6 * * *", c.VisionSchedule())
|
||||
|
||||
c.options.VisionSchedule = "invalid"
|
||||
assert.Equal(t, "", c.VisionSchedule())
|
||||
}
|
||||
|
||||
func TestConfig_VisionFilter(t *testing.T) {
|
||||
c := NewConfig(CliTestContext())
|
||||
c.options.VisionFilter = " private:false "
|
||||
assert.Equal(t, "private:false", c.VisionFilter())
|
||||
|
||||
c.options.VisionFilter = ""
|
||||
assert.Equal(t, "", c.VisionFilter())
|
||||
}
|
||||
|
||||
@@ -4,5 +4,5 @@ package feat
|
||||
var (
|
||||
VisionModelGenerate = false // controls exposure of the generate endpoint and CLI commands
|
||||
VisionModelMarkers = false // gates marker generation/return until downstream UI and reconciliation paths are ready
|
||||
VisionServiceOpenAI = false // controls whether users are able to configure OpenAI as vision service provider
|
||||
VisionServiceOpenAI = false // controls whether users are able to configure OpenAI as a vision service engine
|
||||
)
|
||||
|
||||
@@ -1151,6 +1151,17 @@ var Flags = CliFlags{
|
||||
Value: "",
|
||||
EnvVars: EnvVars("VISION_KEY"),
|
||||
}}, {
|
||||
Flag: &cli.StringFlag{
|
||||
Name: "vision-schedule",
|
||||
Usage: "vision worker `SCHEDULE` for background processing (e.g. \"0 12 * * *\" for daily at noon) or at a random time (daily, weekly)",
|
||||
EnvVars: EnvVars("VISION_SCHEDULE"),
|
||||
}}, {
|
||||
Flag: &cli.StringFlag{
|
||||
Name: "vision-filter",
|
||||
Usage: "vision worker search `FILTER` applied to scheduled runs (same syntax as photoprism vision run)",
|
||||
Value: "public:true",
|
||||
EnvVars: EnvVars("VISION_FILTER"),
|
||||
}}, {
|
||||
Flag: &cli.BoolFlag{
|
||||
Name: "detect-nsfw",
|
||||
Usage: "flags newly added pictures as private if they might be offensive (requires TensorFlow)",
|
||||
|
||||
@@ -226,6 +226,8 @@ type Options struct {
|
||||
VisionApi bool `yaml:"VisionApi" json:"-" flag:"vision-api"`
|
||||
VisionUri string `yaml:"VisionUri" json:"-" flag:"vision-uri"`
|
||||
VisionKey string `yaml:"VisionKey" json:"-" flag:"vision-key"`
|
||||
VisionSchedule string `yaml:"VisionSchedule" json:"VisionSchedule" flag:"vision-schedule"`
|
||||
VisionFilter string `yaml:"VisionFilter" json:"VisionFilter" flag:"vision-filter"`
|
||||
DetectNSFW bool `yaml:"DetectNSFW" json:"DetectNSFW" flag:"detect-nsfw"`
|
||||
FaceSize int `yaml:"-" json:"-" flag:"face-size"`
|
||||
FaceScore float64 `yaml:"-" json:"-" flag:"face-score"`
|
||||
|
||||
@@ -275,6 +275,8 @@ func (c *Config) Report() (rows [][]string, cols []string) {
|
||||
{"vision-api", fmt.Sprintf("%t", c.VisionApi())},
|
||||
{"vision-uri", c.VisionUri()},
|
||||
{"vision-key", strings.Repeat("*", utf8.RuneCountInString(c.VisionKey()))},
|
||||
{"vision-schedule", c.VisionSchedule()},
|
||||
{"vision-filter", c.VisionFilter()},
|
||||
{"nasnet-model-path", c.NasnetModelPath()},
|
||||
{"facenet-model-path", c.FacenetModelPath()},
|
||||
{"nsfw-model-path", c.NsfwModelPath()},
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/photoprism/photoprism/internal/event"
|
||||
"github.com/photoprism/photoprism/internal/form"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/list"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
"github.com/photoprism/photoprism/pkg/react"
|
||||
"github.com/photoprism/photoprism/pkg/rnd"
|
||||
@@ -752,6 +753,25 @@ func (m *Photo) SaveDetails() error {
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldGenerateLabels checks if labels should be generated for this model.
|
||||
func (m *Photo) ShouldGenerateLabels(force bool) bool {
|
||||
// Return true if force is set or there are no labels yet.
|
||||
if len(m.Labels) == 0 || force {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if any of the existing labels were generated using a vision model.
|
||||
for _, l := range m.Labels {
|
||||
if list.Contains(VisionSrcList, l.LabelSrc) {
|
||||
return false
|
||||
} else if l.LabelSrc == SrcCaption && list.Contains(VisionSrcList, m.CaptionSrc) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AddLabels updates the entity with additional or updated label information.
|
||||
func (m *Photo) AddLabels(labels classify.Labels) {
|
||||
for _, classifyLabel := range labels {
|
||||
|
||||
@@ -19,6 +19,11 @@ func (m *Photo) NoCaption() bool {
|
||||
return strings.TrimSpace(m.GetCaption()) == ""
|
||||
}
|
||||
|
||||
// ShouldGenerateCaption checks if a caption should be generated for this model.
|
||||
func (m *Photo) ShouldGenerateCaption(src Src, force bool) bool {
|
||||
return SrcPriority[src] >= SrcPriority[m.CaptionSrc] && (m.NoCaption() || force)
|
||||
}
|
||||
|
||||
// GetCaption returns the photo caption, if any.
|
||||
func (m *Photo) GetCaption() string {
|
||||
return m.PhotoCaption
|
||||
|
||||
@@ -241,6 +241,94 @@ func TestPhoto_SaveLabels(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPhoto_ShouldGenerateLabels(t *testing.T) {
|
||||
t.Run("NoLabels", func(t *testing.T) {
|
||||
p := Photo{}
|
||||
assert.True(t, p.ShouldGenerateLabels(false))
|
||||
})
|
||||
|
||||
t.Run("Force", func(t *testing.T) {
|
||||
p := Photo{Labels: []PhotoLabel{{LabelSrc: string(SrcManual)}}}
|
||||
assert.True(t, p.ShouldGenerateLabels(true))
|
||||
})
|
||||
|
||||
t.Run("ExistingVisionLabel", func(t *testing.T) {
|
||||
p := Photo{Labels: []PhotoLabel{{LabelSrc: string(SrcOllama)}}}
|
||||
assert.False(t, p.ShouldGenerateLabels(false))
|
||||
})
|
||||
|
||||
t.Run("CaptionGeneratedLabels", func(t *testing.T) {
|
||||
p := Photo{
|
||||
Labels: []PhotoLabel{{LabelSrc: string(SrcCaption)}},
|
||||
CaptionSrc: SrcOllama,
|
||||
}
|
||||
assert.False(t, p.ShouldGenerateLabels(false))
|
||||
})
|
||||
|
||||
t.Run("ManualLabels", func(t *testing.T) {
|
||||
p := Photo{Labels: []PhotoLabel{{LabelSrc: string(SrcManual)}}}
|
||||
assert.True(t, p.ShouldGenerateLabels(false))
|
||||
})
|
||||
|
||||
t.Run("CaptionManualWithoutVision", func(t *testing.T) {
|
||||
p := Photo{
|
||||
Labels: []PhotoLabel{{LabelSrc: string(SrcCaption)}},
|
||||
CaptionSrc: SrcManual,
|
||||
}
|
||||
assert.True(t, p.ShouldGenerateLabels(false))
|
||||
})
|
||||
}
|
||||
|
||||
func TestPhoto_ShouldGenerateCaption(t *testing.T) {
|
||||
ctx := []struct {
|
||||
name string
|
||||
photo Photo
|
||||
source Src
|
||||
force bool
|
||||
expect bool
|
||||
}{
|
||||
{
|
||||
name: "NoCaptionAutoSource",
|
||||
photo: Photo{CaptionSrc: SrcAuto},
|
||||
source: SrcOllama,
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "LowerPriority",
|
||||
photo: Photo{CaptionSrc: SrcOllama},
|
||||
source: SrcImage,
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "HigherPriority",
|
||||
photo: Photo{CaptionSrc: SrcImage},
|
||||
source: SrcOllama,
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "ForceOverrides",
|
||||
photo: Photo{CaptionSrc: SrcImage, PhotoCaption: "existing"},
|
||||
source: SrcImage,
|
||||
force: true,
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "SamePriorityNoForce",
|
||||
photo: Photo{CaptionSrc: SrcOllama, PhotoCaption: "existing"},
|
||||
source: SrcOllama,
|
||||
expect: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range ctx {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := tc.photo.ShouldGenerateCaption(tc.source, tc.force)
|
||||
assert.Equal(t, tc.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPhoto_ClassifyLabels(t *testing.T) {
|
||||
t.Run("NewPhoto", func(t *testing.T) {
|
||||
m := PhotoFixtures.Get("Photo19")
|
||||
|
||||
@@ -95,8 +95,8 @@ var VisionSrcNames = SrcMap{
|
||||
SrcVision: SrcVision,
|
||||
}
|
||||
|
||||
// VisionSrc contains all the sources commonly used by computer vision models and services.
|
||||
var VisionSrc = []Src{
|
||||
// VisionSrcList contains all the sources commonly used by computer vision models and services.
|
||||
var VisionSrcList = []Src{
|
||||
SrcMarker,
|
||||
SrcImage,
|
||||
SrcOllama,
|
||||
|
||||
@@ -1,17 +1,32 @@
|
||||
package photoprism
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
)
|
||||
|
||||
// Caption returns generated caption for the specified media file.
|
||||
func (ind *Index) Caption(file *MediaFile) (caption *vision.CaptionResult, err error) {
|
||||
// Caption generates a caption for the provided media file using the active
|
||||
// vision model. When captionSrc is SrcAuto the model's declared source is used;
|
||||
// otherwise the explicit source is recorded on the returned caption.
|
||||
func (ind *Index) Caption(file *MediaFile, captionSrc entity.Src) (caption *vision.CaptionResult, err error) {
|
||||
start := time.Now()
|
||||
|
||||
model := vision.Config.Model(vision.ModelTypeCaption)
|
||||
|
||||
// No caption generation model configured or usable.
|
||||
if model == nil {
|
||||
return caption, errors.New("no caption model configured")
|
||||
}
|
||||
|
||||
if captionSrc == entity.SrcAuto {
|
||||
captionSrc = model.GetSource()
|
||||
}
|
||||
|
||||
size := vision.Thumb(vision.ModelTypeCaption)
|
||||
|
||||
// Get thumbnail filenames for the selected sizes.
|
||||
@@ -22,9 +37,14 @@ func (ind *Index) Caption(file *MediaFile) (caption *vision.CaptionResult, err e
|
||||
}
|
||||
|
||||
// Get matching labels from computer vision model.
|
||||
// Generate a caption using the configured vision model.
|
||||
if caption, _, err = vision.Caption(vision.Files{fileName}, media.SrcLocal); err != nil {
|
||||
// Failed.
|
||||
} else if caption.Text != "" {
|
||||
if captionSrc != entity.SrcAuto {
|
||||
caption.Source = captionSrc
|
||||
}
|
||||
|
||||
log.Infof("vision: generated caption for %s [%s]", clean.Log(file.BaseName()), time.Since(start))
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,9 @@ import (
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
)
|
||||
|
||||
// Labels classifies a JPEG image and returns matching labels.
|
||||
// Labels classifies the media file and returns matching labels. When labelSrc
|
||||
// is SrcAuto the model's declared source is used; otherwise the provided source
|
||||
// is applied to every returned label.
|
||||
func (ind *Index) Labels(file *MediaFile, labelSrc entity.Src) (labels classify.Labels) {
|
||||
start := time.Now()
|
||||
|
||||
@@ -21,8 +23,24 @@ func (ind *Index) Labels(file *MediaFile, labelSrc entity.Src) (labels classify.
|
||||
var sizes []thumb.Name
|
||||
var thumbnails []string
|
||||
|
||||
model := vision.Config.Model(vision.ModelTypeLabels)
|
||||
|
||||
// No label generation model configured or usable.
|
||||
if model == nil {
|
||||
return labels
|
||||
}
|
||||
|
||||
if labelSrc == entity.SrcAuto {
|
||||
labelSrc = model.GetSource()
|
||||
}
|
||||
|
||||
size := vision.Thumb(vision.ModelTypeLabels)
|
||||
|
||||
// The thumbnail size may need to be adjusted to use other models.
|
||||
if file.Square() {
|
||||
if size.Name != "" && size.Name != thumb.Tile224 {
|
||||
sizes = []thumb.Name{size.Name}
|
||||
thumbnails = make([]string, 0, 1)
|
||||
} else if file.Square() {
|
||||
// Only one thumbnail is required for square images.
|
||||
sizes = []thumb.Name{thumb.Tile224}
|
||||
thumbnails = make([]string, 0, 1)
|
||||
@@ -33,8 +51,8 @@ func (ind *Index) Labels(file *MediaFile, labelSrc entity.Src) (labels classify.
|
||||
}
|
||||
|
||||
// Get thumbnail filenames for the selected sizes.
|
||||
for _, size := range sizes {
|
||||
if thumbnail, fileErr := file.Thumbnail(Config().ThumbCachePath(), size); fileErr != nil {
|
||||
for _, s := range sizes {
|
||||
if thumbnail, fileErr := file.Thumbnail(Config().ThumbCachePath(), s); fileErr != nil {
|
||||
log.Debugf("index: %s in %s", err, clean.Log(file.BaseName()))
|
||||
continue
|
||||
} else {
|
||||
@@ -42,7 +60,7 @@ func (ind *Index) Labels(file *MediaFile, labelSrc entity.Src) (labels classify.
|
||||
}
|
||||
}
|
||||
|
||||
// Get matching labels from computer vision model.
|
||||
// Run the configured vision model to obtain labels for the generated thumbnails.
|
||||
if labels, err = vision.Labels(thumbnails, media.SrcLocal, labelSrc); err != nil {
|
||||
log.Debugf("labels: %s in %s", err, clean.Log(file.BaseName()))
|
||||
return labels
|
||||
|
||||
@@ -815,7 +815,7 @@ func (ind *Index) UserMediaFile(m *MediaFile, o IndexOptions, originalName, phot
|
||||
|
||||
// Classify images with TensorFlow?
|
||||
if ind.findLabels {
|
||||
labels = ind.Labels(m, entity.SrcImage)
|
||||
labels = ind.Labels(m, entity.SrcAuto)
|
||||
|
||||
// Append labels from other sources such as face detection.
|
||||
if len(extraLabels) > 0 {
|
||||
|
||||
111
internal/photoprism/index_vision_test.go
Normal file
111
internal/photoprism/index_vision_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package photoprism
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/classify"
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
)
|
||||
|
||||
func TestIndexCaptionSource(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping vision-dependent test in short mode")
|
||||
}
|
||||
|
||||
cfg := config.TestConfig()
|
||||
require.NoError(t, cfg.InitializeTestData())
|
||||
|
||||
ind := NewIndex(cfg, NewConvert(cfg), NewFiles(), NewPhotos())
|
||||
mediaFile, err := NewMediaFile("testdata/flash.jpg")
|
||||
require.NoError(t, err)
|
||||
|
||||
originalConfig := vision.Config
|
||||
t.Cleanup(func() {
|
||||
vision.Config = originalConfig
|
||||
vision.SetCaptionFunc(nil)
|
||||
})
|
||||
|
||||
captionModel := &vision.Model{Type: vision.ModelTypeCaption, Engine: vision.ApiFormatOpenAI}
|
||||
captionModel.ApplyEngineDefaults()
|
||||
vision.Config = &vision.ConfigValues{Models: vision.Models{captionModel}}
|
||||
|
||||
t.Run("AutoUsesModelSource", func(t *testing.T) {
|
||||
vision.SetCaptionFunc(func(files vision.Files, mediaSrc media.Src) (*vision.CaptionResult, *vision.Model, error) {
|
||||
return &vision.CaptionResult{Text: "stub", Source: captionModel.GetSource()}, captionModel, nil
|
||||
})
|
||||
t.Cleanup(func() { vision.SetCaptionFunc(nil) })
|
||||
|
||||
caption, err := ind.Caption(mediaFile, entity.SrcAuto)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, caption)
|
||||
assert.Equal(t, captionModel.GetSource(), caption.Source)
|
||||
})
|
||||
|
||||
t.Run("CustomSource", func(t *testing.T) {
|
||||
originalSource := captionModel.GetSource()
|
||||
vision.SetCaptionFunc(func(files vision.Files, mediaSrc media.Src) (*vision.CaptionResult, *vision.Model, error) {
|
||||
return &vision.CaptionResult{Text: "stub", Source: originalSource}, captionModel, nil
|
||||
})
|
||||
t.Cleanup(func() { vision.SetCaptionFunc(nil) })
|
||||
|
||||
caption, err := ind.Caption(mediaFile, entity.SrcManual)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, caption)
|
||||
assert.Equal(t, entity.SrcManual, caption.Source)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIndexLabelsSource(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping vision-dependent test in short mode")
|
||||
}
|
||||
|
||||
cfg := config.TestConfig()
|
||||
require.NoError(t, cfg.InitializeTestData())
|
||||
|
||||
ind := NewIndex(cfg, NewConvert(cfg), NewFiles(), NewPhotos())
|
||||
mediaFile, err := NewMediaFile("testdata/flash.jpg")
|
||||
require.NoError(t, err)
|
||||
|
||||
originalConfig := vision.Config
|
||||
t.Cleanup(func() {
|
||||
vision.Config = originalConfig
|
||||
vision.SetLabelsFunc(nil)
|
||||
})
|
||||
|
||||
labelModel := &vision.Model{Type: vision.ModelTypeLabels, Engine: vision.ApiFormatOllama}
|
||||
labelModel.ApplyEngineDefaults()
|
||||
vision.Config = &vision.ConfigValues{Models: vision.Models{labelModel}}
|
||||
|
||||
t.Run("AutoUsesModelSource", func(t *testing.T) {
|
||||
var captured string
|
||||
vision.SetLabelsFunc(func(files vision.Files, mediaSrc media.Src, src string) (classify.Labels, error) {
|
||||
captured = src
|
||||
return classify.Labels{{Name: "stub", Source: src, Uncertainty: 0}}, nil
|
||||
})
|
||||
t.Cleanup(func() { vision.SetLabelsFunc(nil) })
|
||||
|
||||
labels := ind.Labels(mediaFile, entity.SrcAuto)
|
||||
assert.NotEmpty(t, labels)
|
||||
assert.Equal(t, labelModel.GetSource(), captured)
|
||||
})
|
||||
|
||||
t.Run("CustomSource", func(t *testing.T) {
|
||||
var captured string
|
||||
vision.SetLabelsFunc(func(files vision.Files, mediaSrc media.Src, src string) (classify.Labels, error) {
|
||||
captured = src
|
||||
return classify.Labels{{Name: "stub", Source: src, Uncertainty: 0}}, nil
|
||||
})
|
||||
t.Cleanup(func() { vision.SetLabelsFunc(nil) })
|
||||
|
||||
labels := ind.Labels(mediaFile, entity.SrcManual)
|
||||
assert.NotEmpty(t, labels)
|
||||
assert.Equal(t, entity.SrcManual, captured)
|
||||
})
|
||||
}
|
||||
@@ -57,8 +57,8 @@ func (w *Meta) Start(delay, interval time.Duration, force bool) (err error) {
|
||||
log.Debugf("index: running face recognition")
|
||||
if faces := photoprism.NewFaces(w.conf); faces.Disabled() {
|
||||
log.Debugf("index: skipping face recognition")
|
||||
} else if err := faces.Start(photoprism.FacesOptions{}); err != nil {
|
||||
log.Warn(err)
|
||||
} else if facesErr := faces.Start(photoprism.FacesOptions{}); facesErr != nil {
|
||||
log.Warn(facesErr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,8 +74,8 @@ func (w *Meta) Start(delay, interval time.Duration, force bool) (err error) {
|
||||
|
||||
ind := get.Index()
|
||||
|
||||
generateLabels := w.conf.VisionModelShouldRun(vision.ModelTypeLabels, vision.RunNewlyIndexed)
|
||||
generateCaptions := w.conf.VisionModelShouldRun(vision.ModelTypeCaption, vision.RunNewlyIndexed)
|
||||
labelsModelShouldRun := w.conf.VisionModelShouldRun(vision.ModelTypeLabels, vision.RunNewlyIndexed)
|
||||
captionModelShouldRun := w.conf.VisionModelShouldRun(vision.ModelTypeCaption, vision.RunNewlyIndexed)
|
||||
|
||||
for {
|
||||
photos, queryErr := query.PhotosMetadataUpdate(limit, offset, delay, interval)
|
||||
@@ -99,8 +99,11 @@ func (w *Meta) Start(delay, interval time.Duration, force bool) (err error) {
|
||||
|
||||
done[photo.PhotoUID] = true
|
||||
|
||||
generateLabels := labelsModelShouldRun && photo.ShouldGenerateLabels(false)
|
||||
generateCaption := captionModelShouldRun && photo.ShouldGenerateCaption(entity.SrcAuto, false)
|
||||
|
||||
// If configured, generate metadata for newly indexed photos using external vision services.
|
||||
if photo.IsNewlyIndexed() && (generateLabels || generateCaptions) {
|
||||
if photo.IsNewlyIndexed() && (generateLabels || generateCaption) {
|
||||
primaryFile, fileErr := photo.PrimaryFile()
|
||||
|
||||
if fileErr != nil {
|
||||
@@ -120,16 +123,13 @@ func (w *Meta) Start(delay, interval time.Duration, force bool) (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
if generateCaptions && photo.PhotoCaption == "" && photo.CaptionSrc == entity.SrcAuto {
|
||||
if caption, captionErr := ind.Caption(mediaFile); captionErr != nil {
|
||||
if generateCaption {
|
||||
if caption, captionErr := ind.Caption(mediaFile, entity.SrcAuto); captionErr != nil {
|
||||
log.Debugf("index: %s (generate caption for %s)", clean.Error(captionErr), photo.PhotoUID)
|
||||
} else if caption != nil {
|
||||
text := strings.TrimSpace(caption.Text)
|
||||
if text != "" && caption.Source != "" {
|
||||
photo.SetCaption(text, clean.ShortTypeLower(caption.Source))
|
||||
if updateErr := photo.UpdateCaptionLabels(); updateErr != nil {
|
||||
log.Warnf("index: %s (update caption labels for %s)", clean.Error(updateErr), photo.PhotoUID)
|
||||
}
|
||||
} else if text := strings.TrimSpace(caption.Text); text != "" {
|
||||
photo.SetCaption(text, caption.Source)
|
||||
if updateErr := photo.UpdateCaptionLabels(); updateErr != nil {
|
||||
log.Warnf("index: %s (update caption labels for %s)", clean.Error(updateErr), photo.PhotoUID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,45 +25,67 @@ import (
|
||||
"github.com/photoprism/photoprism/pkg/txt"
|
||||
)
|
||||
|
||||
// Vision represents a computer vision worker.
|
||||
// Vision orchestrates background computer-vision tasks (labels, captions,
|
||||
// NSFW detection). It wraps configuration lookups and scheduling helpers.
|
||||
type Vision struct {
|
||||
conf *config.Config
|
||||
}
|
||||
|
||||
// NewVision returns a new Vision worker.
|
||||
// NewVision constructs a Vision worker bound to the provided configuration.
|
||||
func NewVision(conf *config.Config) *Vision {
|
||||
return &Vision{conf: conf}
|
||||
}
|
||||
|
||||
func captionSourceFromModel(model *vision.Model) string {
|
||||
if model == nil {
|
||||
return entity.SrcImage
|
||||
// StartScheduled executes the worker in scheduled mode, selecting models that
|
||||
// are allowed to run in the RunOnSchedule context.
|
||||
func (w *Vision) StartScheduled() {
|
||||
models := w.scheduledModels()
|
||||
|
||||
if len(models) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
switch model.EndpointRequestFormat() {
|
||||
case vision.ApiFormatOllama:
|
||||
return entity.SrcOllama
|
||||
case vision.ApiFormatOpenAI:
|
||||
return entity.SrcOpenAI
|
||||
if err := w.Start(
|
||||
w.conf.VisionFilter(),
|
||||
0,
|
||||
models,
|
||||
entity.SrcAuto,
|
||||
false,
|
||||
vision.RunOnSchedule,
|
||||
); err != nil {
|
||||
log.Errorf("scheduler: %s (vision)", err)
|
||||
}
|
||||
|
||||
switch model.ProviderName() {
|
||||
case "ollama":
|
||||
return entity.SrcOllama
|
||||
case "openai":
|
||||
return entity.SrcOpenAI
|
||||
}
|
||||
|
||||
return entity.SrcImage
|
||||
}
|
||||
|
||||
// originalsPath returns the original media files path as string.
|
||||
// scheduledModels returns the model types that should run for scheduled jobs.
|
||||
func (w *Vision) scheduledModels() []string {
|
||||
models := make([]string, 0, 3)
|
||||
|
||||
if w.conf.VisionModelShouldRun(vision.ModelTypeLabels, vision.RunOnSchedule) {
|
||||
models = append(models, vision.ModelTypeLabels)
|
||||
}
|
||||
|
||||
if w.conf.VisionModelShouldRun(vision.ModelTypeNsfw, vision.RunOnSchedule) {
|
||||
models = append(models, vision.ModelTypeNsfw)
|
||||
}
|
||||
|
||||
if w.conf.VisionModelShouldRun(vision.ModelTypeCaption, vision.RunOnSchedule) {
|
||||
models = append(models, vision.ModelTypeCaption)
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
// originalsPath returns the path that holds original media files.
|
||||
func (w *Vision) originalsPath() string {
|
||||
return w.conf.OriginalsPath()
|
||||
}
|
||||
|
||||
// Start runs the specified model types for photos matching the search query filter string.
|
||||
func (w *Vision) Start(filter string, count int, models []string, customSrc string, force bool) (err error) {
|
||||
// Start runs the requested vision models against photos matching the search
|
||||
// filter. `customSrc` allows the caller to override the metadata source string,
|
||||
// `force` regenerates metadata regardless of existing values, and `runType`
|
||||
// describes the scheduling context (manual, scheduled, etc.).
|
||||
func (w *Vision) Start(filter string, count int, models []string, customSrc string, force bool, runType vision.RunType) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("vision: %s (worker panic)\nstack: %s", r, debug.Stack())
|
||||
@@ -77,6 +99,10 @@ func (w *Vision) Start(filter string, count int, models []string, customSrc stri
|
||||
|
||||
defer mutex.VisionWorker.Stop()
|
||||
|
||||
models = vision.FilterModels(models, runType, func(mt vision.ModelType, when vision.RunType) bool {
|
||||
return w.conf.VisionModelShouldRun(mt, when)
|
||||
})
|
||||
|
||||
updateLabels := slices.Contains(models, vision.ModelTypeLabels)
|
||||
updateNsfw := slices.Contains(models, vision.ModelTypeNsfw)
|
||||
updateCaptions := slices.Contains(models, vision.ModelTypeCaption)
|
||||
@@ -90,21 +116,6 @@ func (w *Vision) Start(filter string, count int, models []string, customSrc stri
|
||||
}
|
||||
|
||||
customSrc = clean.ShortTypeLower(customSrc)
|
||||
useAutoSource := customSrc == entity.SrcAuto
|
||||
|
||||
labelSource := customSrc
|
||||
if useAutoSource {
|
||||
labelSource = entity.SrcAuto
|
||||
}
|
||||
|
||||
if labelSource == entity.SrcImage {
|
||||
labelSource = entity.SrcAuto
|
||||
}
|
||||
|
||||
captionSource := customSrc
|
||||
if useAutoSource {
|
||||
captionSource = captionSourceFromModel(vision.Config.Model(vision.ModelTypeCaption))
|
||||
}
|
||||
|
||||
// Check time when worker was last executed.
|
||||
updateIndex := false
|
||||
@@ -161,13 +172,6 @@ func (w *Vision) Start(filter string, count int, models []string, customSrc stri
|
||||
done[photo.PhotoUID] = true
|
||||
|
||||
photoName := path.Join(photo.PhotoPath, photo.PhotoName)
|
||||
fileName := photoprism.FileName(photo.FileRoot, photo.FileName)
|
||||
file, fileErr := photoprism.NewMediaFile(fileName)
|
||||
|
||||
if fileErr != nil {
|
||||
log.Errorf("vision: failed to open %s (%s)", photoName, fileErr)
|
||||
continue
|
||||
}
|
||||
|
||||
m, loadErr := query.PhotoByUID(photo.PhotoUID)
|
||||
|
||||
@@ -176,33 +180,48 @@ func (w *Vision) Start(filter string, count int, models []string, customSrc stri
|
||||
continue
|
||||
}
|
||||
|
||||
generateLabels := updateLabels && m.ShouldGenerateLabels(force)
|
||||
generateCaptions := updateCaptions && m.ShouldGenerateCaption(customSrc, force)
|
||||
generateNsfw := updateNsfw && (!photo.PhotoPrivate || force)
|
||||
|
||||
if !(generateLabels || generateCaptions || generateNsfw) {
|
||||
continue
|
||||
}
|
||||
|
||||
fileName := photoprism.FileName(photo.FileRoot, photo.FileName)
|
||||
file, fileErr := photoprism.NewMediaFile(fileName)
|
||||
|
||||
if fileErr != nil {
|
||||
log.Errorf("vision: failed to open %s (%s)", photoName, fileErr)
|
||||
continue
|
||||
}
|
||||
|
||||
changed := false
|
||||
|
||||
// Generate labels.
|
||||
if updateLabels && (len(m.Labels) == 0 || force) {
|
||||
labelSrc := labelSource
|
||||
if labels := ind.Labels(file, labelSrc); len(labels) > 0 {
|
||||
if generateLabels {
|
||||
if labels := ind.Labels(file, customSrc); len(labels) > 0 {
|
||||
m.AddLabels(labels)
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
// Detect NSFW content.
|
||||
if updateNsfw && (!photo.PhotoPrivate || force) {
|
||||
if isNsfw := ind.IsNsfw(file); photo.PhotoPrivate != isNsfw {
|
||||
photo.PhotoPrivate = isNsfw
|
||||
if generateNsfw {
|
||||
if isNsfw := ind.IsNsfw(file); m.PhotoPrivate != isNsfw {
|
||||
m.PhotoPrivate = isNsfw
|
||||
changed = true
|
||||
log.Infof("vision: changed private flag of %s to %t", photoName, photo.PhotoPrivate)
|
||||
log.Infof("vision: changed private flag of %s to %t", photoName, m.PhotoPrivate)
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a caption if none exists or the force flag is used,
|
||||
// and only if no caption was set or removed by a higher-priority source.
|
||||
if updateCaptions && entity.SrcPriority[captionSource] >= entity.SrcPriority[m.CaptionSrc] && (m.NoCaption() || force) {
|
||||
if caption, captionErr := ind.Caption(file); captionErr != nil {
|
||||
if generateCaptions {
|
||||
if caption, captionErr := ind.Caption(file, customSrc); captionErr != nil {
|
||||
log.Warnf("vision: %s in %s (generate caption)", clean.Error(captionErr), photoName)
|
||||
} else if caption.Text = strings.TrimSpace(caption.Text); caption.Text != "" {
|
||||
m.SetCaption(caption.Text, captionSource)
|
||||
} else if text := strings.TrimSpace(caption.Text); text != "" {
|
||||
m.SetCaption(text, caption.Source)
|
||||
if updateErr := m.UpdateCaptionLabels(); updateErr != nil {
|
||||
log.Warnf("vision: %s in %s (update caption labels)", clean.Error(updateErr), photoName)
|
||||
}
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
package workers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
)
|
||||
|
||||
func TestCaptionSourceFromModel(t *testing.T) {
|
||||
if got := captionSourceFromModel(nil); got != entity.SrcImage {
|
||||
t.Fatalf("expected SrcImage for nil model, got %s", got)
|
||||
}
|
||||
|
||||
openAIModel := &vision.Model{
|
||||
Service: vision.Service{RequestFormat: vision.ApiFormatOpenAI},
|
||||
}
|
||||
|
||||
if got := captionSourceFromModel(openAIModel); got != entity.SrcOpenAI {
|
||||
t.Fatalf("expected SrcOpenAI for openai model, got %s", got)
|
||||
}
|
||||
|
||||
providerModel := &vision.Model{Provider: "ollama"}
|
||||
if got := captionSourceFromModel(providerModel); got != entity.SrcOllama {
|
||||
t.Fatalf("expected SrcOllama from provider, got %s", got)
|
||||
}
|
||||
|
||||
fallbackModel := &vision.Model{}
|
||||
if got := captionSourceFromModel(fallbackModel); got != entity.SrcImage {
|
||||
t.Fatalf("expected SrcImage fallback, got %s", got)
|
||||
}
|
||||
}
|
||||
@@ -38,7 +38,9 @@ import (
|
||||
var log = event.Log
|
||||
var stop = make(chan bool, 1)
|
||||
|
||||
// Start starts the execution of background workers and scheduled tasks based on the current configuration.
|
||||
// Start launches background workers and scheduled tasks based on the current
|
||||
// configuration. It sets up the cron scheduler and the periodic metadata/share
|
||||
// workers.
|
||||
func Start(conf *config.Config) {
|
||||
if scheduler, err := gocron.NewScheduler(gocron.WithLocation(conf.DefaultTimezone())); err != nil {
|
||||
log.Errorf("scheduler: %s (start)", err)
|
||||
@@ -56,6 +58,11 @@ func Start(conf *config.Config) {
|
||||
log.Errorf("scheduler: %s (index)", err)
|
||||
}
|
||||
|
||||
// Schedule vision job.
|
||||
if err = NewJob("vision", conf.VisionSchedule(), NewVision(conf).StartScheduled); err != nil {
|
||||
log.Errorf("scheduler: %s (vision)", err)
|
||||
}
|
||||
|
||||
// Start the scheduler.
|
||||
Scheduler.Start()
|
||||
}
|
||||
@@ -89,7 +96,7 @@ func Start(conf *config.Config) {
|
||||
}()
|
||||
}
|
||||
|
||||
// Shutdown stops the background workers and scheduled tasks.
|
||||
// Shutdown stops the background workers and shuts down the scheduler.
|
||||
func Shutdown() {
|
||||
log.Info("shutting down workers")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user