diff --git a/internal/ai/nsfw/model.go b/internal/ai/nsfw/model.go
index 54a910252..19d1af5ce 100644
--- a/internal/ai/nsfw/model.go
+++ b/internal/ai/nsfw/model.go
@@ -47,7 +47,7 @@ func (m *Model) File(fileName string) (result Result, err error) {
 
 	var img []byte
 
-	if img, err = os.ReadFile(fileName); err != nil {
+	if img, err = os.ReadFile(fileName); err != nil { //nolint:gosec // fileName is provided by trusted callers; reading local test fixtures is intentional
 		return result, err
 	}
 
@@ -109,7 +109,7 @@ func (m *Model) Run(img []byte) (result Result, err error) {
 	return result, nil
 }
 
-// Init initialises tensorflow models if not disabled
+// Init initializes tensorflow models if not disabled.
 func (m *Model) Init() (err error) {
 	if m.disabled {
 		return nil
@@ -133,13 +133,17 @@ func (m *Model) loadModel() error {
 
 	if len(m.meta.Tags) == 0 {
 		infos, err := tensorflow.GetModelTagsInfo(m.modelPath)
-		if err != nil {
-			log.Errorf("nsfw: could not get the model info at %s: %v", clean.Log(m.modelPath))
-		} else if len(infos) == 1 {
+
+		switch {
+		case err != nil:
+			log.Errorf("nsfw: could not get the model info at %s: %v", clean.Log(m.modelPath), err)
+		case len(infos) == 1:
 			log.Debugf("nsfw: model info: %+v", infos[0])
 			m.meta.Merge(&infos[0])
-		} else {
+		case len(infos) > 1:
 			log.Warnf("nsfw: found %d metagraphs... that's too many", len(infos))
+		default:
+			log.Warnf("nsfw: no metagraphs found in %s", clean.Log(m.modelPath))
 		}
 	}
 
@@ -179,7 +183,7 @@ func (m *Model) loadModel() error {
 func (m *Model) loadLabels(modelPath string) (err error) {
 	m.labels, err = tensorflow.LoadLabels(modelPath, int(m.meta.Output.NumOutputs))
 
-	return nil
+	return err
 }
 
 func (m *Model) getLabels(p []float32) Result {
diff --git a/internal/ai/nsfw/nsfw.go b/internal/ai/nsfw/nsfw.go
index 9853b5635..7b7d9f3b1 100644
--- a/internal/ai/nsfw/nsfw.go
+++ b/internal/ai/nsfw/nsfw.go
@@ -28,6 +28,7 @@ import (
 	"github.com/photoprism/photoprism/internal/event"
 )
 
+// Thresholds for classifying NSFW scores.
 const (
 	ThresholdSafe   = 0.75
 	ThresholdMedium = 0.85
@@ -36,6 +37,7 @@
 
 var log = event.Log
 
+// Result represents the classification scores returned by the NSFW model.
 type Result struct {
 	Drawing float32
 	Hentai  float32
diff --git a/internal/ai/nsfw/nsfw_test.go b/internal/ai/nsfw/nsfw_test.go
index 1edb6bb45..9f39fc336 100644
--- a/internal/ai/nsfw/nsfw_test.go
+++ b/internal/ai/nsfw/nsfw_test.go
@@ -84,7 +84,7 @@ func TestIsSafe(t *testing.T) {
 		assert.GreaterOrEqual(t, l.Sexy, e.Sexy)
 	}
 
-	isSafe := !strings.Contains(basename, "porn") && !strings.Contains(basename, "hentai")
+	isSafe := !strings.Contains(basename, "porn") && !strings.Contains(basename, "hentai")
 
 	if isSafe {
 		assert.True(t, l.IsSafe())
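
Usage sketch for the exported API touched above, assuming only what the diff shows: the two threshold constants and the Drawing/Hentai fields of Result. The isSafe helper and the Porn field below are hypothetical stand-ins added for illustration; the package's real IsSafe method is not part of this diff.

package main

import "fmt"

// Thresholds as defined in internal/ai/nsfw/nsfw.go.
const (
	ThresholdSafe   = 0.75
	ThresholdMedium = 0.85
)

// Result mirrors the struct from the diff; Porn is an assumed extra field.
type Result struct {
	Drawing float32
	Hentai  float32
	Porn    float32 // hypothetical: not shown in the hunk above
}

// isSafe is a hypothetical helper, not the package's actual IsSafe method:
// it treats a result as safe while every explicit score stays below ThresholdSafe.
func isSafe(r Result) bool {
	return r.Hentai < ThresholdSafe && r.Porn < ThresholdSafe
}

func main() {
	fmt.Println(isSafe(Result{Drawing: 0.9, Hentai: 0.1, Porn: 0.2})) // true: high Drawing score alone is harmless
	fmt.Println(isSafe(Result{Hentai: 0.05, Porn: 0.9}))              // false: Porn exceeds ThresholdSafe
}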