AI: Add VisionApi, VisionUri, and VisionKey config options #127 #1090

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2025-04-07 20:30:07 +02:00
parent 2a6b9fb237
commit d304509c0d
14 changed files with 176 additions and 71 deletions

View File

@@ -14,10 +14,11 @@ type Model struct {
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
Resolution int `yaml:"Resolution,omitempty" json:"resolution,omitempty"`
Url string `yaml:"Url,omitempty" json:"-"`
Uri string `yaml:"Uri,omitempty" json:"-"`
Key string `yaml:"Key,omitempty" json:"-"`
Method string `yaml:"Method,omitempty" json:"-"`
Format string `yaml:"Format,omitempty" json:"-"`
Path string `yaml:"Path,omitempty" json:"-"`
Format string `yaml:"Format,omitempty" json:"-"`
Tags []string `yaml:"Tags,omitempty" json:"-"`
Disabled bool `yaml:"Disabled,omitempty" json:"-"`
classifyModel *classify.Model

View File

@@ -7,6 +7,7 @@ import (
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/internal/auth/acl"
"github.com/photoprism/photoprism/internal/photoprism/get"
)
// PostVisionCaption returns a suitable caption for an image.
@@ -36,6 +37,13 @@ func PostVisionCaption(router *gin.RouterGroup) {
return
}
// Check if the Computer Vision API is enabled, otherwise abort with an error.
if !get.Config().VisionApi() {
AbortFeatureDisabled(c)
c.JSON(http.StatusForbidden, vision.NewApiError(request.GetId(), http.StatusForbidden))
return
}
// Generate Vision API service response.
response := vision.ApiResponse{
Id: request.GetId(),

View File

@@ -7,6 +7,7 @@ import (
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/internal/auth/acl"
"github.com/photoprism/photoprism/internal/photoprism/get"
)
// PostVisionFaces returns the positions and embeddings of detected faces.
@@ -36,6 +37,13 @@ func PostVisionFaces(router *gin.RouterGroup) {
return
}
// Check if the Computer Vision API is enabled, otherwise abort with an error.
if !get.Config().VisionApi() {
AbortFeatureDisabled(c)
c.JSON(http.StatusForbidden, vision.NewApiError(request.GetId(), http.StatusForbidden))
return
}
// Generate Vision API service response.
response := vision.ApiResponse{
Id: request.GetId(),

View File

@@ -7,6 +7,7 @@ import (
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/internal/auth/acl"
"github.com/photoprism/photoprism/internal/photoprism/get"
)
// PostVisionLabels returns suitable labels for an image.
@@ -37,6 +38,13 @@ func PostVisionLabels(router *gin.RouterGroup) {
return
}
// Check if the Computer Vision API is enabled, otherwise abort with an error.
if !get.Config().VisionApi() {
AbortFeatureDisabled(c)
c.JSON(http.StatusForbidden, vision.NewApiError(request.GetId(), http.StatusForbidden))
return
}
// Run inference to find matching labels.
labels, err := vision.Labels(request.Images)

View File

@@ -1,32 +0,0 @@
package config
import (
"path/filepath"
tf "github.com/wamuir/graft/tensorflow"
)
// TensorFlowVersion returns the TensorFlow framework version
// as reported by tf.Version().
func (c *Config) TensorFlowVersion() string {
	version := tf.Version()

	return version
}
// NasnetModelPath returns the Nasnet TensorFlow model path,
// which is the "nasnet" entry within the assets path.
func (c *Config) NasnetModelPath() string {
	assets := c.AssetsPath()

	return filepath.Join(assets, "nasnet")
}
// FaceNetModelPath returns the FaceNet model path,
// which is the "facenet" entry within the assets path.
func (c *Config) FaceNetModelPath() string {
	assets := c.AssetsPath()

	return filepath.Join(assets, "facenet")
}
// NSFWModelPath returns the "not safe for work" TensorFlow model path,
// which is the "nsfw" entry within the assets path.
func (c *Config) NSFWModelPath() string {
	assets := c.AssetsPath()

	return filepath.Join(assets, "nsfw")
}
// DetectNSFW reports whether NSFW photos should be detected and flagged,
// based on the value of the DetectNSFW option.
func (c *Config) DetectNSFW() bool {
	enabled := c.options.DetectNSFW

	return enabled
}

View File

@@ -257,15 +257,6 @@ func (c *Config) DefaultsYaml() string {
return fs.Abs(c.options.DefaultsYaml)
}
// VisionYaml returns the vision config YAML filename. If no filename has
// been configured, it defaults to "vision.yml" in the config path.
func (c *Config) VisionYaml() string {
	// Use the default location when the option is not set.
	if c.options.VisionYaml == "" {
		return filepath.Join(c.ConfigPath(), "vision.yml")
	}

	return fs.Abs(c.options.VisionYaml)
}
// HubConfigFile returns the backend api config file name.
func (c *Config) HubConfigFile() string {
return filepath.Join(c.ConfigPath(), "hub.yml")

View File

@@ -412,11 +412,6 @@ func TestConfig_CreateDirectories2(t *testing.T) {
}
*/
// TestConfig_VisionYaml verifies the vision config filename
// returned for the CLI test context.
func TestConfig_VisionYaml(t *testing.T) {
	const expected = "/go/src/github.com/photoprism/photoprism/storage/testdata/config/vision.yml"

	conf := NewConfig(CliTestContext())

	assert.Equal(t, expected, conf.VisionYaml())
}
func TestConfig_PIDFilename2(t *testing.T) {
c := NewConfig(CliTestContext())
assert.Equal(t, "/go/src/github.com/photoprism/photoprism/storage/testdata/photoprism.pid", c.PIDFilename())

View File

@@ -0,0 +1,71 @@
package config
import (
"os"
"path/filepath"
tf "github.com/wamuir/graft/tensorflow"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/fs"
)
// VisionYaml returns the vision config YAML filename. If no filename has
// been configured, it defaults to "vision.yml" in the config path.
func (c *Config) VisionYaml() string {
	// Use the default location when the option is not set.
	if c.options.VisionYaml == "" {
		return filepath.Join(c.ConfigPath(), "vision.yml")
	}

	return fs.Abs(c.options.VisionYaml)
}
// VisionApi reports whether the Computer Vision API endpoints should be
// enabled, based on the value of the VisionApi option.
func (c *Config) VisionApi() bool {
	enabled := c.options.VisionApi

	return enabled
}
// VisionUri returns the remote computer vision endpoint URI,
// e.g. https://example.com/api/v1/vision. The configured value
// is passed through clean.Uri before it is returned.
func (c *Config) VisionUri() string {
	uri := c.options.VisionUri

	return clean.Uri(uri)
}
// VisionKey returns the remote computer vision endpoint access token.
//
// A token set directly in the options takes precedence. If none is set,
// the token is read from the file reported by FlagFilePath("VISION_KEY"),
// if any. An empty string is returned when no token is available, when
// the file cannot be read, or when the file is empty. Returned tokens
// are passed through clean.Password.
func (c *Config) VisionKey() string {
	// A directly configured access token always wins.
	if c.options.VisionKey != "" {
		return clean.Password(c.options.VisionKey)
	}

	// Otherwise, try to read the access token from a file.
	fileName := FlagFilePath("VISION_KEY")

	if fileName == "" {
		// No access token set, this is not an error.
		return ""
	}

	b, err := os.ReadFile(fileName)

	if err != nil {
		log.Warnf("config: failed to read vision key from %s (%s)", fileName, err)
		return ""
	}

	// An empty file is reported separately, because err is nil in this case
	// and must not be formatted into the log message.
	if len(b) == 0 {
		log.Warnf("config: failed to read vision key from %s (file is empty)", fileName)
		return ""
	}

	return clean.Password(string(b))
}
// TensorFlowVersion returns the TensorFlow framework version
// as reported by tf.Version().
func (c *Config) TensorFlowVersion() string {
	version := tf.Version()

	return version
}
// NasnetModelPath returns the Nasnet TensorFlow model path,
// which is the "nasnet" entry within the assets path.
func (c *Config) NasnetModelPath() string {
	assets := c.AssetsPath()

	return filepath.Join(assets, "nasnet")
}
// FaceNetModelPath returns the FaceNet model path,
// which is the "facenet" entry within the assets path.
func (c *Config) FaceNetModelPath() string {
	assets := c.AssetsPath()

	return filepath.Join(assets, "facenet")
}
// NSFWModelPath returns the "not safe for work" TensorFlow model path,
// which is the "nsfw" entry within the assets path.
func (c *Config) NSFWModelPath() string {
	assets := c.AssetsPath()

	return filepath.Join(assets, "nsfw")
}
// DetectNSFW reports whether NSFW photos should be detected and flagged,
// based on the value of the DetectNSFW option.
func (c *Config) DetectNSFW() bool {
	enabled := c.options.DetectNSFW

	return enabled
}

View File

@@ -6,11 +6,39 @@ import (
"github.com/stretchr/testify/assert"
)
// TestConfig_VisionYaml verifies the vision config filename
// returned for the CLI test context.
func TestConfig_VisionYaml(t *testing.T) {
	const expected = "/go/src/github.com/photoprism/photoprism/storage/testdata/config/vision.yml"

	conf := NewConfig(CliTestContext())

	assert.Equal(t, expected, conf.VisionYaml())
}
// TestConfig_VisionApi verifies that the Computer Vision API
// is reported as enabled for the CLI test context.
func TestConfig_VisionApi(t *testing.T) {
	conf := NewConfig(CliTestContext())
	enabled := conf.VisionApi()

	assert.True(t, enabled)
}
// TestConfig_VisionUri verifies that VisionUri is empty by default,
// reflects a configured endpoint URI, and is empty again after a reset.
func TestConfig_VisionUri(t *testing.T) {
	const endpoint = "https://www.example.com/api/v1/vision"

	conf := NewConfig(CliTestContext())

	// Not configured in the test context.
	assert.Equal(t, "", conf.VisionUri())

	// Configured value is returned unchanged.
	conf.options.VisionUri = endpoint
	assert.Equal(t, endpoint, conf.VisionUri())

	// Resetting the option restores the empty default.
	conf.options.VisionUri = ""
	assert.Equal(t, "", conf.VisionUri())
}
// TestConfig_VisionKey verifies that VisionKey is empty by default,
// reflects a configured access token, and is empty again after a reset.
func TestConfig_VisionKey(t *testing.T) {
	const token = "SecretAccessToken!"

	conf := NewConfig(CliTestContext())

	// Not configured in the test context.
	assert.Equal(t, "", conf.VisionKey())

	// Configured value is returned unchanged.
	conf.options.VisionKey = token
	assert.Equal(t, token, conf.VisionKey())

	// Resetting the option restores the empty default.
	conf.options.VisionKey = ""
	assert.Equal(t, "", conf.VisionKey())
}
// TestConfig_TensorFlowVersion checks the type of the value returned by
// Config.TensorFlowVersion.
func TestConfig_TensorFlowVersion(t *testing.T) {
c := NewConfig(CliTestContext())
version := c.TensorFlowVersion()
// NOTE(review): IsType presumably compares only the types of its arguments,
// so the literal version numbers below would not be matched against the
// actual version; confirm against the testify documentation.
assert.IsType(t, "1.15.0", version)
assert.IsType(t, "2.18.0", version)
}
func TestConfig_TensorFlowModelPath(t *testing.T) {

View File

@@ -196,13 +196,6 @@ var Flags = CliFlags{
EnvVars: EnvVars("DEFAULTS_YAML"),
TakesFile: true,
}}, {
Flag: &cli.StringFlag{
Name: "vision-yaml",
Usage: "load computer vision model configuration from `FILE`*optional*",
Value: "",
EnvVars: EnvVars("VISION_YAML"),
TakesFile: true,
}}, {
Flag: &cli.PathFlag{
Name: "originals-path",
Aliases: []string{"o"},
@@ -492,11 +485,6 @@ var Flags = CliFlags{
Usage: "always perform a brute-force search if no Exif headers were found",
EnvVars: EnvVars("EXIF_BRUTEFORCE"),
}}, {
Flag: &cli.BoolFlag{
Name: "detect-nsfw",
Usage: "flag newly added pictures as private if they might be offensive (requires TensorFlow)",
EnvVars: EnvVars("DETECT_NSFW"),
}}, {
Flag: &cli.StringFlag{
Name: "default-locale",
Aliases: []string{"lang"},
@@ -972,6 +960,35 @@ var Flags = CliFlags{
Value: 7680,
EnvVars: EnvVars("PNG_SIZE"),
}}, {
Flag: &cli.StringFlag{
Name: "vision-yaml",
Usage: "computer vision model configuration `FILE`*optional*",
Value: "",
EnvVars: EnvVars("VISION_YAML"),
TakesFile: true,
}}, {
Flag: &cli.BoolFlag{
Name: "vision-api",
Usage: "enable computer vision server API endpoints",
EnvVars: EnvVars("VISION_API"),
}}, {
Flag: &cli.StringFlag{
Name: "vision-uri",
Usage: "remote computer vision endpoint `URI`, e.g. https://example.com/api/v1/vision (leave blank to disable)",
Value: "",
EnvVars: EnvVars("VISION_URI"),
}}, {
Flag: &cli.StringFlag{
Name: "vision-key",
Usage: "remote computer vision endpoint access `TOKEN`*optional*",
Value: "",
EnvVars: EnvVars("VISION_KEY"),
}}, {
Flag: &cli.BoolFlag{
Name: "detect-nsfw",
Usage: "flag newly added pictures as private if they might be offensive (requires TensorFlow)",
EnvVars: EnvVars("DETECT_NSFW"),
}}, {
Flag: &cli.IntFlag{
Name: "face-size",
Usage: "minimum size of faces in `PIXELS` (20-10000)",

View File

@@ -59,7 +59,6 @@ type Options struct {
Sponsor bool `yaml:"-" json:"-" flag:"sponsor"`
ConfigPath string `yaml:"ConfigPath" json:"-" flag:"config-path"`
DefaultsYaml string `json:"-" yaml:"-" flag:"defaults-yaml"`
VisionYaml string `json:"-" yaml:"-" flag:"vision-yaml"`
OriginalsPath string `yaml:"OriginalsPath" json:"-" flag:"originals-path"`
OriginalsLimit int `yaml:"OriginalsLimit" json:"OriginalsLimit" flag:"originals-limit"`
ResolutionLimit int `yaml:"ResolutionLimit" json:"ResolutionLimit" flag:"resolution-limit"`
@@ -114,7 +113,6 @@ type Options struct {
DisableRaw bool `yaml:"DisableRaw" json:"DisableRaw" flag:"disable-raw"`
RawPresets bool `yaml:"RawPresets" json:"RawPresets" flag:"raw-presets"`
ExifBruteForce bool `yaml:"ExifBruteForce" json:"ExifBruteForce" flag:"exif-bruteforce"`
DetectNSFW bool `yaml:"DetectNSFW" json:"DetectNSFW" flag:"detect-nsfw"`
DefaultLocale string `yaml:"DefaultLocale" json:"DefaultLocale" flag:"default-locale"`
DefaultTimezone string `yaml:"DefaultTimezone" json:"DefaultTimezone" flag:"default-timezone"`
DefaultTheme string `yaml:"DefaultTheme" json:"DefaultTheme" flag:"default-theme"`
@@ -194,6 +192,11 @@ type Options struct {
JpegQuality int `yaml:"JpegQuality" json:"JpegQuality" flag:"jpeg-quality"`
JpegSize int `yaml:"JpegSize" json:"JpegSize" flag:"jpeg-size"`
PngSize int `yaml:"PngSize" json:"PngSize" flag:"png-size"`
VisionYaml string `yaml:"VisionYaml" json:"-" flag:"vision-yaml"`
VisionApi bool `yaml:"VisionApi" json:"-" flag:"vision-api"`
VisionUri string `yaml:"VisionUri" json:"-" flag:"vision-uri"`
VisionKey string `yaml:"VisionKey" json:"-" flag:"vision-key"`
DetectNSFW bool `yaml:"DetectNSFW" json:"DetectNSFW" flag:"detect-nsfw"`
FaceSize int `yaml:"-" json:"-" flag:"face-size"`
FaceScore float64 `yaml:"-" json:"-" flag:"face-score"`
FaceOverlap int `yaml:"-" json:"-" flag:"face-overlap"`

View File

@@ -51,7 +51,6 @@ func (c *Config) Report() (rows [][]string, cols []string) {
rows = append(rows, [][]string{
{"settings-yaml", c.SettingsYaml()},
{"vision-yaml", c.VisionYaml()},
// Originals.
{"originals-path", c.OriginalsPath()},
@@ -135,13 +134,6 @@ func (c *Config) Report() (rows [][]string, cols []string) {
{"raw-presets", fmt.Sprintf("%t", c.RawPresets())},
{"exif-bruteforce", fmt.Sprintf("%t", c.ExifBruteForce())},
// Computer Vision.
{"detect-nsfw", fmt.Sprintf("%t", c.DetectNSFW())},
{"nsfw-model-path", c.NSFWModelPath()},
{"nasnet-model-path", c.NasnetModelPath()},
{"facenet-model-path", c.FaceNetModelPath()},
{"tensorflow-version", c.TensorFlowVersion()},
// Customization.
{"default-locale", c.DefaultLocale()},
{"default-timezone", c.DefaultTimezone().String()},
@@ -249,6 +241,17 @@ func (c *Config) Report() (rows [][]string, cols []string) {
{"jpeg-size", fmt.Sprintf("%d", c.JpegSize())},
{"png-size", fmt.Sprintf("%d", c.PngSize())},
// Computer Vision.
{"vision-yaml", c.VisionYaml()},
{"vision-api", fmt.Sprintf("%t", c.VisionApi())},
{"vision-uri", c.VisionUri()},
{"vision-key", strings.Repeat("*", utf8.RuneCountInString(c.VisionKey()))},
{"tensorflow-version", c.TensorFlowVersion()},
{"nasnet-model-path", c.NasnetModelPath()},
{"facenet-model-path", c.FaceNetModelPath()},
{"nsfw-model-path", c.NSFWModelPath()},
{"detect-nsfw", fmt.Sprintf("%t", c.DetectNSFW())},
// Facial Recognition.
{"face-size", fmt.Sprintf("%d", c.FaceSize())},
{"face-score", fmt.Sprintf("%f", c.FaceScore())},

View File

@@ -32,6 +32,7 @@ var OptionsReportSections = []ReportSection{
{Start: "PHOTOPRISM_DOWNLOAD_TOKEN", Title: "Security Tokens"},
{Start: "PHOTOPRISM_THUMB_LIBRARY", Title: "Preview Images"},
{Start: "PHOTOPRISM_JPEG_QUALITY", Title: "Image Quality"},
{Start: "PHOTOPRISM_VISION_YAML", Title: "Computer Vision"},
{Start: "PHOTOPRISM_FACE_SIZE", Title: "Face Recognition",
Info: faceFlagsInfo},
{Start: "PHOTOPRISM_PID_FILENAME", Title: "Daemon Mode",
@@ -56,6 +57,7 @@ var YamlReportSections = []ReportSection{
{Start: "DownloadToken", Title: "Security Tokens"},
{Start: "ThumbLibrary", Title: "Preview Images"},
{Start: "JpegQuality", Title: "Image Quality"},
{Start: "VisionYaml", Title: "Computer Vision"},
{Start: "PIDFilename", Title: "Daemon Mode",
Info: "If you start the server as a *daemon* in the background, you can additionally specify a filename for the log and the process ID:"},
}

View File

@@ -254,6 +254,7 @@ func CliTestContext() *cli.Context {
globalSet.String("darktable-exclude", config.DarktableExclude, "doc")
globalSet.String("sips-exclude", config.SipsExclude, "doc")
globalSet.String("wakeup-interval", "1h34m9s", "doc")
globalSet.Bool("vision-api", config.VisionApi, "doc")
globalSet.Bool("detect-nsfw", config.DetectNSFW, "doc")
globalSet.Bool("debug", false, "doc")
globalSet.Bool("sponsor", true, "doc")
@@ -288,6 +289,7 @@ func CliTestContext() *cli.Context {
LogErr(c.Set("darktable-exclude", "raf, cr3"))
LogErr(c.Set("sips-exclude", "avif, avifs, thm"))
LogErr(c.Set("wakeup-interval", "1h34m9s"))
LogErr(c.Set("vision-api", "true"))
LogErr(c.Set("detect-nsfw", "true"))
LogErr(c.Set("debug", "false"))
LogErr(c.Set("sponsor", "true"))