mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path"
|
||||
"runtime/debug"
|
||||
"sort"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/tensorflow"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
// Model represents a TensorFlow classification model.
|
||||
@@ -48,23 +50,38 @@ func (m *Model) Init() (err error) {
|
||||
return m.loadModel()
|
||||
}
|
||||
|
||||
// File returns matching labels for a jpeg media file.
|
||||
func (m *Model) File(imageUri string, confidenceThreshold int) (result Labels, err error) {
|
||||
// File returns matching labels for a local jpeg file.
|
||||
func (m *Model) File(fileName string, confidenceThreshold int) (result Labels, err error) {
|
||||
if m.disabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var data []byte
|
||||
|
||||
if data, err = media.ReadUrl(imageUri); err != nil {
|
||||
if data, err = os.ReadFile(fileName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.Labels(data, confidenceThreshold)
|
||||
return m.Run(data, confidenceThreshold)
|
||||
}
|
||||
|
||||
// Labels returns matching labels for a jpeg media string.
|
||||
func (m *Model) Labels(img []byte, confidenceThreshold int) (result Labels, err error) {
|
||||
// Url returns matching labels for a remote jpeg file.
|
||||
func (m *Model) Url(imgUrl string, confidenceThreshold int) (result Labels, err error) {
|
||||
if m.disabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var data []byte
|
||||
|
||||
if data, err = media.ReadUrl(imgUrl, scheme.HttpsData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.Run(data, confidenceThreshold)
|
||||
}
|
||||
|
||||
// Run returns matching labels for the specified JPEG image.
|
||||
func (m *Model) Run(img []byte, confidenceThreshold int) (result Labels, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("classify: %s (inference panic)\nstack: %s", r, debug.Stack())
|
||||
|
||||
@@ -116,7 +116,7 @@ func TestModel_LabelsFromFile(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestModel_Labels(t *testing.T) {
|
||||
func TestModel_Run(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
}
|
||||
@@ -127,7 +127,7 @@ func TestModel_Labels(t *testing.T) {
|
||||
if imageBuffer, err := os.ReadFile(examplesPath + "/chameleon_lime.jpg"); err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
result, err := tensorFlow.Labels(imageBuffer, 10)
|
||||
result, err := tensorFlow.Run(imageBuffer, 10)
|
||||
|
||||
t.Log(result)
|
||||
|
||||
@@ -151,7 +151,7 @@ func TestModel_Labels(t *testing.T) {
|
||||
if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
result, err := tensorFlow.Labels(imageBuffer, 10)
|
||||
result, err := tensorFlow.Run(imageBuffer, 10)
|
||||
|
||||
t.Log(result)
|
||||
|
||||
@@ -175,7 +175,7 @@ func TestModel_Labels(t *testing.T) {
|
||||
if imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx"); err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
result, err := tensorFlow.Labels(imageBuffer, 10)
|
||||
result, err := tensorFlow.Run(imageBuffer, 10)
|
||||
assert.Empty(t, result)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
@@ -186,7 +186,7 @@ func TestModel_Labels(t *testing.T) {
|
||||
if imageBuffer, err := os.ReadFile(examplesPath + "/6720px_white.jpg"); err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
result, err := tensorFlow.Labels(imageBuffer, 10)
|
||||
result, err := tensorFlow.Run(imageBuffer, 10)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -201,7 +201,7 @@ func TestModel_Labels(t *testing.T) {
|
||||
if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
result, err := tensorFlow.Labels(imageBuffer, 10)
|
||||
result, err := tensorFlow.Run(imageBuffer, 10)
|
||||
|
||||
t.Log(result)
|
||||
|
||||
|
||||
@@ -27,8 +27,14 @@ type Model struct {
|
||||
}
|
||||
|
||||
// NewModel returns a new TensorFlow Facenet instance.
|
||||
func NewModel(modelPath, cachePath string, disabled bool) *Model {
|
||||
return &Model{modelPath: modelPath, cachePath: cachePath, resolution: CropSize.Width, modelTags: []string{"serve"}, disabled: disabled}
|
||||
func NewModel(modelPath, cachePath string, resolution int, tags []string, disabled bool) *Model {
|
||||
if resolution == 0 {
|
||||
resolution = CropSize.Width
|
||||
}
|
||||
if len(tags) == 0 {
|
||||
tags = []string{"serve"}
|
||||
}
|
||||
return &Model{modelPath: modelPath, cachePath: cachePath, resolution: resolution, modelTags: tags, disabled: disabled}
|
||||
}
|
||||
|
||||
// Detect runs the detection and facenet algorithms over the provided source image.
|
||||
@@ -57,9 +63,9 @@ func (m *Model) Detect(fileName string, minSize int, cacheCrop bool, expected in
|
||||
continue
|
||||
}
|
||||
|
||||
if img, imgErr := crop.ImageFromThumb(fileName, f.CropArea(), CropSize, cacheCrop); imgErr != nil {
|
||||
if img, _, imgErr := crop.ImageFromThumb(fileName, f.CropArea(), CropSize, cacheCrop); imgErr != nil {
|
||||
log.Errorf("faces: failed to decode image: %s", imgErr)
|
||||
} else if embeddings := m.getEmbeddings(img); !embeddings.Empty() {
|
||||
} else if embeddings := m.Run(img); !embeddings.Empty() {
|
||||
faces[i].Embeddings = embeddings
|
||||
}
|
||||
}
|
||||
@@ -67,6 +73,15 @@ func (m *Model) Detect(fileName string, minSize int, cacheCrop bool, expected in
|
||||
return faces, nil
|
||||
}
|
||||
|
||||
// Init initialises tensorflow models if not disabled
|
||||
func (m *Model) Init() (err error) {
|
||||
if m.disabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
return m.loadModel()
|
||||
}
|
||||
|
||||
// ModelLoaded tests if the TensorFlow model is loaded.
|
||||
func (m *Model) ModelLoaded() bool {
|
||||
return m.model != nil
|
||||
@@ -99,8 +114,8 @@ func (m *Model) loadModel() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// getEmbeddings returns the face embeddings for an image.
|
||||
func (m *Model) getEmbeddings(img image.Image) Embeddings {
|
||||
// Run returns the face embeddings for an image.
|
||||
func (m *Model) Run(img image.Image) Embeddings {
|
||||
tensor, err := imageToTensor(img, m.resolution)
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -54,7 +54,7 @@ func TestNet(t *testing.T) {
|
||||
|
||||
var embeddings = make(Embeddings, 11)
|
||||
|
||||
faceNet := NewModel(modelPath, "testdata/cache", false)
|
||||
faceNet := NewModel(modelPath, "testdata/cache", 160, []string{"serve"}, false)
|
||||
|
||||
if err := fastwalk.Walk("testdata", func(fileName string, info os.FileMode) error {
|
||||
if info.IsDir() || filepath.Base(filepath.Dir(fileName)) != "testdata" {
|
||||
|
||||
@@ -11,7 +11,9 @@ import (
|
||||
"github.com/photoprism/photoprism/internal/ai/tensorflow"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/header"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
// Model uses TensorFlow to label drawing, hentai, neutral, porn and sexy images.
|
||||
@@ -21,31 +23,53 @@ type Model struct {
|
||||
resolution int
|
||||
modelTags []string
|
||||
labels []string
|
||||
disabled bool
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewModel returns a new detector instance.
|
||||
func NewModel(modelPath string) *Model {
|
||||
return &Model{modelPath: modelPath, resolution: 224, modelTags: []string{"serve"}}
|
||||
func NewModel(modelPath string, resolution int, tags []string, disabled bool) *Model {
|
||||
if resolution <= 0 {
|
||||
resolution = 224
|
||||
}
|
||||
if len(tags) == 0 {
|
||||
tags = []string{"serve"}
|
||||
}
|
||||
return &Model{modelPath: modelPath, resolution: resolution, modelTags: tags, disabled: disabled}
|
||||
}
|
||||
|
||||
// File returns matching labels for a jpeg media file.
|
||||
func (m *Model) File(filename string) (result Labels, err error) {
|
||||
if fs.MimeType(filename) != header.ContentTypeJpeg {
|
||||
return result, fmt.Errorf("nsfw: %s is not a jpeg file", clean.Log(filepath.Base(filename)))
|
||||
// File checks the specified JPEG file for inappropriate content.
|
||||
func (m *Model) File(fileName string) (result Result, err error) {
|
||||
if fs.MimeType(fileName) != header.ContentTypeJpeg {
|
||||
return result, fmt.Errorf("nsfw: %s is not a jpeg file", clean.Log(filepath.Base(fileName)))
|
||||
}
|
||||
|
||||
imageBuffer, err := os.ReadFile(filename)
|
||||
var img []byte
|
||||
|
||||
if err != nil {
|
||||
if img, err = os.ReadFile(fileName); err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
return m.Labels(imageBuffer)
|
||||
return m.Run(img)
|
||||
}
|
||||
|
||||
// Labels returns matching labels for a jpeg media string.
|
||||
func (m *Model) Labels(img []byte) (result Labels, err error) {
|
||||
// Url checks the JPEG file from the specified https or data URL for inappropriate content.
|
||||
func (m *Model) Url(imgUrl string) (result Result, err error) {
|
||||
if m.disabled {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
var img []byte
|
||||
|
||||
if img, err = media.ReadUrl(imgUrl, scheme.HttpsData); err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
return m.Run(img)
|
||||
}
|
||||
|
||||
// Run returns matching labels for a jpeg media string.
|
||||
func (m *Model) Run(img []byte) (result Result, err error) {
|
||||
if loadErr := m.loadModel(); loadErr != nil {
|
||||
return result, loadErr
|
||||
}
|
||||
@@ -83,11 +107,15 @@ func (m *Model) Labels(img []byte) (result Labels, err error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) loadLabels(modelPath string) (err error) {
|
||||
m.labels, err = tensorflow.LoadLabels(modelPath)
|
||||
// Init initialises tensorflow models if not disabled
|
||||
func (m *Model) Init() (err error) {
|
||||
if m.disabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
return m.loadModel()
|
||||
}
|
||||
|
||||
func (m *Model) loadModel() error {
|
||||
// Use mutex to prevent the model from being loaded and
|
||||
// initialized twice by different indexing workers.
|
||||
@@ -113,8 +141,13 @@ func (m *Model) loadModel() error {
|
||||
return m.loadLabels(m.modelPath)
|
||||
}
|
||||
|
||||
func (m *Model) getLabels(p []float32) Labels {
|
||||
return Labels{
|
||||
func (m *Model) loadLabels(modelPath string) (err error) {
|
||||
m.labels, err = tensorflow.LoadLabels(modelPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) getLabels(p []float32) Result {
|
||||
return Result{
|
||||
Drawing: p[0],
|
||||
Hentai: p[1],
|
||||
Neutral: p[2],
|
||||
|
||||
@@ -36,7 +36,7 @@ const (
|
||||
|
||||
var log = event.Log
|
||||
|
||||
type Labels struct {
|
||||
type Result struct {
|
||||
Drawing float32
|
||||
Hentai float32
|
||||
Neutral float32
|
||||
@@ -45,12 +45,12 @@ type Labels struct {
|
||||
}
|
||||
|
||||
// IsSafe returns true if the image is probably safe for work.
|
||||
func (l *Labels) IsSafe() bool {
|
||||
func (l *Result) IsSafe() bool {
|
||||
return !l.NSFW(ThresholdSafe)
|
||||
}
|
||||
|
||||
// NSFW returns true if the image is may not be safe for work.
|
||||
func (l *Labels) NSFW(threshold float32) bool {
|
||||
func (l *Result) NSFW(threshold float32) bool {
|
||||
if l.Neutral > 0.25 {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -13,10 +13,10 @@ import (
|
||||
|
||||
var modelPath, _ = filepath.Abs("../../../assets/nsfw")
|
||||
|
||||
var detector = NewModel(modelPath)
|
||||
var detector = NewModel(modelPath, 224, nil, false)
|
||||
|
||||
func TestIsSafe(t *testing.T) {
|
||||
detect := func(filename string) Labels {
|
||||
detect := func(filename string) Result {
|
||||
result, err := detector.File(filename)
|
||||
|
||||
if err != nil {
|
||||
@@ -24,12 +24,12 @@ func TestIsSafe(t *testing.T) {
|
||||
}
|
||||
|
||||
assert.NotNil(t, result)
|
||||
assert.IsType(t, Labels{}, result)
|
||||
assert.IsType(t, Result{}, result)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
expected := map[string]Labels{
|
||||
expected := map[string]Result{
|
||||
"beach_sand.jpg": {0, 0, 0.9, 0, 0},
|
||||
"beach_wood.jpg": {0, 0, 0.36, 0.59, 0},
|
||||
"cat_brown.jpg": {0, 0, 0.93, 0, 0},
|
||||
@@ -98,11 +98,11 @@ func TestIsSafe(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNSFW(t *testing.T) {
|
||||
porn := Labels{0, 0, 0.11, 0.88, 0}
|
||||
sexy := Labels{0, 0, 0.2, 0.59, 0.98}
|
||||
maxi := Labels{0, 0.999, 0.1, 0.999, 0.999}
|
||||
drawing := Labels{0.999, 0, 0, 0, 0}
|
||||
hentai := Labels{0, 0.80, 0.2, 0, 0}
|
||||
porn := Result{0, 0, 0.11, 0.88, 0}
|
||||
sexy := Result{0, 0, 0.2, 0.59, 0.98}
|
||||
maxi := Result{0, 0.999, 0.1, 0.999, 0.999}
|
||||
drawing := Result{0.999, 0, 0, 0, 0}
|
||||
hentai := Result{0, 0.80, 0.2, 0, 0}
|
||||
|
||||
assert.Equal(t, true, porn.NSFW(ThresholdSafe))
|
||||
assert.Equal(t, true, sexy.NSFW(ThresholdSafe))
|
||||
|
||||
@@ -2,39 +2,59 @@ package vision
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/api/download"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
"github.com/photoprism/photoprism/pkg/rnd"
|
||||
)
|
||||
|
||||
const (
|
||||
LabelsEndpoint = "labels"
|
||||
)
|
||||
type Files = []string
|
||||
|
||||
// ApiRequest represents a Vision API service request.
|
||||
type ApiRequest struct {
|
||||
Id string `form:"id" yaml:"Id,omitempty" json:"id,omitempty"`
|
||||
Model string `form:"model" yaml:"Model,omitempty" json:"model,omitempty"`
|
||||
Images []string `form:"images" yaml:"Images,omitempty" json:"images,omitempty"`
|
||||
Videos []string `form:"videos" yaml:"Videos,omitempty" json:"videos,omitempty"`
|
||||
Images Files `form:"images" yaml:"Images,omitempty" json:"images,omitempty"`
|
||||
}
|
||||
|
||||
func NewClientRequest(model string, images []string) *ApiRequest {
|
||||
imageUrls := make([]string, 0, len(images))
|
||||
// NewClientRequest returns a new Vision API request with the specified file payload and scheme.
|
||||
func NewClientRequest(images Files, fileScheme string) (*ApiRequest, error) {
|
||||
imageUrls := make(Files, len(images))
|
||||
|
||||
if fileScheme == scheme.Https && !strings.HasPrefix(DownloadUrl, "https://") {
|
||||
log.Tracef("vision: file request scheme changed from https to data because https is not configured")
|
||||
fileScheme = scheme.Data
|
||||
}
|
||||
|
||||
for i := range images {
|
||||
switch fileScheme {
|
||||
case scheme.Https:
|
||||
if id, err := download.Register(images[i]); err != nil {
|
||||
log.Errorf("vision: %s (register download)", err)
|
||||
return nil, fmt.Errorf("%s (register download)", err)
|
||||
} else {
|
||||
imageUrls = append(imageUrls, path.Join(DownloadUrl, id))
|
||||
imageUrls[i] = fmt.Sprintf("%s/%s", DownloadUrl, id)
|
||||
}
|
||||
case scheme.Data:
|
||||
if file, err := os.Open(images[i]); err != nil {
|
||||
return nil, fmt.Errorf("%s (create data url)", err)
|
||||
} else {
|
||||
imageUrls[i] = media.DataUrl(file)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid file scheme %s", clean.Log(fileScheme))
|
||||
}
|
||||
}
|
||||
|
||||
return &ApiRequest{
|
||||
Id: rnd.UUID(),
|
||||
Model: "",
|
||||
Images: imageUrls,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetId returns the request ID string and generates a random ID if none was set.
|
||||
@@ -48,5 +68,5 @@ func (r *ApiRequest) GetId() string {
|
||||
|
||||
// MarshalJSON returns request as JSON.
|
||||
func (r *ApiRequest) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(r)
|
||||
return json.Marshal(*r)
|
||||
}
|
||||
|
||||
46
internal/ai/vision/api_request_test.go
Normal file
46
internal/ai/vision/api_request_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
func TestNewClientRequest(t *testing.T) {
|
||||
var assetsPath = fs.Abs("../../../assets")
|
||||
var examplesPath = assetsPath + "/examples"
|
||||
|
||||
t.Run("Data", func(t *testing.T) {
|
||||
thumbnails := Files{examplesPath + "/chameleon_lime.jpg"}
|
||||
result, err := NewClientRequest(thumbnails, scheme.Data)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
// t.Logf("request: %#v", result)
|
||||
|
||||
if result != nil {
|
||||
json, jsonErr := result.MarshalJSON()
|
||||
assert.NoError(t, jsonErr)
|
||||
assert.NotEmpty(t, json)
|
||||
// t.Logf("json: %s", json)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Https", func(t *testing.T) {
|
||||
thumbnails := Files{examplesPath + "/chameleon_lime.jpg"}
|
||||
result, err := NewClientRequest(thumbnails, scheme.Https)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
// t.Logf("request: %#v", result)
|
||||
if result != nil {
|
||||
json, jsonErr := result.MarshalJSON()
|
||||
assert.NoError(t, jsonErr)
|
||||
assert.NotEmpty(t, json)
|
||||
t.Logf("json: %s", json)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/classify"
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
)
|
||||
|
||||
@@ -20,9 +22,10 @@ type ApiResponse struct {
|
||||
// ApiResult represents the model response(s) to a Vision API service
|
||||
// request and can optionally include data from multiple models.
|
||||
type ApiResult struct {
|
||||
Caption *CaptionResult `yaml:"Caption,omitempty" json:"caption,omitempty"`
|
||||
Faces *[]string `yaml:"Faces,omitempty" json:"faces,omitempty"`
|
||||
Labels []LabelResult `yaml:"Labels,omitempty" json:"labels,omitempty"`
|
||||
Nsfw []nsfw.Result `yaml:"Nsfw,omitempty" json:"nsfw,omitempty"`
|
||||
Embeddings []face.Embeddings `yaml:"Embeddings,omitempty" json:"embeddings,omitempty"`
|
||||
Caption *CaptionResult `yaml:"Caption,omitempty" json:"caption,omitempty"`
|
||||
}
|
||||
|
||||
// CaptionResult represents the result generated by a caption generation model.
|
||||
@@ -41,6 +44,7 @@ type LabelResult struct {
|
||||
Categories []string `yaml:"Categories,omitempty" json:"categories,omitempty"`
|
||||
}
|
||||
|
||||
// ToClassify returns the label results as classify.Label.
|
||||
func (r LabelResult) ToClassify() classify.Label {
|
||||
uncertainty := math.RoundToEven(float64(100 - r.Confidence*100))
|
||||
return classify.Label{
|
||||
|
||||
16
internal/ai/vision/config.go
Normal file
16
internal/ai/vision/config.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
)
|
||||
|
||||
var (
|
||||
AssetsPath = fs.Abs("../../../assets")
|
||||
FaceNetModelPath = fs.Abs("../../../assets/facenet")
|
||||
NsfwModelPath = fs.Abs("../../../assets/nsfw")
|
||||
CachePath = fs.Abs("../../../storage/cache")
|
||||
ServiceUri = ""
|
||||
ServiceKey = ""
|
||||
DownloadUrl = ""
|
||||
DefaultResolution = 224
|
||||
)
|
||||
10
internal/ai/vision/const.go
Normal file
10
internal/ai/vision/const.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package vision
|
||||
|
||||
type ModelType = string
|
||||
|
||||
const (
|
||||
ModelTypeLabels ModelType = "labels"
|
||||
ModelTypeNsfw ModelType = "nsfw"
|
||||
ModelTypeFaceEmbeddings ModelType = "face/embeddings"
|
||||
ModelTypeCaption ModelType = "caption"
|
||||
)
|
||||
37
internal/ai/vision/face_embeddings.go
Normal file
37
internal/ai/vision/face_embeddings.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"image/jpeg"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
)
|
||||
|
||||
// FaceEmbeddings returns the embeddings for the specified face jpeg.
|
||||
func FaceEmbeddings(imgData []byte) (embeddings face.Embeddings, err error) {
|
||||
if len(imgData) == 0 {
|
||||
return embeddings, errors.New("missing image")
|
||||
}
|
||||
|
||||
if Config == nil {
|
||||
return embeddings, errors.New("missing configuration")
|
||||
} else if model := Config.Model(ModelTypeFaceEmbeddings); model != nil {
|
||||
img, imgErr := jpeg.Decode(bytes.NewReader(imgData))
|
||||
|
||||
if imgErr != nil {
|
||||
return embeddings, imgErr
|
||||
}
|
||||
|
||||
if tf := model.FaceModel(); tf == nil {
|
||||
return embeddings, fmt.Errorf("invalid face model configuration")
|
||||
} else if embeddings = tf.Run(img); !embeddings.Empty() {
|
||||
return embeddings, nil
|
||||
} else {
|
||||
return face.Embeddings{}, nil
|
||||
}
|
||||
} else {
|
||||
return embeddings, fmt.Errorf("no face model configured")
|
||||
}
|
||||
}
|
||||
119
internal/ai/vision/faces.go
Normal file
119
internal/ai/vision/faces.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/thumb/crop"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/header"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
// Faces runs face detection and facenet algorithms over the provided source image.
|
||||
func Faces(fileName string, minSize int, cacheCrop bool, expected int) (faces face.Faces, err error) {
|
||||
if fileName == "" {
|
||||
return faces, errors.New("missing image filename")
|
||||
}
|
||||
|
||||
// Return if there is no configuration or no image classification models are configured.
|
||||
if Config == nil {
|
||||
return faces, errors.New("missing configuration")
|
||||
} else if model := Config.Model(ModelTypeFaceEmbeddings); model != nil {
|
||||
faces, err = face.Detect(fileName, false, minSize)
|
||||
|
||||
if err != nil {
|
||||
return faces, err
|
||||
}
|
||||
|
||||
// Skip embeddings?
|
||||
if c := len(faces); c == 0 || expected > 0 && c == expected {
|
||||
return faces, nil
|
||||
}
|
||||
|
||||
if uri, method := model.Endpoint(); uri != "" && method != "" {
|
||||
faceCrops := make([]string, len(faces))
|
||||
|
||||
for i, f := range faces {
|
||||
if f.Area.Col == 0 && f.Area.Row == 0 {
|
||||
faceCrops[i] = ""
|
||||
continue
|
||||
}
|
||||
|
||||
if _, faceCrop, imgErr := crop.ImageFromThumb(fileName, f.CropArea(), face.CropSize, cacheCrop); imgErr != nil {
|
||||
log.Errorf("faces: failed to decode image: %s", imgErr)
|
||||
faceCrops[i] = ""
|
||||
} else if faceCrop != "" {
|
||||
faceCrops[i] = faceCrop
|
||||
}
|
||||
}
|
||||
|
||||
apiRequest, apiRequestErr := NewClientRequest(faceCrops, scheme.Data)
|
||||
|
||||
if apiRequestErr != nil {
|
||||
return faces, apiRequestErr
|
||||
}
|
||||
|
||||
if model.Name != "" {
|
||||
apiRequest.Model = model.Name
|
||||
}
|
||||
|
||||
data, jsonErr := apiRequest.MarshalJSON()
|
||||
|
||||
if jsonErr != nil {
|
||||
return faces, jsonErr
|
||||
}
|
||||
|
||||
// Create HTTP client and authenticated service API request.
|
||||
client := http.Client{}
|
||||
req, reqErr := http.NewRequest(method, uri, bytes.NewReader(data))
|
||||
header.SetAuthorization(req, model.EndpointKey())
|
||||
|
||||
if reqErr != nil {
|
||||
return faces, reqErr
|
||||
}
|
||||
|
||||
// Perform API request.
|
||||
clientResp, clientErr := client.Do(req)
|
||||
|
||||
if clientErr != nil {
|
||||
return faces, clientErr
|
||||
}
|
||||
|
||||
apiResponse := &ApiResponse{}
|
||||
|
||||
if apiJson, apiErr := io.ReadAll(clientResp.Body); apiErr != nil {
|
||||
return faces, apiErr
|
||||
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
|
||||
return faces, apiErr
|
||||
}
|
||||
|
||||
for i := range faces {
|
||||
if len(apiResponse.Result.Embeddings) > i {
|
||||
faces[i].Embeddings = apiResponse.Result.Embeddings[i]
|
||||
}
|
||||
}
|
||||
} else if tf := model.FaceModel(); tf != nil {
|
||||
for i, f := range faces {
|
||||
if f.Area.Col == 0 && f.Area.Row == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if img, _, imgErr := crop.ImageFromThumb(fileName, f.CropArea(), face.CropSize, cacheCrop); imgErr != nil {
|
||||
log.Errorf("faces: failed to decode image: %s", imgErr)
|
||||
} else if embeddings := tf.Run(img); !embeddings.Empty() {
|
||||
faces[i].Embeddings = embeddings
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return faces, errors.New("invalid face model configuration")
|
||||
}
|
||||
} else {
|
||||
return faces, errors.New("missing face model")
|
||||
}
|
||||
|
||||
return faces, nil
|
||||
}
|
||||
@@ -4,33 +4,41 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/classify"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/header"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
// Labels returns suitable labels for the specified image thumbnail.
|
||||
func Labels(thumbnails []string) (result classify.Labels, err error) {
|
||||
func Labels(images Files, src media.Src) (result classify.Labels, err error) {
|
||||
// Return if no thumbnail filenames were given.
|
||||
if len(thumbnails) == 0 {
|
||||
return result, errors.New("missing thumbnail filenames")
|
||||
if len(images) == 0 {
|
||||
return result, errors.New("missing image filenames")
|
||||
}
|
||||
|
||||
// Return if there is no configuration or no image classification models are configured.
|
||||
if Config == nil {
|
||||
return result, errors.New("missing configuration")
|
||||
} else if len(Config.Labels) == 0 {
|
||||
return result, errors.New("missing labels model configuration")
|
||||
} else if model := Config.Model(ModelTypeLabels); model != nil {
|
||||
// Use remote service API if a server endpoint has been configured.
|
||||
if uri, method := model.Endpoint(); uri != "" && method != "" {
|
||||
apiRequest, apiRequestErr := NewClientRequest(images, scheme.Data)
|
||||
|
||||
if apiRequestErr != nil {
|
||||
return result, apiRequestErr
|
||||
}
|
||||
|
||||
if model.Name != "" {
|
||||
apiRequest.Model = model.Name
|
||||
}
|
||||
|
||||
// Use computer vision models configured for image classification.
|
||||
for _, model := range Config.Labels {
|
||||
// Use remote service API if a server endpoint has been configured.
|
||||
if uri, method := model.Endpoint(LabelsEndpoint); uri != "" && method != "" {
|
||||
apiRequest := NewClientRequest(model.Name, thumbnails)
|
||||
data, jsonErr := apiRequest.MarshalJSON()
|
||||
|
||||
if jsonErr != nil {
|
||||
@@ -67,18 +75,29 @@ func Labels(thumbnails []string) (result classify.Labels, err error) {
|
||||
}
|
||||
} else if tf := model.ClassifyModel(); tf != nil {
|
||||
// Predict labels with local TensorFlow model.
|
||||
for i := range thumbnails {
|
||||
labels, modelErr := tf.File(thumbnails[i], Config.Thresholds.Confidence)
|
||||
for i := range images {
|
||||
var labels classify.Labels
|
||||
|
||||
if modelErr != nil {
|
||||
return result, modelErr
|
||||
switch src {
|
||||
case media.SrcLocal:
|
||||
labels, err = tf.File(images[i], Config.Thresholds.Confidence)
|
||||
case media.SrcRemote:
|
||||
labels, err = tf.Url(images[i], Config.Thresholds.Confidence)
|
||||
default:
|
||||
return result, fmt.Errorf("invalid image source %s", clean.Log(src))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
result = mergeLabels(result, labels)
|
||||
}
|
||||
} else {
|
||||
return result, errors.New("missing labels model")
|
||||
return result, errors.New("invalid labels model configuration")
|
||||
}
|
||||
} else {
|
||||
return result, errors.New("missing labels model")
|
||||
}
|
||||
|
||||
sort.Sort(result)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/classify"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
)
|
||||
|
||||
func TestLabels(t *testing.T) {
|
||||
@@ -14,7 +15,7 @@ func TestLabels(t *testing.T) {
|
||||
var examplesPath = assetsPath + "/examples"
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
result, err := Labels([]string{examplesPath + "/chameleon_lime.jpg"})
|
||||
result, err := Labels(Files{examplesPath + "/chameleon_lime.jpg"}, media.SrcLocal)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.IsType(t, classify.Labels{}, result)
|
||||
@@ -26,7 +27,7 @@ func TestLabels(t *testing.T) {
|
||||
assert.Equal(t, 7, result[0].Uncertainty)
|
||||
})
|
||||
t.Run("Cats", func(t *testing.T) {
|
||||
result, err := Labels([]string{examplesPath + "/cat_720.jpeg"})
|
||||
result, err := Labels(Files{examplesPath + "/cat_720.jpeg"}, media.SrcLocal)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.IsType(t, classify.Labels{}, result)
|
||||
@@ -39,7 +40,7 @@ func TestLabels(t *testing.T) {
|
||||
assert.InDelta(t, float32(0.4), result[0].Confidence(), 0.01)
|
||||
})
|
||||
t.Run("InvalidFile", func(t *testing.T) {
|
||||
_, err := Labels([]string{examplesPath + "/notexisting.jpg"})
|
||||
_, err := Labels(Files{examplesPath + "/notexisting.jpg"}, media.SrcLocal)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/classify"
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
)
|
||||
|
||||
@@ -13,6 +16,7 @@ var modelMutex = sync.Mutex{}
|
||||
|
||||
// Model represents a computer vision model configuration.
|
||||
type Model struct {
|
||||
Type ModelType `yaml:"Type,omitempty" json:"type,omitempty"`
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
|
||||
Resolution int `yaml:"Resolution,omitempty" json:"resolution,omitempty"`
|
||||
@@ -20,18 +24,19 @@ type Model struct {
|
||||
Key string `yaml:"Key,omitempty" json:"-"`
|
||||
Method string `yaml:"Method,omitempty" json:"-"`
|
||||
Path string `yaml:"Path,omitempty" json:"-"`
|
||||
Format string `yaml:"Format,omitempty" json:"-"`
|
||||
Tags []string `yaml:"Tags,omitempty" json:"-"`
|
||||
Disabled bool `yaml:"Disabled,omitempty" json:"-"`
|
||||
classifyModel *classify.Model
|
||||
faceModel *face.Model
|
||||
nsfwModel *nsfw.Model
|
||||
}
|
||||
|
||||
// Models represents a set of computer vision models.
|
||||
type Models []*Model
|
||||
|
||||
// Endpoint returns the remote service request method and endpoint URL, if any.
|
||||
func (m *Model) Endpoint(name string) (method, uri string) {
|
||||
if m.Uri == "" && ServiceUri == "" {
|
||||
func (m *Model) Endpoint() (uri, method string) {
|
||||
if m.Uri == "" && ServiceUri == "" || m.Type == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
@@ -44,7 +49,7 @@ func (m *Model) Endpoint(name string) (method, uri string) {
|
||||
if m.Uri != "" {
|
||||
return m.Uri, method
|
||||
} else {
|
||||
return path.Join(ServiceUri, name), method
|
||||
return fmt.Sprintf("%s/%s", ServiceUri, clean.TypeLowerUnderscore(m.Type)), method
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,3 +119,115 @@ func (m *Model) ClassifyModel() *classify.Model {
|
||||
|
||||
return m.classifyModel
|
||||
}
|
||||
|
||||
// FaceModel returns the matching face model instance, if any.
|
||||
func (m *Model) FaceModel() *face.Model {
|
||||
// Use mutex to prevent models from being loaded and
|
||||
// initialized twice by different indexing workers.
|
||||
modelMutex.Lock()
|
||||
defer modelMutex.Unlock()
|
||||
|
||||
// Return the existing model instance if it has already been created.
|
||||
if m.faceModel != nil {
|
||||
return m.faceModel
|
||||
}
|
||||
|
||||
switch m.Name {
|
||||
case "":
|
||||
log.Warnf("vision: missing name, model instance cannot be created")
|
||||
return nil
|
||||
case FacenetModel.Name, "facenet":
|
||||
// Load and initialize the Nasnet image classification model.
|
||||
if model := face.NewModel(FaceNetModelPath, CachePath, m.Resolution, m.Tags, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
return nil
|
||||
} else {
|
||||
m.faceModel = model
|
||||
}
|
||||
default:
|
||||
// Set model path from model name if no path is configured.
|
||||
if m.Path == "" {
|
||||
m.Path = clean.TypeLowerUnderscore(m.Name)
|
||||
}
|
||||
|
||||
// Set default thumbnail resolution if no tags are configured.
|
||||
if m.Resolution <= 0 {
|
||||
m.Resolution = DefaultResolution
|
||||
}
|
||||
|
||||
// Set default tag if no tags are configured.
|
||||
if len(m.Tags) == 0 {
|
||||
m.Tags = []string{"serve"}
|
||||
}
|
||||
|
||||
// Try to load custom model based on the configuration values.
|
||||
if model := face.NewModel(filepath.Join(AssetsPath, m.Path), CachePath, m.Resolution, m.Tags, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
return nil
|
||||
} else {
|
||||
m.faceModel = model
|
||||
}
|
||||
}
|
||||
|
||||
return m.faceModel
|
||||
}
|
||||
|
||||
// NsfwModel returns the matching nsfw model instance, if any.
|
||||
func (m *Model) NsfwModel() *nsfw.Model {
|
||||
// Use mutex to prevent models from being loaded and
|
||||
// initialized twice by different indexing workers.
|
||||
modelMutex.Lock()
|
||||
defer modelMutex.Unlock()
|
||||
|
||||
// Return the existing model instance if it has already been created.
|
||||
if m.nsfwModel != nil {
|
||||
return m.nsfwModel
|
||||
}
|
||||
|
||||
switch m.Name {
|
||||
case "":
|
||||
log.Warnf("vision: missing name, model instance cannot be created")
|
||||
return nil
|
||||
case NsfwModel.Name, "nsfw":
|
||||
// Load and initialize the Nasnet image classification model.
|
||||
if model := nsfw.NewModel(NsfwModelPath, m.Resolution, m.Tags, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
return nil
|
||||
} else {
|
||||
m.nsfwModel = model
|
||||
}
|
||||
default:
|
||||
// Set model path from model name if no path is configured.
|
||||
if m.Path == "" {
|
||||
m.Path = clean.TypeLowerUnderscore(m.Name)
|
||||
}
|
||||
|
||||
// Set default thumbnail resolution if no tags are configured.
|
||||
if m.Resolution <= 0 {
|
||||
m.Resolution = DefaultResolution
|
||||
}
|
||||
|
||||
// Set default tag if no tags are configured.
|
||||
if len(m.Tags) == 0 {
|
||||
m.Tags = []string{"serve"}
|
||||
}
|
||||
|
||||
// Try to load custom model based on the configuration values.
|
||||
if model := nsfw.NewModel(filepath.Join(AssetsPath, m.Path), m.Resolution, m.Tags, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
return nil
|
||||
} else {
|
||||
m.nsfwModel = model
|
||||
}
|
||||
}
|
||||
|
||||
return m.nsfwModel
|
||||
}
|
||||
|
||||
21
internal/ai/vision/model_test.go
Normal file
21
internal/ai/vision/model_test.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestModel(t *testing.T) {
|
||||
t.Run("Nasnet", func(t *testing.T) {
|
||||
ServiceUri = "https://app.localssl.dev/api/v1/vision"
|
||||
uri, method := NasnetModel.Endpoint()
|
||||
ServiceUri = ""
|
||||
assert.Equal(t, "https://app.localssl.dev/api/v1/vision/labels", uri)
|
||||
assert.Equal(t, http.MethodPost, method)
|
||||
uri, method = NasnetModel.Endpoint()
|
||||
assert.Equal(t, "", uri)
|
||||
assert.Equal(t, "", method)
|
||||
})
|
||||
}
|
||||
@@ -1,20 +1,39 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ModelType = string
|
||||
|
||||
// Default computer vision model configuration.
|
||||
var (
|
||||
AssetsPath = fs.Abs("../../../assets")
|
||||
ServiceUri = ""
|
||||
ServiceKey = ""
|
||||
DownloadUrl = ""
|
||||
DefaultResolution = 224
|
||||
)
|
||||
|
||||
// NasnetModel is a standard TensorFlow model used for label generation.
|
||||
var (
|
||||
NasnetModel = &Model{Name: "Nasnet", Version: "Mobile", Resolution: 224, Tags: []string{"photoprism"}}
|
||||
NasnetModel = &Model{
|
||||
Type: ModelTypeLabels,
|
||||
Name: "NASNet",
|
||||
Version: "Mobile",
|
||||
Resolution: 224,
|
||||
Tags: []string{"photoprism"},
|
||||
}
|
||||
NsfwModel = &Model{
|
||||
Type: ModelTypeNsfw,
|
||||
Name: "Nsfw",
|
||||
Version: "",
|
||||
Resolution: 224,
|
||||
Tags: []string{"serve"},
|
||||
}
|
||||
FacenetModel = &Model{
|
||||
Type: ModelTypeFaceEmbeddings,
|
||||
Name: "FaceNet",
|
||||
Version: "",
|
||||
Resolution: 160,
|
||||
Tags: []string{"serve"},
|
||||
}
|
||||
CaptionModel = &Model{
|
||||
Type: ModelTypeCaption,
|
||||
Name: "Caption",
|
||||
Uri: "http://photoprism-vision/api/v1/vision/describe",
|
||||
Method: http.MethodPost,
|
||||
Resolution: 720,
|
||||
}
|
||||
DefaultModels = Models{NasnetModel, NsfwModel, FacenetModel, CaptionModel}
|
||||
DefaultThresholds = Thresholds{Confidence: 10}
|
||||
)
|
||||
|
||||
103
internal/ai/vision/nsfw.go
Normal file
103
internal/ai/vision/nsfw.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/header"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
// Nsfw checks the specified images for inappropriate content.
|
||||
func Nsfw(images Files, src media.Src) (result []nsfw.Result, err error) {
|
||||
// Return if no thumbnail filenames were given.
|
||||
if len(images) == 0 {
|
||||
return result, errors.New("missing image filenames")
|
||||
}
|
||||
|
||||
result = make([]nsfw.Result, len(images))
|
||||
|
||||
// Return if there is no configuration or no image classification models are configured.
|
||||
if Config == nil {
|
||||
return result, errors.New("missing configuration")
|
||||
} else if model := Config.Model(ModelTypeNsfw); model != nil {
|
||||
// Use remote service API if a server endpoint has been configured.
|
||||
if uri, method := model.Endpoint(); uri != "" && method != "" {
|
||||
apiRequest, apiRequestErr := NewClientRequest(images, scheme.Data)
|
||||
|
||||
if apiRequestErr != nil {
|
||||
return result, apiRequestErr
|
||||
}
|
||||
|
||||
if model.Name != "" {
|
||||
apiRequest.Model = model.Name
|
||||
}
|
||||
|
||||
data, jsonErr := apiRequest.MarshalJSON()
|
||||
|
||||
if jsonErr != nil {
|
||||
return result, jsonErr
|
||||
}
|
||||
|
||||
// Create HTTP client and authenticated service API request.
|
||||
client := http.Client{}
|
||||
req, reqErr := http.NewRequest(method, uri, bytes.NewReader(data))
|
||||
header.SetAuthorization(req, model.EndpointKey())
|
||||
|
||||
if reqErr != nil {
|
||||
return result, reqErr
|
||||
}
|
||||
|
||||
// Perform API request.
|
||||
clientResp, clientErr := client.Do(req)
|
||||
|
||||
if clientErr != nil {
|
||||
return result, clientErr
|
||||
}
|
||||
|
||||
apiResponse := &ApiResponse{}
|
||||
|
||||
// Unmarshal response and add labels, if returned.
|
||||
if apiJson, apiErr := io.ReadAll(clientResp.Body); apiErr != nil {
|
||||
return result, apiErr
|
||||
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
|
||||
return result, apiErr
|
||||
}
|
||||
|
||||
result = apiResponse.Result.Nsfw
|
||||
} else if tf := model.NsfwModel(); tf != nil {
|
||||
// Predict labels with local TensorFlow model.
|
||||
for i := range images {
|
||||
var labels nsfw.Result
|
||||
|
||||
switch src {
|
||||
case media.SrcLocal:
|
||||
labels, err = tf.File(images[i])
|
||||
case media.SrcRemote:
|
||||
labels, err = tf.Url(images[i])
|
||||
default:
|
||||
return result, fmt.Errorf("invalid image source %s", clean.Log(src))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("nsfw: %s", err)
|
||||
}
|
||||
|
||||
result[i] = labels
|
||||
}
|
||||
} else {
|
||||
return result, errors.New("invalid nsfw model configuration")
|
||||
}
|
||||
} else {
|
||||
return result, errors.New("missing nsfw model")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -15,21 +15,15 @@ var Config = NewOptions()
|
||||
|
||||
// Options represents a computer vision configuration for the supported Model types.
|
||||
type Options struct {
|
||||
Caption Models `yaml:"Caption,omitempty" json:"caption,omitempty"`
|
||||
Faces Models `yaml:"Faces,omitempty" json:"faces,omitempty"`
|
||||
Labels Models `yaml:"Labels,omitempty" json:"labels,omitempty"`
|
||||
Nsfw Models `yaml:"Nsfw,omitempty" json:"nsfw,omitempty"`
|
||||
Models Models `yaml:"Models,omitempty" json:"models,omitempty"`
|
||||
Thresholds Thresholds `yaml:"Thresholds" json:"thresholds"`
|
||||
}
|
||||
|
||||
// NewOptions returns a new computer vision config with defaults.
|
||||
func NewOptions() *Options {
|
||||
return &Options{
|
||||
Caption: Models{},
|
||||
Faces: Models{},
|
||||
Labels: Models{NasnetModel},
|
||||
Nsfw: Models{},
|
||||
Thresholds: Thresholds{Confidence: 10},
|
||||
Models: DefaultModels,
|
||||
Thresholds: DefaultThresholds,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,3 +66,14 @@ func (c *Options) Save(fileName string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Model returns the first enabled model with the matching type from the configuration.
|
||||
func (c *Options) Model(t ModelType) *Model {
|
||||
for _, m := range c.Models {
|
||||
if m.Type == t && !m.Disabled {
|
||||
return m
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
26
internal/ai/vision/vision_test.go
Normal file
26
internal/ai/vision/vision_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/event"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Init test logger.
|
||||
log = logrus.StandardLogger()
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
event.AuditLog = log
|
||||
|
||||
// Set test config values.
|
||||
DownloadUrl = "https://app.localssl.dev/api/v1/dl"
|
||||
ServiceUri = ""
|
||||
|
||||
// Run unit tests.
|
||||
code := m.Run()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/dustin/go-humanize/english"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/internal/auth/acl"
|
||||
"github.com/photoprism/photoprism/internal/entity/query"
|
||||
"github.com/photoprism/photoprism/internal/event"
|
||||
@@ -21,6 +22,7 @@ import (
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/i18n"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
)
|
||||
|
||||
// UploadUserFiles adds files to the user upload folder, from where they can be moved and indexed.
|
||||
@@ -184,19 +186,18 @@ func UploadUserFiles(router *gin.RouterGroup) {
|
||||
|
||||
// Check if the uploaded file may contain inappropriate content.
|
||||
if len(uploads) > 0 && !conf.UploadNSFW() {
|
||||
nd := get.NsfwDetector()
|
||||
|
||||
containsNSFW := false
|
||||
|
||||
for _, filename := range uploads {
|
||||
labels, nsfwErr := nd.File(filename)
|
||||
labels, nsfwErr := vision.Nsfw([]string{filename}, media.SrcLocal)
|
||||
|
||||
if nsfwErr != nil {
|
||||
log.Debug(nsfwErr)
|
||||
continue
|
||||
}
|
||||
|
||||
if labels.IsSafe() {
|
||||
} else if len(labels) < 1 {
|
||||
log.Errorf("nsfw: model returned no result")
|
||||
continue
|
||||
} else if labels[0].IsSafe() {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -5,24 +5,27 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/internal/auth/acl"
|
||||
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/header"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
// PostVisionFaces returns the positions and embeddings of detected faces.
|
||||
// PostVisionFaceEmbeddings returns the embeddings of detected faces.
|
||||
//
|
||||
// @Summary returns the positions and embeddings of detected faces
|
||||
// @Id PostVisionFaces
|
||||
// @Id PostVisionFaceEmbeddings
|
||||
// @Tags Vision
|
||||
// @Produce json
|
||||
// @Success 200 {object} vision.ApiResponse
|
||||
// @Failure 401,403,429,501 {object} i18n.Response
|
||||
// @Param images body vision.ApiRequest true "list of image file urls"
|
||||
// @Router /api/v1/vision/faces [post]
|
||||
func PostVisionFaces(router *gin.RouterGroup) {
|
||||
router.POST("/vision/faces", func(c *gin.Context) {
|
||||
// @Router /api/v1/vision/face/embeddings [post]
|
||||
func PostVisionFaceEmbeddings(router *gin.RouterGroup) {
|
||||
router.POST("/vision/face/embeddings", func(c *gin.Context) {
|
||||
s := Auth(c, acl.ResourceVision, acl.AccessAll)
|
||||
|
||||
// Abort if permission is not granted.
|
||||
@@ -51,13 +54,27 @@ func PostVisionFaces(router *gin.RouterGroup) {
|
||||
return
|
||||
}
|
||||
|
||||
// Run inference to find matching labels.
|
||||
results := make([]face.Embeddings, len(request.Images))
|
||||
|
||||
for i := range request.Images {
|
||||
if data, err := media.ReadUrl(request.Images[i], scheme.HttpsData); err != nil {
|
||||
results[i] = face.Embeddings{}
|
||||
log.Errorf("vision: %s (read face embedding from url)", err)
|
||||
} else if result, faceErr := vision.FaceEmbeddings(data); faceErr != nil {
|
||||
results[i] = face.Embeddings{}
|
||||
log.Errorf("vision: %s (generate face embedding)", faceErr)
|
||||
} else {
|
||||
results[i] = result
|
||||
}
|
||||
}
|
||||
|
||||
// Generate Vision API service response.
|
||||
response := vision.ApiResponse{
|
||||
Id: request.GetId(),
|
||||
Code: http.StatusNotImplemented,
|
||||
Error: http.StatusText(http.StatusNotImplemented),
|
||||
Model: &vision.Model{Name: "Faces"},
|
||||
Result: &vision.ApiResult{},
|
||||
Code: http.StatusOK,
|
||||
Model: &vision.Model{Name: vision.FacenetModel.Name},
|
||||
Result: &vision.ApiResult{Embeddings: results},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNotImplemented, response)
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/internal/auth/acl"
|
||||
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/header"
|
||||
)
|
||||
|
||||
@@ -53,7 +54,7 @@ func PostVisionLabels(router *gin.RouterGroup) {
|
||||
}
|
||||
|
||||
// Run inference to find matching labels.
|
||||
labels, err := vision.Labels(request.Images)
|
||||
labels, err := vision.Labels(request.Images, media.SrcLocal)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, vision.NewApiError(request.GetId(), http.StatusBadRequest))
|
||||
|
||||
74
internal/api/vision_nsfw.go
Normal file
74
internal/api/vision_nsfw.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/internal/auth/acl"
|
||||
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/header"
|
||||
)
|
||||
|
||||
// PostVisionNsfw checks the specified images for inappropriate content.
|
||||
//
|
||||
// @Summary checks the specified images for inappropriate content
|
||||
// @Id PostVisionNsfw
|
||||
// @Tags Vision
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} vision.ApiResponse
|
||||
// @Failure 401,403,429 {object} i18n.Response
|
||||
// @Param images body vision.ApiRequest true "list of image file urls"
|
||||
// @Router /api/v1/vision/nsfw [post]
|
||||
func PostVisionNsfw(router *gin.RouterGroup) {
|
||||
router.POST("/vision/nsfw", func(c *gin.Context) {
|
||||
s := Auth(c, acl.ResourceVision, acl.AccessAll)
|
||||
|
||||
// Abort if permission is not granted.
|
||||
if s.Abort(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var request vision.ApiRequest
|
||||
|
||||
// File uploads are not currently supported for this API endpoint.
|
||||
if header.HasContentType(&c.Request.Header, header.ContentTypeMultipart) {
|
||||
c.JSON(http.StatusBadRequest, vision.NewApiError(request.GetId(), http.StatusBadRequest))
|
||||
return
|
||||
}
|
||||
|
||||
// Assign and validate request form values.
|
||||
if err := c.BindJSON(&request); err != nil {
|
||||
c.JSON(http.StatusBadRequest, vision.NewApiError(request.GetId(), http.StatusBadRequest))
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the Computer Vision API is enabled, otherwise abort with an error.
|
||||
if !get.Config().VisionApi() {
|
||||
AbortFeatureDisabled(c)
|
||||
c.JSON(http.StatusForbidden, vision.NewApiError(request.GetId(), http.StatusForbidden))
|
||||
return
|
||||
}
|
||||
|
||||
// Run inference to check the specified images for inappropriate content.
|
||||
results, err := vision.Nsfw(request.Images, media.SrcRemote)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, vision.NewApiError(request.GetId(), http.StatusBadRequest))
|
||||
return
|
||||
}
|
||||
|
||||
// Generate Vision API service response.
|
||||
response := vision.ApiResponse{
|
||||
Id: request.GetId(),
|
||||
Code: http.StatusOK,
|
||||
Model: &vision.Model{Name: vision.NsfwModel.Name},
|
||||
Result: &vision.ApiResult{Nsfw: results},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
})
|
||||
}
|
||||
@@ -282,6 +282,9 @@ func (c *Config) Propagate() {
|
||||
|
||||
// Configure computer vision package.
|
||||
vision.AssetsPath = c.AssetsPath()
|
||||
vision.FaceNetModelPath = c.FaceNetModelPath()
|
||||
vision.NsfwModelPath = c.NSFWModelPath()
|
||||
vision.CachePath = c.CachePath()
|
||||
vision.ServiceUri = c.VisionUri()
|
||||
vision.ServiceKey = c.VisionKey()
|
||||
vision.DownloadUrl = c.DownloadUrl()
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
package get
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/classify"
|
||||
)
|
||||
|
||||
var onceClassify sync.Once
|
||||
|
||||
func initClassify() {
|
||||
services.Classify = classify.NewNasnet(Config().AssetsPath(), Config().DisableClassification())
|
||||
}
|
||||
|
||||
func Classify() *classify.Model {
|
||||
onceClassify.Do(initClassify)
|
||||
|
||||
return services.Classify
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package get
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
)
|
||||
|
||||
var onceFaceNet sync.Once
|
||||
|
||||
func initFaceNet() {
|
||||
services.FaceNet = face.NewModel(conf.FaceNetModelPath(), "", conf.DisableFaces())
|
||||
}
|
||||
|
||||
func FaceNet() *face.Model {
|
||||
onceFaceNet.Do(initFaceNet)
|
||||
|
||||
return services.FaceNet
|
||||
}
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
var onceIndex sync.Once
|
||||
|
||||
func initIndex() {
|
||||
services.Index = photoprism.NewIndex(Config(), NsfwDetector(), FaceNet(), Convert(), Files(), Photos())
|
||||
services.Index = photoprism.NewIndex(Config(), Convert(), Files(), Photos())
|
||||
}
|
||||
|
||||
func Index() *photoprism.Index {
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
package get
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
)
|
||||
|
||||
var onceNsfwDetector sync.Once
|
||||
|
||||
func initNsfwDetector() {
|
||||
services.Nsfw = nsfw.NewModel(conf.NSFWModelPath())
|
||||
}
|
||||
|
||||
func NsfwDetector() *nsfw.Model {
|
||||
onceNsfwDetector.Do(initNsfwDetector)
|
||||
|
||||
return services.Nsfw
|
||||
}
|
||||
@@ -27,9 +27,6 @@ package get
|
||||
import (
|
||||
gc "github.com/patrickmn/go-cache"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/classify"
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/internal/auth/oidc"
|
||||
"github.com/photoprism/photoprism/internal/auth/session"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
@@ -43,7 +40,6 @@ var services struct {
|
||||
FolderCache *gc.Cache
|
||||
CoverCache *gc.Cache
|
||||
ThumbCache *gc.Cache
|
||||
Classify *classify.Model
|
||||
Convert *photoprism.Convert
|
||||
Files *photoprism.Files
|
||||
Photos *photoprism.Photos
|
||||
@@ -54,8 +50,6 @@ var services struct {
|
||||
Places *photoprism.Places
|
||||
Purge *photoprism.Purge
|
||||
CleanUp *photoprism.CleanUp
|
||||
Nsfw *nsfw.Model
|
||||
FaceNet *face.Model
|
||||
Query *query.Query
|
||||
Thumbs *photoprism.Thumbs
|
||||
Session *session.Session
|
||||
|
||||
@@ -6,8 +6,6 @@ import (
|
||||
gc "github.com/patrickmn/go-cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/classify"
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/internal/auth/oidc"
|
||||
"github.com/photoprism/photoprism/internal/auth/session"
|
||||
"github.com/photoprism/photoprism/internal/entity/query"
|
||||
@@ -30,10 +28,6 @@ func TestThumbCache(t *testing.T) {
|
||||
assert.IsType(t, &gc.Cache{}, ThumbCache())
|
||||
}
|
||||
|
||||
func TestClassify(t *testing.T) {
|
||||
assert.IsType(t, &classify.Model{}, Classify())
|
||||
}
|
||||
|
||||
func TestConvert(t *testing.T) {
|
||||
assert.IsType(t, &photoprism.Convert{}, Convert())
|
||||
}
|
||||
@@ -58,10 +52,6 @@ func TestCleanUp(t *testing.T) {
|
||||
assert.IsType(t, &photoprism.CleanUp{}, CleanUp())
|
||||
}
|
||||
|
||||
func TestNsfwDetector(t *testing.T) {
|
||||
assert.IsType(t, &nsfw.Model{}, NsfwDetector())
|
||||
}
|
||||
|
||||
func TestQuery(t *testing.T) {
|
||||
assert.IsType(t, &query.Query{}, Query())
|
||||
}
|
||||
|
||||
@@ -5,40 +5,34 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
)
|
||||
|
||||
func TestNewImport(t *testing.T) {
|
||||
conf := config.TestConfig()
|
||||
cfg := config.TestConfig()
|
||||
|
||||
nd := nsfw.NewModel(conf.NSFWModelPath())
|
||||
fn := face.NewModel(conf.FaceNetModelPath(), "", conf.DisableTensorFlow())
|
||||
convert := NewConvert(conf)
|
||||
convert := NewConvert(cfg)
|
||||
|
||||
ind := NewIndex(conf, nd, fn, convert, NewFiles(), NewPhotos())
|
||||
imp := NewImport(conf, ind, convert)
|
||||
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
|
||||
imp := NewImport(cfg, ind, convert)
|
||||
|
||||
assert.IsType(t, &Import{}, imp)
|
||||
}
|
||||
|
||||
func TestImport_DestinationFilename(t *testing.T) {
|
||||
conf := config.TestConfig()
|
||||
cfg := config.TestConfig()
|
||||
|
||||
if err := conf.InitializeTestData(); err != nil {
|
||||
if err := cfg.InitializeTestData(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nd := nsfw.NewModel(conf.NSFWModelPath())
|
||||
fn := face.NewModel(conf.FaceNetModelPath(), "", conf.DisableTensorFlow())
|
||||
convert := NewConvert(conf)
|
||||
convert := NewConvert(cfg)
|
||||
|
||||
ind := NewIndex(conf, nd, fn, convert, NewFiles(), NewPhotos())
|
||||
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
|
||||
|
||||
imp := NewImport(conf, ind, convert)
|
||||
imp := NewImport(cfg, ind, convert)
|
||||
|
||||
rawFile, err := NewMediaFile(conf.ImportPath() + "/raw/IMG_2567.CR2")
|
||||
rawFile, err := NewMediaFile(cfg.ImportPath() + "/raw/IMG_2567.CR2")
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -51,7 +45,7 @@ func TestImport_DestinationFilename(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assert.Equal(t, conf.OriginalsPath()+"/2019/07/20190705_153230_C167C6FD.cr2", fileName)
|
||||
assert.Equal(t, cfg.OriginalsPath()+"/2019/07/20190705_153230_C167C6FD.cr2", fileName)
|
||||
})
|
||||
|
||||
t.Run("WithBasePath", func(t *testing.T) {
|
||||
@@ -61,7 +55,7 @@ func TestImport_DestinationFilename(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assert.Equal(t, conf.OriginalsPath()+"/users/guest/2019/07/20190705_153230_C167C6FD.cr2", fileName)
|
||||
assert.Equal(t, cfg.OriginalsPath()+"/users/guest/2019/07/20190705_153230_C167C6FD.cr2", fileName)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -70,19 +64,17 @@ func TestImport_Start(t *testing.T) {
|
||||
t.Skip("skipping test in short mode.")
|
||||
}
|
||||
|
||||
conf := config.TestConfig()
|
||||
cfg := config.TestConfig()
|
||||
|
||||
conf.InitializeTestData()
|
||||
cfg.InitializeTestData()
|
||||
|
||||
nd := nsfw.NewModel(conf.NSFWModelPath())
|
||||
fn := face.NewModel(conf.FaceNetModelPath(), "", conf.DisableTensorFlow())
|
||||
convert := NewConvert(conf)
|
||||
convert := NewConvert(cfg)
|
||||
|
||||
ind := NewIndex(conf, nd, fn, convert, NewFiles(), NewPhotos())
|
||||
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
|
||||
|
||||
imp := NewImport(conf, ind, convert)
|
||||
imp := NewImport(cfg, ind, convert)
|
||||
|
||||
opt := ImportOptionsMove(conf.ImportPath(), "")
|
||||
opt := ImportOptionsMove(cfg.ImportPath(), "")
|
||||
|
||||
imp.Start(opt)
|
||||
}
|
||||
|
||||
@@ -5,36 +5,32 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
)
|
||||
|
||||
func TestImportWorker_OriginalFileNames(t *testing.T) {
|
||||
c := config.TestConfig()
|
||||
cfg := config.TestConfig()
|
||||
|
||||
if err := c.InitializeTestData(); err != nil {
|
||||
if err := cfg.InitializeTestData(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nd := nsfw.NewModel(c.NSFWModelPath())
|
||||
fn := face.NewModel(c.FaceNetModelPath(), "", c.DisableTensorFlow())
|
||||
convert := NewConvert(c)
|
||||
ind := NewIndex(c, nd, fn, convert, NewFiles(), NewPhotos())
|
||||
imp := &Import{c, ind, convert, c.ImportAllow()}
|
||||
convert := NewConvert(cfg)
|
||||
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
|
||||
imp := &Import{cfg, ind, convert, cfg.ImportAllow()}
|
||||
|
||||
mediaFileName := c.ExamplesPath() + "/beach_sand.jpg"
|
||||
mediaFileName := cfg.ExamplesPath() + "/beach_sand.jpg"
|
||||
mediaFile, err := NewMediaFile(mediaFileName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mediaFileName2 := c.ExamplesPath() + "/beach_wood.jpg"
|
||||
mediaFileName2 := cfg.ExamplesPath() + "/beach_wood.jpg"
|
||||
mediaFile2, err2 := NewMediaFile(mediaFileName2)
|
||||
if err2 != nil {
|
||||
t.Fatal(err2)
|
||||
}
|
||||
mediaFileName3 := c.ExamplesPath() + "/beach_colorfilter.jpg"
|
||||
mediaFileName3 := cfg.ExamplesPath() + "/beach_colorfilter.jpg"
|
||||
mediaFile3, err3 := NewMediaFile(mediaFileName3)
|
||||
if err3 != nil {
|
||||
t.Fatal(err3)
|
||||
@@ -56,7 +52,7 @@ func TestImportWorker_OriginalFileNames(t *testing.T) {
|
||||
FileName: mediaFile.FileName(),
|
||||
Related: relatedFiles,
|
||||
IndexOpt: IndexOptionsAll(),
|
||||
ImportOpt: ImportOptionsCopy(c.ImportPath(), c.ImportDest()),
|
||||
ImportOpt: ImportOptionsCopy(cfg.ImportPath(), cfg.ImportDest()),
|
||||
Imp: imp,
|
||||
}
|
||||
|
||||
|
||||
@@ -11,8 +11,6 @@ import (
|
||||
|
||||
"github.com/karrick/godirwalk"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/internal/event"
|
||||
@@ -26,8 +24,6 @@ import (
|
||||
// Index represents an indexer that indexes files in the originals directory.
|
||||
type Index struct {
|
||||
conf *config.Config
|
||||
nsfwDetector *nsfw.Model
|
||||
faceNet *face.Model
|
||||
convert *Convert
|
||||
files *Files
|
||||
photos *Photos
|
||||
@@ -38,7 +34,7 @@ type Index struct {
|
||||
}
|
||||
|
||||
// NewIndex returns a new indexer and expects its dependencies as arguments.
|
||||
func NewIndex(conf *config.Config, nsfwDetector *nsfw.Model, faceNet *face.Model, convert *Convert, files *Files, photos *Photos) *Index {
|
||||
func NewIndex(conf *config.Config, convert *Convert, files *Files, photos *Photos) *Index {
|
||||
if conf == nil {
|
||||
log.Errorf("index: config is not set")
|
||||
return nil
|
||||
@@ -46,8 +42,6 @@ func NewIndex(conf *config.Config, nsfwDetector *nsfw.Model, faceNet *face.Model
|
||||
|
||||
i := &Index{
|
||||
conf: conf,
|
||||
nsfwDetector: nsfwDetector,
|
||||
faceNet: faceNet,
|
||||
convert: convert,
|
||||
files: files,
|
||||
photos: photos,
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/dustin/go-humanize/english"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/internal/thumb"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
)
|
||||
@@ -39,7 +40,7 @@ func (ind *Index) Faces(jpeg *MediaFile, expected int) face.Faces {
|
||||
|
||||
start := time.Now()
|
||||
|
||||
faces, err := ind.faceNet.Detect(thumbName, Config().FaceSize(), true, expected)
|
||||
faces, err := vision.Faces(thumbName, Config().FaceSize(), true, expected)
|
||||
|
||||
if err != nil {
|
||||
log.Debugf("%s in %s", err, clean.Log(jpeg.BaseName()))
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/internal/thumb"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
)
|
||||
|
||||
// Labels classifies a JPEG image and returns matching labels.
|
||||
@@ -41,7 +42,7 @@ func (ind *Index) Labels(file *MediaFile) (labels classify.Labels) {
|
||||
}
|
||||
|
||||
// Get matching labels from computer vision model.
|
||||
if labels, err = vision.Labels(thumbnails); err != nil {
|
||||
if labels, err = vision.Labels(thumbnails, media.SrcLocal); err != nil {
|
||||
log.Debugf("labels: %s in %s", err, clean.Log(file.BaseName()))
|
||||
return labels
|
||||
}
|
||||
|
||||
@@ -805,7 +805,7 @@ func (ind *Index) UserMediaFile(m *MediaFile, o IndexOptions, originalName, phot
|
||||
}
|
||||
|
||||
if !photoExists && Config().Settings().Features.Private && Config().DetectNSFW() {
|
||||
photo.PhotoPrivate = ind.NSFW(m)
|
||||
photo.PhotoPrivate = ind.IsNsfw(m)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,8 +5,6 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
)
|
||||
@@ -21,11 +19,9 @@ func TestIndex_MediaFile(t *testing.T) {
|
||||
|
||||
cfg.InitializeTestData()
|
||||
|
||||
nd := nsfw.NewModel(cfg.NSFWModelPath())
|
||||
fn := face.NewModel(cfg.FaceNetModelPath(), "", cfg.DisableTensorFlow())
|
||||
convert := NewConvert(cfg)
|
||||
|
||||
ind := NewIndex(cfg, nd, fn, convert, NewFiles(), NewPhotos())
|
||||
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
|
||||
indexOpt := IndexOptionsAll()
|
||||
mediaFile, err := NewMediaFile("testdata/flash.jpg")
|
||||
|
||||
@@ -57,11 +53,9 @@ func TestIndex_MediaFile(t *testing.T) {
|
||||
|
||||
cfg.InitializeTestData()
|
||||
|
||||
nd := nsfw.NewModel(cfg.NSFWModelPath())
|
||||
fn := face.NewModel(cfg.FaceNetModelPath(), "", cfg.DisableTensorFlow())
|
||||
convert := NewConvert(cfg)
|
||||
|
||||
ind := NewIndex(cfg, nd, fn, convert, NewFiles(), NewPhotos())
|
||||
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
|
||||
indexOpt := IndexOptionsAll()
|
||||
mediaFile, err := NewMediaFile(cfg.ExamplesPath() + "/blue-go-video.mp4")
|
||||
if err != nil {
|
||||
@@ -79,11 +73,9 @@ func TestIndex_MediaFile(t *testing.T) {
|
||||
|
||||
cfg.InitializeTestData()
|
||||
|
||||
nd := nsfw.NewModel(cfg.NSFWModelPath())
|
||||
fn := face.NewModel(cfg.FaceNetModelPath(), "", cfg.DisableTensorFlow())
|
||||
convert := NewConvert(cfg)
|
||||
|
||||
ind := NewIndex(cfg, nd, fn, convert, NewFiles(), NewPhotos())
|
||||
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
|
||||
indexOpt := IndexOptionsAll()
|
||||
|
||||
result := ind.MediaFile(nil, indexOpt, "blue-go-video.mp4", "")
|
||||
|
||||
@@ -2,12 +2,14 @@ package photoprism
|
||||
|
||||
import (
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/internal/thumb"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
)
|
||||
|
||||
// NSFW returns true if media file might be offensive and detection is enabled.
|
||||
func (ind *Index) NSFW(m *MediaFile) bool {
|
||||
// IsNsfw returns true if media file might be offensive and detection is enabled.
|
||||
func (ind *Index) IsNsfw(m *MediaFile) bool {
|
||||
filename, err := m.Thumbnail(Config().ThumbCachePath(), thumb.Fit720)
|
||||
|
||||
if err != nil {
|
||||
@@ -15,15 +17,16 @@ func (ind *Index) NSFW(m *MediaFile) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if nsfwLabels, err := ind.nsfwDetector.File(filename); err != nil {
|
||||
log.Errorf("index: %s in %s (detect nsfw)", err, m.RootRelName())
|
||||
if results, modelErr := vision.Nsfw([]string{filename}, media.SrcLocal); modelErr != nil {
|
||||
log.Errorf("index: %s in %s (detect nsfw)", modelErr, m.RootRelName())
|
||||
return false
|
||||
} else {
|
||||
if nsfwLabels.NSFW(nsfw.ThresholdHigh) {
|
||||
} else if len(results) < 1 {
|
||||
log.Errorf("index: nsfw model returned no result for %s", m.RootRelName())
|
||||
return false
|
||||
} else if results[0].NSFW(nsfw.ThresholdHigh) {
|
||||
log.Warnf("index: %s might contain offensive content", clean.Log(m.RelName(Config().OriginalsPath())))
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -6,8 +6,6 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/entity/query"
|
||||
"github.com/photoprism/photoprism/pkg/rnd"
|
||||
@@ -15,7 +13,7 @@ import (
|
||||
|
||||
func TestIndexRelated(t *testing.T) {
|
||||
t.Run("2018-04-12 19_24_49.gif", func(t *testing.T) {
|
||||
conf := config.TestConfig()
|
||||
cfg := config.TestConfig()
|
||||
|
||||
testFile, err := NewMediaFile("testdata/2018-04-12 19_24_49.gif")
|
||||
|
||||
@@ -30,7 +28,7 @@ func TestIndexRelated(t *testing.T) {
|
||||
}
|
||||
|
||||
testToken := rnd.Base36(8)
|
||||
testPath := filepath.Join(conf.OriginalsPath(), testToken)
|
||||
testPath := filepath.Join(cfg.OriginalsPath(), testToken)
|
||||
|
||||
for _, f := range testRelated.Files {
|
||||
dest := filepath.Join(testPath, f.BaseName())
|
||||
@@ -52,11 +50,8 @@ func TestIndexRelated(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nd := nsfw.NewModel(conf.NSFWModelPath())
|
||||
fn := face.NewModel(conf.FaceNetModelPath(), "", conf.DisableTensorFlow())
|
||||
convert := NewConvert(conf)
|
||||
|
||||
ind := NewIndex(conf, nd, fn, convert, NewFiles(), NewPhotos())
|
||||
convert := NewConvert(cfg)
|
||||
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
|
||||
opt := IndexOptionsAll()
|
||||
|
||||
result := IndexRelated(related, ind, opt)
|
||||
@@ -75,7 +70,7 @@ func TestIndexRelated(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("apple-test-2.jpg", func(t *testing.T) {
|
||||
conf := config.TestConfig()
|
||||
cfg := config.TestConfig()
|
||||
|
||||
testFile, err := NewMediaFile("testdata/apple-test-2.jpg")
|
||||
|
||||
@@ -90,7 +85,7 @@ func TestIndexRelated(t *testing.T) {
|
||||
}
|
||||
|
||||
testToken := rnd.Base36(8)
|
||||
testPath := filepath.Join(conf.OriginalsPath(), testToken)
|
||||
testPath := filepath.Join(cfg.OriginalsPath(), testToken)
|
||||
|
||||
for _, f := range testRelated.Files {
|
||||
dest := filepath.Join(testPath, f.BaseName())
|
||||
@@ -112,11 +107,8 @@ func TestIndexRelated(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nd := nsfw.NewModel(conf.NSFWModelPath())
|
||||
fn := face.NewModel(conf.FaceNetModelPath(), "", conf.DisableTensorFlow())
|
||||
convert := NewConvert(conf)
|
||||
|
||||
ind := NewIndex(conf, nd, fn, convert, NewFiles(), NewPhotos())
|
||||
convert := NewConvert(cfg)
|
||||
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
|
||||
opt := IndexOptionsAll()
|
||||
|
||||
result := IndexRelated(related, ind, opt)
|
||||
|
||||
@@ -7,8 +7,6 @@ import (
|
||||
"github.com/dustin/go-humanize/english"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
)
|
||||
|
||||
@@ -17,17 +15,13 @@ func TestIndex_Start(t *testing.T) {
|
||||
t.Skip("skipping test in short mode.")
|
||||
}
|
||||
|
||||
conf := config.TestConfig()
|
||||
cfg := config.TestConfig()
|
||||
cfg.InitializeTestData()
|
||||
|
||||
conf.InitializeTestData()
|
||||
|
||||
nd := nsfw.NewModel(conf.NSFWModelPath())
|
||||
fn := face.NewModel(conf.FaceNetModelPath(), "", conf.DisableTensorFlow())
|
||||
convert := NewConvert(conf)
|
||||
|
||||
ind := NewIndex(conf, nd, fn, convert, NewFiles(), NewPhotos())
|
||||
imp := NewImport(conf, ind, convert)
|
||||
opt := ImportOptionsMove(conf.ImportPath(), "")
|
||||
convert := NewConvert(cfg)
|
||||
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
|
||||
imp := NewImport(cfg, ind, convert)
|
||||
opt := ImportOptionsMove(cfg.ImportPath(), "")
|
||||
|
||||
imp.Start(opt)
|
||||
|
||||
@@ -65,15 +59,11 @@ func TestIndex_File(t *testing.T) {
|
||||
t.Skip("skipping test in short mode.")
|
||||
}
|
||||
|
||||
conf := config.TestConfig()
|
||||
cfg := config.TestConfig()
|
||||
cfg.InitializeTestData()
|
||||
|
||||
conf.InitializeTestData()
|
||||
|
||||
nd := nsfw.NewModel(conf.NSFWModelPath())
|
||||
fn := face.NewModel(conf.FaceNetModelPath(), "", conf.DisableTensorFlow())
|
||||
convert := NewConvert(conf)
|
||||
|
||||
ind := NewIndex(conf, nd, fn, convert, NewFiles(), NewPhotos())
|
||||
convert := NewConvert(cfg)
|
||||
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
|
||||
|
||||
err := ind.FileName("xxx", IndexOptionsAll())
|
||||
|
||||
|
||||
@@ -9,8 +9,6 @@ import (
|
||||
|
||||
"github.com/disintegration/imaging"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/internal/thumb"
|
||||
@@ -21,26 +19,23 @@ func TestResample_Start(t *testing.T) {
|
||||
t.Skip("skipping test in short mode.")
|
||||
}
|
||||
|
||||
conf := config.TestConfig()
|
||||
cfg := config.TestConfig()
|
||||
|
||||
if err := conf.CreateDirectories(); err != nil {
|
||||
if err := cfg.CreateDirectories(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conf.InitializeTestData()
|
||||
cfg.InitializeTestData()
|
||||
|
||||
nd := nsfw.NewModel(conf.NSFWModelPath())
|
||||
fn := face.NewModel(conf.FaceNetModelPath(), "", conf.DisableTensorFlow())
|
||||
convert := NewConvert(conf)
|
||||
convert := NewConvert(cfg)
|
||||
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
|
||||
|
||||
ind := NewIndex(conf, nd, fn, convert, NewFiles(), NewPhotos())
|
||||
|
||||
imp := NewImport(conf, ind, convert)
|
||||
opt := ImportOptionsMove(conf.ImportPath(), "")
|
||||
imp := NewImport(cfg, ind, convert)
|
||||
opt := ImportOptionsMove(cfg.ImportPath(), "")
|
||||
|
||||
imp.Start(opt)
|
||||
|
||||
rs := NewThumbs(conf)
|
||||
rs := NewThumbs(cfg)
|
||||
|
||||
err := rs.Start("", true, false)
|
||||
|
||||
|
||||
@@ -164,9 +164,10 @@ func registerRoutes(router *gin.Engine, conf *config.Config) {
|
||||
api.FolderCover(APIv1)
|
||||
|
||||
// Computer Vision.
|
||||
api.PostVisionCaption(APIv1)
|
||||
api.PostVisionFaces(APIv1)
|
||||
api.PostVisionLabels(APIv1)
|
||||
api.PostVisionNsfw(APIv1)
|
||||
api.PostVisionFaceEmbeddings(APIv1)
|
||||
api.PostVisionCaption(APIv1)
|
||||
|
||||
// People.
|
||||
api.SearchSubjects(APIv1)
|
||||
|
||||
@@ -39,7 +39,7 @@ var thumbFileSizes = []thumb.Size{
|
||||
}
|
||||
|
||||
// ImageFromThumb returns a cropped area from an existing thumbnail image.
|
||||
func ImageFromThumb(thumbName string, area Area, size Size, cache bool) (img image.Image, err error) {
|
||||
func ImageFromThumb(thumbName string, area Area, size Size, cache bool) (img image.Image, cropName string, err error) {
|
||||
// Use same folder for caching if "cache" is true.
|
||||
filePath := filepath.Dir(thumbName)
|
||||
|
||||
@@ -48,12 +48,12 @@ func ImageFromThumb(thumbName string, area Area, size Size, cache bool) (img ima
|
||||
|
||||
// Resolve symlinks.
|
||||
if thumbName, err = fs.Resolve(thumbName); err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Compose cached crop image file name.
|
||||
cropBase := fmt.Sprintf("%s_%dx%d_crop_%s%s", hash, size.Width, size.Height, area.String(), fs.ExtJpeg)
|
||||
cropName := filepath.Join(filePath, cropBase)
|
||||
cropName = filepath.Join(filePath, cropBase)
|
||||
|
||||
// Cached?
|
||||
if !fs.FileExists(cropName) {
|
||||
@@ -61,14 +61,14 @@ func ImageFromThumb(thumbName string, area Area, size Size, cache bool) (img ima
|
||||
} else if cropImg, cropErr := imaging.Open(cropName); cropErr != nil {
|
||||
log.Errorf("crop: failed loading %s", filepath.Base(cropName))
|
||||
} else {
|
||||
return cropImg, nil
|
||||
return cropImg, cropName, nil
|
||||
}
|
||||
|
||||
// Open thumb image file.
|
||||
img, err = openIdealThumbFile(thumbName, hash, area, size)
|
||||
|
||||
if err != nil {
|
||||
return img, err
|
||||
return img, "", err
|
||||
}
|
||||
|
||||
// Get absolute crop coordinates and dimension.
|
||||
@@ -93,7 +93,7 @@ func ImageFromThumb(thumbName string, area Area, size Size, cache bool) (img ima
|
||||
}
|
||||
}
|
||||
|
||||
return img, nil
|
||||
return img, cropName, nil
|
||||
}
|
||||
|
||||
// ThumbFileName returns the ideal thumb file name.
|
||||
|
||||
@@ -1,69 +1,98 @@
|
||||
package media
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/header"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
// DataUrl returns a data URL representing the binary buffer data.
|
||||
func DataUrl(buf *bytes.Buffer) string {
|
||||
encoded := EncodeBase64(buf.Bytes())
|
||||
// DataUrl generates a data URL of the binary data from the specified io.Reader.
|
||||
func DataUrl(r io.Reader) string {
|
||||
// Read binary data.
|
||||
data, err := io.ReadAll(r)
|
||||
|
||||
if encoded == "" {
|
||||
if err != nil || len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Return as string if it already appears to be a data URL.
|
||||
if string(data[0:4]) == "data:" {
|
||||
return string(data)
|
||||
}
|
||||
|
||||
// Detect mime type.
|
||||
var mime *mimetype.MIME
|
||||
var mimeType string
|
||||
|
||||
mime, err := mimetype.DetectReader(buf)
|
||||
|
||||
if err != nil {
|
||||
mimeType = "application/octet-stream"
|
||||
if mime = mimetype.Detect(data); mime == nil {
|
||||
mimeType = header.ContentTypeBinary
|
||||
} else {
|
||||
mimeType = mime.String()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("data:%s;base64,%s", mimeType, encoded)
|
||||
// Generate data URL.
|
||||
return fmt.Sprintf("data:%s;base64,%s", mimeType, EncodeBase64(data))
|
||||
}
|
||||
|
||||
// ReadUrl reads binary data from a regular file path,
|
||||
// fetches its data from a remote http or https URL,
|
||||
// or decodes a base64 data URL as created by DataUrl.
|
||||
func ReadUrl(file string) (data []byte, err error) {
|
||||
u, err := url.Parse(file)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
func ReadUrl(fileUrl string, schemes []string) (data []byte, err error) {
|
||||
if fileUrl == "" {
|
||||
return data, errors.New("missing url")
|
||||
}
|
||||
|
||||
// Also supports http, https, and data URLs instead of a file name for remote processing.
|
||||
if u.Scheme == "http" || u.Scheme == "https" {
|
||||
resp, httpErr := http.Get(file)
|
||||
// Parse file URL.
|
||||
var u *url.URL
|
||||
|
||||
if u, err = url.Parse(fileUrl); err != nil {
|
||||
return data, fmt.Errorf("invalid url (%s)", err)
|
||||
}
|
||||
|
||||
// Reject it if it is not absolute, i.e. it does not contain a scheme.
|
||||
if !u.IsAbs() {
|
||||
return data, fmt.Errorf("url %s requires a scheme", clean.Log(fileUrl))
|
||||
} else if !slices.Contains(schemes, u.Scheme) {
|
||||
return data, fmt.Errorf("invalid url scheme %s", clean.Log(u.Scheme))
|
||||
}
|
||||
|
||||
// Fetch the file data from the specified URL, depending on its scheme.
|
||||
switch u.Scheme {
|
||||
case scheme.Https, scheme.Http, scheme.Unix, scheme.HttpUnix:
|
||||
resp, httpErr := http.Get(fileUrl)
|
||||
|
||||
if httpErr != nil {
|
||||
return nil, httpErr
|
||||
return data, fmt.Errorf("invalid %s url (%s)", u.Scheme, httpErr)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
if data, err = io.ReadAll(resp.Body); err != nil {
|
||||
return nil, err
|
||||
return data, err
|
||||
}
|
||||
} else if u.Scheme == "data" {
|
||||
case scheme.Data:
|
||||
if _, binaryData, found := strings.Cut(u.Opaque, ";base64,"); !found || len(binaryData) == 0 {
|
||||
return nil, fmt.Errorf("invalid data URL")
|
||||
return data, fmt.Errorf("invalid %s url", u.Scheme)
|
||||
} else {
|
||||
return DecodeBase64(binaryData)
|
||||
}
|
||||
} else if data, err = os.ReadFile(file); err != nil {
|
||||
return nil, err
|
||||
case scheme.File:
|
||||
if data, err = os.ReadFile(fileUrl); err != nil {
|
||||
return data, fmt.Errorf("invalid %s url (%s)", u.Scheme, err)
|
||||
}
|
||||
default:
|
||||
return data, fmt.Errorf("unsupported url scheme %s", clean.Log(u.Scheme))
|
||||
}
|
||||
|
||||
return data, err
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package media
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -16,17 +15,14 @@ func gopherPng() io.Reader { return ReadBase64(strings.NewReader(gopher)) }
|
||||
|
||||
func TestDataUrl(t *testing.T) {
|
||||
t.Run("Gopher", func(t *testing.T) {
|
||||
buf := new(bytes.Buffer)
|
||||
_, bufErr := buf.ReadFrom(gopherPng())
|
||||
assert.NoError(t, bufErr)
|
||||
assert.Equal(t, "data:image/png;base64,"+gopher, DataUrl(buf))
|
||||
assert.Equal(t, "data:image/png;base64,"+gopher, DataUrl(gopherPng()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadUrl(t *testing.T) {
|
||||
t.Run("Gopher", func(t *testing.T) {
|
||||
dataUrl := "data:image/png;base64," + gopher
|
||||
if data, err := ReadUrl(dataUrl); err != nil {
|
||||
if data, err := ReadUrl(dataUrl, []string{"https", "data"}); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
expected, _ := DecodeBase64(gopher)
|
||||
|
||||
17
pkg/media/http/scheme/const.go
Normal file
17
pkg/media/http/scheme/const.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package scheme
|
||||
|
||||
const (
|
||||
File = "file"
|
||||
Data = "data"
|
||||
Http = "http"
|
||||
Https = "https"
|
||||
HttpUnix = Http + "+" + Unix
|
||||
Websocket = "wss"
|
||||
Unix = "unix"
|
||||
Unixgram = "unixgram"
|
||||
Unixpacket = "unixpacket"
|
||||
)
|
||||
|
||||
var (
|
||||
HttpsData = []string{Https, Data}
|
||||
)
|
||||
@@ -1,8 +0,0 @@
|
||||
package scheme
|
||||
|
||||
const (
|
||||
Http = "http"
|
||||
Https = "https"
|
||||
HttpUnix = Http + "+" + Unix
|
||||
Websocket = "wss"
|
||||
)
|
||||
@@ -1,7 +0,0 @@
|
||||
package scheme
|
||||
|
||||
const (
|
||||
Unix = "unix"
|
||||
Unixgram = "unixgram"
|
||||
Unixpacket = "unixpacket"
|
||||
)
|
||||
9
pkg/media/source.go
Normal file
9
pkg/media/source.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package media
|
||||
|
||||
type Src = string
|
||||
|
||||
// Data source types.
|
||||
const (
|
||||
SrcLocal Src = "local"
|
||||
SrcRemote Src = "remote"
|
||||
)
|
||||
Reference in New Issue
Block a user