mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 08:44:04 +01:00
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
@@ -214,7 +214,10 @@ func (m *Model) createTensor(image []byte) (*tf.Tensor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Resize the image only if its resolution does not match the model.
|
||||
if img.Bounds().Dx() != m.resolution || img.Bounds().Dy() != m.resolution {
|
||||
img = imaging.Fill(img, m.resolution, m.resolution, imaging.Center, imaging.Lanczos)
|
||||
}
|
||||
|
||||
return tensorflow.Image(img, m.resolution)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
|
||||
"github.com/disintegration/imaging"
|
||||
tf "github.com/wamuir/graft/tensorflow"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/thumb/crop"
|
||||
@@ -116,6 +117,7 @@ func (m *Model) loadModel() error {
|
||||
|
||||
// Run returns the face embeddings for an image.
|
||||
func (m *Model) Run(img image.Image) Embeddings {
|
||||
// Create input tensor from image.
|
||||
tensor, err := imageToTensor(img, m.resolution)
|
||||
|
||||
if err != nil {
|
||||
@@ -160,6 +162,11 @@ func imageToTensor(img image.Image, resolution int) (tfTensor *tf.Tensor, err er
|
||||
return tfTensor, fmt.Errorf("faces: invalid model resolution")
|
||||
}
|
||||
|
||||
// Resize the image only if its resolution does not match the model.
|
||||
if img.Bounds().Dx() != resolution || img.Bounds().Dy() != resolution {
|
||||
img = imaging.Fill(img, resolution, resolution, imaging.Center, imaging.Lanczos)
|
||||
}
|
||||
|
||||
var tfImage [1][][][3]float32
|
||||
|
||||
for j := 0; j < resolution; j++ {
|
||||
|
||||
@@ -46,11 +46,11 @@ type Result struct {
|
||||
|
||||
// IsSafe returns true if the image is probably safe for work.
|
||||
func (l *Result) IsSafe() bool {
|
||||
return !l.NSFW(ThresholdSafe)
|
||||
return !l.IsNsfw(ThresholdSafe)
|
||||
}
|
||||
|
||||
// NSFW returns true if the image is may not be safe for work.
|
||||
func (l *Result) NSFW(threshold float32) bool {
|
||||
// IsNsfw returns true if the image is may not be safe for work.
|
||||
func (l *Result) IsNsfw(threshold float32) bool {
|
||||
if l.Neutral > 0.25 {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -97,28 +97,28 @@ func TestIsSafe(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNSFW(t *testing.T) {
|
||||
func TestIsNsfw(t *testing.T) {
|
||||
porn := Result{0, 0, 0.11, 0.88, 0}
|
||||
sexy := Result{0, 0, 0.2, 0.59, 0.98}
|
||||
maxi := Result{0, 0.999, 0.1, 0.999, 0.999}
|
||||
drawing := Result{0.999, 0, 0, 0, 0}
|
||||
hentai := Result{0, 0.80, 0.2, 0, 0}
|
||||
|
||||
assert.Equal(t, true, porn.NSFW(ThresholdSafe))
|
||||
assert.Equal(t, true, sexy.NSFW(ThresholdSafe))
|
||||
assert.Equal(t, true, hentai.NSFW(ThresholdSafe))
|
||||
assert.Equal(t, false, drawing.NSFW(ThresholdSafe))
|
||||
assert.Equal(t, true, maxi.NSFW(ThresholdSafe))
|
||||
assert.Equal(t, true, porn.IsNsfw(ThresholdSafe))
|
||||
assert.Equal(t, true, sexy.IsNsfw(ThresholdSafe))
|
||||
assert.Equal(t, true, hentai.IsNsfw(ThresholdSafe))
|
||||
assert.Equal(t, false, drawing.IsNsfw(ThresholdSafe))
|
||||
assert.Equal(t, true, maxi.IsNsfw(ThresholdSafe))
|
||||
|
||||
assert.Equal(t, true, porn.NSFW(ThresholdMedium))
|
||||
assert.Equal(t, true, sexy.NSFW(ThresholdMedium))
|
||||
assert.Equal(t, false, hentai.NSFW(ThresholdMedium))
|
||||
assert.Equal(t, false, drawing.NSFW(ThresholdMedium))
|
||||
assert.Equal(t, true, maxi.NSFW(ThresholdMedium))
|
||||
assert.Equal(t, true, porn.IsNsfw(ThresholdMedium))
|
||||
assert.Equal(t, true, sexy.IsNsfw(ThresholdMedium))
|
||||
assert.Equal(t, false, hentai.IsNsfw(ThresholdMedium))
|
||||
assert.Equal(t, false, drawing.IsNsfw(ThresholdMedium))
|
||||
assert.Equal(t, true, maxi.IsNsfw(ThresholdMedium))
|
||||
|
||||
assert.Equal(t, false, porn.NSFW(ThresholdHigh))
|
||||
assert.Equal(t, false, sexy.NSFW(ThresholdHigh))
|
||||
assert.Equal(t, false, hentai.NSFW(ThresholdHigh))
|
||||
assert.Equal(t, false, drawing.NSFW(ThresholdHigh))
|
||||
assert.Equal(t, true, maxi.NSFW(ThresholdHigh))
|
||||
assert.Equal(t, false, porn.IsNsfw(ThresholdHigh))
|
||||
assert.Equal(t, false, sexy.IsNsfw(ThresholdHigh))
|
||||
assert.Equal(t, false, hentai.IsNsfw(ThresholdHigh))
|
||||
assert.Equal(t, false, drawing.IsNsfw(ThresholdHigh))
|
||||
assert.Equal(t, true, maxi.IsNsfw(ThresholdHigh))
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
|
||||
@@ -16,7 +18,35 @@ type ApiResponse struct {
|
||||
Code int `yaml:"Code,omitempty" json:"code,omitempty"`
|
||||
Error string `yaml:"Error,omitempty" json:"error,omitempty"`
|
||||
Model *Model `yaml:"Model,omitempty" json:"model,omitempty"`
|
||||
Result *ApiResult `yaml:"Result,omitempty" json:"result,omitempty"`
|
||||
Result ApiResult `yaml:"Result,omitempty" json:"result,omitempty"`
|
||||
}
|
||||
|
||||
// Err returns an error if the request has failed.
|
||||
func (r *ApiResponse) Err() error {
|
||||
if r == nil {
|
||||
return errors.New("response is nil")
|
||||
}
|
||||
|
||||
if r.Code >= 400 {
|
||||
if r.Error != "" {
|
||||
return errors.New(r.Error)
|
||||
}
|
||||
|
||||
return fmt.Errorf("error %d", r.Code)
|
||||
} else if r.Result.IsEmpty() {
|
||||
return errors.New("no result")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasResult checks if there is at least one result in the response data.
|
||||
func (r *ApiResponse) HasResult() bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return !r.Result.IsEmpty()
|
||||
}
|
||||
|
||||
// ApiResult represents the model response(s) to a Vision API service
|
||||
@@ -28,6 +58,15 @@ type ApiResult struct {
|
||||
Caption *CaptionResult `yaml:"Caption,omitempty" json:"caption,omitempty"`
|
||||
}
|
||||
|
||||
// IsEmpty checks if there is no result in the response data.
|
||||
func (r *ApiResult) IsEmpty() bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return len(r.Labels) == 0 && len(r.Nsfw) == 0 && len(r.Embeddings) == 0 && r.Caption == nil
|
||||
}
|
||||
|
||||
// CaptionResult represents the result generated by a caption generation model.
|
||||
type CaptionResult struct {
|
||||
Text string `yaml:"Text,omitempty" json:"text,omitempty"`
|
||||
@@ -85,7 +124,7 @@ func NewLabelsResponse(id string, model *Model, results classify.Labels) ApiResp
|
||||
return ApiResponse{
|
||||
Id: clean.Type(id),
|
||||
Code: http.StatusOK,
|
||||
Model: model,
|
||||
Result: &ApiResult{Labels: labels},
|
||||
Model: &Model{Type: ModelTypeLabels, Name: model.Name, Version: model.Version, Resolution: model.Resolution},
|
||||
Result: ApiResult{Labels: labels},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ func FaceEmbeddings(imgData []byte) (embeddings face.Embeddings, err error) {
|
||||
}
|
||||
|
||||
if Config == nil {
|
||||
return embeddings, errors.New("missing configuration")
|
||||
return embeddings, errors.New("vision service is not configured")
|
||||
} else if model := Config.Model(ModelTypeFaceEmbeddings); model != nil {
|
||||
img, imgErr := jpeg.Decode(bytes.NewReader(imgData))
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ func Faces(fileName string, minSize int, cacheCrop bool, expected int) (faces fa
|
||||
|
||||
// Return if there is no configuration or no image classification models are configured.
|
||||
if Config == nil {
|
||||
return faces, errors.New("missing configuration")
|
||||
return faces, errors.New("vision service is not configured")
|
||||
} else if model := Config.Model(ModelTypeFaceEmbeddings); model != nil {
|
||||
faces, err = face.Detect(fileName, false, minSize)
|
||||
|
||||
|
||||
@@ -20,12 +20,12 @@ import (
|
||||
func Labels(images Files, src media.Src) (result classify.Labels, err error) {
|
||||
// Return if no thumbnail filenames were given.
|
||||
if len(images) == 0 {
|
||||
return result, errors.New("missing image filenames")
|
||||
return result, errors.New("at least one image required")
|
||||
}
|
||||
|
||||
// Return if there is no configuration or no image classification models are configured.
|
||||
if Config == nil {
|
||||
return result, errors.New("missing configuration")
|
||||
return result, errors.New("vision service is not configured")
|
||||
} else if model := Config.Model(ModelTypeLabels); model != nil {
|
||||
// Use remote service API if a server endpoint has been configured.
|
||||
if uri, method := model.Endpoint(); uri != "" && method != "" {
|
||||
|
||||
@@ -26,7 +26,20 @@ func TestLabels(t *testing.T) {
|
||||
assert.Equal(t, "chameleon", result[0].Name)
|
||||
assert.Equal(t, 7, result[0].Uncertainty)
|
||||
})
|
||||
t.Run("Cats", func(t *testing.T) {
|
||||
t.Run("Cat224", func(t *testing.T) {
|
||||
result, err := Labels(Files{examplesPath + "/cat_224.jpeg"}, media.SrcLocal)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.IsType(t, classify.Labels{}, result)
|
||||
assert.Equal(t, 1, len(result))
|
||||
|
||||
t.Log(result)
|
||||
|
||||
assert.Equal(t, "cat", result[0].Name)
|
||||
assert.InDelta(t, 59, result[0].Uncertainty, 10)
|
||||
assert.InDelta(t, float32(0.41), result[0].Confidence(), 0.1)
|
||||
})
|
||||
t.Run("Cat720", func(t *testing.T) {
|
||||
result, err := Labels(Files{examplesPath + "/cat_720.jpeg"}, media.SrcLocal)
|
||||
|
||||
assert.NoError(t, err)
|
||||
@@ -36,8 +49,8 @@ func TestLabels(t *testing.T) {
|
||||
t.Log(result)
|
||||
|
||||
assert.Equal(t, "cat", result[0].Name)
|
||||
assert.Equal(t, 60, result[0].Uncertainty)
|
||||
assert.InDelta(t, float32(0.4), result[0].Confidence(), 0.01)
|
||||
assert.InDelta(t, 60, result[0].Uncertainty, 10)
|
||||
assert.InDelta(t, float32(0.4), result[0].Confidence(), 0.1)
|
||||
})
|
||||
t.Run("InvalidFile", func(t *testing.T) {
|
||||
_, err := Labels(Files{examplesPath + "/notexisting.jpg"}, media.SrcLocal)
|
||||
|
||||
@@ -19,14 +19,14 @@ import (
|
||||
func Nsfw(images Files, src media.Src) (result []nsfw.Result, err error) {
|
||||
// Return if no thumbnail filenames were given.
|
||||
if len(images) == 0 {
|
||||
return result, errors.New("missing image filenames")
|
||||
return result, errors.New("at least one image required")
|
||||
}
|
||||
|
||||
result = make([]nsfw.Result, len(images))
|
||||
|
||||
// Return if there is no configuration or no image classification models are configured.
|
||||
if Config == nil {
|
||||
return result, errors.New("missing configuration")
|
||||
return result, errors.New("vision service is not configured")
|
||||
} else if model := Config.Model(ModelTypeNsfw); model != nil {
|
||||
// Use remote service API if a server endpoint has been configured.
|
||||
if uri, method := model.Endpoint(); uri != "" && method != "" {
|
||||
|
||||
BIN
internal/api/testdata/cat_224x224.jpg
vendored
Normal file
BIN
internal/api/testdata/cat_224x224.jpg
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 16 KiB |
BIN
internal/api/testdata/face_160x160.jpg
vendored
Normal file
BIN
internal/api/testdata/face_160x160.jpg
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 9.8 KiB |
BIN
internal/api/testdata/face_320x320.jpg
vendored
Normal file
BIN
internal/api/testdata/face_320x320.jpg
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 30 KiB |
BIN
internal/api/testdata/green_224x224.jpg
vendored
Normal file
BIN
internal/api/testdata/green_224x224.jpg
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 22 KiB |
BIN
internal/api/testdata/london_160x160.jpg
vendored
Normal file
BIN
internal/api/testdata/london_160x160.jpg
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 9.2 KiB |
BIN
internal/api/testdata/nsfw_224x224.jpg
vendored
Normal file
BIN
internal/api/testdata/nsfw_224x224.jpg
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 9.9 KiB |
@@ -51,15 +51,18 @@ func PostVisionCaption(router *gin.RouterGroup) {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Return error code 501 until this service is implemented.
|
||||
code := http.StatusNotImplemented
|
||||
|
||||
// Generate Vision API service response.
|
||||
response := vision.ApiResponse{
|
||||
Id: request.GetId(),
|
||||
Code: http.StatusNotImplemented,
|
||||
Code: code,
|
||||
Error: http.StatusText(http.StatusNotImplemented),
|
||||
Model: &vision.Model{Name: "Caption"},
|
||||
Result: &vision.ApiResult{Caption: &vision.CaptionResult{Text: "This is a test.", Confidence: 0.14159265359}},
|
||||
Model: &vision.Model{Type: vision.ModelTypeCaption},
|
||||
Result: vision.ApiResult{Caption: &vision.CaptionResult{Text: "This is a test.", Confidence: 0.14159265359}},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNotImplemented, response)
|
||||
c.JSON(code, response)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -54,6 +54,13 @@ func PostVisionFaceEmbeddings(router *gin.RouterGroup) {
|
||||
return
|
||||
}
|
||||
|
||||
// Return if no thumbnail filenames were given.
|
||||
if len(request.Images) == 0 {
|
||||
log.Errorf("vision: at least one image required (run face embeddings)")
|
||||
c.JSON(http.StatusBadRequest, vision.NewApiError(request.GetId(), http.StatusBadRequest))
|
||||
return
|
||||
}
|
||||
|
||||
// Run inference to find matching labels.
|
||||
results := make([]face.Embeddings, len(request.Images))
|
||||
|
||||
@@ -63,7 +70,7 @@ func PostVisionFaceEmbeddings(router *gin.RouterGroup) {
|
||||
log.Errorf("vision: %s (read face embedding from url)", err)
|
||||
} else if result, faceErr := vision.FaceEmbeddings(data); faceErr != nil {
|
||||
results[i] = face.Embeddings{}
|
||||
log.Errorf("vision: %s (generate face embedding)", faceErr)
|
||||
log.Errorf("vision: %s (run face embeddings)", faceErr)
|
||||
} else {
|
||||
results[i] = result
|
||||
}
|
||||
@@ -73,10 +80,10 @@ func PostVisionFaceEmbeddings(router *gin.RouterGroup) {
|
||||
response := vision.ApiResponse{
|
||||
Id: request.GetId(),
|
||||
Code: http.StatusOK,
|
||||
Model: &vision.Model{Name: vision.FacenetModel.Name},
|
||||
Result: &vision.ApiResult{Embeddings: results},
|
||||
Model: &vision.Model{Type: vision.ModelTypeFaceEmbeddings},
|
||||
Result: vision.ApiResult{Embeddings: results},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNotImplemented, response)
|
||||
c.JSON(http.StatusOK, response)
|
||||
})
|
||||
}
|
||||
|
||||
185
internal/api/vision_face_embeddings_test.go
Normal file
185
internal/api/vision_face_embeddings_test.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
func TestPostVisionFaceEmbeddings(t *testing.T) {
|
||||
t.Run("Face", func(t *testing.T) {
|
||||
app, router, _ := NewApiTest()
|
||||
PostVisionFaceEmbeddings(router)
|
||||
|
||||
files := vision.Files{
|
||||
fs.Abs("./testdata/face_160x160.jpg"),
|
||||
}
|
||||
|
||||
req, err := vision.NewClientRequest(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
jsonReq, jsonErr := req.MarshalJSON()
|
||||
|
||||
if jsonErr != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// t.Logf("request: %s", string(jsonReq))
|
||||
|
||||
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/face/embeddings", string(jsonReq))
|
||||
|
||||
apiResponse := &vision.ApiResponse{}
|
||||
|
||||
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
}
|
||||
|
||||
// t.Logf("response: %#v", apiResponse)
|
||||
|
||||
assert.Len(t, apiResponse.Result.Embeddings, 1)
|
||||
|
||||
if len(apiResponse.Result.Embeddings) != 1 {
|
||||
t.Fatal("one nsfw result expected")
|
||||
}
|
||||
|
||||
assert.Equal(t, vision.ModelTypeFaceEmbeddings, apiResponse.Model.Type)
|
||||
assert.Equal(t, http.StatusOK, r.Code)
|
||||
})
|
||||
t.Run("London", func(t *testing.T) {
|
||||
app, router, _ := NewApiTest()
|
||||
PostVisionFaceEmbeddings(router)
|
||||
|
||||
files := vision.Files{
|
||||
fs.Abs("./testdata/london_160x160.jpg"),
|
||||
}
|
||||
|
||||
req, err := vision.NewClientRequest(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
jsonReq, jsonErr := req.MarshalJSON()
|
||||
|
||||
if jsonErr != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// t.Logf("request: %s", string(jsonReq))
|
||||
|
||||
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/face/embeddings", string(jsonReq))
|
||||
|
||||
apiResponse := &vision.ApiResponse{}
|
||||
|
||||
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
}
|
||||
|
||||
assert.Len(t, apiResponse.Result.Embeddings, 1)
|
||||
assert.Equal(t, vision.ModelTypeFaceEmbeddings, apiResponse.Model.Type)
|
||||
assert.Equal(t, http.StatusOK, r.Code)
|
||||
})
|
||||
t.Run("WrongResolution", func(t *testing.T) {
|
||||
app, router, _ := NewApiTest()
|
||||
PostVisionFaceEmbeddings(router)
|
||||
|
||||
files := vision.Files{
|
||||
fs.Abs("./testdata/face_320x320.jpg"),
|
||||
}
|
||||
|
||||
req, err := vision.NewClientRequest(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
jsonReq, jsonErr := req.MarshalJSON()
|
||||
|
||||
if jsonErr != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// t.Logf("request: %s", string(jsonReq))
|
||||
|
||||
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/face/embeddings", string(jsonReq))
|
||||
|
||||
apiResponse := &vision.ApiResponse{}
|
||||
|
||||
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
}
|
||||
|
||||
// t.Logf("response: %#v", apiResponse)
|
||||
|
||||
assert.Len(t, apiResponse.Result.Embeddings, 1)
|
||||
|
||||
if len(apiResponse.Result.Embeddings) != 1 {
|
||||
t.Fatal("one nsfw result expected")
|
||||
}
|
||||
|
||||
assert.Equal(t, vision.ModelTypeFaceEmbeddings, apiResponse.Model.Type)
|
||||
assert.Equal(t, http.StatusOK, r.Code)
|
||||
})
|
||||
t.Run("NoImages", func(t *testing.T) {
|
||||
app, router, _ := NewApiTest()
|
||||
PostVisionFaceEmbeddings(router)
|
||||
|
||||
files := vision.Files{}
|
||||
|
||||
req, err := vision.NewClientRequest(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
jsonReq, jsonErr := req.MarshalJSON()
|
||||
|
||||
if jsonErr != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// t.Logf("request: %s", string(jsonReq))
|
||||
|
||||
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/face/embeddings", string(jsonReq))
|
||||
|
||||
apiResponse := &vision.ApiResponse{}
|
||||
|
||||
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
}
|
||||
|
||||
if apiResponse == nil {
|
||||
t.Fatal("api response expected")
|
||||
}
|
||||
|
||||
// t.Logf("error: %s", apiResponse.Err())
|
||||
|
||||
assert.Error(t, apiResponse.Err())
|
||||
assert.False(t, apiResponse.HasResult())
|
||||
assert.Equal(t, http.StatusBadRequest, r.Code)
|
||||
})
|
||||
t.Run("NoBody", func(t *testing.T) {
|
||||
app, router, _ := NewApiTest()
|
||||
PostVisionFaceEmbeddings(router)
|
||||
r := PerformRequest(app, http.MethodPost, "/api/v1/vision/face/embeddings")
|
||||
assert.Equal(t, http.StatusBadRequest, r.Code)
|
||||
})
|
||||
}
|
||||
@@ -54,9 +54,10 @@ func PostVisionLabels(router *gin.RouterGroup) {
|
||||
}
|
||||
|
||||
// Run inference to find matching labels.
|
||||
labels, err := vision.Labels(request.Images, media.SrcLocal)
|
||||
labels, err := vision.Labels(request.Images, media.SrcRemote)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("vision: %s (run labels)", err)
|
||||
c.JSON(http.StatusBadRequest, vision.NewApiError(request.GetId(), http.StatusBadRequest))
|
||||
return
|
||||
}
|
||||
|
||||
136
internal/api/vision_labels_test.go
Normal file
136
internal/api/vision_labels_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
func TestPostVisionLabels(t *testing.T) {
|
||||
t.Run("OneImage", func(t *testing.T) {
|
||||
app, router, _ := NewApiTest()
|
||||
PostVisionLabels(router)
|
||||
|
||||
files := vision.Files{
|
||||
fs.Abs("./testdata/cat_224x224.jpg"),
|
||||
}
|
||||
|
||||
req, err := vision.NewClientRequest(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
jsonReq, jsonErr := req.MarshalJSON()
|
||||
|
||||
if jsonErr != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// t.Logf("request: %s", string(jsonReq))
|
||||
|
||||
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/labels", string(jsonReq))
|
||||
|
||||
apiResponse := &vision.ApiResponse{}
|
||||
|
||||
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
}
|
||||
|
||||
assert.Len(t, apiResponse.Result.Labels, 1)
|
||||
assert.Equal(t, vision.ModelTypeLabels, apiResponse.Model.Type)
|
||||
assert.Equal(t, http.StatusOK, r.Code)
|
||||
})
|
||||
t.Run("TwoImages", func(t *testing.T) {
|
||||
app, router, _ := NewApiTest()
|
||||
PostVisionLabels(router)
|
||||
|
||||
files := vision.Files{
|
||||
fs.Abs("./testdata/cat_224x224.jpg"),
|
||||
fs.Abs("./testdata/green_224x224.jpg"),
|
||||
}
|
||||
|
||||
req, err := vision.NewClientRequest(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
jsonReq, jsonErr := req.MarshalJSON()
|
||||
|
||||
if jsonErr != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// t.Logf("request: %s", string(jsonReq))
|
||||
|
||||
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/labels", string(jsonReq))
|
||||
|
||||
apiResponse := &vision.ApiResponse{}
|
||||
|
||||
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
}
|
||||
|
||||
assert.Len(t, apiResponse.Result.Labels, 2)
|
||||
assert.Equal(t, vision.ModelTypeLabels, apiResponse.Model.Type)
|
||||
assert.Equal(t, http.StatusOK, r.Code)
|
||||
})
|
||||
t.Run("NoImages", func(t *testing.T) {
|
||||
app, router, _ := NewApiTest()
|
||||
PostVisionLabels(router)
|
||||
|
||||
files := vision.Files{}
|
||||
|
||||
req, err := vision.NewClientRequest(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
jsonReq, jsonErr := req.MarshalJSON()
|
||||
|
||||
if jsonErr != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("request: %s", string(jsonReq))
|
||||
|
||||
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/labels", string(jsonReq))
|
||||
|
||||
apiResponse := &vision.ApiResponse{}
|
||||
|
||||
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
}
|
||||
|
||||
if apiResponse == nil {
|
||||
t.Fatal("api response expected")
|
||||
}
|
||||
|
||||
t.Logf("error: %s", apiResponse.Err())
|
||||
|
||||
assert.Error(t, apiResponse.Err())
|
||||
assert.False(t, apiResponse.HasResult())
|
||||
assert.Equal(t, http.StatusBadRequest, r.Code)
|
||||
})
|
||||
t.Run("NoBody", func(t *testing.T) {
|
||||
app, router, _ := NewApiTest()
|
||||
PostVisionLabels(router)
|
||||
r := PerformRequest(app, http.MethodPost, "/api/v1/vision/labels")
|
||||
assert.Equal(t, http.StatusBadRequest, r.Code)
|
||||
})
|
||||
}
|
||||
@@ -57,6 +57,7 @@ func PostVisionNsfw(router *gin.RouterGroup) {
|
||||
results, err := vision.Nsfw(request.Images, media.SrcRemote)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("vision: %s (run nsfw)", err)
|
||||
c.JSON(http.StatusBadRequest, vision.NewApiError(request.GetId(), http.StatusBadRequest))
|
||||
return
|
||||
}
|
||||
@@ -65,8 +66,8 @@ func PostVisionNsfw(router *gin.RouterGroup) {
|
||||
response := vision.ApiResponse{
|
||||
Id: request.GetId(),
|
||||
Code: http.StatusOK,
|
||||
Model: &vision.Model{Name: vision.NsfwModel.Name},
|
||||
Result: &vision.ApiResult{Nsfw: results},
|
||||
Model: &vision.Model{Type: vision.ModelTypeNsfw},
|
||||
Result: vision.ApiResult{Nsfw: results},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
|
||||
151
internal/api/vision_nsfw_test.go
Normal file
151
internal/api/vision_nsfw_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
func TestPostVisionNsfw(t *testing.T) {
|
||||
t.Run("OneImage", func(t *testing.T) {
|
||||
app, router, _ := NewApiTest()
|
||||
PostVisionNsfw(router)
|
||||
|
||||
files := vision.Files{
|
||||
fs.Abs("./testdata/nsfw_224x224.jpg"),
|
||||
}
|
||||
|
||||
req, err := vision.NewClientRequest(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
jsonReq, jsonErr := req.MarshalJSON()
|
||||
|
||||
if jsonErr != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// t.Logf("request: %s", string(jsonReq))
|
||||
|
||||
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/nsfw", string(jsonReq))
|
||||
|
||||
apiResponse := &vision.ApiResponse{}
|
||||
|
||||
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
}
|
||||
|
||||
// t.Logf("response: %#v", apiResponse)
|
||||
|
||||
assert.Len(t, apiResponse.Result.Nsfw, 1)
|
||||
|
||||
if len(apiResponse.Result.Nsfw) != 1 {
|
||||
t.Fatal("one nsfw result expected")
|
||||
} else if nsfw := apiResponse.Result.Nsfw[0]; !nsfw.IsNsfw(0.6) {
|
||||
t.Fatalf("image should not be safe for work: %#v", nsfw)
|
||||
} else {
|
||||
// Drawing:7.547473e-05, Hentai:0.19912475, Neutral:0.00097554235, Porn:0.67095983, Sexy:0.12886441
|
||||
assert.InDelta(t, nsfw.Drawing, 0.01, 0.2)
|
||||
assert.InDelta(t, nsfw.Hentai, 0.2, 0.2)
|
||||
assert.InDelta(t, nsfw.Porn, 0.7, 0.2)
|
||||
assert.InDelta(t, nsfw.Sexy, 0.1, 0.2)
|
||||
}
|
||||
|
||||
assert.Equal(t, vision.ModelTypeNsfw, apiResponse.Model.Type)
|
||||
assert.Equal(t, http.StatusOK, r.Code)
|
||||
})
|
||||
t.Run("TwoImages", func(t *testing.T) {
|
||||
app, router, _ := NewApiTest()
|
||||
PostVisionNsfw(router)
|
||||
|
||||
files := vision.Files{
|
||||
fs.Abs("./testdata/cat_224x224.jpg"),
|
||||
fs.Abs("./testdata/green_224x224.jpg"),
|
||||
}
|
||||
|
||||
req, err := vision.NewClientRequest(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
jsonReq, jsonErr := req.MarshalJSON()
|
||||
|
||||
if jsonErr != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// t.Logf("request: %s", string(jsonReq))
|
||||
|
||||
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/nsfw", string(jsonReq))
|
||||
|
||||
apiResponse := &vision.ApiResponse{}
|
||||
|
||||
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
}
|
||||
|
||||
assert.Len(t, apiResponse.Result.Nsfw, 2)
|
||||
assert.Equal(t, vision.ModelTypeNsfw, apiResponse.Model.Type)
|
||||
assert.Equal(t, http.StatusOK, r.Code)
|
||||
})
|
||||
t.Run("NoImages", func(t *testing.T) {
|
||||
app, router, _ := NewApiTest()
|
||||
PostVisionNsfw(router)
|
||||
|
||||
files := vision.Files{}
|
||||
|
||||
req, err := vision.NewClientRequest(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
jsonReq, jsonErr := req.MarshalJSON()
|
||||
|
||||
if jsonErr != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// t.Logf("request: %s", string(jsonReq))
|
||||
|
||||
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/vision/nsfw", string(jsonReq))
|
||||
|
||||
apiResponse := &vision.ApiResponse{}
|
||||
|
||||
if apiJson, apiErr := io.ReadAll(r.Body); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
|
||||
t.Fatal(apiErr)
|
||||
}
|
||||
|
||||
if apiResponse == nil {
|
||||
t.Fatal("api response expected")
|
||||
}
|
||||
|
||||
// t.Logf("error: %s", apiResponse.Err())
|
||||
|
||||
assert.Error(t, apiResponse.Err())
|
||||
assert.False(t, apiResponse.HasResult())
|
||||
assert.Equal(t, http.StatusBadRequest, r.Code)
|
||||
})
|
||||
t.Run("NoBody", func(t *testing.T) {
|
||||
app, router, _ := NewApiTest()
|
||||
PostVisionNsfw(router)
|
||||
r := PerformRequest(app, http.MethodPost, "/api/v1/vision/nsfw")
|
||||
assert.Equal(t, http.StatusBadRequest, r.Code)
|
||||
})
|
||||
}
|
||||
@@ -102,7 +102,6 @@ func NewTestOptions(pkg string) *Options {
|
||||
Trace: false,
|
||||
Experimental: true,
|
||||
ReadOnly: false,
|
||||
DetectNSFW: true,
|
||||
UploadNSFW: false,
|
||||
ExifBruteForce: false,
|
||||
AssetsPath: assetsPath,
|
||||
@@ -123,6 +122,8 @@ func NewTestOptions(pkg string) *Options {
|
||||
AdminPassword: "photoprism",
|
||||
OriginalsLimit: 66,
|
||||
ResolutionLimit: 33,
|
||||
VisionApi: true,
|
||||
DetectNSFW: true,
|
||||
}
|
||||
|
||||
return c
|
||||
|
||||
@@ -23,7 +23,7 @@ func (ind *Index) IsNsfw(m *MediaFile) bool {
|
||||
} else if len(results) < 1 {
|
||||
log.Errorf("index: nsfw model returned no result for %s", m.RootRelName())
|
||||
return false
|
||||
} else if results[0].NSFW(nsfw.ThresholdHigh) {
|
||||
} else if results[0].IsNsfw(nsfw.ThresholdHigh) {
|
||||
log.Warnf("index: %s might contain offensive content", clean.Log(m.RelName(Config().OriginalsPath())))
|
||||
return true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user