AI: Finalize facial embeddings, labels and nsfw API endpoints #127 #1090

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2025-04-10 20:28:26 +02:00
parent caf3ae1ab5
commit 190be2a1b5
25 changed files with 596 additions and 49 deletions

View File

@@ -214,7 +214,10 @@ func (m *Model) createTensor(image []byte) (*tf.Tensor, error) {
return nil, err
}
// Resize the image only if its resolution does not match the model.
if img.Bounds().Dx() != m.resolution || img.Bounds().Dy() != m.resolution {
img = imaging.Fill(img, m.resolution, m.resolution, imaging.Center, imaging.Lanczos)
}
return tensorflow.Image(img, m.resolution)
}

View File

@@ -8,6 +8,7 @@ import (
"runtime/debug"
"sync"
"github.com/disintegration/imaging"
tf "github.com/wamuir/graft/tensorflow"
"github.com/photoprism/photoprism/internal/thumb/crop"
@@ -116,6 +117,7 @@ func (m *Model) loadModel() error {
// Run returns the face embeddings for an image.
func (m *Model) Run(img image.Image) Embeddings {
// Create input tensor from image.
tensor, err := imageToTensor(img, m.resolution)
if err != nil {
@@ -160,6 +162,11 @@ func imageToTensor(img image.Image, resolution int) (tfTensor *tf.Tensor, err er
return tfTensor, fmt.Errorf("faces: invalid model resolution")
}
// Resize the image only if its resolution does not match the model.
if img.Bounds().Dx() != resolution || img.Bounds().Dy() != resolution {
img = imaging.Fill(img, resolution, resolution, imaging.Center, imaging.Lanczos)
}
var tfImage [1][][][3]float32
for j := 0; j < resolution; j++ {

View File

@@ -46,11 +46,11 @@ type Result struct {
// IsSafe returns true if the image is probably safe for work.
func (l *Result) IsSafe() bool {
return !l.NSFW(ThresholdSafe)
return !l.IsNsfw(ThresholdSafe)
}
// NSFW returns true if the image is may not be safe for work.
func (l *Result) NSFW(threshold float32) bool {
// IsNsfw returns true if the image may not be safe for work.
func (l *Result) IsNsfw(threshold float32) bool {
if l.Neutral > 0.25 {
return false
}

View File

@@ -97,28 +97,28 @@ func TestIsSafe(t *testing.T) {
}
}
func TestNSFW(t *testing.T) {
func TestIsNsfw(t *testing.T) {
porn := Result{0, 0, 0.11, 0.88, 0}
sexy := Result{0, 0, 0.2, 0.59, 0.98}
maxi := Result{0, 0.999, 0.1, 0.999, 0.999}
drawing := Result{0.999, 0, 0, 0, 0}
hentai := Result{0, 0.80, 0.2, 0, 0}
assert.Equal(t, true, porn.NSFW(ThresholdSafe))
assert.Equal(t, true, sexy.NSFW(ThresholdSafe))
assert.Equal(t, true, hentai.NSFW(ThresholdSafe))
assert.Equal(t, false, drawing.NSFW(ThresholdSafe))
assert.Equal(t, true, maxi.NSFW(ThresholdSafe))
assert.Equal(t, true, porn.IsNsfw(ThresholdSafe))
assert.Equal(t, true, sexy.IsNsfw(ThresholdSafe))
assert.Equal(t, true, hentai.IsNsfw(ThresholdSafe))
assert.Equal(t, false, drawing.IsNsfw(ThresholdSafe))
assert.Equal(t, true, maxi.IsNsfw(ThresholdSafe))
assert.Equal(t, true, porn.NSFW(ThresholdMedium))
assert.Equal(t, true, sexy.NSFW(ThresholdMedium))
assert.Equal(t, false, hentai.NSFW(ThresholdMedium))
assert.Equal(t, false, drawing.NSFW(ThresholdMedium))
assert.Equal(t, true, maxi.NSFW(ThresholdMedium))
assert.Equal(t, true, porn.IsNsfw(ThresholdMedium))
assert.Equal(t, true, sexy.IsNsfw(ThresholdMedium))
assert.Equal(t, false, hentai.IsNsfw(ThresholdMedium))
assert.Equal(t, false, drawing.IsNsfw(ThresholdMedium))
assert.Equal(t, true, maxi.IsNsfw(ThresholdMedium))
assert.Equal(t, false, porn.NSFW(ThresholdHigh))
assert.Equal(t, false, sexy.NSFW(ThresholdHigh))
assert.Equal(t, false, hentai.NSFW(ThresholdHigh))
assert.Equal(t, false, drawing.NSFW(ThresholdHigh))
assert.Equal(t, true, maxi.NSFW(ThresholdHigh))
assert.Equal(t, false, porn.IsNsfw(ThresholdHigh))
assert.Equal(t, false, sexy.IsNsfw(ThresholdHigh))
assert.Equal(t, false, hentai.IsNsfw(ThresholdHigh))
assert.Equal(t, false, drawing.IsNsfw(ThresholdHigh))
assert.Equal(t, true, maxi.IsNsfw(ThresholdHigh))
}

View File

@@ -1,6 +1,8 @@
package vision
import (
"errors"
"fmt"
"math"
"net/http"
@@ -16,7 +18,35 @@ type ApiResponse struct {
Code int `yaml:"Code,omitempty" json:"code,omitempty"`
Error string `yaml:"Error,omitempty" json:"error,omitempty"`
Model *Model `yaml:"Model,omitempty" json:"model,omitempty"`
Result *ApiResult `yaml:"Result,omitempty" json:"result,omitempty"`
Result ApiResult `yaml:"Result,omitempty" json:"result,omitempty"`
}
// Err reports why a request did not produce a usable result.
//
// It returns an error when the receiver is nil, when Code is 400 or higher
// (using the Error text if one is set, otherwise the numeric code), or when
// the response carries no result. It returns nil otherwise.
func (r *ApiResponse) Err() error {
	if r == nil {
		return errors.New("response is nil")
	}

	// Error status codes take precedence over the result check.
	failed := r.Code >= 400

	switch {
	case failed && r.Error != "":
		return errors.New(r.Error)
	case failed:
		return fmt.Errorf("error %d", r.Code)
	case r.Result.IsEmpty():
		return errors.New("no result")
	}

	return nil
}
// HasResult reports whether the response holds at least one result.
// A nil receiver has no result.
func (r *ApiResponse) HasResult() bool {
	return r != nil && !r.Result.IsEmpty()
}
// ApiResult represents the model response(s) to a Vision API service
@@ -28,6 +58,15 @@ type ApiResult struct {
Caption *CaptionResult `yaml:"Caption,omitempty" json:"caption,omitempty"`
}
// IsEmpty reports whether the result contains no labels, no NSFW scores,
// no embeddings, and no caption. A nil receiver holds no data and is
// therefore reported as empty.
func (r *ApiResult) IsEmpty() bool {
	if r == nil {
		return true
	}

	return len(r.Labels) == 0 && len(r.Nsfw) == 0 && len(r.Embeddings) == 0 && r.Caption == nil
}
// CaptionResult represents the result generated by a caption generation model.
type CaptionResult struct {
Text string `yaml:"Text,omitempty" json:"text,omitempty"`
@@ -85,7 +124,7 @@ func NewLabelsResponse(id string, model *Model, results classify.Labels) ApiResp
return ApiResponse{
Id: clean.Type(id),
Code: http.StatusOK,
Model: model,
Result: &ApiResult{Labels: labels},
Model: &Model{Type: ModelTypeLabels, Name: model.Name, Version: model.Version, Resolution: model.Resolution},
Result: ApiResult{Labels: labels},
}
}

View File

@@ -16,7 +16,7 @@ func FaceEmbeddings(imgData []byte) (embeddings face.Embeddings, err error) {
}
if Config == nil {
return embeddings, errors.New("missing configuration")
return embeddings, errors.New("vision service is not configured")
} else if model := Config.Model(ModelTypeFaceEmbeddings); model != nil {
img, imgErr := jpeg.Decode(bytes.NewReader(imgData))

View File

@@ -21,7 +21,7 @@ func Faces(fileName string, minSize int, cacheCrop bool, expected int) (faces fa
// Return if there is no configuration or no image classification models are configured.
if Config == nil {
return faces, errors.New("missing configuration")
return faces, errors.New("vision service is not configured")
} else if model := Config.Model(ModelTypeFaceEmbeddings); model != nil {
faces, err = face.Detect(fileName, false, minSize)

View File

@@ -20,12 +20,12 @@ import (
func Labels(images Files, src media.Src) (result classify.Labels, err error) {
// Return if no thumbnail filenames were given.
if len(images) == 0 {
return result, errors.New("missing image filenames")
return result, errors.New("at least one image required")
}
// Return if there is no configuration or no image classification models are configured.
if Config == nil {
return result, errors.New("missing configuration")
return result, errors.New("vision service is not configured")
} else if model := Config.Model(ModelTypeLabels); model != nil {
// Use remote service API if a server endpoint has been configured.
if uri, method := model.Endpoint(); uri != "" && method != "" {

View File

@@ -26,7 +26,20 @@ func TestLabels(t *testing.T) {
assert.Equal(t, "chameleon", result[0].Name)
assert.Equal(t, 7, result[0].Uncertainty)
})
t.Run("Cats", func(t *testing.T) {
t.Run("Cat224", func(t *testing.T) {
result, err := Labels(Files{examplesPath + "/cat_224.jpeg"}, media.SrcLocal)
assert.NoError(t, err)
assert.IsType(t, classify.Labels{}, result)
assert.Equal(t, 1, len(result))
t.Log(result)
assert.Equal(t, "cat", result[0].Name)
assert.InDelta(t, 59, result[0].Uncertainty, 10)
assert.InDelta(t, float32(0.41), result[0].Confidence(), 0.1)
})
t.Run("Cat720", func(t *testing.T) {
result, err := Labels(Files{examplesPath + "/cat_720.jpeg"}, media.SrcLocal)
assert.NoError(t, err)
@@ -36,8 +49,8 @@ func TestLabels(t *testing.T) {
t.Log(result)
assert.Equal(t, "cat", result[0].Name)
assert.Equal(t, 60, result[0].Uncertainty)
assert.InDelta(t, float32(0.4), result[0].Confidence(), 0.01)
assert.InDelta(t, 60, result[0].Uncertainty, 10)
assert.InDelta(t, float32(0.4), result[0].Confidence(), 0.1)
})
t.Run("InvalidFile", func(t *testing.T) {
_, err := Labels(Files{examplesPath + "/notexisting.jpg"}, media.SrcLocal)

View File

@@ -19,14 +19,14 @@ import (
func Nsfw(images Files, src media.Src) (result []nsfw.Result, err error) {
// Return if no thumbnail filenames were given.
if len(images) == 0 {
return result, errors.New("missing image filenames")
return result, errors.New("at least one image required")
}
result = make([]nsfw.Result, len(images))
// Return if there is no configuration or no image classification models are configured.
if Config == nil {
return result, errors.New("missing configuration")
return result, errors.New("vision service is not configured")
} else if model := Config.Model(ModelTypeNsfw); model != nil {
// Use remote service API if a server endpoint has been configured.
if uri, method := model.Endpoint(); uri != "" && method != "" {

BIN
internal/api/testdata/cat_224x224.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

BIN
internal/api/testdata/face_160x160.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.8 KiB

BIN
internal/api/testdata/face_320x320.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

BIN
internal/api/testdata/green_224x224.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

BIN
internal/api/testdata/london_160x160.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.2 KiB

BIN
internal/api/testdata/nsfw_224x224.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.9 KiB

View File

@@ -51,15 +51,18 @@ func PostVisionCaption(router *gin.RouterGroup) {
return
}
// TODO: Return error code 501 until this service is implemented.
code := http.StatusNotImplemented
// Generate Vision API service response.
response := vision.ApiResponse{
Id: request.GetId(),
Code: http.StatusNotImplemented,
Code: code,
Error: http.StatusText(http.StatusNotImplemented),
Model: &vision.Model{Name: "Caption"},
Result: &vision.ApiResult{Caption: &vision.CaptionResult{Text: "This is a test.", Confidence: 0.14159265359}},
Model: &vision.Model{Type: vision.ModelTypeCaption},
Result: vision.ApiResult{Caption: &vision.CaptionResult{Text: "This is a test.", Confidence: 0.14159265359}},
}
c.JSON(http.StatusNotImplemented, response)
c.JSON(code, response)
})
}

View File

@@ -54,6 +54,13 @@ func PostVisionFaceEmbeddings(router *gin.RouterGroup) {
return
}
// Return if no thumbnail filenames were given.
if len(request.Images) == 0 {
log.Errorf("vision: at least one image required (run face embeddings)")
c.JSON(http.StatusBadRequest, vision.NewApiError(request.GetId(), http.StatusBadRequest))
return
}
// Run inference to generate face embeddings for each image.
results := make([]face.Embeddings, len(request.Images))
@@ -63,7 +70,7 @@ func PostVisionFaceEmbeddings(router *gin.RouterGroup) {
log.Errorf("vision: %s (read face embedding from url)", err)
} else if result, faceErr := vision.FaceEmbeddings(data); faceErr != nil {
results[i] = face.Embeddings{}
log.Errorf("vision: %s (generate face embedding)", faceErr)
log.Errorf("vision: %s (run face embeddings)", faceErr)
} else {
results[i] = result
}
@@ -73,10 +80,10 @@ func PostVisionFaceEmbeddings(router *gin.RouterGroup) {
response := vision.ApiResponse{
Id: request.GetId(),
Code: http.StatusOK,
Model: &vision.Model{Name: vision.FacenetModel.Name},
Result: &vision.ApiResult{Embeddings: results},
Model: &vision.Model{Type: vision.ModelTypeFaceEmbeddings},
Result: vision.ApiResult{Embeddings: results},
}
c.JSON(http.StatusNotImplemented, response)
c.JSON(http.StatusOK, response)
})
}

View File

@@ -0,0 +1,185 @@
package api
import (
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
)
func TestPostVisionFaceEmbeddings(t *testing.T) {
t.Run("Face", func(t *testing.T) {
app, router, _ := NewApiTest()
PostVisionFaceEmbeddings(router)
files := vision.Files{
fs.Abs("./testdata/face_160x160.jpg"),
}
req, err := vision.NewClientRequest(files, scheme.Data)
if err != nil {
t.Fatal(err)
}
jsonReq, jsonErr := req.MarshalJSON()
if jsonErr != nil {
t.Fatal(err)
}
// t.Logf("request: %s", string(jsonReq))
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/face/embeddings", string(jsonReq))
apiResponse := &vision.ApiResponse{}
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
t.Fatal(apiErr)
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
t.Fatal(apiErr)
}
// t.Logf("response: %#v", apiResponse)
assert.Len(t, apiResponse.Result.Embeddings, 1)
if len(apiResponse.Result.Embeddings) != 1 {
t.Fatal("one nsfw result expected")
}
assert.Equal(t, vision.ModelTypeFaceEmbeddings, apiResponse.Model.Type)
assert.Equal(t, http.StatusOK, r.Code)
})
t.Run("London", func(t *testing.T) {
app, router, _ := NewApiTest()
PostVisionFaceEmbeddings(router)
files := vision.Files{
fs.Abs("./testdata/london_160x160.jpg"),
}
req, err := vision.NewClientRequest(files, scheme.Data)
if err != nil {
t.Fatal(err)
}
jsonReq, jsonErr := req.MarshalJSON()
if jsonErr != nil {
t.Fatal(err)
}
// t.Logf("request: %s", string(jsonReq))
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/face/embeddings", string(jsonReq))
apiResponse := &vision.ApiResponse{}
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
t.Fatal(apiErr)
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
t.Fatal(apiErr)
}
assert.Len(t, apiResponse.Result.Embeddings, 1)
assert.Equal(t, vision.ModelTypeFaceEmbeddings, apiResponse.Model.Type)
assert.Equal(t, http.StatusOK, r.Code)
})
t.Run("WrongResolution", func(t *testing.T) {
app, router, _ := NewApiTest()
PostVisionFaceEmbeddings(router)
files := vision.Files{
fs.Abs("./testdata/face_320x320.jpg"),
}
req, err := vision.NewClientRequest(files, scheme.Data)
if err != nil {
t.Fatal(err)
}
jsonReq, jsonErr := req.MarshalJSON()
if jsonErr != nil {
t.Fatal(err)
}
// t.Logf("request: %s", string(jsonReq))
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/face/embeddings", string(jsonReq))
apiResponse := &vision.ApiResponse{}
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
t.Fatal(apiErr)
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
t.Fatal(apiErr)
}
// t.Logf("response: %#v", apiResponse)
assert.Len(t, apiResponse.Result.Embeddings, 1)
if len(apiResponse.Result.Embeddings) != 1 {
t.Fatal("one nsfw result expected")
}
assert.Equal(t, vision.ModelTypeFaceEmbeddings, apiResponse.Model.Type)
assert.Equal(t, http.StatusOK, r.Code)
})
t.Run("NoImages", func(t *testing.T) {
app, router, _ := NewApiTest()
PostVisionFaceEmbeddings(router)
files := vision.Files{}
req, err := vision.NewClientRequest(files, scheme.Data)
if err != nil {
t.Fatal(err)
}
jsonReq, jsonErr := req.MarshalJSON()
if jsonErr != nil {
t.Fatal(err)
}
// t.Logf("request: %s", string(jsonReq))
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/face/embeddings", string(jsonReq))
apiResponse := &vision.ApiResponse{}
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
t.Fatal(apiErr)
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
t.Fatal(apiErr)
}
if apiResponse == nil {
t.Fatal("api response expected")
}
// t.Logf("error: %s", apiResponse.Err())
assert.Error(t, apiResponse.Err())
assert.False(t, apiResponse.HasResult())
assert.Equal(t, http.StatusBadRequest, r.Code)
})
t.Run("NoBody", func(t *testing.T) {
app, router, _ := NewApiTest()
PostVisionFaceEmbeddings(router)
r := PerformRequest(app, http.MethodPost, "/api/v1/vision/face/embeddings")
assert.Equal(t, http.StatusBadRequest, r.Code)
})
}

View File

@@ -54,9 +54,10 @@ func PostVisionLabels(router *gin.RouterGroup) {
}
// Run inference to find matching labels.
labels, err := vision.Labels(request.Images, media.SrcLocal)
labels, err := vision.Labels(request.Images, media.SrcRemote)
if err != nil {
log.Errorf("vision: %s (run labels)", err)
c.JSON(http.StatusBadRequest, vision.NewApiError(request.GetId(), http.StatusBadRequest))
return
}

View File

@@ -0,0 +1,136 @@
package api
import (
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
)
func TestPostVisionLabels(t *testing.T) {
t.Run("OneImage", func(t *testing.T) {
app, router, _ := NewApiTest()
PostVisionLabels(router)
files := vision.Files{
fs.Abs("./testdata/cat_224x224.jpg"),
}
req, err := vision.NewClientRequest(files, scheme.Data)
if err != nil {
t.Fatal(err)
}
jsonReq, jsonErr := req.MarshalJSON()
if jsonErr != nil {
t.Fatal(err)
}
// t.Logf("request: %s", string(jsonReq))
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/labels", string(jsonReq))
apiResponse := &vision.ApiResponse{}
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
t.Fatal(apiErr)
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
t.Fatal(apiErr)
}
assert.Len(t, apiResponse.Result.Labels, 1)
assert.Equal(t, vision.ModelTypeLabels, apiResponse.Model.Type)
assert.Equal(t, http.StatusOK, r.Code)
})
t.Run("TwoImages", func(t *testing.T) {
app, router, _ := NewApiTest()
PostVisionLabels(router)
files := vision.Files{
fs.Abs("./testdata/cat_224x224.jpg"),
fs.Abs("./testdata/green_224x224.jpg"),
}
req, err := vision.NewClientRequest(files, scheme.Data)
if err != nil {
t.Fatal(err)
}
jsonReq, jsonErr := req.MarshalJSON()
if jsonErr != nil {
t.Fatal(err)
}
// t.Logf("request: %s", string(jsonReq))
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/labels", string(jsonReq))
apiResponse := &vision.ApiResponse{}
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
t.Fatal(apiErr)
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
t.Fatal(apiErr)
}
assert.Len(t, apiResponse.Result.Labels, 2)
assert.Equal(t, vision.ModelTypeLabels, apiResponse.Model.Type)
assert.Equal(t, http.StatusOK, r.Code)
})
t.Run("NoImages", func(t *testing.T) {
app, router, _ := NewApiTest()
PostVisionLabels(router)
files := vision.Files{}
req, err := vision.NewClientRequest(files, scheme.Data)
if err != nil {
t.Fatal(err)
}
jsonReq, jsonErr := req.MarshalJSON()
if jsonErr != nil {
t.Fatal(err)
}
t.Logf("request: %s", string(jsonReq))
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/labels", string(jsonReq))
apiResponse := &vision.ApiResponse{}
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
t.Fatal(apiErr)
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
t.Fatal(apiErr)
}
if apiResponse == nil {
t.Fatal("api response expected")
}
t.Logf("error: %s", apiResponse.Err())
assert.Error(t, apiResponse.Err())
assert.False(t, apiResponse.HasResult())
assert.Equal(t, http.StatusBadRequest, r.Code)
})
t.Run("NoBody", func(t *testing.T) {
app, router, _ := NewApiTest()
PostVisionLabels(router)
r := PerformRequest(app, http.MethodPost, "/api/v1/vision/labels")
assert.Equal(t, http.StatusBadRequest, r.Code)
})
}

View File

@@ -57,6 +57,7 @@ func PostVisionNsfw(router *gin.RouterGroup) {
results, err := vision.Nsfw(request.Images, media.SrcRemote)
if err != nil {
log.Errorf("vision: %s (run nsfw)", err)
c.JSON(http.StatusBadRequest, vision.NewApiError(request.GetId(), http.StatusBadRequest))
return
}
@@ -65,8 +66,8 @@ func PostVisionNsfw(router *gin.RouterGroup) {
response := vision.ApiResponse{
Id: request.GetId(),
Code: http.StatusOK,
Model: &vision.Model{Name: vision.NsfwModel.Name},
Result: &vision.ApiResult{Nsfw: results},
Model: &vision.Model{Type: vision.ModelTypeNsfw},
Result: vision.ApiResult{Nsfw: results},
}
c.JSON(http.StatusOK, response)

View File

@@ -0,0 +1,151 @@
package api
import (
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
)
func TestPostVisionNsfw(t *testing.T) {
t.Run("OneImage", func(t *testing.T) {
app, router, _ := NewApiTest()
PostVisionNsfw(router)
files := vision.Files{
fs.Abs("./testdata/nsfw_224x224.jpg"),
}
req, err := vision.NewClientRequest(files, scheme.Data)
if err != nil {
t.Fatal(err)
}
jsonReq, jsonErr := req.MarshalJSON()
if jsonErr != nil {
t.Fatal(err)
}
// t.Logf("request: %s", string(jsonReq))
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/nsfw", string(jsonReq))
apiResponse := &vision.ApiResponse{}
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
t.Fatal(apiErr)
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
t.Fatal(apiErr)
}
// t.Logf("response: %#v", apiResponse)
assert.Len(t, apiResponse.Result.Nsfw, 1)
if len(apiResponse.Result.Nsfw) != 1 {
t.Fatal("one nsfw result expected")
} else if nsfw := apiResponse.Result.Nsfw[0]; !nsfw.IsNsfw(0.6) {
t.Fatalf("image should not be safe for work: %#v", nsfw)
} else {
// Drawing:7.547473e-05, Hentai:0.19912475, Neutral:0.00097554235, Porn:0.67095983, Sexy:0.12886441
assert.InDelta(t, nsfw.Drawing, 0.01, 0.2)
assert.InDelta(t, nsfw.Hentai, 0.2, 0.2)
assert.InDelta(t, nsfw.Porn, 0.7, 0.2)
assert.InDelta(t, nsfw.Sexy, 0.1, 0.2)
}
assert.Equal(t, vision.ModelTypeNsfw, apiResponse.Model.Type)
assert.Equal(t, http.StatusOK, r.Code)
})
t.Run("TwoImages", func(t *testing.T) {
app, router, _ := NewApiTest()
PostVisionNsfw(router)
files := vision.Files{
fs.Abs("./testdata/cat_224x224.jpg"),
fs.Abs("./testdata/green_224x224.jpg"),
}
req, err := vision.NewClientRequest(files, scheme.Data)
if err != nil {
t.Fatal(err)
}
jsonReq, jsonErr := req.MarshalJSON()
if jsonErr != nil {
t.Fatal(err)
}
// t.Logf("request: %s", string(jsonReq))
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/nsfw", string(jsonReq))
apiResponse := &vision.ApiResponse{}
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
t.Fatal(apiErr)
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
t.Fatal(apiErr)
}
assert.Len(t, apiResponse.Result.Nsfw, 2)
assert.Equal(t, vision.ModelTypeNsfw, apiResponse.Model.Type)
assert.Equal(t, http.StatusOK, r.Code)
})
t.Run("NoImages", func(t *testing.T) {
app, router, _ := NewApiTest()
PostVisionNsfw(router)
files := vision.Files{}
req, err := vision.NewClientRequest(files, scheme.Data)
if err != nil {
t.Fatal(err)
}
jsonReq, jsonErr := req.MarshalJSON()
if jsonErr != nil {
t.Fatal(err)
}
// t.Logf("request: %s", string(jsonReq))
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/nsfw", string(jsonReq))
apiResponse := &vision.ApiResponse{}
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
t.Fatal(apiErr)
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
t.Fatal(apiErr)
}
if apiResponse == nil {
t.Fatal("api response expected")
}
// t.Logf("error: %s", apiResponse.Err())
assert.Error(t, apiResponse.Err())
assert.False(t, apiResponse.HasResult())
assert.Equal(t, http.StatusBadRequest, r.Code)
})
t.Run("NoBody", func(t *testing.T) {
app, router, _ := NewApiTest()
PostVisionNsfw(router)
r := PerformRequest(app, http.MethodPost, "/api/v1/vision/nsfw")
assert.Equal(t, http.StatusBadRequest, r.Code)
})
}

View File

@@ -102,7 +102,6 @@ func NewTestOptions(pkg string) *Options {
Trace: false,
Experimental: true,
ReadOnly: false,
DetectNSFW: true,
UploadNSFW: false,
ExifBruteForce: false,
AssetsPath: assetsPath,
@@ -123,6 +122,8 @@ func NewTestOptions(pkg string) *Options {
AdminPassword: "photoprism",
OriginalsLimit: 66,
ResolutionLimit: 33,
VisionApi: true,
DetectNSFW: true,
}
return c

View File

@@ -23,7 +23,7 @@ func (ind *Index) IsNsfw(m *MediaFile) bool {
} else if len(results) < 1 {
log.Errorf("index: nsfw model returned no result for %s", m.RootRelName())
return false
} else if results[0].NSFW(nsfw.ThresholdHigh) {
} else if results[0].IsNsfw(nsfw.ThresholdHigh) {
log.Warnf("index: %s might contain offensive content", clean.Log(m.RelName(Config().OriginalsPath())))
return true
}