mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 08:44:04 +01:00
189 lines
5.7 KiB
Go
189 lines
5.7 KiB
Go
package vision
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"net/http"
|
|
|
|
"github.com/photoprism/photoprism/internal/ai/classify"
|
|
"github.com/photoprism/photoprism/internal/ai/face"
|
|
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
|
"github.com/photoprism/photoprism/internal/entity"
|
|
"github.com/photoprism/photoprism/pkg/clean"
|
|
)
|
|
|
|
// ApiResponse represents a Vision API service response.
|
|
type ApiResponse struct {
|
|
Id string `yaml:"Id,omitempty" json:"id,omitempty"`
|
|
Code int `yaml:"Code,omitempty" json:"code,omitempty"`
|
|
Error string `yaml:"Error,omitempty" json:"error,omitempty"`
|
|
Model *Model `yaml:"Model,omitempty" json:"model,omitempty"`
|
|
Result ApiResult `yaml:"Result,omitempty" json:"result,omitempty"`
|
|
}
|
|
|
|
// Err reports why the request has failed, or returns nil if it succeeded
// and produced a result.
//
// A nil receiver yields "response is nil". A status code of 400 or above
// yields the response's error message or, if that is empty, a generic
// "error <code>". Any other response without result data yields "no result".
func (r *ApiResponse) Err() error {
	switch {
	case r == nil:
		return errors.New("response is nil")
	case r.Code >= 400 && r.Error != "":
		// Prefer the message that came with the response.
		return errors.New(r.Error)
	case r.Code >= 400:
		return fmt.Errorf("error %d", r.Code)
	case r.Result.IsEmpty():
		return errors.New("no result")
	default:
		return nil
	}
}
|
|
|
|
// HasResult checks if there is at least one result in the response data.
|
|
func (r *ApiResponse) HasResult() bool {
|
|
if r == nil {
|
|
return false
|
|
}
|
|
|
|
return !r.Result.IsEmpty()
|
|
}
|
|
|
|
// ApiResult represents the model response(s) to a Vision API service
|
|
// request and can optionally include data from multiple models.
|
|
type ApiResult struct {
|
|
Labels []LabelResult `yaml:"Labels,omitempty" json:"labels,omitempty"`
|
|
Nsfw []nsfw.Result `yaml:"Nsfw,omitempty" json:"nsfw,omitempty"`
|
|
Embeddings []face.Embeddings `yaml:"Embeddings,omitempty" json:"embeddings,omitempty"`
|
|
Caption *CaptionResult `yaml:"Caption,omitempty" json:"caption,omitempty"`
|
|
}
|
|
|
|
// IsEmpty checks if there is no result in the response data.
|
|
func (r *ApiResult) IsEmpty() bool {
|
|
if r == nil {
|
|
return false
|
|
}
|
|
|
|
return len(r.Labels) == 0 && len(r.Nsfw) == 0 && len(r.Embeddings) == 0 && r.Caption == nil
|
|
}
|
|
|
|
// CaptionResult is the output of a caption generation model.
type CaptionResult struct {
	// Text is the generated caption.
	Text string `yaml:"Text,omitempty" json:"text,omitempty"`
	// Source names where the caption comes from.
	Source string `yaml:"Source,omitempty" json:"source,omitempty"`
	// Confidence is the confidence reported for the caption.
	Confidence float32 `yaml:"Confidence,omitempty" json:"confidence,omitempty"`
}
|
|
|
|
// LabelResult is a single label produced by an image classification model.
type LabelResult struct {
	// Name is the label name; it is always included in the JSON output.
	Name string `yaml:"Name,omitempty" json:"name"`
	// Source names where the label comes from; always included in JSON.
	Source string `yaml:"Source,omitempty" json:"source"`
	// Priority is passed through unchanged by ToClassify.
	Priority int `yaml:"Priority,omitempty" json:"priority,omitempty"`
	// Confidence is a fraction that ToClassify scales by 100 to derive the
	// uncertainty in percent.
	Confidence float32 `yaml:"Confidence,omitempty" json:"confidence,omitempty"`
	// Topicality is a fraction that ToClassify scales by 100.
	Topicality float32 `yaml:"Topicality,omitempty" json:"topicality,omitempty"`
	// Categories lists the categories assigned to the label.
	Categories []string `yaml:"Categories,omitempty" json:"categories,omitempty"`
	// NSFW flags the label as not safe for work.
	NSFW bool `yaml:"Nsfw,omitempty" json:"nsfw,omitempty"`
	// NSFWConfidence is a fraction that ToClassify scales by 100.
	NSFWConfidence float32 `yaml:"NsfwConfidence,omitempty" json:"nsfw_confidence,omitempty"`
}
|
|
|
|
// ToClassify returns the label results as classify.Label.
|
|
func (r LabelResult) ToClassify(labelSrc string) classify.Label {
|
|
// Calculate uncertainty from confidence or assume a default of 20%.
|
|
var uncertainty int
|
|
|
|
if r.Confidence <= 0 {
|
|
uncertainty = 20
|
|
} else {
|
|
uncertainty = int(math.RoundToEven(float64(100 - r.Confidence*100)))
|
|
}
|
|
|
|
// Default to "image" if no source name is provided.
|
|
switch {
|
|
case labelSrc != entity.SrcAuto:
|
|
labelSrc = clean.ShortTypeLower(labelSrc)
|
|
case r.Source != "":
|
|
labelSrc = clean.ShortTypeLower(r.Source)
|
|
default:
|
|
labelSrc = entity.SrcImage
|
|
}
|
|
|
|
topicality := int(math.RoundToEven(float64(r.Topicality * 100)))
|
|
if topicality < 0 {
|
|
topicality = 0
|
|
} else if topicality > 100 {
|
|
topicality = 100
|
|
}
|
|
|
|
// Return label.
|
|
confidenceScaled := int(math.RoundToEven(float64(r.NSFWConfidence * 100)))
|
|
if confidenceScaled < 0 {
|
|
confidenceScaled = 0
|
|
} else if confidenceScaled > 100 {
|
|
confidenceScaled = 100
|
|
}
|
|
if r.NSFW && confidenceScaled == 0 {
|
|
confidenceScaled = 100
|
|
}
|
|
|
|
return classify.Label{
|
|
Name: r.Name,
|
|
Source: labelSrc,
|
|
Priority: r.Priority,
|
|
Uncertainty: uncertainty,
|
|
Topicality: topicality,
|
|
Categories: r.Categories,
|
|
NSFW: r.NSFW,
|
|
NSFWConfidence: confidenceScaled,
|
|
}
|
|
}
|
|
|
|
// NewApiError generates a Vision API error response based on the specified HTTP status code.
|
|
func NewApiError(id string, code int) ApiResponse {
|
|
return ApiResponse{
|
|
Id: clean.Type(id),
|
|
Code: code,
|
|
Error: http.StatusText(code),
|
|
}
|
|
}
|
|
|
|
// NewLabelsResponse generates a new Vision API image classification service response.
|
|
func NewLabelsResponse(id string, model *Model, results classify.Labels) ApiResponse {
|
|
if model == nil {
|
|
model = NasnetModel
|
|
}
|
|
|
|
var labels = make([]LabelResult, 0, len(results))
|
|
|
|
for _, label := range results {
|
|
|
|
labels = append(labels, LabelResult{
|
|
Name: label.Name,
|
|
Source: label.Source,
|
|
Priority: label.Priority,
|
|
Confidence: label.Confidence(),
|
|
Topicality: float32(label.Topicality) / 100,
|
|
Categories: label.Categories,
|
|
NSFW: label.NSFW,
|
|
NSFWConfidence: float32(label.NSFWConfidence) / 100,
|
|
})
|
|
}
|
|
|
|
return ApiResponse{
|
|
Id: clean.Type(id),
|
|
Code: http.StatusOK,
|
|
Model: &Model{Type: ModelTypeLabels, Name: model.Name, Version: model.Version, Resolution: model.Resolution},
|
|
Result: ApiResult{Labels: labels},
|
|
}
|
|
}
|
|
|
|
// NewCaptionResponse generates a new Vision API image caption service response.
|
|
func NewCaptionResponse(id string, model *Model, result *CaptionResult) ApiResponse {
|
|
return ApiResponse{
|
|
Id: clean.Type(id),
|
|
Code: http.StatusOK,
|
|
Model: &Model{Type: ModelTypeLabels, Name: model.Name, Version: model.Version, Resolution: model.Resolution},
|
|
Result: ApiResult{Caption: result},
|
|
}
|
|
}
|