AI: Add additional vision service API configuration options #127 #1090

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2025-04-13 10:58:15 +02:00
parent e5916b98b9
commit 8189503a69
20 changed files with 342 additions and 190 deletions

View File

@@ -7,53 +7,11 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"os"
"strings"
"github.com/photoprism/photoprism/internal/api/download"
"github.com/photoprism/photoprism/pkg/clean" "github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/media"
"github.com/photoprism/photoprism/pkg/media/http/header" "github.com/photoprism/photoprism/pkg/media/http/header"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
"github.com/photoprism/photoprism/pkg/rnd"
) )
// NewApiRequest returns a new Vision API request with the specified file payload and scheme.
func NewApiRequest(images Files, fileScheme string) (*ApiRequest, error) {
imageUrls := make(Files, len(images))
if fileScheme == scheme.Https && !strings.HasPrefix(DownloadUrl, "https://") {
log.Tracef("vision: file request scheme changed from https to data because https is not configured")
fileScheme = scheme.Data
}
for i := range images {
switch fileScheme {
case scheme.Https:
fileUuid := rnd.UUID()
if err := download.Register(fileUuid, images[i]); err != nil {
return nil, fmt.Errorf("%s (create download url)", err)
} else {
imageUrls[i] = fmt.Sprintf("%s/%s", DownloadUrl, fileUuid)
}
case scheme.Data:
if file, err := os.Open(images[i]); err != nil {
return nil, fmt.Errorf("%s (create data url)", err)
} else {
imageUrls[i] = media.DataUrl(file)
}
default:
return nil, fmt.Errorf("invalid file scheme %s", clean.Log(fileScheme))
}
}
return &ApiRequest{
Id: rnd.UUID(),
Model: "",
Images: imageUrls,
}, nil
}
// PerformApiRequest performs a Vision API request and returns the result. // PerformApiRequest performs a Vision API request and returns the result.
func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResponse *ApiResponse, err error) { func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResponse *ApiResponse, err error) {
if apiRequest == nil { if apiRequest == nil {
@@ -89,15 +47,19 @@ func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResp
return apiResponse, clientErr return apiResponse, clientErr
} }
apiResponse = &ApiResponse{} // Parse and return response, or an error if the request failed.
switch apiRequest.GetResponseFormat() {
// Unmarshal response and add labels, if returned. case ApiFormatVision:
if apiJson, apiErr := io.ReadAll(clientResp.Body); apiErr != nil { apiResponse = &ApiResponse{}
return apiResponse, apiErr if apiJson, apiErr := io.ReadAll(clientResp.Body); apiErr != nil {
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil { return apiResponse, apiErr
return apiResponse, apiErr } else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
} else if clientResp.StatusCode >= 300 { return apiResponse, apiErr
log.Debugf("vision: %s (status code %d)", apiJson, clientResp.StatusCode) } else if clientResp.StatusCode >= 300 {
log.Debugf("vision: %s (status code %d)", apiJson, clientResp.StatusCode)
}
default:
return apiResponse, fmt.Errorf("unsupported response format %s", clean.Log(apiRequest.responseFormat))
} }
return apiResponse, nil return apiResponse, nil

View File

@@ -15,7 +15,7 @@ func TestNewApiRequest(t *testing.T) {
t.Run("Data", func(t *testing.T) { t.Run("Data", func(t *testing.T) {
thumbnails := Files{examplesPath + "/chameleon_lime.jpg"} thumbnails := Files{examplesPath + "/chameleon_lime.jpg"}
result, err := NewApiRequest(thumbnails, scheme.Data) result, err := NewApiRequestImages(thumbnails, scheme.Data)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, result) assert.NotNil(t, result)
@@ -30,7 +30,7 @@ func TestNewApiRequest(t *testing.T) {
}) })
t.Run("Https", func(t *testing.T) { t.Run("Https", func(t *testing.T) {
thumbnails := Files{examplesPath + "/chameleon_lime.jpg"} thumbnails := Files{examplesPath + "/chameleon_lime.jpg"}
result, err := NewApiRequest(thumbnails, scheme.Https) result, err := NewApiRequestImages(thumbnails, scheme.Https)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, result) assert.NotNil(t, result)

View File

@@ -0,0 +1,9 @@
package vision
type ApiFormat = string
const (
ApiFormatUrl ApiFormat = "url"
ApiFormatImages ApiFormat = "images"
ApiFormatVision ApiFormat = "vision"
)

View File

@@ -2,7 +2,18 @@ package vision
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt"
"net/url"
"os"
"slices"
"strings"
"github.com/photoprism/photoprism/internal/api/download"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/media"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
"github.com/photoprism/photoprism/pkg/rnd" "github.com/photoprism/photoprism/pkg/rnd"
) )
@@ -10,10 +21,104 @@ type Files = []string
// ApiRequest represents a Vision API service request. // ApiRequest represents a Vision API service request.
type ApiRequest struct { type ApiRequest struct {
Id string `form:"id" yaml:"Id,omitempty" json:"id,omitempty"` Id string `form:"id" yaml:"Id,omitempty" json:"id,omitempty"`
Model string `form:"model" yaml:"Model,omitempty" json:"model,omitempty"` Model string `form:"model" yaml:"Model,omitempty" json:"model,omitempty"`
Url string `form:"url" yaml:"Url,omitempty" json:"url,omitempty"` Url string `form:"url" yaml:"Url,omitempty" json:"url,omitempty"`
Images Files `form:"images" yaml:"Images,omitempty" json:"images,omitempty"` Images Files `form:"images" yaml:"Images,omitempty" json:"images,omitempty"`
responseFormat ApiFormat `form:"-"`
}
// NewApiRequest returns a new service API request with the specified format and payload.
func NewApiRequest(requestFormat ApiFormat, files Files, fileScheme scheme.Type) (result *ApiRequest, err error) {
if len(files) == 0 {
return result, errors.New("missing files")
}
switch requestFormat {
case ApiFormatUrl:
return NewApiRequestUrl(files[0], fileScheme)
case ApiFormatImages, ApiFormatVision:
return NewApiRequestImages(files, fileScheme)
default:
return result, errors.New("invalid request format")
}
}
// NewApiRequestUrl returns a new Vision API request with the specified image Url as payload.
func NewApiRequestUrl(fileName string, fileScheme scheme.Type) (result *ApiRequest, err error) {
var imgUrl string
switch fileScheme {
case scheme.Https:
// Return if no thumbnail filenames were given.
if !fs.FileExistsNotEmpty(fileName) {
return result, errors.New("invalid image file name")
}
// Generate a random token for the remote service to download the file.
fileUuid := rnd.UUID()
if err = download.Register(fileUuid, fileName); err != nil {
return result, fmt.Errorf("%s (create download url)", err)
}
imgUrl = fmt.Sprintf("%s/%s", DownloadUrl, fileUuid)
case scheme.Data:
var u *url.URL
if u, err = url.Parse(fileName); err != nil {
return result, fmt.Errorf("%s (invalid image url)", err)
} else if !slices.Contains(scheme.HttpsHttp, u.Scheme) {
return nil, fmt.Errorf("unsupported image url scheme %s", clean.Log(u.Scheme))
} else {
imgUrl = u.String()
}
default:
return nil, fmt.Errorf("unsupported file scheme %s", clean.Log(fileScheme))
}
return &ApiRequest{
Id: rnd.UUID(),
Model: "",
Url: imgUrl,
responseFormat: ApiFormatVision,
}, nil
}
// NewApiRequestImages returns a new Vision API request with the specified images as payload.
func NewApiRequestImages(images Files, fileScheme scheme.Type) (*ApiRequest, error) {
imageUrls := make(Files, len(images))
if fileScheme == scheme.Https && !strings.HasPrefix(DownloadUrl, "https://") {
log.Tracef("vision: file request scheme changed from https to data because https is not configured")
fileScheme = scheme.Data
}
for i := range images {
switch fileScheme {
case scheme.Https:
fileUuid := rnd.UUID()
if err := download.Register(fileUuid, images[i]); err != nil {
return nil, fmt.Errorf("%s (create download url)", err)
} else {
imageUrls[i] = fmt.Sprintf("%s/%s", DownloadUrl, fileUuid)
}
case scheme.Data:
if file, err := os.Open(images[i]); err != nil {
return nil, fmt.Errorf("%s (create data url)", err)
} else {
imageUrls[i] = media.DataUrl(file)
}
default:
return nil, fmt.Errorf("unsupported file scheme %s", clean.Log(fileScheme))
}
}
return &ApiRequest{
Id: rnd.UUID(),
Model: "",
Images: imageUrls,
responseFormat: ApiFormatVision,
}, nil
} }
// GetId returns the request ID string and generates a random ID if none was set. // GetId returns the request ID string and generates a random ID if none was set.
@@ -25,6 +130,15 @@ func (r *ApiRequest) GetId() string {
return r.Id return r.Id
} }
// GetResponseFormat returns the expected response format type.
func (r *ApiRequest) GetResponseFormat() ApiFormat {
if r.responseFormat == "" {
return ApiFormatVision
}
return r.responseFormat
}
// JSON returns the request data as JSON-encoded bytes. // JSON returns the request data as JSON-encoded bytes.
func (r *ApiRequest) JSON() ([]byte, error) { func (r *ApiRequest) JSON() ([]byte, error) {
return json.Marshal(*r) return json.Marshal(*r)

View File

@@ -2,17 +2,9 @@ package vision
import ( import (
"errors" "errors"
"fmt"
"net/url"
"slices"
"github.com/photoprism/photoprism/internal/api/download"
"github.com/photoprism/photoprism/internal/entity" "github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/media" "github.com/photoprism/photoprism/pkg/media"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
"github.com/photoprism/photoprism/pkg/rnd"
) )
// Caption returns generated captions for the specified images. // Caption returns generated captions for the specified images.
@@ -23,56 +15,23 @@ func Caption(imgName string, src media.Src) (result CaptionResult, err error) {
} else if model := Config.Model(ModelTypeCaption); model != nil { } else if model := Config.Model(ModelTypeCaption); model != nil {
// Use remote service API if a server endpoint has been configured. // Use remote service API if a server endpoint has been configured.
if uri, method := model.Endpoint(); uri != "" && method != "" { if uri, method := model.Endpoint(); uri != "" && method != "" {
var imgUrl string var apiRequest *ApiRequest
var apiResponse *ApiResponse
switch src { if apiRequest, err = NewApiRequest(model.EndpointRequestFormat(), Files{imgName}, model.EndpointFileScheme()); err != nil {
case media.SrcLocal: return result, err
// Return if no thumbnail filenames were given.
if !fs.FileExistsNotEmpty(imgName) {
return result, errors.New("invalid image file name")
}
/* TODO: Add support for data URLs to the service.
if file, fileErr := os.Open(imgName); fileErr != nil {
return result, fmt.Errorf("%s (open image file)", err)
} else {
imgUrl = media.DataUrl(file)
} */
fileUuid := rnd.UUID()
if dlErr := download.Register(imgName, fileUuid); dlErr != nil {
return result, fmt.Errorf("%s (create download url)", err)
}
imgUrl = fmt.Sprintf("%s/%s", DownloadUrl, fileUuid)
case media.SrcRemote:
var u *url.URL
if u, err = url.Parse(imgName); err != nil {
return result, fmt.Errorf("%s (invalid image url)", err)
} else if !slices.Contains(scheme.HttpsHttp, u.Scheme) {
return result, fmt.Errorf("unsupported image url scheme %s", clean.Log(u.Scheme))
} else {
imgUrl = u.String()
}
default:
return result, fmt.Errorf("unsupported media source type %s", clean.Log(src))
} }
apiRequest := &ApiRequest{ if model.Name != "" {
Id: rnd.UUID(), apiRequest.Model = model.Name
Model: model.Name,
Url: imgUrl,
} }
/* if json, _ := apiRequest.JSON(); len(json) > 0 { /* if json, _ := apiRequest.JSON(); len(json) > 0 {
log.Debugf("request: %s", json) log.Debugf("request: %s", json)
} */ } */
apiResponse, apiErr := PerformApiRequest(apiRequest, uri, method, model.EndpointKey()) if apiResponse, err = PerformApiRequest(apiRequest, uri, method, model.EndpointKey()); err != nil {
return result, err
if apiErr != nil {
return result, apiErr
} else if apiResponse.Result.Caption == nil { } else if apiResponse.Result.Caption == nil {
return result, errors.New("invalid caption model response") return result, errors.New("invalid caption model response")
} }

View File

@@ -1,19 +1,25 @@
package vision package vision
import ( import (
"net/http"
"time" "time"
"github.com/photoprism/photoprism/pkg/fs" "github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
) )
var ( var (
AssetsPath = fs.Abs("../../../assets") AssetsPath = fs.Abs("../../../assets")
FaceNetModelPath = fs.Abs("../../../assets/facenet") FaceNetModelPath = fs.Abs("../../../assets/facenet")
NsfwModelPath = fs.Abs("../../../assets/nsfw") NsfwModelPath = fs.Abs("../../../assets/nsfw")
CachePath = fs.Abs("../../../storage/cache") CachePath = fs.Abs("../../../storage/cache")
ServiceUri = "" DownloadUrl = ""
ServiceKey = "" ServiceUri = ""
ServiceTimeout = time.Minute ServiceKey = ""
DownloadUrl = "" ServiceTimeout = time.Minute
DefaultResolution = 224 ServiceMethod = http.MethodPost
ServiceFileScheme = scheme.Data
ServiceRequestFormat = ApiFormatVision
ServiceResponseFormat = ApiFormatVision
DefaultResolution = 224
) )

View File

@@ -5,7 +5,6 @@ import (
"github.com/photoprism/photoprism/internal/ai/face" "github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/thumb/crop" "github.com/photoprism/photoprism/internal/thumb/crop"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
) )
// Faces detects faces in the specified image and generates embeddings from them. // Faces detects faces in the specified image and generates embeddings from them.
@@ -30,7 +29,11 @@ func Faces(fileName string, minSize int, cacheCrop bool, expected int) (result f
} }
if uri, method := model.Endpoint(); uri != "" && method != "" { if uri, method := model.Endpoint(); uri != "" && method != "" {
faceCrops := make([]string, len(result)) var faceCrops []string
var apiRequest *ApiRequest
var apiResponse *ApiResponse
faceCrops = make([]string, len(result))
for i, f := range result { for i, f := range result {
if f.Area.Col == 0 && f.Area.Row == 0 { if f.Area.Col == 0 && f.Area.Row == 0 {
@@ -46,20 +49,16 @@ func Faces(fileName string, minSize int, cacheCrop bool, expected int) (result f
} }
} }
apiRequest, apiRequestErr := NewApiRequest(faceCrops, scheme.Data) if apiRequest, err = NewApiRequest(model.EndpointRequestFormat(), faceCrops, model.EndpointFileScheme()); err != nil {
return result, err
if apiRequestErr != nil {
return result, apiRequestErr
} }
if model.Name != "" { if model.Name != "" {
apiRequest.Model = model.Name apiRequest.Model = model.Name
} }
apiResponse, apiErr := PerformApiRequest(apiRequest, uri, method, model.EndpointKey()) if apiResponse, err = PerformApiRequest(apiRequest, uri, method, model.EndpointKey()); err != nil {
return result, err
if apiErr != nil {
return result, apiErr
} }
for i := range result { for i := range result {

View File

@@ -8,7 +8,6 @@ import (
"github.com/photoprism/photoprism/internal/ai/classify" "github.com/photoprism/photoprism/internal/ai/classify"
"github.com/photoprism/photoprism/pkg/clean" "github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/media" "github.com/photoprism/photoprism/pkg/media"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
) )
// Labels finds matching labels for the specified image. // Labels finds matching labels for the specified image.
@@ -24,20 +23,19 @@ func Labels(images Files, src media.Src) (result classify.Labels, err error) {
} else if model := Config.Model(ModelTypeLabels); model != nil { } else if model := Config.Model(ModelTypeLabels); model != nil {
// Use remote service API if a server endpoint has been configured. // Use remote service API if a server endpoint has been configured.
if uri, method := model.Endpoint(); uri != "" && method != "" { if uri, method := model.Endpoint(); uri != "" && method != "" {
apiRequest, apiRequestErr := NewApiRequest(images, scheme.Data) var apiRequest *ApiRequest
var apiResponse *ApiResponse
if apiRequestErr != nil { if apiRequest, err = NewApiRequest(model.EndpointRequestFormat(), images, model.EndpointFileScheme()); err != nil {
return result, apiRequestErr return result, err
} }
if model.Name != "" { if model.Name != "" {
apiRequest.Model = model.Name apiRequest.Model = model.Name
} }
apiResponse, apiErr := PerformApiRequest(apiRequest, uri, method, model.EndpointKey()) if apiResponse, err = PerformApiRequest(apiRequest, uri, method, model.EndpointKey()); err != nil {
return result, err
if apiErr != nil {
return result, apiErr
} }
for _, label := range apiResponse.Result.Labels { for _, label := range apiResponse.Result.Labels {

View File

@@ -2,7 +2,6 @@ package vision
import ( import (
"fmt" "fmt"
"net/http"
"path/filepath" "path/filepath"
"sync" "sync"
@@ -10,6 +9,7 @@ import (
"github.com/photoprism/photoprism/internal/ai/face" "github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/ai/nsfw" "github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/pkg/clean" "github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
) )
var modelMutex = sync.Mutex{} var modelMutex = sync.Mutex{}
@@ -20,12 +20,10 @@ type Model struct {
Name string `yaml:"Name,omitempty" json:"name,omitempty"` Name string `yaml:"Name,omitempty" json:"name,omitempty"`
Version string `yaml:"Version,omitempty" json:"version,omitempty"` Version string `yaml:"Version,omitempty" json:"version,omitempty"`
Resolution int `yaml:"Resolution,omitempty" json:"resolution,omitempty"` Resolution int `yaml:"Resolution,omitempty" json:"resolution,omitempty"`
Uri string `yaml:"Uri,omitempty" json:"-"` Service Service `yaml:"Service,omitempty" json:"Service,omitempty"`
Key string `yaml:"Key,omitempty" json:"-"`
Method string `yaml:"Method,omitempty" json:"-"`
Path string `yaml:"Path,omitempty" json:"-"` Path string `yaml:"Path,omitempty" json:"-"`
Tags []string `yaml:"Tags,omitempty" json:"-"` Tags []string `yaml:"Tags,omitempty" json:"-"`
Disabled bool `yaml:"Disabled,omitempty" json:"-"` Disabled bool `yaml:"Disabled,omitempty" json:"disabled,omitempty"`
classifyModel *classify.Model classifyModel *classify.Model
faceModel *face.Model faceModel *face.Model
nsfwModel *nsfw.Model nsfwModel *nsfw.Model
@@ -36,32 +34,51 @@ type Models []*Model
// Endpoint returns the remote service request method and endpoint URL, if any. // Endpoint returns the remote service request method and endpoint URL, if any.
func (m *Model) Endpoint() (uri, method string) { func (m *Model) Endpoint() (uri, method string) {
if m.Uri == "" && ServiceUri == "" || m.Type == "" { if uri, method = m.Service.Endpoint(); uri != "" && method != "" {
return uri, method
} else if ServiceUri == "" {
return "", ""
} else if serviceType := clean.TypeLowerUnderscore(m.Type); serviceType == "" {
return "", "" return "", ""
}
if m.Method != "" {
method = m.Method
} else { } else {
method = http.MethodPost return fmt.Sprintf("%s/%s", ServiceUri, serviceType), ServiceMethod
}
if m.Uri != "" {
return m.Uri, method
} else {
return fmt.Sprintf("%s/%s", ServiceUri, clean.TypeLowerUnderscore(m.Type)), method
} }
} }
// EndpointKey returns the access token belonging to the remote service endpoint, if any. // EndpointKey returns the access token belonging to the remote service endpoint, if any.
func (m *Model) EndpointKey() string { func (m *Model) EndpointKey() (key string) {
if m.Key != "" { if key = m.Service.EndpointKey(); key != "" {
return m.Key return key
} else if ServiceKey != "" { } else {
return ServiceKey return ServiceKey
} }
}
return "" // EndpointFileScheme returns the endpoint API request file scheme type.
func (m *Model) EndpointFileScheme() (fileScheme scheme.Type) {
if fileScheme = m.Service.EndpointFileScheme(); fileScheme != "" {
return fileScheme
}
return ServiceFileScheme
}
// EndpointRequestFormat returns the endpoint API request format.
func (m *Model) EndpointRequestFormat() (format ApiFormat) {
if format = m.Service.EndpointRequestFormat(); format != "" {
return format
}
return ServiceRequestFormat
}
// EndpointResponseFormat returns the endpoint API response format.
func (m *Model) EndpointResponseFormat() (format ApiFormat) {
if format = m.Service.EndpointResponseFormat(); format != "" {
return format
}
return ServiceResponseFormat
} }
// ClassifyModel returns the matching classify model instance, if any. // ClassifyModel returns the matching classify model instance, if any.

View File

@@ -1,5 +1,9 @@
package vision package vision
import (
"github.com/photoprism/photoprism/pkg/media/http/scheme"
)
// Default computer vision model configuration. // Default computer vision model configuration.
var ( var (
NasnetModel = &Model{ NasnetModel = &Model{
@@ -26,7 +30,12 @@ var (
CaptionModel = &Model{ CaptionModel = &Model{
Type: ModelTypeCaption, Type: ModelTypeCaption,
Resolution: 224, Resolution: 224,
Uri: "http://photoprism-vision:5000/api/v1/vision/caption", Service: Service{
Uri: "http://photoprism-vision:5000/api/v1/vision/caption",
FileScheme: scheme.Https,
RequestFormat: ApiFormatUrl,
ResponseFormat: ApiFormatVision,
},
} }
DefaultModels = Models{NasnetModel, NsfwModel, FacenetModel, CaptionModel} DefaultModels = Models{NasnetModel, NsfwModel, FacenetModel, CaptionModel}
DefaultThresholds = Thresholds{Confidence: 10} DefaultThresholds = Thresholds{Confidence: 10}

View File

@@ -7,7 +7,6 @@ import (
"github.com/photoprism/photoprism/internal/ai/nsfw" "github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/pkg/clean" "github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/media" "github.com/photoprism/photoprism/pkg/media"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
) )
// Nsfw checks the specified images for inappropriate content. // Nsfw checks the specified images for inappropriate content.
@@ -25,20 +24,19 @@ func Nsfw(images Files, src media.Src) (result []nsfw.Result, err error) {
} else if model := Config.Model(ModelTypeNsfw); model != nil { } else if model := Config.Model(ModelTypeNsfw); model != nil {
// Use remote service API if a server endpoint has been configured. // Use remote service API if a server endpoint has been configured.
if uri, method := model.Endpoint(); uri != "" && method != "" { if uri, method := model.Endpoint(); uri != "" && method != "" {
apiRequest, apiRequestErr := NewApiRequest(images, scheme.Data) var apiRequest *ApiRequest
var apiResponse *ApiResponse
if apiRequestErr != nil { if apiRequest, err = NewApiRequest(model.EndpointRequestFormat(), images, model.EndpointFileScheme()); err != nil {
return result, apiRequestErr return result, err
} }
if model.Name != "" { if model.Name != "" {
apiRequest.Model = model.Name apiRequest.Model = model.Name
} }
apiResponse, apiErr := PerformApiRequest(apiRequest, uri, method, model.EndpointKey()) if apiResponse, err = PerformApiRequest(apiRequest, uri, method, model.EndpointKey()); err != nil {
return result, err
if apiErr != nil {
return result, apiErr
} }
result = apiResponse.Result.Nsfw result = apiResponse.Result.Nsfw

View File

@@ -0,0 +1,73 @@
package vision
import (
"github.com/photoprism/photoprism/pkg/media/http/scheme"
)
// Service represents a remote computer vision service configuration.
type Service struct {
Uri string `yaml:"Uri,omitempty" json:"uri"`
Method string `yaml:"Method,omitempty" json:"method"`
Key string `yaml:"Key,omitempty" json:"-"`
FileScheme string `yaml:"FileScheme,omitempty" json:"fileScheme,omitempty"`
RequestFormat ApiFormat `yaml:"RequestFormat,omitempty" json:"requestFormat,omitempty"`
ResponseFormat ApiFormat `yaml:"ResponseFormat,omitempty" json:"responseFormat,omitempty"`
Disabled bool `yaml:"Disabled,omitempty" json:"disabled,omitempty"`
}
// Endpoint returns the remote service request method and endpoint URL, if any.
func (m *Service) Endpoint() (uri, method string) {
if m.Disabled || m.Uri == "" {
return "", ""
}
if m.Method != "" {
method = m.Method
} else {
method = ServiceMethod
}
return m.Uri, method
}
// EndpointKey returns the access token belonging to the remote service endpoint, if any.
func (m *Service) EndpointKey() string {
if m.Disabled {
return ""
}
return m.Key
}
// EndpointFileScheme returns the endpoint API file scheme type.
func (m *Service) EndpointFileScheme() scheme.Type {
if m.Disabled {
return ""
} else if m.FileScheme == "" {
return ServiceFileScheme
}
return m.FileScheme
}
// EndpointRequestFormat returns the endpoint API request format.
func (m *Service) EndpointRequestFormat() ApiFormat {
if m.Disabled {
return ""
} else if m.RequestFormat == "" {
return ApiFormatVision
}
return m.RequestFormat
}
// EndpointResponseFormat returns the endpoint API response format.
func (m *Service) EndpointResponseFormat() ApiFormat {
if m.Disabled {
return ""
} else if m.ResponseFormat == "" {
return ApiFormatVision
}
return m.ResponseFormat
}

View File

@@ -17,6 +17,10 @@ Models:
- serve - serve
- Type: caption - Type: caption
Resolution: 224 Resolution: 224
Uri: http://photoprism-vision:5000/api/v1/vision/caption Service:
Uri: http://photoprism-vision:5000/api/v1/vision/caption
FileScheme: https
RequestFormat: url
ResponseFormat: vision
Thresholds: Thresholds:
Confidence: 10 Confidence: 10

View File

@@ -13,19 +13,19 @@ import (
// for download until the cache expires, or the server is restarted. // for download until the cache expires, or the server is restarted.
func Register(fileUuid, fileName string) error { func Register(fileUuid, fileName string) error {
if !rnd.IsUUID(fileUuid) { if !rnd.IsUUID(fileUuid) {
event.AuditWarn([]string{"api", "create download token", "%s", authn.Failed}, fileName) event.AuditWarn([]string{"api", "download", "create temporary token for %s", authn.Failed}, fileName)
return errors.New("invalid file uuid") return errors.New("invalid file uuid")
} }
if fileName = fs.Abs(fileName); !fs.FileExists(fileName) { if fileName = fs.Abs(fileName); !fs.FileExists(fileName) {
event.AuditWarn([]string{"api", "create download token", "%s", authn.Failed}, fileName) event.AuditWarn([]string{"api", "download", "create temporary token for %s", authn.Failed}, fileName)
return errors.New("file not found") return errors.New("file not found")
} else if Deny(fileName) { } else if Deny(fileName) {
event.AuditErr([]string{"api", "create download token", "%s", authn.Denied}, fileName) event.AuditErr([]string{"api", "download", "create temporary token for %s", authn.Denied}, fileName)
return errors.New("forbidden file path") return errors.New("forbidden file path")
} }
event.AuditInfo([]string{"api", "create download token", "%s", authn.Succeeded}, fileName, expires.String()) event.AuditInfo([]string{"api", "download", "create temporary token for %s", authn.Succeeded}, fileName)
cache.SetDefault(fileUuid, fileName) cache.SetDefault(fileUuid, fileName)

View File

@@ -11,6 +11,7 @@ import (
func TestRegister(t *testing.T) { func TestRegister(t *testing.T) {
t.Run("Success", func(t *testing.T) { t.Run("Success", func(t *testing.T) {
// Generate a random token for the remote service to download the file.
fileUuid := rnd.UUID() fileUuid := rnd.UUID()
fileName := fs.Abs("./testdata/image.jpg") fileName := fs.Abs("./testdata/image.jpg")
err := Register(fileUuid, fileName) err := Register(fileUuid, fileName)
@@ -30,6 +31,7 @@ func TestRegister(t *testing.T) {
assert.Equal(t, "", findName) assert.Equal(t, "", findName)
}) })
t.Run("NotFound", func(t *testing.T) { t.Run("NotFound", func(t *testing.T) {
// Generate a random token for the remote service to download the file.
fileUuid := rnd.UUID() fileUuid := rnd.UUID()
fileName := fs.Abs("./testdata/invalid.jpg") fileName := fs.Abs("./testdata/invalid.jpg")
err := Register(fileUuid, fileName) err := Register(fileUuid, fileName)

View File

@@ -22,7 +22,7 @@ func TestPostVisionFace(t *testing.T) {
fs.Abs("./testdata/face_160x160.jpg"), fs.Abs("./testdata/face_160x160.jpg"),
} }
req, err := vision.NewApiRequest(files, scheme.Data) req, err := vision.NewApiRequestImages(files, scheme.Data)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -65,7 +65,7 @@ func TestPostVisionFace(t *testing.T) {
fs.Abs("./testdata/london_160x160.jpg"), fs.Abs("./testdata/london_160x160.jpg"),
} }
req, err := vision.NewApiRequest(files, scheme.Data) req, err := vision.NewApiRequestImages(files, scheme.Data)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -101,7 +101,7 @@ func TestPostVisionFace(t *testing.T) {
fs.Abs("./testdata/face_320x320.jpg"), fs.Abs("./testdata/face_320x320.jpg"),
} }
req, err := vision.NewApiRequest(files, scheme.Data) req, err := vision.NewApiRequestImages(files, scheme.Data)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -142,7 +142,7 @@ func TestPostVisionFace(t *testing.T) {
files := vision.Files{} files := vision.Files{}
req, err := vision.NewApiRequest(files, scheme.Data) req, err := vision.NewApiRequestImages(files, scheme.Data)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@@ -22,7 +22,7 @@ func TestPostVisionLabels(t *testing.T) {
fs.Abs("./testdata/cat_224x224.jpg"), fs.Abs("./testdata/cat_224x224.jpg"),
} }
req, err := vision.NewApiRequest(files, scheme.Data) req, err := vision.NewApiRequestImages(files, scheme.Data)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -59,7 +59,7 @@ func TestPostVisionLabels(t *testing.T) {
fs.Abs("./testdata/green_224x224.jpg"), fs.Abs("./testdata/green_224x224.jpg"),
} }
req, err := vision.NewApiRequest(files, scheme.Data) req, err := vision.NewApiRequestImages(files, scheme.Data)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -93,7 +93,7 @@ func TestPostVisionLabels(t *testing.T) {
files := vision.Files{} files := vision.Files{}
req, err := vision.NewApiRequest(files, scheme.Data) req, err := vision.NewApiRequestImages(files, scheme.Data)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@@ -22,7 +22,7 @@ func TestPostVisionNsfw(t *testing.T) {
fs.Abs("./testdata/nsfw_224x224.jpg"), fs.Abs("./testdata/nsfw_224x224.jpg"),
} }
req, err := vision.NewApiRequest(files, scheme.Data) req, err := vision.NewApiRequestImages(files, scheme.Data)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -74,7 +74,7 @@ func TestPostVisionNsfw(t *testing.T) {
fs.Abs("./testdata/green_224x224.jpg"), fs.Abs("./testdata/green_224x224.jpg"),
} }
req, err := vision.NewApiRequest(files, scheme.Data) req, err := vision.NewApiRequestImages(files, scheme.Data)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -108,7 +108,7 @@ func TestPostVisionNsfw(t *testing.T) {
files := vision.Files{} files := vision.Files{}
req, err := vision.NewApiRequest(files, scheme.Data) req, err := vision.NewApiRequestImages(files, scheme.Data)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@@ -19,8 +19,7 @@ var VisionRunCommand = &cli.Command{
&cli.StringFlag{ &cli.StringFlag{
Name: "models", Name: "models",
Aliases: []string{"m"}, Aliases: []string{"m"},
// TODO: Add captions to the list once the service can be used from the CLI. Usage: "model types (labels, nsfw, caption)",
Usage: "model types (labels, nsfw)",
}, },
&cli.BoolFlag{ &cli.BoolFlag{
Name: "force", Name: "force",

View File

@@ -1,15 +1,18 @@
package scheme package scheme
// Type represents a URL scheme type.
type Type = string
const ( const (
File = "file" File Type = "file"
Data = "data" Data Type = "data"
Http = "http" Http Type = "http"
Https = "https" Https Type = "https"
HttpUnix = Http + "+" + Unix Websocket Type = "wss"
Websocket = "wss" Unix Type = "unix"
Unix = "unix" HttpUnix Type = "http+unix"
Unixgram = "unixgram" Unixgram Type = "unixgram"
Unixpacket = "unixpacket" Unixpacket Type = "unixpacket"
) )
var ( var (