Files
photoprism/internal/ai/vision/nsfw.go
2025-11-24 14:41:13 +01:00

101 lines
2.6 KiB
Go

package vision
import (
"errors"
"fmt"
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/media"
)
var (
	// nsfwFunc is the detector that DetectNSFW delegates to. It starts out as
	// nsfwInternal and can be swapped through SetNSFWFunc.
	nsfwFunc = nsfwInternal
)
// SetNSFWFunc replaces the detector used by DetectNSFW, which is mainly useful in tests.
// Passing nil restores the built-in nsfwInternal implementation.
func SetNSFWFunc(fn func(Files, media.Src) ([]nsfw.Result, error)) {
	detector := fn

	// Fall back to the default detector when no replacement is given.
	if detector == nil {
		detector = nsfwInternal
	}

	nsfwFunc = detector
}
// DetectNSFW runs the currently installed NSFW detector on the given images and
// returns its results and error unchanged. The detector defaults to nsfwInternal
// and can be replaced with SetNSFWFunc.
func DetectNSFW(images Files, mediaSrc media.Src) (result []nsfw.Result, err error) {
	result, err = nsfwFunc(images, mediaSrc)

	return result, err
}
// nsfwInternal is the default NSFW detector. If the configured NSFW model has a
// service endpoint, the images are submitted to it in a single API request and
// the NSFW results of the response are returned. Otherwise, each image is passed
// to the model's local classifier, using File for media.SrcLocal and Url for
// media.SrcRemote.
//
// An error is returned if no images are given, the vision service or its NSFW
// model is not configured, the API request cannot be created or performed, or
// mediaSrc is neither media.SrcLocal nor media.SrcRemote. Errors for individual
// images in local mode are only logged, and whatever the classifier returned for
// that image is stored in the result.
func nsfwInternal(images Files, mediaSrc media.Src) (result []nsfw.Result, err error) {
	// Nothing to do without at least one thumbnail filename.
	if len(images) == 0 {
		return result, errors.New("at least one image required")
	}

	result = make([]nsfw.Result, len(images))

	// The vision service must be configured.
	if Config == nil {
		return result, errors.New("vision service is not configured")
	}

	// An NSFW model must be part of the configuration.
	model := Config.Model(ModelTypeNsfw)

	if model == nil {
		return result, errors.New("missing nsfw model")
	}

	// Prefer the remote service API when a server endpoint has been configured.
	if uri, method := model.Endpoint(); uri != "" && method != "" {
		var request *ApiRequest

		request, err = NewApiRequest(model.EndpointRequestFormat(), images, model.EndpointFileScheme())

		if err != nil {
			return result, err
		}

		// Fill in model name and version only if the request does not name a model yet.
		if request.Model == "" {
			request.Model, _, request.Version = model.GetModel()
		}

		model.ApplyService(request)

		// Non-empty system and user prompts of the model take precedence.
		if model.System != "" {
			request.System = model.System
		}

		if model.Prompt != "" {
			request.Prompt = model.Prompt
		}

		// Log JSON request data in trace mode.
		request.WriteLog()

		var response *ApiResponse

		response, err = PerformApiRequest(request, uri, method, model.EndpointKey())

		if err != nil {
			return result, err
		}

		return response.Result.Nsfw, nil
	}

	// Without an endpoint, a local TensorFlow model is required.
	tf := model.NsfwModel()

	if tf == nil {
		return result, errors.New("invalid nsfw model configuration")
	}

	// Classify the images one by one with the local model.
	for idx := range images {
		var labels nsfw.Result

		switch mediaSrc {
		case media.SrcLocal:
			labels, err = tf.File(images[idx])
		case media.SrcRemote:
			labels, err = tf.Url(images[idx])
		default:
			return result, fmt.Errorf("invalid media source %s", clean.Log(mediaSrc))
		}

		// A failed image is logged only so that the remaining images still get processed.
		if err != nil {
			log.Errorf("nsfw: %s", err)
		}

		result[idx] = labels
	}

	return result, nil
}