mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
CI: Apply Go more linter recommendations to "ai/nsfw" package #5330
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
@@ -47,7 +47,7 @@ func (m *Model) File(fileName string) (result Result, err error) {
|
|||||||
|
|
||||||
var img []byte
|
var img []byte
|
||||||
|
|
||||||
if img, err = os.ReadFile(fileName); err != nil {
|
if img, err = os.ReadFile(fileName); err != nil { //nolint:gosec // fileName is provided by trusted callers; reading local test fixtures is intentional
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,7 +109,7 @@ func (m *Model) Run(img []byte) (result Result, err error) {
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init initialises tensorflow models if not disabled
|
// Init initializes tensorflow models if not disabled.
|
||||||
func (m *Model) Init() (err error) {
|
func (m *Model) Init() (err error) {
|
||||||
if m.disabled {
|
if m.disabled {
|
||||||
return nil
|
return nil
|
||||||
@@ -133,13 +133,17 @@ func (m *Model) loadModel() error {
|
|||||||
|
|
||||||
if len(m.meta.Tags) == 0 {
|
if len(m.meta.Tags) == 0 {
|
||||||
infos, err := tensorflow.GetModelTagsInfo(m.modelPath)
|
infos, err := tensorflow.GetModelTagsInfo(m.modelPath)
|
||||||
if err != nil {
|
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
log.Errorf("nsfw: could not get the model info at %s: %v", clean.Log(m.modelPath))
|
log.Errorf("nsfw: could not get the model info at %s: %v", clean.Log(m.modelPath))
|
||||||
} else if len(infos) == 1 {
|
case len(infos) == 1:
|
||||||
log.Debugf("nsfw: model info: %+v", infos[0])
|
log.Debugf("nsfw: model info: %+v", infos[0])
|
||||||
m.meta.Merge(&infos[0])
|
m.meta.Merge(&infos[0])
|
||||||
} else {
|
case len(infos) > 1:
|
||||||
log.Warnf("nsfw: found %d metagraphs... that's too many", len(infos))
|
log.Warnf("nsfw: found %d metagraphs... that's too many", len(infos))
|
||||||
|
default:
|
||||||
|
log.Warnf("nsfw: no metagraphs found in %s", clean.Log(m.modelPath))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -179,7 +183,7 @@ func (m *Model) loadModel() error {
|
|||||||
|
|
||||||
func (m *Model) loadLabels(modelPath string) (err error) {
|
func (m *Model) loadLabels(modelPath string) (err error) {
|
||||||
m.labels, err = tensorflow.LoadLabels(modelPath, int(m.meta.Output.NumOutputs))
|
m.labels, err = tensorflow.LoadLabels(modelPath, int(m.meta.Output.NumOutputs))
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) getLabels(p []float32) Result {
|
func (m *Model) getLabels(p []float32) Result {
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import (
|
|||||||
"github.com/photoprism/photoprism/internal/event"
|
"github.com/photoprism/photoprism/internal/event"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Thresholds for classifying NSFW scores.
|
||||||
const (
|
const (
|
||||||
ThresholdSafe = 0.75
|
ThresholdSafe = 0.75
|
||||||
ThresholdMedium = 0.85
|
ThresholdMedium = 0.85
|
||||||
@@ -36,6 +37,7 @@ const (
|
|||||||
|
|
||||||
var log = event.Log
|
var log = event.Log
|
||||||
|
|
||||||
|
// Result represents the classification scores returned by the NSFW model.
|
||||||
type Result struct {
|
type Result struct {
|
||||||
Drawing float32
|
Drawing float32
|
||||||
Hentai float32
|
Hentai float32
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ func TestIsSafe(t *testing.T) {
|
|||||||
assert.GreaterOrEqual(t, l.Sexy, e.Sexy)
|
assert.GreaterOrEqual(t, l.Sexy, e.Sexy)
|
||||||
}
|
}
|
||||||
|
|
||||||
isSafe := !(strings.Contains(basename, "porn") || strings.Contains(basename, "hentai"))
|
isSafe := !strings.Contains(basename, "porn") && !strings.Contains(basename, "hentai")
|
||||||
|
|
||||||
if isSafe {
|
if isSafe {
|
||||||
assert.True(t, l.IsSafe())
|
assert.True(t, l.IsSafe())
|
||||||
|
|||||||
Reference in New Issue
Block a user