mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
@@ -7,53 +7,11 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/api/download"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/header"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
"github.com/photoprism/photoprism/pkg/rnd"
|
||||
)
|
||||
|
||||
// NewApiRequest returns a new Vision API request with the specified file payload and scheme.
|
||||
func NewApiRequest(images Files, fileScheme string) (*ApiRequest, error) {
|
||||
imageUrls := make(Files, len(images))
|
||||
|
||||
if fileScheme == scheme.Https && !strings.HasPrefix(DownloadUrl, "https://") {
|
||||
log.Tracef("vision: file request scheme changed from https to data because https is not configured")
|
||||
fileScheme = scheme.Data
|
||||
}
|
||||
|
||||
for i := range images {
|
||||
switch fileScheme {
|
||||
case scheme.Https:
|
||||
fileUuid := rnd.UUID()
|
||||
if err := download.Register(fileUuid, images[i]); err != nil {
|
||||
return nil, fmt.Errorf("%s (create download url)", err)
|
||||
} else {
|
||||
imageUrls[i] = fmt.Sprintf("%s/%s", DownloadUrl, fileUuid)
|
||||
}
|
||||
case scheme.Data:
|
||||
if file, err := os.Open(images[i]); err != nil {
|
||||
return nil, fmt.Errorf("%s (create data url)", err)
|
||||
} else {
|
||||
imageUrls[i] = media.DataUrl(file)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid file scheme %s", clean.Log(fileScheme))
|
||||
}
|
||||
}
|
||||
|
||||
return &ApiRequest{
|
||||
Id: rnd.UUID(),
|
||||
Model: "",
|
||||
Images: imageUrls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PerformApiRequest performs a Vision API request and returns the result.
|
||||
func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResponse *ApiResponse, err error) {
|
||||
if apiRequest == nil {
|
||||
@@ -89,9 +47,10 @@ func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResp
|
||||
return apiResponse, clientErr
|
||||
}
|
||||
|
||||
// Parse and return response, or an error if the request failed.
|
||||
switch apiRequest.GetResponseFormat() {
|
||||
case ApiFormatVision:
|
||||
apiResponse = &ApiResponse{}
|
||||
|
||||
// Unmarshal response and add labels, if returned.
|
||||
if apiJson, apiErr := io.ReadAll(clientResp.Body); apiErr != nil {
|
||||
return apiResponse, apiErr
|
||||
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
|
||||
@@ -99,6 +58,9 @@ func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResp
|
||||
} else if clientResp.StatusCode >= 300 {
|
||||
log.Debugf("vision: %s (status code %d)", apiJson, clientResp.StatusCode)
|
||||
}
|
||||
default:
|
||||
return apiResponse, fmt.Errorf("unsupported response format %s", clean.Log(apiRequest.responseFormat))
|
||||
}
|
||||
|
||||
return apiResponse, nil
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ func TestNewApiRequest(t *testing.T) {
|
||||
|
||||
t.Run("Data", func(t *testing.T) {
|
||||
thumbnails := Files{examplesPath + "/chameleon_lime.jpg"}
|
||||
result, err := NewApiRequest(thumbnails, scheme.Data)
|
||||
result, err := NewApiRequestImages(thumbnails, scheme.Data)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
@@ -30,7 +30,7 @@ func TestNewApiRequest(t *testing.T) {
|
||||
})
|
||||
t.Run("Https", func(t *testing.T) {
|
||||
thumbnails := Files{examplesPath + "/chameleon_lime.jpg"}
|
||||
result, err := NewApiRequest(thumbnails, scheme.Https)
|
||||
result, err := NewApiRequestImages(thumbnails, scheme.Https)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
|
||||
9
internal/ai/vision/api_format.go
Normal file
9
internal/ai/vision/api_format.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package vision
|
||||
|
||||
type ApiFormat = string
|
||||
|
||||
const (
|
||||
ApiFormatUrl ApiFormat = "url"
|
||||
ApiFormatImages ApiFormat = "images"
|
||||
ApiFormatVision ApiFormat = "vision"
|
||||
)
|
||||
@@ -2,7 +2,18 @@ package vision
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/api/download"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
"github.com/photoprism/photoprism/pkg/rnd"
|
||||
)
|
||||
|
||||
@@ -14,6 +25,100 @@ type ApiRequest struct {
|
||||
Model string `form:"model" yaml:"Model,omitempty" json:"model,omitempty"`
|
||||
Url string `form:"url" yaml:"Url,omitempty" json:"url,omitempty"`
|
||||
Images Files `form:"images" yaml:"Images,omitempty" json:"images,omitempty"`
|
||||
responseFormat ApiFormat `form:"-"`
|
||||
}
|
||||
|
||||
// NewApiRequest returns a new service API request with the specified format and payload.
|
||||
func NewApiRequest(requestFormat ApiFormat, files Files, fileScheme scheme.Type) (result *ApiRequest, err error) {
|
||||
if len(files) == 0 {
|
||||
return result, errors.New("missing files")
|
||||
}
|
||||
|
||||
switch requestFormat {
|
||||
case ApiFormatUrl:
|
||||
return NewApiRequestUrl(files[0], fileScheme)
|
||||
case ApiFormatImages, ApiFormatVision:
|
||||
return NewApiRequestImages(files, fileScheme)
|
||||
default:
|
||||
return result, errors.New("invalid request format")
|
||||
}
|
||||
}
|
||||
|
||||
// NewApiRequestUrl returns a new Vision API request with the specified image Url as payload.
|
||||
func NewApiRequestUrl(fileName string, fileScheme scheme.Type) (result *ApiRequest, err error) {
|
||||
var imgUrl string
|
||||
|
||||
switch fileScheme {
|
||||
case scheme.Https:
|
||||
// Return if no thumbnail filenames were given.
|
||||
if !fs.FileExistsNotEmpty(fileName) {
|
||||
return result, errors.New("invalid image file name")
|
||||
}
|
||||
|
||||
// Generate a random token for the remote service to download the file.
|
||||
fileUuid := rnd.UUID()
|
||||
|
||||
if err = download.Register(fileUuid, fileName); err != nil {
|
||||
return result, fmt.Errorf("%s (create download url)", err)
|
||||
}
|
||||
|
||||
imgUrl = fmt.Sprintf("%s/%s", DownloadUrl, fileUuid)
|
||||
case scheme.Data:
|
||||
var u *url.URL
|
||||
if u, err = url.Parse(fileName); err != nil {
|
||||
return result, fmt.Errorf("%s (invalid image url)", err)
|
||||
} else if !slices.Contains(scheme.HttpsHttp, u.Scheme) {
|
||||
return nil, fmt.Errorf("unsupported image url scheme %s", clean.Log(u.Scheme))
|
||||
} else {
|
||||
imgUrl = u.String()
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported file scheme %s", clean.Log(fileScheme))
|
||||
}
|
||||
|
||||
return &ApiRequest{
|
||||
Id: rnd.UUID(),
|
||||
Model: "",
|
||||
Url: imgUrl,
|
||||
responseFormat: ApiFormatVision,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewApiRequestImages returns a new Vision API request with the specified images as payload.
|
||||
func NewApiRequestImages(images Files, fileScheme scheme.Type) (*ApiRequest, error) {
|
||||
imageUrls := make(Files, len(images))
|
||||
|
||||
if fileScheme == scheme.Https && !strings.HasPrefix(DownloadUrl, "https://") {
|
||||
log.Tracef("vision: file request scheme changed from https to data because https is not configured")
|
||||
fileScheme = scheme.Data
|
||||
}
|
||||
|
||||
for i := range images {
|
||||
switch fileScheme {
|
||||
case scheme.Https:
|
||||
fileUuid := rnd.UUID()
|
||||
if err := download.Register(fileUuid, images[i]); err != nil {
|
||||
return nil, fmt.Errorf("%s (create download url)", err)
|
||||
} else {
|
||||
imageUrls[i] = fmt.Sprintf("%s/%s", DownloadUrl, fileUuid)
|
||||
}
|
||||
case scheme.Data:
|
||||
if file, err := os.Open(images[i]); err != nil {
|
||||
return nil, fmt.Errorf("%s (create data url)", err)
|
||||
} else {
|
||||
imageUrls[i] = media.DataUrl(file)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported file scheme %s", clean.Log(fileScheme))
|
||||
}
|
||||
}
|
||||
|
||||
return &ApiRequest{
|
||||
Id: rnd.UUID(),
|
||||
Model: "",
|
||||
Images: imageUrls,
|
||||
responseFormat: ApiFormatVision,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetId returns the request ID string and generates a random ID if none was set.
|
||||
@@ -25,6 +130,15 @@ func (r *ApiRequest) GetId() string {
|
||||
return r.Id
|
||||
}
|
||||
|
||||
// GetResponseFormat returns the expected response format type.
|
||||
func (r *ApiRequest) GetResponseFormat() ApiFormat {
|
||||
if r.responseFormat == "" {
|
||||
return ApiFormatVision
|
||||
}
|
||||
|
||||
return r.responseFormat
|
||||
}
|
||||
|
||||
// JSON returns the request data as JSON-encoded bytes.
|
||||
func (r *ApiRequest) JSON() ([]byte, error) {
|
||||
return json.Marshal(*r)
|
||||
|
||||
@@ -2,17 +2,9 @@ package vision
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"slices"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/api/download"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
"github.com/photoprism/photoprism/pkg/rnd"
|
||||
)
|
||||
|
||||
// Caption returns generated captions for the specified images.
|
||||
@@ -23,56 +15,23 @@ func Caption(imgName string, src media.Src) (result CaptionResult, err error) {
|
||||
} else if model := Config.Model(ModelTypeCaption); model != nil {
|
||||
// Use remote service API if a server endpoint has been configured.
|
||||
if uri, method := model.Endpoint(); uri != "" && method != "" {
|
||||
var imgUrl string
|
||||
var apiRequest *ApiRequest
|
||||
var apiResponse *ApiResponse
|
||||
|
||||
switch src {
|
||||
case media.SrcLocal:
|
||||
// Return if no thumbnail filenames were given.
|
||||
if !fs.FileExistsNotEmpty(imgName) {
|
||||
return result, errors.New("invalid image file name")
|
||||
if apiRequest, err = NewApiRequest(model.EndpointRequestFormat(), Files{imgName}, model.EndpointFileScheme()); err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
/* TODO: Add support for data URLs to the service.
|
||||
if file, fileErr := os.Open(imgName); fileErr != nil {
|
||||
return result, fmt.Errorf("%s (open image file)", err)
|
||||
} else {
|
||||
imgUrl = media.DataUrl(file)
|
||||
} */
|
||||
|
||||
fileUuid := rnd.UUID()
|
||||
|
||||
if dlErr := download.Register(imgName, fileUuid); dlErr != nil {
|
||||
return result, fmt.Errorf("%s (create download url)", err)
|
||||
}
|
||||
|
||||
imgUrl = fmt.Sprintf("%s/%s", DownloadUrl, fileUuid)
|
||||
case media.SrcRemote:
|
||||
var u *url.URL
|
||||
if u, err = url.Parse(imgName); err != nil {
|
||||
return result, fmt.Errorf("%s (invalid image url)", err)
|
||||
} else if !slices.Contains(scheme.HttpsHttp, u.Scheme) {
|
||||
return result, fmt.Errorf("unsupported image url scheme %s", clean.Log(u.Scheme))
|
||||
} else {
|
||||
imgUrl = u.String()
|
||||
}
|
||||
default:
|
||||
return result, fmt.Errorf("unsupported media source type %s", clean.Log(src))
|
||||
}
|
||||
|
||||
apiRequest := &ApiRequest{
|
||||
Id: rnd.UUID(),
|
||||
Model: model.Name,
|
||||
Url: imgUrl,
|
||||
if model.Name != "" {
|
||||
apiRequest.Model = model.Name
|
||||
}
|
||||
|
||||
/* if json, _ := apiRequest.JSON(); len(json) > 0 {
|
||||
log.Debugf("request: %s", json)
|
||||
} */
|
||||
|
||||
apiResponse, apiErr := PerformApiRequest(apiRequest, uri, method, model.EndpointKey())
|
||||
|
||||
if apiErr != nil {
|
||||
return result, apiErr
|
||||
if apiResponse, err = PerformApiRequest(apiRequest, uri, method, model.EndpointKey()); err != nil {
|
||||
return result, err
|
||||
} else if apiResponse.Result.Caption == nil {
|
||||
return result, errors.New("invalid caption model response")
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -11,9 +13,13 @@ var (
|
||||
FaceNetModelPath = fs.Abs("../../../assets/facenet")
|
||||
NsfwModelPath = fs.Abs("../../../assets/nsfw")
|
||||
CachePath = fs.Abs("../../../storage/cache")
|
||||
DownloadUrl = ""
|
||||
ServiceUri = ""
|
||||
ServiceKey = ""
|
||||
ServiceTimeout = time.Minute
|
||||
DownloadUrl = ""
|
||||
ServiceMethod = http.MethodPost
|
||||
ServiceFileScheme = scheme.Data
|
||||
ServiceRequestFormat = ApiFormatVision
|
||||
ServiceResponseFormat = ApiFormatVision
|
||||
DefaultResolution = 224
|
||||
)
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/thumb/crop"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
// Faces detects faces in the specified image and generates embeddings from them.
|
||||
@@ -30,7 +29,11 @@ func Faces(fileName string, minSize int, cacheCrop bool, expected int) (result f
|
||||
}
|
||||
|
||||
if uri, method := model.Endpoint(); uri != "" && method != "" {
|
||||
faceCrops := make([]string, len(result))
|
||||
var faceCrops []string
|
||||
var apiRequest *ApiRequest
|
||||
var apiResponse *ApiResponse
|
||||
|
||||
faceCrops = make([]string, len(result))
|
||||
|
||||
for i, f := range result {
|
||||
if f.Area.Col == 0 && f.Area.Row == 0 {
|
||||
@@ -46,20 +49,16 @@ func Faces(fileName string, minSize int, cacheCrop bool, expected int) (result f
|
||||
}
|
||||
}
|
||||
|
||||
apiRequest, apiRequestErr := NewApiRequest(faceCrops, scheme.Data)
|
||||
|
||||
if apiRequestErr != nil {
|
||||
return result, apiRequestErr
|
||||
if apiRequest, err = NewApiRequest(model.EndpointRequestFormat(), faceCrops, model.EndpointFileScheme()); err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
if model.Name != "" {
|
||||
apiRequest.Model = model.Name
|
||||
}
|
||||
|
||||
apiResponse, apiErr := PerformApiRequest(apiRequest, uri, method, model.EndpointKey())
|
||||
|
||||
if apiErr != nil {
|
||||
return result, apiErr
|
||||
if apiResponse, err = PerformApiRequest(apiRequest, uri, method, model.EndpointKey()); err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
for i := range result {
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/photoprism/photoprism/internal/ai/classify"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
// Labels finds matching labels for the specified image.
|
||||
@@ -24,20 +23,19 @@ func Labels(images Files, src media.Src) (result classify.Labels, err error) {
|
||||
} else if model := Config.Model(ModelTypeLabels); model != nil {
|
||||
// Use remote service API if a server endpoint has been configured.
|
||||
if uri, method := model.Endpoint(); uri != "" && method != "" {
|
||||
apiRequest, apiRequestErr := NewApiRequest(images, scheme.Data)
|
||||
var apiRequest *ApiRequest
|
||||
var apiResponse *ApiResponse
|
||||
|
||||
if apiRequestErr != nil {
|
||||
return result, apiRequestErr
|
||||
if apiRequest, err = NewApiRequest(model.EndpointRequestFormat(), images, model.EndpointFileScheme()); err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
if model.Name != "" {
|
||||
apiRequest.Model = model.Name
|
||||
}
|
||||
|
||||
apiResponse, apiErr := PerformApiRequest(apiRequest, uri, method, model.EndpointKey())
|
||||
|
||||
if apiErr != nil {
|
||||
return result, apiErr
|
||||
if apiResponse, err = PerformApiRequest(apiRequest, uri, method, model.EndpointKey()); err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
for _, label := range apiResponse.Result.Labels {
|
||||
|
||||
@@ -2,7 +2,6 @@ package vision
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
@@ -10,6 +9,7 @@ import (
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
var modelMutex = sync.Mutex{}
|
||||
@@ -20,12 +20,10 @@ type Model struct {
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
|
||||
Resolution int `yaml:"Resolution,omitempty" json:"resolution,omitempty"`
|
||||
Uri string `yaml:"Uri,omitempty" json:"-"`
|
||||
Key string `yaml:"Key,omitempty" json:"-"`
|
||||
Method string `yaml:"Method,omitempty" json:"-"`
|
||||
Service Service `yaml:"Service,omitempty" json:"Service,omitempty"`
|
||||
Path string `yaml:"Path,omitempty" json:"-"`
|
||||
Tags []string `yaml:"Tags,omitempty" json:"-"`
|
||||
Disabled bool `yaml:"Disabled,omitempty" json:"-"`
|
||||
Disabled bool `yaml:"Disabled,omitempty" json:"disabled,omitempty"`
|
||||
classifyModel *classify.Model
|
||||
faceModel *face.Model
|
||||
nsfwModel *nsfw.Model
|
||||
@@ -36,32 +34,51 @@ type Models []*Model
|
||||
|
||||
// Endpoint returns the remote service request method and endpoint URL, if any.
|
||||
func (m *Model) Endpoint() (uri, method string) {
|
||||
if m.Uri == "" && ServiceUri == "" || m.Type == "" {
|
||||
if uri, method = m.Service.Endpoint(); uri != "" && method != "" {
|
||||
return uri, method
|
||||
} else if ServiceUri == "" {
|
||||
return "", ""
|
||||
} else if serviceType := clean.TypeLowerUnderscore(m.Type); serviceType == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
if m.Method != "" {
|
||||
method = m.Method
|
||||
} else {
|
||||
method = http.MethodPost
|
||||
}
|
||||
|
||||
if m.Uri != "" {
|
||||
return m.Uri, method
|
||||
} else {
|
||||
return fmt.Sprintf("%s/%s", ServiceUri, clean.TypeLowerUnderscore(m.Type)), method
|
||||
return fmt.Sprintf("%s/%s", ServiceUri, serviceType), ServiceMethod
|
||||
}
|
||||
}
|
||||
|
||||
// EndpointKey returns the access token belonging to the remote service endpoint, if any.
|
||||
func (m *Model) EndpointKey() string {
|
||||
if m.Key != "" {
|
||||
return m.Key
|
||||
} else if ServiceKey != "" {
|
||||
func (m *Model) EndpointKey() (key string) {
|
||||
if key = m.Service.EndpointKey(); key != "" {
|
||||
return key
|
||||
} else {
|
||||
return ServiceKey
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
// EndpointFileScheme returns the endpoint API request file scheme type.
|
||||
func (m *Model) EndpointFileScheme() (fileScheme scheme.Type) {
|
||||
if fileScheme = m.Service.EndpointFileScheme(); fileScheme != "" {
|
||||
return fileScheme
|
||||
}
|
||||
|
||||
return ServiceFileScheme
|
||||
}
|
||||
|
||||
// EndpointRequestFormat returns the endpoint API request format.
|
||||
func (m *Model) EndpointRequestFormat() (format ApiFormat) {
|
||||
if format = m.Service.EndpointRequestFormat(); format != "" {
|
||||
return format
|
||||
}
|
||||
|
||||
return ServiceRequestFormat
|
||||
}
|
||||
|
||||
// EndpointResponseFormat returns the endpoint API response format.
|
||||
func (m *Model) EndpointResponseFormat() (format ApiFormat) {
|
||||
if format = m.Service.EndpointResponseFormat(); format != "" {
|
||||
return format
|
||||
}
|
||||
|
||||
return ServiceResponseFormat
|
||||
}
|
||||
|
||||
// ClassifyModel returns the matching classify model instance, if any.
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
// Default computer vision model configuration.
|
||||
var (
|
||||
NasnetModel = &Model{
|
||||
@@ -26,7 +30,12 @@ var (
|
||||
CaptionModel = &Model{
|
||||
Type: ModelTypeCaption,
|
||||
Resolution: 224,
|
||||
Service: Service{
|
||||
Uri: "http://photoprism-vision:5000/api/v1/vision/caption",
|
||||
FileScheme: scheme.Https,
|
||||
RequestFormat: ApiFormatUrl,
|
||||
ResponseFormat: ApiFormatVision,
|
||||
},
|
||||
}
|
||||
DefaultModels = Models{NasnetModel, NsfwModel, FacenetModel, CaptionModel}
|
||||
DefaultThresholds = Thresholds{Confidence: 10}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
// Nsfw checks the specified images for inappropriate content.
|
||||
@@ -25,20 +24,19 @@ func Nsfw(images Files, src media.Src) (result []nsfw.Result, err error) {
|
||||
} else if model := Config.Model(ModelTypeNsfw); model != nil {
|
||||
// Use remote service API if a server endpoint has been configured.
|
||||
if uri, method := model.Endpoint(); uri != "" && method != "" {
|
||||
apiRequest, apiRequestErr := NewApiRequest(images, scheme.Data)
|
||||
var apiRequest *ApiRequest
|
||||
var apiResponse *ApiResponse
|
||||
|
||||
if apiRequestErr != nil {
|
||||
return result, apiRequestErr
|
||||
if apiRequest, err = NewApiRequest(model.EndpointRequestFormat(), images, model.EndpointFileScheme()); err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
if model.Name != "" {
|
||||
apiRequest.Model = model.Name
|
||||
}
|
||||
|
||||
apiResponse, apiErr := PerformApiRequest(apiRequest, uri, method, model.EndpointKey())
|
||||
|
||||
if apiErr != nil {
|
||||
return result, apiErr
|
||||
if apiResponse, err = PerformApiRequest(apiRequest, uri, method, model.EndpointKey()); err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
result = apiResponse.Result.Nsfw
|
||||
|
||||
73
internal/ai/vision/service.go
Normal file
73
internal/ai/vision/service.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
// Service represents a remote computer vision service configuration.
|
||||
type Service struct {
|
||||
Uri string `yaml:"Uri,omitempty" json:"uri"`
|
||||
Method string `yaml:"Method,omitempty" json:"method"`
|
||||
Key string `yaml:"Key,omitempty" json:"-"`
|
||||
FileScheme string `yaml:"FileScheme,omitempty" json:"fileScheme,omitempty"`
|
||||
RequestFormat ApiFormat `yaml:"RequestFormat,omitempty" json:"requestFormat,omitempty"`
|
||||
ResponseFormat ApiFormat `yaml:"ResponseFormat,omitempty" json:"responseFormat,omitempty"`
|
||||
Disabled bool `yaml:"Disabled,omitempty" json:"disabled,omitempty"`
|
||||
}
|
||||
|
||||
// Endpoint returns the remote service request method and endpoint URL, if any.
|
||||
func (m *Service) Endpoint() (uri, method string) {
|
||||
if m.Disabled || m.Uri == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
if m.Method != "" {
|
||||
method = m.Method
|
||||
} else {
|
||||
method = ServiceMethod
|
||||
}
|
||||
|
||||
return m.Uri, method
|
||||
}
|
||||
|
||||
// EndpointKey returns the access token belonging to the remote service endpoint, if any.
|
||||
func (m *Service) EndpointKey() string {
|
||||
if m.Disabled {
|
||||
return ""
|
||||
}
|
||||
|
||||
return m.Key
|
||||
}
|
||||
|
||||
// EndpointFileScheme returns the endpoint API file scheme type.
|
||||
func (m *Service) EndpointFileScheme() scheme.Type {
|
||||
if m.Disabled {
|
||||
return ""
|
||||
} else if m.FileScheme == "" {
|
||||
return ServiceFileScheme
|
||||
}
|
||||
|
||||
return m.FileScheme
|
||||
}
|
||||
|
||||
// EndpointRequestFormat returns the endpoint API request format.
|
||||
func (m *Service) EndpointRequestFormat() ApiFormat {
|
||||
if m.Disabled {
|
||||
return ""
|
||||
} else if m.RequestFormat == "" {
|
||||
return ApiFormatVision
|
||||
}
|
||||
|
||||
return m.RequestFormat
|
||||
}
|
||||
|
||||
// EndpointResponseFormat returns the endpoint API response format.
|
||||
func (m *Service) EndpointResponseFormat() ApiFormat {
|
||||
if m.Disabled {
|
||||
return ""
|
||||
} else if m.ResponseFormat == "" {
|
||||
return ApiFormatVision
|
||||
}
|
||||
|
||||
return m.ResponseFormat
|
||||
}
|
||||
4
internal/ai/vision/testdata/vision.yml
vendored
4
internal/ai/vision/testdata/vision.yml
vendored
@@ -17,6 +17,10 @@ Models:
|
||||
- serve
|
||||
- Type: caption
|
||||
Resolution: 224
|
||||
Service:
|
||||
Uri: http://photoprism-vision:5000/api/v1/vision/caption
|
||||
FileScheme: https
|
||||
RequestFormat: url
|
||||
ResponseFormat: vision
|
||||
Thresholds:
|
||||
Confidence: 10
|
||||
|
||||
@@ -13,19 +13,19 @@ import (
|
||||
// for download until the cache expires, or the server is restarted.
|
||||
func Register(fileUuid, fileName string) error {
|
||||
if !rnd.IsUUID(fileUuid) {
|
||||
event.AuditWarn([]string{"api", "create download token", "%s", authn.Failed}, fileName)
|
||||
event.AuditWarn([]string{"api", "download", "create temporary token for %s", authn.Failed}, fileName)
|
||||
return errors.New("invalid file uuid")
|
||||
}
|
||||
|
||||
if fileName = fs.Abs(fileName); !fs.FileExists(fileName) {
|
||||
event.AuditWarn([]string{"api", "create download token", "%s", authn.Failed}, fileName)
|
||||
event.AuditWarn([]string{"api", "download", "create temporary token for %s", authn.Failed}, fileName)
|
||||
return errors.New("file not found")
|
||||
} else if Deny(fileName) {
|
||||
event.AuditErr([]string{"api", "create download token", "%s", authn.Denied}, fileName)
|
||||
event.AuditErr([]string{"api", "download", "create temporary token for %s", authn.Denied}, fileName)
|
||||
return errors.New("forbidden file path")
|
||||
}
|
||||
|
||||
event.AuditInfo([]string{"api", "create download token", "%s", authn.Succeeded}, fileName, expires.String())
|
||||
event.AuditInfo([]string{"api", "download", "create temporary token for %s", authn.Succeeded}, fileName)
|
||||
|
||||
cache.SetDefault(fileUuid, fileName)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
// Generate a random token for the remote service to download the file.
|
||||
fileUuid := rnd.UUID()
|
||||
fileName := fs.Abs("./testdata/image.jpg")
|
||||
err := Register(fileUuid, fileName)
|
||||
@@ -30,6 +31,7 @@ func TestRegister(t *testing.T) {
|
||||
assert.Equal(t, "", findName)
|
||||
})
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
// Generate a random token for the remote service to download the file.
|
||||
fileUuid := rnd.UUID()
|
||||
fileName := fs.Abs("./testdata/invalid.jpg")
|
||||
err := Register(fileUuid, fileName)
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestPostVisionFace(t *testing.T) {
|
||||
fs.Abs("./testdata/face_160x160.jpg"),
|
||||
}
|
||||
|
||||
req, err := vision.NewApiRequest(files, scheme.Data)
|
||||
req, err := vision.NewApiRequestImages(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -65,7 +65,7 @@ func TestPostVisionFace(t *testing.T) {
|
||||
fs.Abs("./testdata/london_160x160.jpg"),
|
||||
}
|
||||
|
||||
req, err := vision.NewApiRequest(files, scheme.Data)
|
||||
req, err := vision.NewApiRequestImages(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -101,7 +101,7 @@ func TestPostVisionFace(t *testing.T) {
|
||||
fs.Abs("./testdata/face_320x320.jpg"),
|
||||
}
|
||||
|
||||
req, err := vision.NewApiRequest(files, scheme.Data)
|
||||
req, err := vision.NewApiRequestImages(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -142,7 +142,7 @@ func TestPostVisionFace(t *testing.T) {
|
||||
|
||||
files := vision.Files{}
|
||||
|
||||
req, err := vision.NewApiRequest(files, scheme.Data)
|
||||
req, err := vision.NewApiRequestImages(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestPostVisionLabels(t *testing.T) {
|
||||
fs.Abs("./testdata/cat_224x224.jpg"),
|
||||
}
|
||||
|
||||
req, err := vision.NewApiRequest(files, scheme.Data)
|
||||
req, err := vision.NewApiRequestImages(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -59,7 +59,7 @@ func TestPostVisionLabels(t *testing.T) {
|
||||
fs.Abs("./testdata/green_224x224.jpg"),
|
||||
}
|
||||
|
||||
req, err := vision.NewApiRequest(files, scheme.Data)
|
||||
req, err := vision.NewApiRequestImages(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -93,7 +93,7 @@ func TestPostVisionLabels(t *testing.T) {
|
||||
|
||||
files := vision.Files{}
|
||||
|
||||
req, err := vision.NewApiRequest(files, scheme.Data)
|
||||
req, err := vision.NewApiRequestImages(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestPostVisionNsfw(t *testing.T) {
|
||||
fs.Abs("./testdata/nsfw_224x224.jpg"),
|
||||
}
|
||||
|
||||
req, err := vision.NewApiRequest(files, scheme.Data)
|
||||
req, err := vision.NewApiRequestImages(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -74,7 +74,7 @@ func TestPostVisionNsfw(t *testing.T) {
|
||||
fs.Abs("./testdata/green_224x224.jpg"),
|
||||
}
|
||||
|
||||
req, err := vision.NewApiRequest(files, scheme.Data)
|
||||
req, err := vision.NewApiRequestImages(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -108,7 +108,7 @@ func TestPostVisionNsfw(t *testing.T) {
|
||||
|
||||
files := vision.Files{}
|
||||
|
||||
req, err := vision.NewApiRequest(files, scheme.Data)
|
||||
req, err := vision.NewApiRequestImages(files, scheme.Data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -19,8 +19,7 @@ var VisionRunCommand = &cli.Command{
|
||||
&cli.StringFlag{
|
||||
Name: "models",
|
||||
Aliases: []string{"m"},
|
||||
// TODO: Add captions to the list once the service can be used from the CLI.
|
||||
Usage: "model types (labels, nsfw)",
|
||||
Usage: "model types (labels, nsfw, caption)",
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "force",
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
package scheme
|
||||
|
||||
// Type represents a URL scheme type.
|
||||
type Type = string
|
||||
|
||||
const (
|
||||
File = "file"
|
||||
Data = "data"
|
||||
Http = "http"
|
||||
Https = "https"
|
||||
HttpUnix = Http + "+" + Unix
|
||||
Websocket = "wss"
|
||||
Unix = "unix"
|
||||
Unixgram = "unixgram"
|
||||
Unixpacket = "unixpacket"
|
||||
File Type = "file"
|
||||
Data Type = "data"
|
||||
Http Type = "http"
|
||||
Https Type = "https"
|
||||
Websocket Type = "wss"
|
||||
Unix Type = "unix"
|
||||
HttpUnix Type = "http+unix"
|
||||
Unixgram Type = "unixgram"
|
||||
Unixpacket Type = "unixpacket"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
Reference in New Issue
Block a user