mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
AI: Auto-add model defaults when loading "vision.yml" #5234
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
@@ -91,40 +91,13 @@ func (c *ConfigValues) Load(fileName string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1. Ensure that there is at least one configuration for each model type,
|
// Replace default placeholders with canonical defaults while respecting
|
||||||
// so that adding a copy of the default configuration to the vision.yml file
|
// explicit Run / Disabled overrides.
|
||||||
// is not required. We could alternatively require a model to included in
|
c.applyDefaultModels()
|
||||||
// the "vision.yml" file, but set the defaults if the "Default" flag is set
|
|
||||||
// while preserving explicit Run / Disabled overrides.
|
|
||||||
// 2. Use the default "Thresholds" if no custom thresholds are configured.
|
|
||||||
|
|
||||||
for i, model := range c.Models {
|
// Add missing default models so users are not required to list them in
|
||||||
if !model.Default {
|
// vision.yml. Custom models continue to override defaults when present.
|
||||||
continue
|
c.ensureDefaultModels()
|
||||||
}
|
|
||||||
|
|
||||||
runType := model.Run
|
|
||||||
disabled := model.Disabled
|
|
||||||
|
|
||||||
switch model.Type {
|
|
||||||
case ModelTypeLabels:
|
|
||||||
c.Models[i] = NasnetModel.Clone()
|
|
||||||
case ModelTypeNsfw:
|
|
||||||
c.Models[i] = NsfwModel.Clone()
|
|
||||||
case ModelTypeFace:
|
|
||||||
c.Models[i] = FacenetModel.Clone()
|
|
||||||
case ModelTypeCaption:
|
|
||||||
c.Models[i] = CaptionModel.Clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
if runType != RunAuto {
|
|
||||||
c.Models[i].Run = runType
|
|
||||||
}
|
|
||||||
|
|
||||||
if disabled {
|
|
||||||
c.Models[i].Disabled = disabled
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, model := range c.Models {
|
for _, model := range c.Models {
|
||||||
model.ApplyEngineDefaults()
|
model.ApplyEngineDefaults()
|
||||||
@@ -145,6 +118,74 @@ func (c *ConfigValues) Load(fileName string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// applyDefaultModels swaps entries marked as Default with the built-in
|
||||||
|
// models while keeping user-specified Run / Disabled overrides intact.
|
||||||
|
func (c *ConfigValues) applyDefaultModels() {
|
||||||
|
for i, model := range c.Models {
|
||||||
|
if !model.Default {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
runType := model.Run
|
||||||
|
disabled := model.Disabled
|
||||||
|
|
||||||
|
switch model.Type {
|
||||||
|
case ModelTypeLabels:
|
||||||
|
c.Models[i] = NasnetModel.Clone()
|
||||||
|
case ModelTypeNsfw:
|
||||||
|
c.Models[i] = NsfwModel.Clone()
|
||||||
|
case ModelTypeFace:
|
||||||
|
c.Models[i] = FacenetModel.Clone()
|
||||||
|
case ModelTypeCaption:
|
||||||
|
c.Models[i] = CaptionModel.Clone()
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if runType != RunAuto {
|
||||||
|
c.Models[i].Run = runType
|
||||||
|
}
|
||||||
|
|
||||||
|
if disabled {
|
||||||
|
c.Models[i].Disabled = disabled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureDefaultModels appends built-in default models for any types
|
||||||
|
// that are completely missing from the configuration. Custom models (enabled
|
||||||
|
// or disabled) block the addition for their respective types so user intent is
|
||||||
|
// preserved.
|
||||||
|
func (c *ConfigValues) ensureDefaultModels() {
|
||||||
|
for _, defaultModel := range DefaultModels {
|
||||||
|
if defaultModel == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.hasModelType(defaultModel.Type) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Models = append(c.Models, defaultModel.Clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasModelType reports whether any configured model (enabled or disabled)
|
||||||
|
// matches the provided type.
|
||||||
|
func (c *ConfigValues) hasModelType(t ModelType) bool {
|
||||||
|
for _, model := range c.Models {
|
||||||
|
if model == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if model.Type == t {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// Save user settings to a file.
|
// Save user settings to a file.
|
||||||
func (c *ConfigValues) Save(fileName string) error {
|
func (c *ConfigValues) Save(fileName string) error {
|
||||||
if fileName == "" {
|
if fileName == "" {
|
||||||
|
|||||||
@@ -29,44 +29,175 @@ func TestOptions(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfigValues_LoadDefaultModelWithCustomRun(t *testing.T) {
|
func TestConfigValues_Load(t *testing.T) {
|
||||||
originalRun := NasnetModel.Run
|
t.Run("DefaultModelWithCustomRun", func(t *testing.T) {
|
||||||
t.Cleanup(func() {
|
originalRun := NasnetModel.Run
|
||||||
NasnetModel.Run = originalRun
|
t.Cleanup(func() {
|
||||||
|
NasnetModel.Run = originalRun
|
||||||
|
})
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
configFile := filepath.Join(tempDir, "vision.yml")
|
||||||
|
|
||||||
|
err := os.WriteFile(configFile, []byte("Models:\n- Type: labels\n Default: true\n Run: on-demand\n"), fs.ModeConfigFile)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
cfg := NewConfig()
|
||||||
|
err = cfg.Load(configFile)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, RunOnDemand, cfg.RunType(ModelTypeLabels))
|
||||||
|
assert.True(t, cfg.ShouldRun(ModelTypeLabels, RunOnSchedule))
|
||||||
|
assert.False(t, cfg.ShouldRun(ModelTypeLabels, RunOnIndex))
|
||||||
})
|
})
|
||||||
|
t.Run("AddsMissingDefaults", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
configFile := filepath.Join(tempDir, "vision.yml")
|
||||||
|
|
||||||
tempDir := t.TempDir()
|
configYml := "Models:\n- Type: caption\n Name: custom-caption\n"
|
||||||
configFile := filepath.Join(tempDir, "vision.yml")
|
|
||||||
|
|
||||||
err := os.WriteFile(configFile, []byte("Models:\n- Type: labels\n Default: true\n Run: on-demand\n"), fs.ModeConfigFile)
|
err := os.WriteFile(configFile, []byte(configYml), fs.ModeConfigFile)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
cfg := NewConfig()
|
cfg := NewConfig()
|
||||||
err = cfg.Load(configFile)
|
err = cfg.Load(configFile)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, RunOnDemand, cfg.RunType(ModelTypeLabels))
|
assert.Len(t, cfg.Models, len(DefaultModels))
|
||||||
assert.True(t, cfg.ShouldRun(ModelTypeLabels, RunOnSchedule))
|
|
||||||
assert.False(t, cfg.ShouldRun(ModelTypeLabels, RunOnIndex))
|
if labels := cfg.Model(ModelTypeLabels); assert.NotNil(t, labels) {
|
||||||
|
assert.Equal(t, NasnetModel.Name, labels.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if caption := cfg.Model(ModelTypeCaption); assert.NotNil(t, caption) {
|
||||||
|
assert.Equal(t, "custom-caption", caption.Name)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("AddsDefaultsWhenModelsMissing", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
configFile := filepath.Join(tempDir, "vision.yml")
|
||||||
|
|
||||||
|
// Empty config should be populated with all default models.
|
||||||
|
err := os.WriteFile(configFile, []byte(""), fs.ModeConfigFile)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
cfg := NewConfig()
|
||||||
|
err = cfg.Load(configFile)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, cfg.Models, len(DefaultModels))
|
||||||
|
assert.True(t, cfg.IsDefault(ModelTypeLabels))
|
||||||
|
assert.True(t, cfg.IsDefault(ModelTypeNsfw))
|
||||||
|
assert.True(t, cfg.IsDefault(ModelTypeFace))
|
||||||
|
})
|
||||||
|
t.Run("DefaultModelDisabled", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
configFile := filepath.Join(tempDir, "vision.yml")
|
||||||
|
|
||||||
|
err := os.WriteFile(configFile, []byte("Models:\n- Type: labels\n Default: true\n Disabled: true\n"), fs.ModeConfigFile)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
cfg := NewConfig()
|
||||||
|
err = cfg.Load(configFile)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
if m := cfg.Model(ModelTypeLabels); m != nil {
|
||||||
|
t.Fatalf("expected disabled default model to be ignored, got %v", m)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, RunNever, cfg.RunType(ModelTypeLabels))
|
||||||
|
assert.False(t, cfg.ShouldRun(ModelTypeLabels, RunManual))
|
||||||
|
})
|
||||||
|
t.Run("MissingThresholdsUsesDefaults", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
configFile := filepath.Join(tempDir, "vision.yml")
|
||||||
|
|
||||||
|
err := os.WriteFile(configFile, []byte("Models:\n- Type: labels\n"), fs.ModeConfigFile)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
cfg := NewConfig()
|
||||||
|
err = cfg.Load(configFile)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, DefaultThresholds, cfg.Thresholds)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfigValues_LoadDefaultModelDisabled(t *testing.T) {
|
func TestConfigValues_applyDefaultModels(t *testing.T) {
|
||||||
tempDir := t.TempDir()
|
t.Run("ReplacesPlaceholderAndKeepsOverrides", func(t *testing.T) {
|
||||||
configFile := filepath.Join(tempDir, "vision.yml")
|
cfg := &ConfigValues{
|
||||||
|
Models: Models{
|
||||||
|
{
|
||||||
|
Type: ModelTypeLabels,
|
||||||
|
Default: true,
|
||||||
|
Run: RunOnDemand,
|
||||||
|
Disabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
err := os.WriteFile(configFile, []byte("Models:\n- Type: labels\n Default: true\n Disabled: true\n"), fs.ModeConfigFile)
|
cfg.applyDefaultModels()
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
cfg := NewConfig()
|
if got := cfg.Models[0]; got.Name != NasnetModel.Name {
|
||||||
err = cfg.Load(configFile)
|
t.Fatalf("expected placeholder to become nasnet, got %s", got.Name)
|
||||||
assert.NoError(t, err)
|
} else if got.Run != RunOnDemand {
|
||||||
|
t.Fatalf("expected Run to be preserved, got %s", got.Run)
|
||||||
|
} else if !got.Disabled {
|
||||||
|
t.Fatalf("expected Disabled to be preserved")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("IgnoresNonDefaultEntries", func(t *testing.T) {
|
||||||
|
original := &Model{Type: ModelTypeLabels, Name: "custom", Default: false}
|
||||||
|
cfg := &ConfigValues{Models: Models{original}}
|
||||||
|
|
||||||
if m := cfg.Model(ModelTypeLabels); m != nil {
|
cfg.applyDefaultModels()
|
||||||
t.Fatalf("expected disabled default model to be ignored, got %v", m)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, RunNever, cfg.RunType(ModelTypeLabels))
|
if cfg.Models[0] != original {
|
||||||
assert.False(t, cfg.ShouldRun(ModelTypeLabels, RunManual))
|
t.Fatalf("expected non-default model to remain unchanged")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigValues_ensureDefaultModels(t *testing.T) {
|
||||||
|
t.Run("AppendsMissingDefaults", func(t *testing.T) {
|
||||||
|
cfg := &ConfigValues{Models: Models{}}
|
||||||
|
|
||||||
|
cfg.ensureDefaultModels()
|
||||||
|
|
||||||
|
if len(cfg.Models) != len(DefaultModels) {
|
||||||
|
t.Fatalf("expected %d models, got %d", len(DefaultModels), len(cfg.Models))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("SkipsTypesAlreadyPresent", func(t *testing.T) {
|
||||||
|
custom := &Model{Type: ModelTypeLabels, Name: "custom"}
|
||||||
|
cfg := &ConfigValues{Models: Models{custom}}
|
||||||
|
|
||||||
|
cfg.ensureDefaultModels()
|
||||||
|
|
||||||
|
if len(cfg.Models) != len(DefaultModels) {
|
||||||
|
t.Fatalf("expected defaults minus duplicate type, got %d", len(cfg.Models))
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Models[0] != custom && cfg.Models[len(cfg.Models)-1] != custom {
|
||||||
|
t.Fatalf("expected existing custom model to remain")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("TreatsDisabledCustomAsPresent", func(t *testing.T) {
|
||||||
|
custom := &Model{Type: ModelTypeNsfw, Name: "custom", Disabled: true}
|
||||||
|
cfg := &ConfigValues{Models: Models{custom}}
|
||||||
|
|
||||||
|
cfg.ensureDefaultModels()
|
||||||
|
|
||||||
|
countType := 0
|
||||||
|
for _, m := range cfg.Models {
|
||||||
|
if m.Type == ModelTypeNsfw {
|
||||||
|
countType++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if countType != 1 {
|
||||||
|
t.Fatalf("expected no additional nsfw default when custom exists, got %d entries", countType)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfigModelPrefersLastEnabled(t *testing.T) {
|
func TestConfigModelPrefersLastEnabled(t *testing.T) {
|
||||||
@@ -151,7 +282,6 @@ func TestConfigValues_ShouldRun(t *testing.T) {
|
|||||||
t.Fatalf("expected false when no model configured")
|
t.Fatalf("expected false when no model configured")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("DefaultAutoModel", func(t *testing.T) {
|
t.Run("DefaultAutoModel", func(t *testing.T) {
|
||||||
cfg := &ConfigValues{Models: Models{NasnetModel.Clone()}}
|
cfg := &ConfigValues{Models: Models{NasnetModel.Clone()}}
|
||||||
assertConfigShouldRun(t, cfg, RunManual, true)
|
assertConfigShouldRun(t, cfg, RunManual, true)
|
||||||
@@ -161,7 +291,6 @@ func TestConfigValues_ShouldRun(t *testing.T) {
|
|||||||
assertConfigShouldRun(t, cfg, RunNewlyIndexed, false)
|
assertConfigShouldRun(t, cfg, RunNewlyIndexed, false)
|
||||||
assertConfigShouldRun(t, cfg, RunNever, false)
|
assertConfigShouldRun(t, cfg, RunNever, false)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("CustomOverridesDefault", func(t *testing.T) {
|
t.Run("CustomOverridesDefault", func(t *testing.T) {
|
||||||
defaultModel := NasnetModel.Clone()
|
defaultModel := NasnetModel.Clone()
|
||||||
custom := &Model{Type: ModelTypeLabels, Name: "custom"}
|
custom := &Model{Type: ModelTypeLabels, Name: "custom"}
|
||||||
@@ -171,7 +300,6 @@ func TestConfigValues_ShouldRun(t *testing.T) {
|
|||||||
assertConfigShouldRun(t, cfg, RunOnIndex, false)
|
assertConfigShouldRun(t, cfg, RunOnIndex, false)
|
||||||
assertConfigShouldRun(t, cfg, RunNewlyIndexed, true)
|
assertConfigShouldRun(t, cfg, RunNewlyIndexed, true)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("DisabledCustomFallsBack", func(t *testing.T) {
|
t.Run("DisabledCustomFallsBack", func(t *testing.T) {
|
||||||
defaultModel := NasnetModel.Clone()
|
defaultModel := NasnetModel.Clone()
|
||||||
custom := &Model{Type: ModelTypeLabels, Name: "custom", Disabled: true}
|
custom := &Model{Type: ModelTypeLabels, Name: "custom", Disabled: true}
|
||||||
@@ -181,7 +309,6 @@ func TestConfigValues_ShouldRun(t *testing.T) {
|
|||||||
assertConfigShouldRun(t, cfg, RunOnIndex, true)
|
assertConfigShouldRun(t, cfg, RunOnIndex, true)
|
||||||
assertConfigShouldRun(t, cfg, RunNewlyIndexed, false)
|
assertConfigShouldRun(t, cfg, RunNewlyIndexed, false)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("ManualOnly", func(t *testing.T) {
|
t.Run("ManualOnly", func(t *testing.T) {
|
||||||
model := &Model{Type: ModelTypeLabels, Run: RunManual}
|
model := &Model{Type: ModelTypeLabels, Run: RunManual}
|
||||||
cfg := &ConfigValues{Models: Models{model}}
|
cfg := &ConfigValues{Models: Models{model}}
|
||||||
|
|||||||
Reference in New Issue
Block a user