mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
141 lines
4.2 KiB
Go
141 lines
4.2 KiB
Go
package photoprism
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/photoprism/photoprism/internal/ai/classify"
|
|
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
|
"github.com/photoprism/photoprism/internal/ai/vision"
|
|
"github.com/photoprism/photoprism/internal/config"
|
|
"github.com/photoprism/photoprism/internal/entity"
|
|
"github.com/photoprism/photoprism/pkg/media"
|
|
)
|
|
|
|
func setupVisionMediaFile(t *testing.T) *MediaFile {
|
|
t.Helper()
|
|
|
|
cfg := config.TestConfig()
|
|
require.NoError(t, cfg.InitializeTestData())
|
|
|
|
mediaFile, err := NewMediaFile("testdata/flash.jpg")
|
|
require.NoError(t, err)
|
|
|
|
return mediaFile
|
|
}
|
|
|
|
func TestMediaFile_GenerateCaption(t *testing.T) {
|
|
mediaFile := setupVisionMediaFile(t)
|
|
|
|
originalConfig := vision.Config
|
|
t.Cleanup(func() {
|
|
vision.Config = originalConfig
|
|
vision.SetCaptionFunc(nil)
|
|
})
|
|
|
|
captionModel := &vision.Model{Type: vision.ModelTypeCaption, Engine: vision.ApiFormatOpenAI}
|
|
captionModel.ApplyEngineDefaults()
|
|
vision.Config = &vision.ConfigValues{Models: vision.Models{captionModel}}
|
|
|
|
t.Run("AutoUsesModelSource", func(t *testing.T) {
|
|
vision.SetCaptionFunc(func(files vision.Files, mediaSrc media.Src) (*vision.CaptionResult, *vision.Model, error) {
|
|
return &vision.CaptionResult{Text: "stub", Source: captionModel.GetSource()}, captionModel, nil
|
|
})
|
|
|
|
caption, err := mediaFile.GenerateCaption(entity.SrcAuto)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, caption)
|
|
assert.Equal(t, captionModel.GetSource(), caption.Source)
|
|
})
|
|
|
|
t.Run("CustomSourceOverrides", func(t *testing.T) {
|
|
vision.SetCaptionFunc(func(files vision.Files, mediaSrc media.Src) (*vision.CaptionResult, *vision.Model, error) {
|
|
return &vision.CaptionResult{Text: "stub", Source: captionModel.GetSource()}, captionModel, nil
|
|
})
|
|
|
|
caption, err := mediaFile.GenerateCaption(entity.SrcManual)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, caption)
|
|
assert.Equal(t, entity.SrcManual, caption.Source)
|
|
})
|
|
|
|
t.Run("MissingModelReturnsError", func(t *testing.T) {
|
|
vision.Config = &vision.ConfigValues{}
|
|
vision.SetCaptionFunc(nil)
|
|
|
|
caption, err := mediaFile.GenerateCaption(entity.SrcAuto)
|
|
assert.Error(t, err)
|
|
assert.Nil(t, caption)
|
|
})
|
|
}
|
|
|
|
func TestMediaFile_GenerateLabels(t *testing.T) {
|
|
mediaFile := setupVisionMediaFile(t)
|
|
|
|
originalConfig := vision.Config
|
|
t.Cleanup(func() {
|
|
vision.Config = originalConfig
|
|
vision.SetLabelsFunc(nil)
|
|
})
|
|
|
|
labelModel := &vision.Model{Type: vision.ModelTypeLabels, Engine: vision.ApiFormatOllama}
|
|
labelModel.ApplyEngineDefaults()
|
|
vision.Config = &vision.ConfigValues{Models: vision.Models{labelModel}}
|
|
|
|
t.Run("AutoUsesModelSource", func(t *testing.T) {
|
|
var captured string
|
|
vision.SetLabelsFunc(func(files vision.Files, mediaSrc media.Src, src entity.Src) (classify.Labels, error) {
|
|
captured = src
|
|
return classify.Labels{{Name: "stub", Source: src}}, nil
|
|
})
|
|
|
|
labels := mediaFile.GenerateLabels(entity.SrcAuto)
|
|
assert.NotEmpty(t, labels)
|
|
assert.Equal(t, labelModel.GetSource(), captured)
|
|
})
|
|
|
|
t.Run("CustomSourceOverrides", func(t *testing.T) {
|
|
var captured string
|
|
vision.SetLabelsFunc(func(files vision.Files, mediaSrc media.Src, src entity.Src) (classify.Labels, error) {
|
|
captured = src
|
|
return classify.Labels{{Name: "stub", Source: src}}, nil
|
|
})
|
|
|
|
labels := mediaFile.GenerateLabels(entity.SrcManual)
|
|
assert.NotEmpty(t, labels)
|
|
assert.Equal(t, entity.SrcManual, captured)
|
|
})
|
|
|
|
t.Run("MissingModel", func(t *testing.T) {
|
|
vision.Config = &vision.ConfigValues{}
|
|
vision.SetLabelsFunc(nil)
|
|
|
|
labels := mediaFile.GenerateLabels(entity.SrcAuto)
|
|
assert.Empty(t, labels)
|
|
})
|
|
}
|
|
|
|
func TestMediaFile_DetectNSFW(t *testing.T) {
|
|
mediaFile := setupVisionMediaFile(t)
|
|
|
|
t.Run("FlagsHighConfidence", func(t *testing.T) {
|
|
vision.SetNSFWFunc(func(files vision.Files, mediaSrc media.Src) ([]nsfw.Result, error) {
|
|
return []nsfw.Result{{Porn: nsfw.ThresholdHigh + 0.01}}, nil
|
|
})
|
|
t.Cleanup(func() { vision.SetNSFWFunc(nil) })
|
|
|
|
assert.True(t, mediaFile.DetectNSFW())
|
|
})
|
|
|
|
t.Run("SafeContent", func(t *testing.T) {
|
|
vision.SetNSFWFunc(func(files vision.Files, mediaSrc media.Src) ([]nsfw.Result, error) {
|
|
return []nsfw.Result{{Neutral: 0.9}}, nil
|
|
})
|
|
t.Cleanup(func() { vision.SetNSFWFunc(nil) })
|
|
|
|
assert.False(t, mediaFile.DetectNSFW())
|
|
})
|
|
}
|