AI: Refactor use of face embeddings, labels, and nsfw models #127 #1090

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2025-04-10 16:49:34 +02:00
parent ecef34a8da
commit caf3ae1ab5
52 changed files with 1040 additions and 396 deletions

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"math"
"os"
"path"
"runtime/debug"
"sort"
@@ -15,6 +16,7 @@ import (
"github.com/photoprism/photoprism/internal/ai/tensorflow"
"github.com/photoprism/photoprism/pkg/media"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
)
// Model represents a TensorFlow classification model.
@@ -48,23 +50,38 @@ func (m *Model) Init() (err error) {
return m.loadModel()
}
// File returns matching labels for a jpeg media file.
func (m *Model) File(imageUri string, confidenceThreshold int) (result Labels, err error) {
// File returns matching labels for a local jpeg file.
func (m *Model) File(fileName string, confidenceThreshold int) (result Labels, err error) {
if m.disabled {
return nil, nil
}
var data []byte
if data, err = media.ReadUrl(imageUri); err != nil {
if data, err = os.ReadFile(fileName); err != nil {
return nil, err
}
return m.Labels(data, confidenceThreshold)
return m.Run(data, confidenceThreshold)
}
// Labels returns matching labels for a jpeg media string.
func (m *Model) Labels(img []byte, confidenceThreshold int) (result Labels, err error) {
// Url returns matching labels for a remote jpeg file.
func (m *Model) Url(imgUrl string, confidenceThreshold int) (result Labels, err error) {
if m.disabled {
return nil, nil
}
var data []byte
if data, err = media.ReadUrl(imgUrl, scheme.HttpsData); err != nil {
return nil, err
}
return m.Run(data, confidenceThreshold)
}
// Run returns matching labels for the specified JPEG image.
func (m *Model) Run(img []byte, confidenceThreshold int) (result Labels, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("classify: %s (inference panic)\nstack: %s", r, debug.Stack())

View File

@@ -116,7 +116,7 @@ func TestModel_LabelsFromFile(t *testing.T) {
})
}
func TestModel_Labels(t *testing.T) {
func TestModel_Run(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
@@ -127,7 +127,7 @@ func TestModel_Labels(t *testing.T) {
if imageBuffer, err := os.ReadFile(examplesPath + "/chameleon_lime.jpg"); err != nil {
t.Error(err)
} else {
result, err := tensorFlow.Labels(imageBuffer, 10)
result, err := tensorFlow.Run(imageBuffer, 10)
t.Log(result)
@@ -151,7 +151,7 @@ func TestModel_Labels(t *testing.T) {
if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil {
t.Error(err)
} else {
result, err := tensorFlow.Labels(imageBuffer, 10)
result, err := tensorFlow.Run(imageBuffer, 10)
t.Log(result)
@@ -175,7 +175,7 @@ func TestModel_Labels(t *testing.T) {
if imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx"); err != nil {
t.Error(err)
} else {
result, err := tensorFlow.Labels(imageBuffer, 10)
result, err := tensorFlow.Run(imageBuffer, 10)
assert.Empty(t, result)
assert.Error(t, err)
}
@@ -186,7 +186,7 @@ func TestModel_Labels(t *testing.T) {
if imageBuffer, err := os.ReadFile(examplesPath + "/6720px_white.jpg"); err != nil {
t.Error(err)
} else {
result, err := tensorFlow.Labels(imageBuffer, 10)
result, err := tensorFlow.Run(imageBuffer, 10)
if err != nil {
t.Fatal(err)
@@ -201,7 +201,7 @@ func TestModel_Labels(t *testing.T) {
if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil {
t.Error(err)
} else {
result, err := tensorFlow.Labels(imageBuffer, 10)
result, err := tensorFlow.Run(imageBuffer, 10)
t.Log(result)

View File

@@ -27,8 +27,14 @@ type Model struct {
}
// NewModel returns a new TensorFlow Facenet instance.
func NewModel(modelPath, cachePath string, disabled bool) *Model {
return &Model{modelPath: modelPath, cachePath: cachePath, resolution: CropSize.Width, modelTags: []string{"serve"}, disabled: disabled}
func NewModel(modelPath, cachePath string, resolution int, tags []string, disabled bool) *Model {
if resolution == 0 {
resolution = CropSize.Width
}
if len(tags) == 0 {
tags = []string{"serve"}
}
return &Model{modelPath: modelPath, cachePath: cachePath, resolution: resolution, modelTags: tags, disabled: disabled}
}
// Detect runs the detection and facenet algorithms over the provided source image.
@@ -57,9 +63,9 @@ func (m *Model) Detect(fileName string, minSize int, cacheCrop bool, expected in
continue
}
if img, imgErr := crop.ImageFromThumb(fileName, f.CropArea(), CropSize, cacheCrop); imgErr != nil {
if img, _, imgErr := crop.ImageFromThumb(fileName, f.CropArea(), CropSize, cacheCrop); imgErr != nil {
log.Errorf("faces: failed to decode image: %s", imgErr)
} else if embeddings := m.getEmbeddings(img); !embeddings.Empty() {
} else if embeddings := m.Run(img); !embeddings.Empty() {
faces[i].Embeddings = embeddings
}
}
@@ -67,6 +73,15 @@ func (m *Model) Detect(fileName string, minSize int, cacheCrop bool, expected in
return faces, nil
}
// Init initialises tensorflow models if not disabled
func (m *Model) Init() (err error) {
if m.disabled {
return nil
}
return m.loadModel()
}
// ModelLoaded tests if the TensorFlow model is loaded.
func (m *Model) ModelLoaded() bool {
return m.model != nil
@@ -99,8 +114,8 @@ func (m *Model) loadModel() error {
return nil
}
// getEmbeddings returns the face embeddings for an image.
func (m *Model) getEmbeddings(img image.Image) Embeddings {
// Run returns the face embeddings for an image.
func (m *Model) Run(img image.Image) Embeddings {
tensor, err := imageToTensor(img, m.resolution)
if err != nil {

View File

@@ -54,7 +54,7 @@ func TestNet(t *testing.T) {
var embeddings = make(Embeddings, 11)
faceNet := NewModel(modelPath, "testdata/cache", false)
faceNet := NewModel(modelPath, "testdata/cache", 160, []string{"serve"}, false)
if err := fastwalk.Walk("testdata", func(fileName string, info os.FileMode) error {
if info.IsDir() || filepath.Base(filepath.Dir(fileName)) != "testdata" {

View File

@@ -11,7 +11,9 @@ import (
"github.com/photoprism/photoprism/internal/ai/tensorflow"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/media"
"github.com/photoprism/photoprism/pkg/media/http/header"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
)
// Model uses TensorFlow to label drawing, hentai, neutral, porn and sexy images.
@@ -21,31 +23,53 @@ type Model struct {
resolution int
modelTags []string
labels []string
disabled bool
mutex sync.Mutex
}
// NewModel returns a new detector instance.
func NewModel(modelPath string) *Model {
return &Model{modelPath: modelPath, resolution: 224, modelTags: []string{"serve"}}
func NewModel(modelPath string, resolution int, tags []string, disabled bool) *Model {
if resolution <= 0 {
resolution = 224
}
if len(tags) == 0 {
tags = []string{"serve"}
}
return &Model{modelPath: modelPath, resolution: resolution, modelTags: tags, disabled: disabled}
}
// File returns matching labels for a jpeg media file.
func (m *Model) File(filename string) (result Labels, err error) {
if fs.MimeType(filename) != header.ContentTypeJpeg {
return result, fmt.Errorf("nsfw: %s is not a jpeg file", clean.Log(filepath.Base(filename)))
// File checks the specified JPEG file for inappropriate content.
func (m *Model) File(fileName string) (result Result, err error) {
if fs.MimeType(fileName) != header.ContentTypeJpeg {
return result, fmt.Errorf("nsfw: %s is not a jpeg file", clean.Log(filepath.Base(fileName)))
}
imageBuffer, err := os.ReadFile(filename)
var img []byte
if err != nil {
if img, err = os.ReadFile(fileName); err != nil {
return result, err
}
return m.Labels(imageBuffer)
return m.Run(img)
}
// Labels returns matching labels for a jpeg media string.
func (m *Model) Labels(img []byte) (result Labels, err error) {
// Url checks the JPEG file from the specified https or data URL for inappropriate content.
func (m *Model) Url(imgUrl string) (result Result, err error) {
if m.disabled {
return result, nil
}
var img []byte
if img, err = media.ReadUrl(imgUrl, scheme.HttpsData); err != nil {
return result, err
}
return m.Run(img)
}
// Run returns matching labels for a jpeg media string.
func (m *Model) Run(img []byte) (result Result, err error) {
if loadErr := m.loadModel(); loadErr != nil {
return result, loadErr
}
@@ -83,11 +107,15 @@ func (m *Model) Labels(img []byte) (result Labels, err error) {
return result, nil
}
func (m *Model) loadLabels(modelPath string) (err error) {
m.labels, err = tensorflow.LoadLabels(modelPath)
// Init initialises tensorflow models if not disabled
func (m *Model) Init() (err error) {
if m.disabled {
return nil
}
return m.loadModel()
}
func (m *Model) loadModel() error {
// Use mutex to prevent the model from being loaded and
// initialized twice by different indexing workers.
@@ -113,8 +141,13 @@ func (m *Model) loadModel() error {
return m.loadLabels(m.modelPath)
}
func (m *Model) getLabels(p []float32) Labels {
return Labels{
func (m *Model) loadLabels(modelPath string) (err error) {
m.labels, err = tensorflow.LoadLabels(modelPath)
return nil
}
func (m *Model) getLabels(p []float32) Result {
return Result{
Drawing: p[0],
Hentai: p[1],
Neutral: p[2],

View File

@@ -36,7 +36,7 @@ const (
var log = event.Log
type Labels struct {
type Result struct {
Drawing float32
Hentai float32
Neutral float32
@@ -45,12 +45,12 @@ type Labels struct {
}
// IsSafe returns true if the image is probably safe for work.
func (l *Labels) IsSafe() bool {
func (l *Result) IsSafe() bool {
return !l.NSFW(ThresholdSafe)
}
// NSFW returns true if the image is may not be safe for work.
func (l *Labels) NSFW(threshold float32) bool {
func (l *Result) NSFW(threshold float32) bool {
if l.Neutral > 0.25 {
return false
}

View File

@@ -13,10 +13,10 @@ import (
var modelPath, _ = filepath.Abs("../../../assets/nsfw")
var detector = NewModel(modelPath)
var detector = NewModel(modelPath, 224, nil, false)
func TestIsSafe(t *testing.T) {
detect := func(filename string) Labels {
detect := func(filename string) Result {
result, err := detector.File(filename)
if err != nil {
@@ -24,12 +24,12 @@ func TestIsSafe(t *testing.T) {
}
assert.NotNil(t, result)
assert.IsType(t, Labels{}, result)
assert.IsType(t, Result{}, result)
return result
}
expected := map[string]Labels{
expected := map[string]Result{
"beach_sand.jpg": {0, 0, 0.9, 0, 0},
"beach_wood.jpg": {0, 0, 0.36, 0.59, 0},
"cat_brown.jpg": {0, 0, 0.93, 0, 0},
@@ -98,11 +98,11 @@ func TestIsSafe(t *testing.T) {
}
func TestNSFW(t *testing.T) {
porn := Labels{0, 0, 0.11, 0.88, 0}
sexy := Labels{0, 0, 0.2, 0.59, 0.98}
maxi := Labels{0, 0.999, 0.1, 0.999, 0.999}
drawing := Labels{0.999, 0, 0, 0, 0}
hentai := Labels{0, 0.80, 0.2, 0, 0}
porn := Result{0, 0, 0.11, 0.88, 0}
sexy := Result{0, 0, 0.2, 0.59, 0.98}
maxi := Result{0, 0.999, 0.1, 0.999, 0.999}
drawing := Result{0.999, 0, 0, 0, 0}
hentai := Result{0, 0.80, 0.2, 0, 0}
assert.Equal(t, true, porn.NSFW(ThresholdSafe))
assert.Equal(t, true, sexy.NSFW(ThresholdSafe))

View File

@@ -2,39 +2,59 @@ package vision
import (
"encoding/json"
"path"
"fmt"
"os"
"strings"
"github.com/photoprism/photoprism/internal/api/download"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/media"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
"github.com/photoprism/photoprism/pkg/rnd"
)
const (
LabelsEndpoint = "labels"
)
type Files = []string
// ApiRequest represents a Vision API service request.
type ApiRequest struct {
Id string `form:"id" yaml:"Id,omitempty" json:"id,omitempty"`
Model string `form:"model" yaml:"Model,omitempty" json:"model,omitempty"`
Images []string `form:"images" yaml:"Images,omitempty" json:"images,omitempty"`
Videos []string `form:"videos" yaml:"Videos,omitempty" json:"videos,omitempty"`
Images Files `form:"images" yaml:"Images,omitempty" json:"images,omitempty"`
}
func NewClientRequest(model string, images []string) *ApiRequest {
imageUrls := make([]string, 0, len(images))
// NewClientRequest returns a new Vision API request with the specified file payload and scheme.
func NewClientRequest(images Files, fileScheme string) (*ApiRequest, error) {
imageUrls := make(Files, len(images))
if fileScheme == scheme.Https && !strings.HasPrefix(DownloadUrl, "https://") {
log.Tracef("vision: file request scheme changed from https to data because https is not configured")
fileScheme = scheme.Data
}
for i := range images {
switch fileScheme {
case scheme.Https:
if id, err := download.Register(images[i]); err != nil {
log.Errorf("vision: %s (register download)", err)
return nil, fmt.Errorf("%s (register download)", err)
} else {
imageUrls = append(imageUrls, path.Join(DownloadUrl, id))
imageUrls[i] = fmt.Sprintf("%s/%s", DownloadUrl, id)
}
case scheme.Data:
if file, err := os.Open(images[i]); err != nil {
return nil, fmt.Errorf("%s (create data url)", err)
} else {
imageUrls[i] = media.DataUrl(file)
}
default:
return nil, fmt.Errorf("invalid file scheme %s", clean.Log(fileScheme))
}
}
return &ApiRequest{
Id: rnd.UUID(),
Model: "",
Images: imageUrls,
}
}, nil
}
// GetId returns the request ID string and generates a random ID if none was set.
@@ -48,5 +68,5 @@ func (r *ApiRequest) GetId() string {
// MarshalJSON returns request as JSON.
func (r *ApiRequest) MarshalJSON() ([]byte, error) {
return json.Marshal(r)
return json.Marshal(*r)
}

View File

@@ -0,0 +1,46 @@
package vision
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
)
func TestNewClientRequest(t *testing.T) {
var assetsPath = fs.Abs("../../../assets")
var examplesPath = assetsPath + "/examples"
t.Run("Data", func(t *testing.T) {
thumbnails := Files{examplesPath + "/chameleon_lime.jpg"}
result, err := NewClientRequest(thumbnails, scheme.Data)
assert.NoError(t, err)
assert.NotNil(t, result)
// t.Logf("request: %#v", result)
if result != nil {
json, jsonErr := result.MarshalJSON()
assert.NoError(t, jsonErr)
assert.NotEmpty(t, json)
// t.Logf("json: %s", json)
}
})
t.Run("Https", func(t *testing.T) {
thumbnails := Files{examplesPath + "/chameleon_lime.jpg"}
result, err := NewClientRequest(thumbnails, scheme.Https)
assert.NoError(t, err)
assert.NotNil(t, result)
// t.Logf("request: %#v", result)
if result != nil {
json, jsonErr := result.MarshalJSON()
assert.NoError(t, jsonErr)
assert.NotEmpty(t, json)
t.Logf("json: %s", json)
}
})
}

View File

@@ -5,6 +5,8 @@ import (
"net/http"
"github.com/photoprism/photoprism/internal/ai/classify"
"github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/pkg/clean"
)
@@ -20,9 +22,10 @@ type ApiResponse struct {
// ApiResult represents the model response(s) to a Vision API service
// request and can optionally include data from multiple models.
type ApiResult struct {
Caption *CaptionResult `yaml:"Caption,omitempty" json:"caption,omitempty"`
Faces *[]string `yaml:"Faces,omitempty" json:"faces,omitempty"`
Labels []LabelResult `yaml:"Labels,omitempty" json:"labels,omitempty"`
Nsfw []nsfw.Result `yaml:"Nsfw,omitempty" json:"nsfw,omitempty"`
Embeddings []face.Embeddings `yaml:"Embeddings,omitempty" json:"embeddings,omitempty"`
Caption *CaptionResult `yaml:"Caption,omitempty" json:"caption,omitempty"`
}
// CaptionResult represents the result generated by a caption generation model.
@@ -41,6 +44,7 @@ type LabelResult struct {
Categories []string `yaml:"Categories,omitempty" json:"categories,omitempty"`
}
// ToClassify returns the label results as classify.Label.
func (r LabelResult) ToClassify() classify.Label {
uncertainty := math.RoundToEven(float64(100 - r.Confidence*100))
return classify.Label{

View File

@@ -0,0 +1,16 @@
package vision
import (
"github.com/photoprism/photoprism/pkg/fs"
)
var (
AssetsPath = fs.Abs("../../../assets")
FaceNetModelPath = fs.Abs("../../../assets/facenet")
NsfwModelPath = fs.Abs("../../../assets/nsfw")
CachePath = fs.Abs("../../../storage/cache")
ServiceUri = ""
ServiceKey = ""
DownloadUrl = ""
DefaultResolution = 224
)

View File

@@ -0,0 +1,10 @@
package vision
type ModelType = string
const (
ModelTypeLabels ModelType = "labels"
ModelTypeNsfw ModelType = "nsfw"
ModelTypeFaceEmbeddings ModelType = "face/embeddings"
ModelTypeCaption ModelType = "caption"
)

View File

@@ -0,0 +1,37 @@
package vision
import (
"bytes"
"errors"
"fmt"
"image/jpeg"
"github.com/photoprism/photoprism/internal/ai/face"
)
// FaceEmbeddings returns the embeddings for the specified face jpeg.
func FaceEmbeddings(imgData []byte) (embeddings face.Embeddings, err error) {
if len(imgData) == 0 {
return embeddings, errors.New("missing image")
}
if Config == nil {
return embeddings, errors.New("missing configuration")
} else if model := Config.Model(ModelTypeFaceEmbeddings); model != nil {
img, imgErr := jpeg.Decode(bytes.NewReader(imgData))
if imgErr != nil {
return embeddings, imgErr
}
if tf := model.FaceModel(); tf == nil {
return embeddings, fmt.Errorf("invalid face model configuration")
} else if embeddings = tf.Run(img); !embeddings.Empty() {
return embeddings, nil
} else {
return face.Embeddings{}, nil
}
} else {
return embeddings, fmt.Errorf("no face model configured")
}
}

119
internal/ai/vision/faces.go Normal file
View File

@@ -0,0 +1,119 @@
package vision
import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
"github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/thumb/crop"
"github.com/photoprism/photoprism/pkg/media/http/header"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
)
// Faces runs face detection and facenet algorithms over the provided source image.
func Faces(fileName string, minSize int, cacheCrop bool, expected int) (faces face.Faces, err error) {
if fileName == "" {
return faces, errors.New("missing image filename")
}
// Return if there is no configuration or no image classification models are configured.
if Config == nil {
return faces, errors.New("missing configuration")
} else if model := Config.Model(ModelTypeFaceEmbeddings); model != nil {
faces, err = face.Detect(fileName, false, minSize)
if err != nil {
return faces, err
}
// Skip embeddings?
if c := len(faces); c == 0 || expected > 0 && c == expected {
return faces, nil
}
if uri, method := model.Endpoint(); uri != "" && method != "" {
faceCrops := make([]string, len(faces))
for i, f := range faces {
if f.Area.Col == 0 && f.Area.Row == 0 {
faceCrops[i] = ""
continue
}
if _, faceCrop, imgErr := crop.ImageFromThumb(fileName, f.CropArea(), face.CropSize, cacheCrop); imgErr != nil {
log.Errorf("faces: failed to decode image: %s", imgErr)
faceCrops[i] = ""
} else if faceCrop != "" {
faceCrops[i] = faceCrop
}
}
apiRequest, apiRequestErr := NewClientRequest(faceCrops, scheme.Data)
if apiRequestErr != nil {
return faces, apiRequestErr
}
if model.Name != "" {
apiRequest.Model = model.Name
}
data, jsonErr := apiRequest.MarshalJSON()
if jsonErr != nil {
return faces, jsonErr
}
// Create HTTP client and authenticated service API request.
client := http.Client{}
req, reqErr := http.NewRequest(method, uri, bytes.NewReader(data))
header.SetAuthorization(req, model.EndpointKey())
if reqErr != nil {
return faces, reqErr
}
// Perform API request.
clientResp, clientErr := client.Do(req)
if clientErr != nil {
return faces, clientErr
}
apiResponse := &ApiResponse{}
if apiJson, apiErr := io.ReadAll(clientResp.Body); apiErr != nil {
return faces, apiErr
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
return faces, apiErr
}
for i := range faces {
if len(apiResponse.Result.Embeddings) > i {
faces[i].Embeddings = apiResponse.Result.Embeddings[i]
}
}
} else if tf := model.FaceModel(); tf != nil {
for i, f := range faces {
if f.Area.Col == 0 && f.Area.Row == 0 {
continue
}
if img, _, imgErr := crop.ImageFromThumb(fileName, f.CropArea(), face.CropSize, cacheCrop); imgErr != nil {
log.Errorf("faces: failed to decode image: %s", imgErr)
} else if embeddings := tf.Run(img); !embeddings.Empty() {
faces[i].Embeddings = embeddings
}
}
} else {
return faces, errors.New("invalid face model configuration")
}
} else {
return faces, errors.New("missing face model")
}
return faces, nil
}

View File

@@ -4,33 +4,41 @@ import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"sort"
"github.com/photoprism/photoprism/internal/ai/classify"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/media"
"github.com/photoprism/photoprism/pkg/media/http/header"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
)
// Labels returns suitable labels for the specified image thumbnail.
func Labels(thumbnails []string) (result classify.Labels, err error) {
func Labels(images Files, src media.Src) (result classify.Labels, err error) {
// Return if no thumbnail filenames were given.
if len(thumbnails) == 0 {
return result, errors.New("missing thumbnail filenames")
if len(images) == 0 {
return result, errors.New("missing image filenames")
}
// Return if there is no configuration or no image classification models are configured.
if Config == nil {
return result, errors.New("missing configuration")
} else if len(Config.Labels) == 0 {
return result, errors.New("missing labels model configuration")
} else if model := Config.Model(ModelTypeLabels); model != nil {
// Use remote service API if a server endpoint has been configured.
if uri, method := model.Endpoint(); uri != "" && method != "" {
apiRequest, apiRequestErr := NewClientRequest(images, scheme.Data)
if apiRequestErr != nil {
return result, apiRequestErr
}
if model.Name != "" {
apiRequest.Model = model.Name
}
// Use computer vision models configured for image classification.
for _, model := range Config.Labels {
// Use remote service API if a server endpoint has been configured.
if uri, method := model.Endpoint(LabelsEndpoint); uri != "" && method != "" {
apiRequest := NewClientRequest(model.Name, thumbnails)
data, jsonErr := apiRequest.MarshalJSON()
if jsonErr != nil {
@@ -67,18 +75,29 @@ func Labels(thumbnails []string) (result classify.Labels, err error) {
}
} else if tf := model.ClassifyModel(); tf != nil {
// Predict labels with local TensorFlow model.
for i := range thumbnails {
labels, modelErr := tf.File(thumbnails[i], Config.Thresholds.Confidence)
for i := range images {
var labels classify.Labels
if modelErr != nil {
return result, modelErr
switch src {
case media.SrcLocal:
labels, err = tf.File(images[i], Config.Thresholds.Confidence)
case media.SrcRemote:
labels, err = tf.Url(images[i], Config.Thresholds.Confidence)
default:
return result, fmt.Errorf("invalid image source %s", clean.Log(src))
}
if err != nil {
return result, err
}
result = mergeLabels(result, labels)
}
} else {
return result, errors.New("missing labels model")
return result, errors.New("invalid labels model configuration")
}
} else {
return result, errors.New("missing labels model")
}
sort.Sort(result)

View File

@@ -7,6 +7,7 @@ import (
"github.com/photoprism/photoprism/internal/ai/classify"
"github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/media"
)
func TestLabels(t *testing.T) {
@@ -14,7 +15,7 @@ func TestLabels(t *testing.T) {
var examplesPath = assetsPath + "/examples"
t.Run("Success", func(t *testing.T) {
result, err := Labels([]string{examplesPath + "/chameleon_lime.jpg"})
result, err := Labels(Files{examplesPath + "/chameleon_lime.jpg"}, media.SrcLocal)
assert.NoError(t, err)
assert.IsType(t, classify.Labels{}, result)
@@ -26,7 +27,7 @@ func TestLabels(t *testing.T) {
assert.Equal(t, 7, result[0].Uncertainty)
})
t.Run("Cats", func(t *testing.T) {
result, err := Labels([]string{examplesPath + "/cat_720.jpeg"})
result, err := Labels(Files{examplesPath + "/cat_720.jpeg"}, media.SrcLocal)
assert.NoError(t, err)
assert.IsType(t, classify.Labels{}, result)
@@ -39,7 +40,7 @@ func TestLabels(t *testing.T) {
assert.InDelta(t, float32(0.4), result[0].Confidence(), 0.01)
})
t.Run("InvalidFile", func(t *testing.T) {
_, err := Labels([]string{examplesPath + "/notexisting.jpg"})
_, err := Labels(Files{examplesPath + "/notexisting.jpg"}, media.SrcLocal)
assert.Error(t, err)
})
}

View File

@@ -1,11 +1,14 @@
package vision
import (
"fmt"
"net/http"
"path"
"path/filepath"
"sync"
"github.com/photoprism/photoprism/internal/ai/classify"
"github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/pkg/clean"
)
@@ -13,6 +16,7 @@ var modelMutex = sync.Mutex{}
// Model represents a computer vision model configuration.
type Model struct {
Type ModelType `yaml:"Type,omitempty" json:"type,omitempty"`
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
Resolution int `yaml:"Resolution,omitempty" json:"resolution,omitempty"`
@@ -20,18 +24,19 @@ type Model struct {
Key string `yaml:"Key,omitempty" json:"-"`
Method string `yaml:"Method,omitempty" json:"-"`
Path string `yaml:"Path,omitempty" json:"-"`
Format string `yaml:"Format,omitempty" json:"-"`
Tags []string `yaml:"Tags,omitempty" json:"-"`
Disabled bool `yaml:"Disabled,omitempty" json:"-"`
classifyModel *classify.Model
faceModel *face.Model
nsfwModel *nsfw.Model
}
// Models represents a set of computer vision models.
type Models []*Model
// Endpoint returns the remote service request method and endpoint URL, if any.
func (m *Model) Endpoint(name string) (method, uri string) {
if m.Uri == "" && ServiceUri == "" {
func (m *Model) Endpoint() (uri, method string) {
if m.Uri == "" && ServiceUri == "" || m.Type == "" {
return "", ""
}
@@ -44,7 +49,7 @@ func (m *Model) Endpoint(name string) (method, uri string) {
if m.Uri != "" {
return m.Uri, method
} else {
return path.Join(ServiceUri, name), method
return fmt.Sprintf("%s/%s", ServiceUri, clean.TypeLowerUnderscore(m.Type)), method
}
}
@@ -114,3 +119,115 @@ func (m *Model) ClassifyModel() *classify.Model {
return m.classifyModel
}
// FaceModel returns the matching face model instance, if any.
func (m *Model) FaceModel() *face.Model {
// Use mutex to prevent models from being loaded and
// initialized twice by different indexing workers.
modelMutex.Lock()
defer modelMutex.Unlock()
// Return the existing model instance if it has already been created.
if m.faceModel != nil {
return m.faceModel
}
switch m.Name {
case "":
log.Warnf("vision: missing name, model instance cannot be created")
return nil
case FacenetModel.Name, "facenet":
// Load and initialize the Nasnet image classification model.
if model := face.NewModel(FaceNetModelPath, CachePath, m.Resolution, m.Tags, m.Disabled); model == nil {
return nil
} else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path)
return nil
} else {
m.faceModel = model
}
default:
// Set model path from model name if no path is configured.
if m.Path == "" {
m.Path = clean.TypeLowerUnderscore(m.Name)
}
// Set default thumbnail resolution if no tags are configured.
if m.Resolution <= 0 {
m.Resolution = DefaultResolution
}
// Set default tag if no tags are configured.
if len(m.Tags) == 0 {
m.Tags = []string{"serve"}
}
// Try to load custom model based on the configuration values.
if model := face.NewModel(filepath.Join(AssetsPath, m.Path), CachePath, m.Resolution, m.Tags, m.Disabled); model == nil {
return nil
} else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path)
return nil
} else {
m.faceModel = model
}
}
return m.faceModel
}
// NsfwModel returns the matching nsfw model instance, if any.
func (m *Model) NsfwModel() *nsfw.Model {
// Use mutex to prevent models from being loaded and
// initialized twice by different indexing workers.
modelMutex.Lock()
defer modelMutex.Unlock()
// Return the existing model instance if it has already been created.
if m.nsfwModel != nil {
return m.nsfwModel
}
switch m.Name {
case "":
log.Warnf("vision: missing name, model instance cannot be created")
return nil
case NsfwModel.Name, "nsfw":
// Load and initialize the Nasnet image classification model.
if model := nsfw.NewModel(NsfwModelPath, m.Resolution, m.Tags, m.Disabled); model == nil {
return nil
} else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path)
return nil
} else {
m.nsfwModel = model
}
default:
// Set model path from model name if no path is configured.
if m.Path == "" {
m.Path = clean.TypeLowerUnderscore(m.Name)
}
// Set default thumbnail resolution if no tags are configured.
if m.Resolution <= 0 {
m.Resolution = DefaultResolution
}
// Set default tag if no tags are configured.
if len(m.Tags) == 0 {
m.Tags = []string{"serve"}
}
// Try to load custom model based on the configuration values.
if model := nsfw.NewModel(filepath.Join(AssetsPath, m.Path), m.Resolution, m.Tags, m.Disabled); model == nil {
return nil
} else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path)
return nil
} else {
m.nsfwModel = model
}
}
return m.nsfwModel
}

View File

@@ -0,0 +1,21 @@
package vision
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
func TestModel(t *testing.T) {
t.Run("Nasnet", func(t *testing.T) {
ServiceUri = "https://app.localssl.dev/api/v1/vision"
uri, method := NasnetModel.Endpoint()
ServiceUri = ""
assert.Equal(t, "https://app.localssl.dev/api/v1/vision/labels", uri)
assert.Equal(t, http.MethodPost, method)
uri, method = NasnetModel.Endpoint()
assert.Equal(t, "", uri)
assert.Equal(t, "", method)
})
}

View File

@@ -1,20 +1,39 @@
package vision
import (
"github.com/photoprism/photoprism/pkg/fs"
"net/http"
)
type ModelType = string
// Default computer vision model configuration.
var (
AssetsPath = fs.Abs("../../../assets")
ServiceUri = ""
ServiceKey = ""
DownloadUrl = ""
DefaultResolution = 224
)
// NasnetModel is a standard TensorFlow model used for label generation.
var (
NasnetModel = &Model{Name: "Nasnet", Version: "Mobile", Resolution: 224, Tags: []string{"photoprism"}}
NasnetModel = &Model{
Type: ModelTypeLabels,
Name: "NASNet",
Version: "Mobile",
Resolution: 224,
Tags: []string{"photoprism"},
}
NsfwModel = &Model{
Type: ModelTypeNsfw,
Name: "Nsfw",
Version: "",
Resolution: 224,
Tags: []string{"serve"},
}
FacenetModel = &Model{
Type: ModelTypeFaceEmbeddings,
Name: "FaceNet",
Version: "",
Resolution: 160,
Tags: []string{"serve"},
}
CaptionModel = &Model{
Type: ModelTypeCaption,
Name: "Caption",
Uri: "http://photoprism-vision/api/v1/vision/describe",
Method: http.MethodPost,
Resolution: 720,
}
DefaultModels = Models{NasnetModel, NsfwModel, FacenetModel, CaptionModel}
DefaultThresholds = Thresholds{Confidence: 10}
)

103
internal/ai/vision/nsfw.go Normal file
View File

@@ -0,0 +1,103 @@
package vision
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/media"
"github.com/photoprism/photoprism/pkg/media/http/header"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
)
// Nsfw checks the specified images for inappropriate content.
func Nsfw(images Files, src media.Src) (result []nsfw.Result, err error) {
// Return if no thumbnail filenames were given.
if len(images) == 0 {
return result, errors.New("missing image filenames")
}
result = make([]nsfw.Result, len(images))
// Return if there is no configuration or no image classification models are configured.
if Config == nil {
return result, errors.New("missing configuration")
} else if model := Config.Model(ModelTypeNsfw); model != nil {
// Use remote service API if a server endpoint has been configured.
if uri, method := model.Endpoint(); uri != "" && method != "" {
apiRequest, apiRequestErr := NewClientRequest(images, scheme.Data)
if apiRequestErr != nil {
return result, apiRequestErr
}
if model.Name != "" {
apiRequest.Model = model.Name
}
data, jsonErr := apiRequest.MarshalJSON()
if jsonErr != nil {
return result, jsonErr
}
// Create HTTP client and authenticated service API request.
client := http.Client{}
req, reqErr := http.NewRequest(method, uri, bytes.NewReader(data))
header.SetAuthorization(req, model.EndpointKey())
if reqErr != nil {
return result, reqErr
}
// Perform API request.
clientResp, clientErr := client.Do(req)
if clientErr != nil {
return result, clientErr
}
apiResponse := &ApiResponse{}
// Unmarshal response and add labels, if returned.
if apiJson, apiErr := io.ReadAll(clientResp.Body); apiErr != nil {
return result, apiErr
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
return result, apiErr
}
result = apiResponse.Result.Nsfw
} else if tf := model.NsfwModel(); tf != nil {
// Predict labels with local TensorFlow model.
for i := range images {
var labels nsfw.Result
switch src {
case media.SrcLocal:
labels, err = tf.File(images[i])
case media.SrcRemote:
labels, err = tf.Url(images[i])
default:
return result, fmt.Errorf("invalid image source %s", clean.Log(src))
}
if err != nil {
log.Errorf("nsfw: %s", err)
}
result[i] = labels
}
} else {
return result, errors.New("invalid nsfw model configuration")
}
} else {
return result, errors.New("missing nsfw model")
}
return result, nil
}

View File

@@ -15,21 +15,15 @@ var Config = NewOptions()
// Options represents a computer vision configuration for the supported Model types.
type Options struct {
Caption Models `yaml:"Caption,omitempty" json:"caption,omitempty"`
Faces Models `yaml:"Faces,omitempty" json:"faces,omitempty"`
Labels Models `yaml:"Labels,omitempty" json:"labels,omitempty"`
Nsfw Models `yaml:"Nsfw,omitempty" json:"nsfw,omitempty"`
Models Models `yaml:"Models,omitempty" json:"models,omitempty"`
Thresholds Thresholds `yaml:"Thresholds" json:"thresholds"`
}
// NewOptions returns a new computer vision config with defaults.
func NewOptions() *Options {
return &Options{
Caption: Models{},
Faces: Models{},
Labels: Models{NasnetModel},
Nsfw: Models{},
Thresholds: Thresholds{Confidence: 10},
Models: DefaultModels,
Thresholds: DefaultThresholds,
}
}
@@ -72,3 +66,14 @@ func (c *Options) Save(fileName string) error {
return nil
}
// Model returns the first enabled model with the matching type from the configuration.
func (c *Options) Model(t ModelType) *Model {
for _, m := range c.Models {
if m.Type == t && !m.Disabled {
return m
}
}
return nil
}

View File

@@ -0,0 +1,26 @@
package vision
import (
"os"
"testing"
"github.com/sirupsen/logrus"
"github.com/photoprism/photoprism/internal/event"
)
func TestMain(m *testing.M) {
// Init test logger.
log = logrus.StandardLogger()
log.SetLevel(logrus.TraceLevel)
event.AuditLog = log
// Set test config values.
DownloadUrl = "https://app.localssl.dev/api/v1/dl"
ServiceUri = ""
// Run unit tests.
code := m.Run()
os.Exit(code)
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/dustin/go-humanize/english"
"github.com/gin-gonic/gin"
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/internal/auth/acl"
"github.com/photoprism/photoprism/internal/entity/query"
"github.com/photoprism/photoprism/internal/event"
@@ -21,6 +22,7 @@ import (
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/i18n"
"github.com/photoprism/photoprism/pkg/media"
)
// UploadUserFiles adds files to the user upload folder, from where they can be moved and indexed.
@@ -184,19 +186,18 @@ func UploadUserFiles(router *gin.RouterGroup) {
// Check if the uploaded file may contain inappropriate content.
if len(uploads) > 0 && !conf.UploadNSFW() {
nd := get.NsfwDetector()
containsNSFW := false
for _, filename := range uploads {
labels, nsfwErr := nd.File(filename)
labels, nsfwErr := vision.Nsfw([]string{filename}, media.SrcLocal)
if nsfwErr != nil {
log.Debug(nsfwErr)
continue
}
if labels.IsSafe() {
} else if len(labels) < 1 {
log.Errorf("nsfw: model returned no result")
continue
} else if labels[0].IsSafe() {
continue
}

View File

@@ -5,24 +5,27 @@ import (
"github.com/gin-gonic/gin"
"github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/internal/auth/acl"
"github.com/photoprism/photoprism/internal/photoprism/get"
"github.com/photoprism/photoprism/pkg/media"
"github.com/photoprism/photoprism/pkg/media/http/header"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
)
// PostVisionFaces returns the positions and embeddings of detected faces.
// PostVisionFaceEmbeddings returns the embeddings of detected faces.
//
// @Summary returns the positions and embeddings of detected faces
// @Id PostVisionFaces
// @Id PostVisionFaceEmbeddings
// @Tags Vision
// @Produce json
// @Success 200 {object} vision.ApiResponse
// @Failure 401,403,429,501 {object} i18n.Response
// @Param images body vision.ApiRequest true "list of image file urls"
// @Router /api/v1/vision/faces [post]
func PostVisionFaces(router *gin.RouterGroup) {
router.POST("/vision/faces", func(c *gin.Context) {
// @Router /api/v1/vision/face/embeddings [post]
func PostVisionFaceEmbeddings(router *gin.RouterGroup) {
router.POST("/vision/face/embeddings", func(c *gin.Context) {
s := Auth(c, acl.ResourceVision, acl.AccessAll)
// Abort if permission is not granted.
@@ -51,13 +54,27 @@ func PostVisionFaces(router *gin.RouterGroup) {
return
}
// Run inference to find matching labels.
results := make([]face.Embeddings, len(request.Images))
for i := range request.Images {
if data, err := media.ReadUrl(request.Images[i], scheme.HttpsData); err != nil {
results[i] = face.Embeddings{}
log.Errorf("vision: %s (read face embedding from url)", err)
} else if result, faceErr := vision.FaceEmbeddings(data); faceErr != nil {
results[i] = face.Embeddings{}
log.Errorf("vision: %s (generate face embedding)", faceErr)
} else {
results[i] = result
}
}
// Generate Vision API service response.
response := vision.ApiResponse{
Id: request.GetId(),
Code: http.StatusNotImplemented,
Error: http.StatusText(http.StatusNotImplemented),
Model: &vision.Model{Name: "Faces"},
Result: &vision.ApiResult{},
Code: http.StatusOK,
Model: &vision.Model{Name: vision.FacenetModel.Name},
Result: &vision.ApiResult{Embeddings: results},
}
c.JSON(http.StatusNotImplemented, response)

View File

@@ -8,6 +8,7 @@ import (
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/internal/auth/acl"
"github.com/photoprism/photoprism/internal/photoprism/get"
"github.com/photoprism/photoprism/pkg/media"
"github.com/photoprism/photoprism/pkg/media/http/header"
)
@@ -53,7 +54,7 @@ func PostVisionLabels(router *gin.RouterGroup) {
}
// Run inference to find matching labels.
labels, err := vision.Labels(request.Images)
labels, err := vision.Labels(request.Images, media.SrcLocal)
if err != nil {
c.JSON(http.StatusBadRequest, vision.NewApiError(request.GetId(), http.StatusBadRequest))

View File

@@ -0,0 +1,74 @@
package api
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/internal/auth/acl"
"github.com/photoprism/photoprism/internal/photoprism/get"
"github.com/photoprism/photoprism/pkg/media"
"github.com/photoprism/photoprism/pkg/media/http/header"
)
// PostVisionNsfw checks the specified images for inappropriate content.
//
// @Summary checks the specified images for inappropriate content
// @Id PostVisionNsfw
// @Tags Vision
// @Accept json
// @Produce json
// @Success 200 {object} vision.ApiResponse
// @Failure 401,403,429 {object} i18n.Response
// @Param images body vision.ApiRequest true "list of image file urls"
// @Router /api/v1/vision/nsfw [post]
func PostVisionNsfw(router *gin.RouterGroup) {
router.POST("/vision/nsfw", func(c *gin.Context) {
s := Auth(c, acl.ResourceVision, acl.AccessAll)
// Abort if permission is not granted.
if s.Abort(c) {
return
}
var request vision.ApiRequest
// File uploads are not currently supported for this API endpoint.
if header.HasContentType(&c.Request.Header, header.ContentTypeMultipart) {
c.JSON(http.StatusBadRequest, vision.NewApiError(request.GetId(), http.StatusBadRequest))
return
}
// Assign and validate request form values.
if err := c.BindJSON(&request); err != nil {
c.JSON(http.StatusBadRequest, vision.NewApiError(request.GetId(), http.StatusBadRequest))
return
}
// Check if the Computer Vision API is enabled, otherwise abort with an error.
if !get.Config().VisionApi() {
AbortFeatureDisabled(c)
c.JSON(http.StatusForbidden, vision.NewApiError(request.GetId(), http.StatusForbidden))
return
}
// Run inference to check the specified images for inappropriate content.
results, err := vision.Nsfw(request.Images, media.SrcRemote)
if err != nil {
c.JSON(http.StatusBadRequest, vision.NewApiError(request.GetId(), http.StatusBadRequest))
return
}
// Generate Vision API service response.
response := vision.ApiResponse{
Id: request.GetId(),
Code: http.StatusOK,
Model: &vision.Model{Name: vision.NsfwModel.Name},
Result: &vision.ApiResult{Nsfw: results},
}
c.JSON(http.StatusOK, response)
})
}

View File

@@ -282,6 +282,9 @@ func (c *Config) Propagate() {
// Configure computer vision package.
vision.AssetsPath = c.AssetsPath()
vision.FaceNetModelPath = c.FaceNetModelPath()
vision.NsfwModelPath = c.NSFWModelPath()
vision.CachePath = c.CachePath()
vision.ServiceUri = c.VisionUri()
vision.ServiceKey = c.VisionKey()
vision.DownloadUrl = c.DownloadUrl()

View File

@@ -1,19 +0,0 @@
package get
import (
"sync"
"github.com/photoprism/photoprism/internal/ai/classify"
)
var onceClassify sync.Once
func initClassify() {
services.Classify = classify.NewNasnet(Config().AssetsPath(), Config().DisableClassification())
}
func Classify() *classify.Model {
onceClassify.Do(initClassify)
return services.Classify
}

View File

@@ -1,19 +0,0 @@
package get
import (
"sync"
"github.com/photoprism/photoprism/internal/ai/face"
)
var onceFaceNet sync.Once
func initFaceNet() {
services.FaceNet = face.NewModel(conf.FaceNetModelPath(), "", conf.DisableFaces())
}
func FaceNet() *face.Model {
onceFaceNet.Do(initFaceNet)
return services.FaceNet
}

View File

@@ -9,7 +9,7 @@ import (
var onceIndex sync.Once
func initIndex() {
services.Index = photoprism.NewIndex(Config(), NsfwDetector(), FaceNet(), Convert(), Files(), Photos())
services.Index = photoprism.NewIndex(Config(), Convert(), Files(), Photos())
}
func Index() *photoprism.Index {

View File

@@ -1,19 +0,0 @@
package get
import (
"sync"
"github.com/photoprism/photoprism/internal/ai/nsfw"
)
var onceNsfwDetector sync.Once
func initNsfwDetector() {
services.Nsfw = nsfw.NewModel(conf.NSFWModelPath())
}
func NsfwDetector() *nsfw.Model {
onceNsfwDetector.Do(initNsfwDetector)
return services.Nsfw
}

View File

@@ -27,9 +27,6 @@ package get
import (
gc "github.com/patrickmn/go-cache"
"github.com/photoprism/photoprism/internal/ai/classify"
"github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/internal/auth/oidc"
"github.com/photoprism/photoprism/internal/auth/session"
"github.com/photoprism/photoprism/internal/config"
@@ -43,7 +40,6 @@ var services struct {
FolderCache *gc.Cache
CoverCache *gc.Cache
ThumbCache *gc.Cache
Classify *classify.Model
Convert *photoprism.Convert
Files *photoprism.Files
Photos *photoprism.Photos
@@ -54,8 +50,6 @@ var services struct {
Places *photoprism.Places
Purge *photoprism.Purge
CleanUp *photoprism.CleanUp
Nsfw *nsfw.Model
FaceNet *face.Model
Query *query.Query
Thumbs *photoprism.Thumbs
Session *session.Session

View File

@@ -6,8 +6,6 @@ import (
gc "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/ai/classify"
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/internal/auth/oidc"
"github.com/photoprism/photoprism/internal/auth/session"
"github.com/photoprism/photoprism/internal/entity/query"
@@ -30,10 +28,6 @@ func TestThumbCache(t *testing.T) {
assert.IsType(t, &gc.Cache{}, ThumbCache())
}
func TestClassify(t *testing.T) {
assert.IsType(t, &classify.Model{}, Classify())
}
func TestConvert(t *testing.T) {
assert.IsType(t, &photoprism.Convert{}, Convert())
}
@@ -58,10 +52,6 @@ func TestCleanUp(t *testing.T) {
assert.IsType(t, &photoprism.CleanUp{}, CleanUp())
}
func TestNsfwDetector(t *testing.T) {
assert.IsType(t, &nsfw.Model{}, NsfwDetector())
}
func TestQuery(t *testing.T) {
assert.IsType(t, &query.Query{}, Query())
}

View File

@@ -5,40 +5,34 @@ import (
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/internal/config"
)
func TestNewImport(t *testing.T) {
conf := config.TestConfig()
cfg := config.TestConfig()
nd := nsfw.NewModel(conf.NSFWModelPath())
fn := face.NewModel(conf.FaceNetModelPath(), "", conf.DisableTensorFlow())
convert := NewConvert(conf)
convert := NewConvert(cfg)
ind := NewIndex(conf, nd, fn, convert, NewFiles(), NewPhotos())
imp := NewImport(conf, ind, convert)
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
imp := NewImport(cfg, ind, convert)
assert.IsType(t, &Import{}, imp)
}
func TestImport_DestinationFilename(t *testing.T) {
conf := config.TestConfig()
cfg := config.TestConfig()
if err := conf.InitializeTestData(); err != nil {
if err := cfg.InitializeTestData(); err != nil {
t.Fatal(err)
}
nd := nsfw.NewModel(conf.NSFWModelPath())
fn := face.NewModel(conf.FaceNetModelPath(), "", conf.DisableTensorFlow())
convert := NewConvert(conf)
convert := NewConvert(cfg)
ind := NewIndex(conf, nd, fn, convert, NewFiles(), NewPhotos())
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
imp := NewImport(conf, ind, convert)
imp := NewImport(cfg, ind, convert)
rawFile, err := NewMediaFile(conf.ImportPath() + "/raw/IMG_2567.CR2")
rawFile, err := NewMediaFile(cfg.ImportPath() + "/raw/IMG_2567.CR2")
if err != nil {
t.Fatal(err)
@@ -51,7 +45,7 @@ func TestImport_DestinationFilename(t *testing.T) {
t.Fatal(err)
}
assert.Equal(t, conf.OriginalsPath()+"/2019/07/20190705_153230_C167C6FD.cr2", fileName)
assert.Equal(t, cfg.OriginalsPath()+"/2019/07/20190705_153230_C167C6FD.cr2", fileName)
})
t.Run("WithBasePath", func(t *testing.T) {
@@ -61,7 +55,7 @@ func TestImport_DestinationFilename(t *testing.T) {
t.Fatal(err)
}
assert.Equal(t, conf.OriginalsPath()+"/users/guest/2019/07/20190705_153230_C167C6FD.cr2", fileName)
assert.Equal(t, cfg.OriginalsPath()+"/users/guest/2019/07/20190705_153230_C167C6FD.cr2", fileName)
})
}
@@ -70,19 +64,17 @@ func TestImport_Start(t *testing.T) {
t.Skip("skipping test in short mode.")
}
conf := config.TestConfig()
cfg := config.TestConfig()
conf.InitializeTestData()
cfg.InitializeTestData()
nd := nsfw.NewModel(conf.NSFWModelPath())
fn := face.NewModel(conf.FaceNetModelPath(), "", conf.DisableTensorFlow())
convert := NewConvert(conf)
convert := NewConvert(cfg)
ind := NewIndex(conf, nd, fn, convert, NewFiles(), NewPhotos())
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
imp := NewImport(conf, ind, convert)
imp := NewImport(cfg, ind, convert)
opt := ImportOptionsMove(conf.ImportPath(), "")
opt := ImportOptionsMove(cfg.ImportPath(), "")
imp.Start(opt)
}

View File

@@ -5,36 +5,32 @@ import (
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/entity"
)
func TestImportWorker_OriginalFileNames(t *testing.T) {
c := config.TestConfig()
cfg := config.TestConfig()
if err := c.InitializeTestData(); err != nil {
if err := cfg.InitializeTestData(); err != nil {
t.Fatal(err)
}
nd := nsfw.NewModel(c.NSFWModelPath())
fn := face.NewModel(c.FaceNetModelPath(), "", c.DisableTensorFlow())
convert := NewConvert(c)
ind := NewIndex(c, nd, fn, convert, NewFiles(), NewPhotos())
imp := &Import{c, ind, convert, c.ImportAllow()}
convert := NewConvert(cfg)
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
imp := &Import{cfg, ind, convert, cfg.ImportAllow()}
mediaFileName := c.ExamplesPath() + "/beach_sand.jpg"
mediaFileName := cfg.ExamplesPath() + "/beach_sand.jpg"
mediaFile, err := NewMediaFile(mediaFileName)
if err != nil {
t.Fatal(err)
}
mediaFileName2 := c.ExamplesPath() + "/beach_wood.jpg"
mediaFileName2 := cfg.ExamplesPath() + "/beach_wood.jpg"
mediaFile2, err2 := NewMediaFile(mediaFileName2)
if err2 != nil {
t.Fatal(err2)
}
mediaFileName3 := c.ExamplesPath() + "/beach_colorfilter.jpg"
mediaFileName3 := cfg.ExamplesPath() + "/beach_colorfilter.jpg"
mediaFile3, err3 := NewMediaFile(mediaFileName3)
if err3 != nil {
t.Fatal(err3)
@@ -56,7 +52,7 @@ func TestImportWorker_OriginalFileNames(t *testing.T) {
FileName: mediaFile.FileName(),
Related: relatedFiles,
IndexOpt: IndexOptionsAll(),
ImportOpt: ImportOptionsCopy(c.ImportPath(), c.ImportDest()),
ImportOpt: ImportOptionsCopy(cfg.ImportPath(), cfg.ImportDest()),
Imp: imp,
}

View File

@@ -11,8 +11,6 @@ import (
"github.com/karrick/godirwalk"
"github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/internal/event"
@@ -26,8 +24,6 @@ import (
// Index represents an indexer that indexes files in the originals directory.
type Index struct {
conf *config.Config
nsfwDetector *nsfw.Model
faceNet *face.Model
convert *Convert
files *Files
photos *Photos
@@ -38,7 +34,7 @@ type Index struct {
}
// NewIndex returns a new indexer and expects its dependencies as arguments.
func NewIndex(conf *config.Config, nsfwDetector *nsfw.Model, faceNet *face.Model, convert *Convert, files *Files, photos *Photos) *Index {
func NewIndex(conf *config.Config, convert *Convert, files *Files, photos *Photos) *Index {
if conf == nil {
log.Errorf("index: config is not set")
return nil
@@ -46,8 +42,6 @@ func NewIndex(conf *config.Config, nsfwDetector *nsfw.Model, faceNet *face.Model
i := &Index{
conf: conf,
nsfwDetector: nsfwDetector,
faceNet: faceNet,
convert: convert,
files: files,
photos: photos,

View File

@@ -6,6 +6,7 @@ import (
"github.com/dustin/go-humanize/english"
"github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/internal/thumb"
"github.com/photoprism/photoprism/pkg/clean"
)
@@ -39,7 +40,7 @@ func (ind *Index) Faces(jpeg *MediaFile, expected int) face.Faces {
start := time.Now()
faces, err := ind.faceNet.Detect(thumbName, Config().FaceSize(), true, expected)
faces, err := vision.Faces(thumbName, Config().FaceSize(), true, expected)
if err != nil {
log.Debugf("%s in %s", err, clean.Log(jpeg.BaseName()))

View File

@@ -9,6 +9,7 @@ import (
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/internal/thumb"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/media"
)
// Labels classifies a JPEG image and returns matching labels.
@@ -41,7 +42,7 @@ func (ind *Index) Labels(file *MediaFile) (labels classify.Labels) {
}
// Get matching labels from computer vision model.
if labels, err = vision.Labels(thumbnails); err != nil {
if labels, err = vision.Labels(thumbnails, media.SrcLocal); err != nil {
log.Debugf("labels: %s in %s", err, clean.Log(file.BaseName()))
return labels
}

View File

@@ -805,7 +805,7 @@ func (ind *Index) UserMediaFile(m *MediaFile, o IndexOptions, originalName, phot
}
if !photoExists && Config().Settings().Features.Private && Config().DetectNSFW() {
photo.PhotoPrivate = ind.NSFW(m)
photo.PhotoPrivate = ind.IsNsfw(m)
}
}

View File

@@ -5,8 +5,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/entity"
)
@@ -21,11 +19,9 @@ func TestIndex_MediaFile(t *testing.T) {
cfg.InitializeTestData()
nd := nsfw.NewModel(cfg.NSFWModelPath())
fn := face.NewModel(cfg.FaceNetModelPath(), "", cfg.DisableTensorFlow())
convert := NewConvert(cfg)
ind := NewIndex(cfg, nd, fn, convert, NewFiles(), NewPhotos())
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
indexOpt := IndexOptionsAll()
mediaFile, err := NewMediaFile("testdata/flash.jpg")
@@ -57,11 +53,9 @@ func TestIndex_MediaFile(t *testing.T) {
cfg.InitializeTestData()
nd := nsfw.NewModel(cfg.NSFWModelPath())
fn := face.NewModel(cfg.FaceNetModelPath(), "", cfg.DisableTensorFlow())
convert := NewConvert(cfg)
ind := NewIndex(cfg, nd, fn, convert, NewFiles(), NewPhotos())
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
indexOpt := IndexOptionsAll()
mediaFile, err := NewMediaFile(cfg.ExamplesPath() + "/blue-go-video.mp4")
if err != nil {
@@ -79,11 +73,9 @@ func TestIndex_MediaFile(t *testing.T) {
cfg.InitializeTestData()
nd := nsfw.NewModel(cfg.NSFWModelPath())
fn := face.NewModel(cfg.FaceNetModelPath(), "", cfg.DisableTensorFlow())
convert := NewConvert(cfg)
ind := NewIndex(cfg, nd, fn, convert, NewFiles(), NewPhotos())
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
indexOpt := IndexOptionsAll()
result := ind.MediaFile(nil, indexOpt, "blue-go-video.mp4", "")

View File

@@ -2,12 +2,14 @@ package photoprism
import (
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/internal/thumb"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/media"
)
// NSFW returns true if media file might be offensive and detection is enabled.
func (ind *Index) NSFW(m *MediaFile) bool {
// IsNsfw returns true if media file might be offensive and detection is enabled.
func (ind *Index) IsNsfw(m *MediaFile) bool {
filename, err := m.Thumbnail(Config().ThumbCachePath(), thumb.Fit720)
if err != nil {
@@ -15,15 +17,16 @@ func (ind *Index) NSFW(m *MediaFile) bool {
return false
}
if nsfwLabels, err := ind.nsfwDetector.File(filename); err != nil {
log.Errorf("index: %s in %s (detect nsfw)", err, m.RootRelName())
if results, modelErr := vision.Nsfw([]string{filename}, media.SrcLocal); modelErr != nil {
log.Errorf("index: %s in %s (detect nsfw)", modelErr, m.RootRelName())
return false
} else {
if nsfwLabels.NSFW(nsfw.ThresholdHigh) {
} else if len(results) < 1 {
log.Errorf("index: nsfw model returned no result for %s", m.RootRelName())
return false
} else if results[0].NSFW(nsfw.ThresholdHigh) {
log.Warnf("index: %s might contain offensive content", clean.Log(m.RelName(Config().OriginalsPath())))
return true
}
}
return false
}

View File

@@ -6,8 +6,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/entity/query"
"github.com/photoprism/photoprism/pkg/rnd"
@@ -15,7 +13,7 @@ import (
func TestIndexRelated(t *testing.T) {
t.Run("2018-04-12 19_24_49.gif", func(t *testing.T) {
conf := config.TestConfig()
cfg := config.TestConfig()
testFile, err := NewMediaFile("testdata/2018-04-12 19_24_49.gif")
@@ -30,7 +28,7 @@ func TestIndexRelated(t *testing.T) {
}
testToken := rnd.Base36(8)
testPath := filepath.Join(conf.OriginalsPath(), testToken)
testPath := filepath.Join(cfg.OriginalsPath(), testToken)
for _, f := range testRelated.Files {
dest := filepath.Join(testPath, f.BaseName())
@@ -52,11 +50,8 @@ func TestIndexRelated(t *testing.T) {
t.Fatal(err)
}
nd := nsfw.NewModel(conf.NSFWModelPath())
fn := face.NewModel(conf.FaceNetModelPath(), "", conf.DisableTensorFlow())
convert := NewConvert(conf)
ind := NewIndex(conf, nd, fn, convert, NewFiles(), NewPhotos())
convert := NewConvert(cfg)
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
opt := IndexOptionsAll()
result := IndexRelated(related, ind, opt)
@@ -75,7 +70,7 @@ func TestIndexRelated(t *testing.T) {
})
t.Run("apple-test-2.jpg", func(t *testing.T) {
conf := config.TestConfig()
cfg := config.TestConfig()
testFile, err := NewMediaFile("testdata/apple-test-2.jpg")
@@ -90,7 +85,7 @@ func TestIndexRelated(t *testing.T) {
}
testToken := rnd.Base36(8)
testPath := filepath.Join(conf.OriginalsPath(), testToken)
testPath := filepath.Join(cfg.OriginalsPath(), testToken)
for _, f := range testRelated.Files {
dest := filepath.Join(testPath, f.BaseName())
@@ -112,11 +107,8 @@ func TestIndexRelated(t *testing.T) {
t.Fatal(err)
}
nd := nsfw.NewModel(conf.NSFWModelPath())
fn := face.NewModel(conf.FaceNetModelPath(), "", conf.DisableTensorFlow())
convert := NewConvert(conf)
ind := NewIndex(conf, nd, fn, convert, NewFiles(), NewPhotos())
convert := NewConvert(cfg)
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
opt := IndexOptionsAll()
result := IndexRelated(related, ind, opt)

View File

@@ -7,8 +7,6 @@ import (
"github.com/dustin/go-humanize/english"
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/internal/config"
)
@@ -17,17 +15,13 @@ func TestIndex_Start(t *testing.T) {
t.Skip("skipping test in short mode.")
}
conf := config.TestConfig()
cfg := config.TestConfig()
cfg.InitializeTestData()
conf.InitializeTestData()
nd := nsfw.NewModel(conf.NSFWModelPath())
fn := face.NewModel(conf.FaceNetModelPath(), "", conf.DisableTensorFlow())
convert := NewConvert(conf)
ind := NewIndex(conf, nd, fn, convert, NewFiles(), NewPhotos())
imp := NewImport(conf, ind, convert)
opt := ImportOptionsMove(conf.ImportPath(), "")
convert := NewConvert(cfg)
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
imp := NewImport(cfg, ind, convert)
opt := ImportOptionsMove(cfg.ImportPath(), "")
imp.Start(opt)
@@ -65,15 +59,11 @@ func TestIndex_File(t *testing.T) {
t.Skip("skipping test in short mode.")
}
conf := config.TestConfig()
cfg := config.TestConfig()
cfg.InitializeTestData()
conf.InitializeTestData()
nd := nsfw.NewModel(conf.NSFWModelPath())
fn := face.NewModel(conf.FaceNetModelPath(), "", conf.DisableTensorFlow())
convert := NewConvert(conf)
ind := NewIndex(conf, nd, fn, convert, NewFiles(), NewPhotos())
convert := NewConvert(cfg)
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
err := ind.FileName("xxx", IndexOptionsAll())

View File

@@ -9,8 +9,6 @@ import (
"github.com/disintegration/imaging"
"github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/internal/thumb"
@@ -21,26 +19,23 @@ func TestResample_Start(t *testing.T) {
t.Skip("skipping test in short mode.")
}
conf := config.TestConfig()
cfg := config.TestConfig()
if err := conf.CreateDirectories(); err != nil {
if err := cfg.CreateDirectories(); err != nil {
t.Fatal(err)
}
conf.InitializeTestData()
cfg.InitializeTestData()
nd := nsfw.NewModel(conf.NSFWModelPath())
fn := face.NewModel(conf.FaceNetModelPath(), "", conf.DisableTensorFlow())
convert := NewConvert(conf)
convert := NewConvert(cfg)
ind := NewIndex(cfg, convert, NewFiles(), NewPhotos())
ind := NewIndex(conf, nd, fn, convert, NewFiles(), NewPhotos())
imp := NewImport(conf, ind, convert)
opt := ImportOptionsMove(conf.ImportPath(), "")
imp := NewImport(cfg, ind, convert)
opt := ImportOptionsMove(cfg.ImportPath(), "")
imp.Start(opt)
rs := NewThumbs(conf)
rs := NewThumbs(cfg)
err := rs.Start("", true, false)

View File

@@ -164,9 +164,10 @@ func registerRoutes(router *gin.Engine, conf *config.Config) {
api.FolderCover(APIv1)
// Computer Vision.
api.PostVisionCaption(APIv1)
api.PostVisionFaces(APIv1)
api.PostVisionLabels(APIv1)
api.PostVisionNsfw(APIv1)
api.PostVisionFaceEmbeddings(APIv1)
api.PostVisionCaption(APIv1)
// People.
api.SearchSubjects(APIv1)

View File

@@ -39,7 +39,7 @@ var thumbFileSizes = []thumb.Size{
}
// ImageFromThumb returns a cropped area from an existing thumbnail image.
func ImageFromThumb(thumbName string, area Area, size Size, cache bool) (img image.Image, err error) {
func ImageFromThumb(thumbName string, area Area, size Size, cache bool) (img image.Image, cropName string, err error) {
// Use same folder for caching if "cache" is true.
filePath := filepath.Dir(thumbName)
@@ -48,12 +48,12 @@ func ImageFromThumb(thumbName string, area Area, size Size, cache bool) (img ima
// Resolve symlinks.
if thumbName, err = fs.Resolve(thumbName); err != nil {
return nil, err
return nil, "", err
}
// Compose cached crop image file name.
cropBase := fmt.Sprintf("%s_%dx%d_crop_%s%s", hash, size.Width, size.Height, area.String(), fs.ExtJpeg)
cropName := filepath.Join(filePath, cropBase)
cropName = filepath.Join(filePath, cropBase)
// Cached?
if !fs.FileExists(cropName) {
@@ -61,14 +61,14 @@ func ImageFromThumb(thumbName string, area Area, size Size, cache bool) (img ima
} else if cropImg, cropErr := imaging.Open(cropName); cropErr != nil {
log.Errorf("crop: failed loading %s", filepath.Base(cropName))
} else {
return cropImg, nil
return cropImg, cropName, nil
}
// Open thumb image file.
img, err = openIdealThumbFile(thumbName, hash, area, size)
if err != nil {
return img, err
return img, "", err
}
// Get absolute crop coordinates and dimension.
@@ -93,7 +93,7 @@ func ImageFromThumb(thumbName string, area Area, size Size, cache bool) (img ima
}
}
return img, nil
return img, cropName, nil
}
// ThumbFileName returns the ideal thumb file name.

View File

@@ -1,69 +1,98 @@
package media
import (
"bytes"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"slices"
"strings"
"github.com/gabriel-vasile/mimetype"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/media/http/header"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
)
// DataUrl returns a data URL representing the binary buffer data.
func DataUrl(buf *bytes.Buffer) string {
encoded := EncodeBase64(buf.Bytes())
// DataUrl generates a data URL of the binary data from the specified io.Reader.
func DataUrl(r io.Reader) string {
// Read binary data.
data, err := io.ReadAll(r)
if encoded == "" {
if err != nil || len(data) == 0 {
return ""
}
// Return as string if it already appears to be a data URL.
if string(data[0:4]) == "data:" {
return string(data)
}
// Detect mime type.
var mime *mimetype.MIME
var mimeType string
mime, err := mimetype.DetectReader(buf)
if err != nil {
mimeType = "application/octet-stream"
if mime = mimetype.Detect(data); mime == nil {
mimeType = header.ContentTypeBinary
} else {
mimeType = mime.String()
}
return fmt.Sprintf("data:%s;base64,%s", mimeType, encoded)
// Generate data URL.
return fmt.Sprintf("data:%s;base64,%s", mimeType, EncodeBase64(data))
}
// ReadUrl reads binary data from a regular file path,
// fetches its data from a remote http or https URL,
// or decodes a base64 data URL as created by DataUrl.
func ReadUrl(file string) (data []byte, err error) {
u, err := url.Parse(file)
if err != nil {
log.Fatal(err)
func ReadUrl(fileUrl string, schemes []string) (data []byte, err error) {
if fileUrl == "" {
return data, errors.New("missing url")
}
// Also supports http, https, and data URLs instead of a file name for remote processing.
if u.Scheme == "http" || u.Scheme == "https" {
resp, httpErr := http.Get(file)
// Parse file URL.
var u *url.URL
if u, err = url.Parse(fileUrl); err != nil {
return data, fmt.Errorf("invalid url (%s)", err)
}
// Reject it if it is not absolute, i.e. it does not contain a scheme.
if !u.IsAbs() {
return data, fmt.Errorf("url %s requires a scheme", clean.Log(fileUrl))
} else if !slices.Contains(schemes, u.Scheme) {
return data, fmt.Errorf("invalid url scheme %s", clean.Log(u.Scheme))
}
// Fetch the file data from the specified URL, depending on its scheme.
switch u.Scheme {
case scheme.Https, scheme.Http, scheme.Unix, scheme.HttpUnix:
resp, httpErr := http.Get(fileUrl)
if httpErr != nil {
return nil, httpErr
return data, fmt.Errorf("invalid %s url (%s)", u.Scheme, httpErr)
}
defer resp.Body.Close()
if data, err = io.ReadAll(resp.Body); err != nil {
return nil, err
return data, err
}
} else if u.Scheme == "data" {
case scheme.Data:
if _, binaryData, found := strings.Cut(u.Opaque, ";base64,"); !found || len(binaryData) == 0 {
return nil, fmt.Errorf("invalid data URL")
return data, fmt.Errorf("invalid %s url", u.Scheme)
} else {
return DecodeBase64(binaryData)
}
} else if data, err = os.ReadFile(file); err != nil {
return nil, err
case scheme.File:
if data, err = os.ReadFile(fileUrl); err != nil {
return data, fmt.Errorf("invalid %s url (%s)", u.Scheme, err)
}
default:
return data, fmt.Errorf("unsupported url scheme %s", clean.Log(u.Scheme))
}
return data, err

View File

@@ -1,7 +1,6 @@
package media
import (
"bytes"
"io"
"strings"
"testing"
@@ -16,17 +15,14 @@ func gopherPng() io.Reader { return ReadBase64(strings.NewReader(gopher)) }
func TestDataUrl(t *testing.T) {
t.Run("Gopher", func(t *testing.T) {
buf := new(bytes.Buffer)
_, bufErr := buf.ReadFrom(gopherPng())
assert.NoError(t, bufErr)
assert.Equal(t, "data:image/png;base64,"+gopher, DataUrl(buf))
assert.Equal(t, "data:image/png;base64,"+gopher, DataUrl(gopherPng()))
})
}
func TestReadUrl(t *testing.T) {
t.Run("Gopher", func(t *testing.T) {
dataUrl := "data:image/png;base64," + gopher
if data, err := ReadUrl(dataUrl); err != nil {
if data, err := ReadUrl(dataUrl, []string{"https", "data"}); err != nil {
t.Fatal(err)
} else {
expected, _ := DecodeBase64(gopher)

View File

@@ -0,0 +1,17 @@
package scheme
const (
File = "file"
Data = "data"
Http = "http"
Https = "https"
HttpUnix = Http + "+" + Unix
Websocket = "wss"
Unix = "unix"
Unixgram = "unixgram"
Unixpacket = "unixpacket"
)
var (
HttpsData = []string{Https, Data}
)

View File

@@ -1,8 +0,0 @@
package scheme
const (
Http = "http"
Https = "https"
HttpUnix = Http + "+" + Unix
Websocket = "wss"
)

View File

@@ -1,7 +0,0 @@
package scheme
const (
Unix = "unix"
Unixgram = "unixgram"
Unixpacket = "unixpacket"
)

9
pkg/media/source.go Normal file
View File

@@ -0,0 +1,9 @@
package media
type Src = string
// Data source types.
const (
SrcLocal Src = "local"
SrcRemote Src = "remote"
)