Backend: Set NSFW flag while indexing

Signed-off-by: Michael Mayer <michael@liquidbytes.net>
This commit is contained in:
Michael Mayer
2019-12-14 20:35:14 +01:00
parent 78eae2f14e
commit 8cce9f7c8c
16 changed files with 75 additions and 27 deletions

View File

@@ -10,6 +10,7 @@ import (
"github.com/photoprism/photoprism/internal/config" "github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/event" "github.com/photoprism/photoprism/internal/event"
"github.com/photoprism/photoprism/internal/form" "github.com/photoprism/photoprism/internal/form"
"github.com/photoprism/photoprism/internal/nsfw"
"github.com/photoprism/photoprism/internal/photoprism" "github.com/photoprism/photoprism/internal/photoprism"
"github.com/photoprism/photoprism/internal/util" "github.com/photoprism/photoprism/internal/util"
) )
@@ -22,8 +23,9 @@ func initIndexer(conf *config.Config) {
} }
tensorFlow := photoprism.NewTensorFlow(conf) tensorFlow := photoprism.NewTensorFlow(conf)
nsfwDetector := nsfw.NewDetector(conf.NSFWModelPath())
indexer = photoprism.NewIndexer(conf, tensorFlow) indexer = photoprism.NewIndexer(conf, tensorFlow, nsfwDetector)
} }
// POST /api/v1/index // POST /api/v1/index

View File

@@ -5,6 +5,7 @@ import (
"time" "time"
"github.com/photoprism/photoprism/internal/config" "github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/nsfw"
"github.com/photoprism/photoprism/internal/photoprism" "github.com/photoprism/photoprism/internal/photoprism"
"github.com/urfave/cli" "github.com/urfave/cli"
) )
@@ -40,8 +41,9 @@ func importAction(ctx *cli.Context) error {
log.Infof("importing photos from %s", conf.ImportPath()) log.Infof("importing photos from %s", conf.ImportPath())
tensorFlow := photoprism.NewTensorFlow(conf) tensorFlow := photoprism.NewTensorFlow(conf)
nsfwDetector := nsfw.NewDetector(conf.NSFWModelPath())
indexer := photoprism.NewIndexer(conf, tensorFlow) indexer := photoprism.NewIndexer(conf, tensorFlow, nsfwDetector)
converter := photoprism.NewConverter(conf) converter := photoprism.NewConverter(conf)

View File

@@ -5,6 +5,7 @@ import (
"time" "time"
"github.com/photoprism/photoprism/internal/config" "github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/nsfw"
"github.com/photoprism/photoprism/internal/photoprism" "github.com/photoprism/photoprism/internal/photoprism"
"github.com/urfave/cli" "github.com/urfave/cli"
) )
@@ -39,8 +40,9 @@ func indexAction(ctx *cli.Context) error {
} }
tensorFlow := photoprism.NewTensorFlow(conf) tensorFlow := photoprism.NewTensorFlow(conf)
nsfwDetector := nsfw.NewDetector(conf.NSFWModelPath())
indexer := photoprism.NewIndexer(conf, tensorFlow) indexer := photoprism.NewIndexer(conf, tensorFlow, nsfwDetector)
options := photoprism.IndexerOptionsAll() options := photoprism.IndexerOptionsAll()
files := indexer.IndexOriginals(options) files := indexer.IndexOriginals(options)

View File

@@ -473,6 +473,11 @@ func (c *Config) TensorFlowModelPath() string {
return c.ResourcesPath() + "/nasnet" return c.ResourcesPath() + "/nasnet"
} }
// NSFWModelPath returns the NSFW tensorflow model path.
func (c *Config) NSFWModelPath() string {
return c.ResourcesPath() + "/nsfw"
}
// HttpTemplatesPath returns the server templates path. // HttpTemplatesPath returns the server templates path.
func (c *Config) HttpTemplatesPath() string { func (c *Config) HttpTemplatesPath() string {
return c.ResourcesPath() + "/templates" return c.ResourcesPath() + "/templates"

1
internal/nsfw/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
testdata/porn*

View File

@@ -27,21 +27,19 @@ func (l *Labels) IsSafe() bool {
} }
func (l *Labels) NSFW() bool { func (l *Labels) NSFW() bool {
if l.Neutral > 0.25 && l.Porn < 0.75 { if l.Neutral > 0.25 {
return false return false
} }
if l.Porn > 0.4 {
if l.Porn > 0.75 {
return true return true
} }
if l.Sexy > 0.5 { if l.Sexy > 0.75 {
return true return true
} }
if l.Hentai > 0.75 { if l.Hentai > 0.75 {
return true return true
} }
if l.Drawing > 0.9 {
return true
}
return false return false
} }

View File

@@ -86,10 +86,11 @@ func TestNSFW(t *testing.T) {
assert.GreaterOrEqual(t, l.Sexy, e.Sexy) assert.GreaterOrEqual(t, l.Sexy, e.Sexy)
} }
isNSFW := strings.Contains(basename, "porn") || strings.Contains(basename, "hentai") isSafe := !(strings.Contains(basename, "porn") || strings.Contains(basename, "hentai"))
assert.Equal(t, isNSFW, l.NSFW()) if isSafe {
assert.Equal(t, !isNSFW, l.IsSafe()) assert.True(t, l.IsSafe())
}
}) })
return nil return nil

BIN
internal/nsfw/testdata/architecture.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 171 KiB

BIN
internal/nsfw/testdata/art.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 118 KiB

BIN
internal/nsfw/testdata/museum.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 135 KiB

BIN
internal/nsfw/testdata/san-francisco.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 161 KiB

View File

@@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/photoprism/photoprism/internal/config" "github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/nsfw"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -11,8 +12,9 @@ func TestNewImporter(t *testing.T) {
conf := config.TestConfig() conf := config.TestConfig()
tensorFlow := NewTensorFlow(conf) tensorFlow := NewTensorFlow(conf)
nsfwDetector := nsfw.NewDetector(conf.NSFWModelPath())
indexer := NewIndexer(conf, tensorFlow) indexer := NewIndexer(conf, tensorFlow, nsfwDetector)
converter := NewConverter(conf) converter := NewConverter(conf)
@@ -27,8 +29,9 @@ func TestImporter_DestinationFilename(t *testing.T) {
conf.InitializeTestData(t) conf.InitializeTestData(t)
tensorFlow := NewTensorFlow(conf) tensorFlow := NewTensorFlow(conf)
nsfwDetector := nsfw.NewDetector(conf.NSFWModelPath())
indexer := NewIndexer(conf, tensorFlow) indexer := NewIndexer(conf, tensorFlow, nsfwDetector)
converter := NewConverter(conf) converter := NewConverter(conf)
@@ -55,8 +58,9 @@ func TestImporter_ImportPhotosFromDirectory(t *testing.T) {
conf.InitializeTestData(t) conf.InitializeTestData(t)
tensorFlow := NewTensorFlow(conf) tensorFlow := NewTensorFlow(conf)
nsfwDetector := nsfw.NewDetector(conf.NSFWModelPath())
indexer := NewIndexer(conf, tensorFlow) indexer := NewIndexer(conf, tensorFlow, nsfwDetector)
converter := NewConverter(conf) converter := NewConverter(conf)

View File

@@ -7,21 +7,24 @@ import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/photoprism/photoprism/internal/config" "github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/nsfw"
) )
// Indexer defines an indexer with originals path tensorflow and a db. // Indexer defines an indexer with originals path tensorflow and a db.
type Indexer struct { type Indexer struct {
conf *config.Config conf *config.Config
tensorFlow *TensorFlow tensorFlow *TensorFlow
nsfwDetector *nsfw.Detector
db *gorm.DB db *gorm.DB
} }
// NewIndexer returns a new indexer. // NewIndexer returns a new indexer.
// TODO: Is it really necessary to return a pointer? // TODO: Is it really necessary to return a pointer?
func NewIndexer(conf *config.Config, tensorFlow *TensorFlow) *Indexer { func NewIndexer(conf *config.Config, tensorFlow *TensorFlow, nsfwDetector *nsfw.Detector) *Indexer {
i := &Indexer{ i := &Indexer{
conf: conf, conf: conf,
tensorFlow: tensorFlow, tensorFlow: tensorFlow,
nsfwDetector: nsfwDetector,
db: conf.Db(), db: conf.Db(),
} }

View File

@@ -2,6 +2,7 @@ package photoprism
import ( import (
"fmt" "fmt"
"math"
"path/filepath" "path/filepath"
"sort" "sort"
"strings" "strings"
@@ -29,6 +30,7 @@ func (i *Indexer) indexMediaFile(m *MediaFile, o IndexerOptions) IndexResult {
var exifData *Exif var exifData *Exif
var photoQuery, fileQuery *gorm.DB var photoQuery, fileQuery *gorm.DB
var keywords []string var keywords []string
var isNSFW bool
labels := Labels{} labels := Labels{}
fileBase := m.Basename() fileBase := m.Basename()
@@ -86,7 +88,8 @@ func (i *Indexer) indexMediaFile(m *MediaFile, o IndexerOptions) IndexResult {
if file.FilePrimary { if file.FilePrimary {
if fileChanged || o.UpdateKeywords || o.UpdateLabels || o.UpdateTitle { if fileChanged || o.UpdateKeywords || o.UpdateLabels || o.UpdateTitle {
// Image classification labels // Image classification labels
labels = i.classifyImage(m) labels, isNSFW = i.classifyImage(m)
photo.PhotoNSFW = isNSFW
} }
if fileChanged || o.UpdateExif { if fileChanged || o.UpdateExif {
@@ -225,7 +228,7 @@ func (i *Indexer) indexMediaFile(m *MediaFile, o IndexerOptions) IndexResult {
} }
// classifyImage returns all matching labels for a media file. // classifyImage returns all matching labels for a media file.
func (i *Indexer) classifyImage(jpeg *MediaFile) (results Labels) { func (i *Indexer) classifyImage(jpeg *MediaFile) (results Labels, isNSFW bool) {
start := time.Now() start := time.Now()
var thumbs []string var thumbs []string
@@ -256,6 +259,25 @@ func (i *Indexer) classifyImage(jpeg *MediaFile) (results Labels) {
labels = append(labels, imageLabels...) labels = append(labels, imageLabels...)
} }
if filename, err := jpeg.Thumbnail(i.thumbnailsPath(), "fit_720"); err != nil {
log.Error(err)
} else {
if nsfwLabels, err := i.nsfwDetector.LabelsFromFile(filename); err != nil {
log.Error(err)
} else {
log.Infof("nsfw: %+v", nsfwLabels)
if nsfwLabels.NSFW() {
isNSFW = true
}
if nsfwLabels.Sexy > 0.2 {
uncertainty := 100 - int(math.Round(float64(nsfwLabels.Sexy*100)))
labels = append(labels, Label{Name: "sexy", Source: "nsfw", Uncertainty: uncertainty, Priority: -1})
}
}
}
// Sort by priority and uncertainty // Sort by priority and uncertainty
sort.Sort(labels) sort.Sort(labels)
@@ -271,11 +293,15 @@ func (i *Indexer) classifyImage(jpeg *MediaFile) (results Labels) {
} }
} }
if isNSFW {
log.Info("index: image might contain sexually explicit content")
}
elapsed := time.Since(start) elapsed := time.Since(start)
log.Debugf("index: image classification took %s", elapsed) log.Debugf("index: image classification took %s", elapsed)
return results return results, isNSFW
} }
func (i *Indexer) addLabels(photoId uint, labels Labels) { func (i *Indexer) addLabels(photoId uint, labels Labels) {

View File

@@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/photoprism/photoprism/internal/config" "github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/nsfw"
) )
func TestIndexer_IndexAll(t *testing.T) { func TestIndexer_IndexAll(t *testing.T) {
@@ -16,8 +17,9 @@ func TestIndexer_IndexAll(t *testing.T) {
conf.InitializeTestData(t) conf.InitializeTestData(t)
tensorFlow := NewTensorFlow(conf) tensorFlow := NewTensorFlow(conf)
nsfwDetector := nsfw.NewDetector(conf.NSFWModelPath())
indexer := NewIndexer(conf, tensorFlow) indexer := NewIndexer(conf, tensorFlow, nsfwDetector)
converter := NewConverter(conf) converter := NewConverter(conf)

View File

@@ -6,6 +6,7 @@ import (
"github.com/disintegration/imaging" "github.com/disintegration/imaging"
"github.com/photoprism/photoprism/internal/entity" "github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/internal/nsfw"
"github.com/photoprism/photoprism/internal/config" "github.com/photoprism/photoprism/internal/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -66,8 +67,9 @@ func TestThumbnails_CreateThumbnailsFromOriginals(t *testing.T) {
conf.InitializeTestData(t) conf.InitializeTestData(t)
tensorFlow := NewTensorFlow(conf) tensorFlow := NewTensorFlow(conf)
nsfwDetector := nsfw.NewDetector(conf.NSFWModelPath())
indexer := NewIndexer(conf, tensorFlow) indexer := NewIndexer(conf, tensorFlow, nsfwDetector)
converter := NewConverter(conf) converter := NewConverter(conf)