Files
photoprism/internal/photoprism/mediafile_vision_test.go
2025-10-30 12:40:03 +01:00

141 lines
4.2 KiB
Go

package photoprism
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/photoprism/photoprism/internal/ai/classify"
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/pkg/media"
)
// setupVisionMediaFile is a test helper that calls InitializeTestData on the
// test config and then opens the "testdata/flash.jpg" fixture as a MediaFile.
// The calling test is aborted via require if either step returns an error.
func setupVisionMediaFile(t *testing.T) *MediaFile {
	t.Helper()

	require.NoError(t, config.TestConfig().InitializeTestData())

	m, err := NewMediaFile("testdata/flash.jpg")
	require.NoError(t, err)

	return m
}
// TestMediaFile_GenerateCaption exercises MediaFile.GenerateCaption against a
// stubbed caption function and a vision config holding a single caption model.
//
// The subtests assert that:
//   - with entity.SrcAuto the returned caption carries the model's source,
//   - with entity.SrcManual the returned caption carries entity.SrcManual,
//   - with an empty config and the caption func reset to nil, an error and a
//     nil caption are returned.
func TestMediaFile_GenerateCaption(t *testing.T) {
	file := setupVisionMediaFile(t)

	// The test mutates package-level vision state; put it back afterwards.
	saved := vision.Config
	t.Cleanup(func() {
		vision.Config = saved
		vision.SetCaptionFunc(nil)
	})

	model := &vision.Model{Type: vision.ModelTypeCaption, Engine: vision.ApiFormatOpenAI}
	model.ApplyEngineDefaults()
	vision.Config = &vision.ConfigValues{Models: vision.Models{model}}

	// stubCaption always answers with a fixed caption attributed to the model.
	stubCaption := func(_ vision.Files, _ media.Src) (*vision.CaptionResult, *vision.Model, error) {
		result := &vision.CaptionResult{Text: "stub", Source: model.GetSource()}
		return result, model, nil
	}

	t.Run("AutoUsesModelSource", func(t *testing.T) {
		vision.SetCaptionFunc(stubCaption)

		got, err := file.GenerateCaption(entity.SrcAuto)

		require.NoError(t, err)
		require.NotNil(t, got)
		assert.Equal(t, model.GetSource(), got.Source)
	})

	t.Run("CustomSourceOverrides", func(t *testing.T) {
		vision.SetCaptionFunc(stubCaption)

		got, err := file.GenerateCaption(entity.SrcManual)

		require.NoError(t, err)
		require.NotNil(t, got)
		assert.Equal(t, entity.SrcManual, got.Source)
	})

	t.Run("MissingModelReturnsError", func(t *testing.T) {
		// Drop all configured models and the stub before calling.
		vision.Config = &vision.ConfigValues{}
		vision.SetCaptionFunc(nil)

		got, err := file.GenerateCaption(entity.SrcAuto)

		assert.Error(t, err)
		assert.Nil(t, got)
	})
}
// TestMediaFile_GenerateLabels exercises MediaFile.GenerateLabels against a
// stubbed labels function and a vision config holding a single labels model.
//
// The subtests assert that:
//   - with entity.SrcAuto the stub is called with the model's source,
//   - with entity.SrcManual the stub is called with entity.SrcManual,
//   - with an empty config and the labels func reset to nil, the result is empty.
func TestMediaFile_GenerateLabels(t *testing.T) {
	file := setupVisionMediaFile(t)

	// The test mutates package-level vision state; put it back afterwards.
	saved := vision.Config
	t.Cleanup(func() {
		vision.Config = saved
		vision.SetLabelsFunc(nil)
	})

	model := &vision.Model{Type: vision.ModelTypeLabels, Engine: vision.ApiFormatOllama}
	model.ApplyEngineDefaults()
	vision.Config = &vision.ConfigValues{Models: vision.Models{model}}

	// recordSrc builds a labels stub that stores the source it receives in *dst
	// and returns one "stub" label tagged with that same source.
	recordSrc := func(dst *string) func(vision.Files, media.Src, entity.Src) (classify.Labels, error) {
		return func(_ vision.Files, _ media.Src, src entity.Src) (classify.Labels, error) {
			*dst = src
			return classify.Labels{{Name: "stub", Source: src}}, nil
		}
	}

	t.Run("AutoUsesModelSource", func(t *testing.T) {
		var seen string
		vision.SetLabelsFunc(recordSrc(&seen))

		got := file.GenerateLabels(entity.SrcAuto)

		assert.NotEmpty(t, got)
		assert.Equal(t, model.GetSource(), seen)
	})

	t.Run("CustomSourceOverrides", func(t *testing.T) {
		var seen string
		vision.SetLabelsFunc(recordSrc(&seen))

		got := file.GenerateLabels(entity.SrcManual)

		assert.NotEmpty(t, got)
		assert.Equal(t, entity.SrcManual, seen)
	})

	t.Run("MissingModel", func(t *testing.T) {
		// Drop all configured models and the stub before calling.
		vision.Config = &vision.ConfigValues{}
		vision.SetLabelsFunc(nil)

		got := file.GenerateLabels(entity.SrcAuto)

		assert.Empty(t, got)
	})
}
// TestMediaFile_DetectNSFW exercises MediaFile.DetectNSFW with a stubbed NSFW
// function. It asserts that a result whose Porn score is just above
// nsfw.ThresholdHigh yields true, and that a result with only a Neutral score
// of 0.9 yields false.
func TestMediaFile_DetectNSFW(t *testing.T) {
	file := setupVisionMediaFile(t)

	// stubNSFW installs an NSFW func returning a single fixed result and
	// registers its removal on the given (sub)test.
	stubNSFW := func(t *testing.T, result nsfw.Result) {
		vision.SetNSFWFunc(func(_ vision.Files, _ media.Src) ([]nsfw.Result, error) {
			return []nsfw.Result{result}, nil
		})
		t.Cleanup(func() { vision.SetNSFWFunc(nil) })
	}

	t.Run("FlagsHighConfidence", func(t *testing.T) {
		stubNSFW(t, nsfw.Result{Porn: nsfw.ThresholdHigh + 0.01})

		assert.True(t, file.DetectNSFW())
	})

	t.Run("SafeContent", func(t *testing.T) {
		stubNSFW(t, nsfw.Result{Neutral: 0.9})

		assert.False(t, file.DetectNSFW())
	})
}