mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
198 lines
4.7 KiB
Go
198 lines
4.7 KiB
Go
package nsfw
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sync"
|
|
|
|
tf "github.com/wamuir/graft/tensorflow"
|
|
|
|
"github.com/photoprism/photoprism/internal/ai/tensorflow"
|
|
"github.com/photoprism/photoprism/pkg/clean"
|
|
"github.com/photoprism/photoprism/pkg/fs"
|
|
"github.com/photoprism/photoprism/pkg/http/header"
|
|
"github.com/photoprism/photoprism/pkg/http/scheme"
|
|
"github.com/photoprism/photoprism/pkg/media"
|
|
)
|
|
|
|
// Model uses TensorFlow to label drawing, hentai, neutral, porn and sexy images.
|
|
type Model struct {
|
|
model *tf.SavedModel
|
|
modelPath string
|
|
labels []string
|
|
meta *tensorflow.ModelInfo
|
|
disabled bool
|
|
mutex sync.Mutex
|
|
}
|
|
|
|
// NewModel returns a new detector instance.
|
|
func NewModel(modelPath string, meta *tensorflow.ModelInfo, disabled bool) *Model {
|
|
if meta == nil {
|
|
meta = new(tensorflow.ModelInfo)
|
|
}
|
|
|
|
return &Model{
|
|
modelPath: modelPath,
|
|
meta: meta,
|
|
disabled: disabled,
|
|
}
|
|
}
|
|
|
|
// File checks the specified JPEG file for inappropriate content.
|
|
func (m *Model) File(fileName string) (result Result, err error) {
|
|
if fs.MimeType(fileName) != header.ContentTypeJpeg {
|
|
return result, fmt.Errorf("%s is not a jpeg file", clean.Log(filepath.Base(fileName)))
|
|
}
|
|
|
|
var img []byte
|
|
|
|
if img, err = os.ReadFile(fileName); err != nil { //nolint:gosec // fileName is provided by trusted callers; reading local test fixtures is intentional
|
|
return result, err
|
|
}
|
|
|
|
return m.Run(img)
|
|
}
|
|
|
|
// Url checks the JPEG file from the specified https or data URL for inappropriate content.
|
|
func (m *Model) Url(imgUrl string) (result Result, err error) {
|
|
if m.disabled {
|
|
return result, nil
|
|
}
|
|
|
|
var img []byte
|
|
|
|
if img, err = media.ReadUrl(imgUrl, scheme.HttpsData); err != nil {
|
|
return result, err
|
|
}
|
|
|
|
return m.Run(img)
|
|
}
|
|
|
|
// Run returns matching labels for a jpeg media string.
|
|
func (m *Model) Run(img []byte) (result Result, err error) {
|
|
if loadErr := m.loadModel(); loadErr != nil {
|
|
return result, loadErr
|
|
}
|
|
|
|
// Create input tensor from image.
|
|
input, err := tensorflow.ImageTransform(
|
|
img, fs.ImageJpeg, m.meta.Input.Resolution())
|
|
|
|
if err != nil {
|
|
return result, fmt.Errorf("%s", err)
|
|
}
|
|
|
|
// Run inference.
|
|
output, err := m.model.Session.Run(
|
|
map[tf.Output]*tf.Tensor{
|
|
m.model.Graph.Operation(m.meta.Input.Name).Output(m.meta.Input.OutputIndex): input,
|
|
},
|
|
[]tf.Output{
|
|
m.model.Graph.Operation(m.meta.Output.Name).Output(m.meta.Output.OutputIndex),
|
|
},
|
|
nil)
|
|
|
|
if err != nil {
|
|
return result, fmt.Errorf("%s (run inference)", err.Error())
|
|
}
|
|
|
|
if len(output) < 1 {
|
|
return result, fmt.Errorf("inference failed, no output")
|
|
}
|
|
|
|
// Return best labels.
|
|
result = m.getLabels(output[0].Value().([][]float32)[0])
|
|
|
|
log.Tracef("nsfw: image classified as %+v", result)
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// Init initializes tensorflow models if not disabled.
|
|
func (m *Model) Init() (err error) {
|
|
if m.disabled {
|
|
return nil
|
|
}
|
|
|
|
return m.loadModel()
|
|
}
|
|
|
|
func (m *Model) loadModel() error {
|
|
// Use mutex to prevent the model from being loaded and
|
|
// initialized twice by different indexing workers.
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
if m.model != nil {
|
|
// Already loaded
|
|
return nil
|
|
}
|
|
|
|
log.Infof("nsfw: loading %s", clean.Log(filepath.Base(m.modelPath)))
|
|
|
|
if len(m.meta.Tags) == 0 {
|
|
infos, err := tensorflow.GetModelTagsInfo(m.modelPath)
|
|
|
|
switch {
|
|
case err != nil:
|
|
log.Errorf("nsfw: could not get the model info at %s: %v", clean.Log(m.modelPath))
|
|
case len(infos) == 1:
|
|
log.Debugf("nsfw: model info: %+v", infos[0])
|
|
m.meta.Merge(&infos[0])
|
|
case len(infos) > 1:
|
|
log.Warnf("nsfw: found %d metagraphs... that's too many", len(infos))
|
|
default:
|
|
log.Warnf("nsfw: no metagraphs found in %s", clean.Log(m.modelPath))
|
|
}
|
|
}
|
|
|
|
// Load saved TensorFlow model from the specified path.
|
|
model, err := tensorflow.SavedModel(m.modelPath, m.meta.Tags)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !m.meta.IsComplete() {
|
|
input, output, err := tensorflow.GetInputAndOutputFromSavedModel(model)
|
|
if err != nil {
|
|
log.Errorf("nsfw: could not get info from signatures: %v", err)
|
|
input, output, err = tensorflow.GuessInputAndOutput(model)
|
|
if err != nil {
|
|
return fmt.Errorf("nsfw: %w", err)
|
|
}
|
|
}
|
|
|
|
m.meta.Merge(&tensorflow.ModelInfo{
|
|
Input: input,
|
|
Output: output,
|
|
})
|
|
}
|
|
|
|
m.model = model
|
|
|
|
if m.meta.Output.OutputsLogits {
|
|
_, err = tensorflow.AddSoftmax(m.model.Graph, m.meta)
|
|
if err != nil {
|
|
return fmt.Errorf("nsfw: could not add softmax (%s)", clean.Error(err))
|
|
}
|
|
}
|
|
|
|
return m.loadLabels(m.modelPath)
|
|
}
|
|
|
|
func (m *Model) loadLabels(modelPath string) (err error) {
|
|
m.labels, err = tensorflow.LoadLabels(modelPath, int(m.meta.Output.NumOutputs))
|
|
return err
|
|
}
|
|
|
|
func (m *Model) getLabels(p []float32) Result {
|
|
return Result{
|
|
Drawing: p[0],
|
|
Hentai: p[1],
|
|
Neutral: p[2],
|
|
Porn: p[3],
|
|
Sexy: p[4],
|
|
}
|
|
}
|