AI: Add "photoprism vision run" command and vision worker #127 #1090

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2025-04-11 05:15:14 +02:00
parent 5b8be2f5d4
commit f80acab4c2
17 changed files with 446 additions and 29 deletions

View File

@@ -70,6 +70,7 @@ func (r *ApiResult) IsEmpty() bool {
// CaptionResult represents the result generated by a caption generation model.
type CaptionResult struct {
Text string `yaml:"Text,omitempty" json:"text,omitempty"`
Source string `yaml:"Source,omitempty" json:"source,omitempty"`
Confidence float32 `yaml:"Confidence,omitempty" json:"confidence,omitempty"`
}

View File

@@ -7,6 +7,7 @@ import (
"slices"
"github.com/photoprism/photoprism/internal/api/download"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/media"
@@ -31,6 +32,13 @@ func Caption(imgName string, src media.Src) (result CaptionResult, err error) {
return result, errors.New("invalid image file name")
}
/* TODO: Add support for data URLs to the service.
if file, fileErr := os.Open(imgName); fileErr != nil {
return result, fmt.Errorf("%s (open image file)", err)
} else {
imgUrl = media.DataUrl(file)
} */
dlId, dlErr := download.Register(imgName)
if dlErr != nil {
@@ -57,9 +65,9 @@ func Caption(imgName string, src media.Src) (result CaptionResult, err error) {
Url: imgUrl,
}
if json, _ := apiRequest.MarshalJSON(); len(json) > 0 {
/* if json, _ := apiRequest.MarshalJSON(); len(json) > 0 {
log.Debugf("request: %s", json)
}
} */
apiResponse, apiErr := PerformApiRequest(apiRequest, uri, method, model.EndpointKey())
@@ -69,6 +77,11 @@ func Caption(imgName string, src media.Src) (result CaptionResult, err error) {
return result, errors.New("invalid caption model response")
}
// Set image as the default caption source.
if apiResponse.Result.Caption.Text != "" && apiResponse.Result.Caption.Source == "" {
apiResponse.Result.Caption.Source = entity.SrcImage
}
result = *apiResponse.Result.Caption
} else {
return result, errors.New("invalid caption model configuration")

View File

@@ -1,10 +0,0 @@
package vision
type ModelType = string
const (
ModelTypeLabels ModelType = "labels"
ModelTypeNsfw ModelType = "nsfw"
ModelTypeFaceEmbeddings ModelType = "face/embeddings"
ModelTypeCaption ModelType = "caption"
)

View File

@@ -19,3 +19,18 @@ func TestModel(t *testing.T) {
assert.Equal(t, "", method)
})
}
func TestParseTypes(t *testing.T) {
t.Run("Valid", func(t *testing.T) {
result := ParseTypes("nsfw, labels, Caption")
assert.Equal(t, ModelTypes{"nsfw", "labels", "caption"}, result)
})
t.Run("None", func(t *testing.T) {
result := ParseTypes("")
assert.Equal(t, ModelTypes{}, result)
})
t.Run("Invalid", func(t *testing.T) {
result := ParseTypes("foo, captions")
assert.Equal(t, ModelTypes{}, result)
})
}

View File

@@ -0,0 +1,38 @@
package vision
import (
"slices"
"strings"
)
type ModelType = string
type ModelTypes = []ModelType
const (
ModelTypeLabels ModelType = "labels"
ModelTypeNsfw ModelType = "nsfw"
ModelTypeFaceEmbeddings ModelType = "face/embeddings"
ModelTypeCaption ModelType = "caption"
)
// ParseTypes parses a model type string.
func ParseTypes(s string) (types ModelTypes) {
if s = strings.TrimSpace(s); s == "" {
return ModelTypes{}
}
s = strings.ToLower(s)
types = make(ModelTypes, 0, strings.Count(s, ","))
for _, t := range strings.Split(s, ",") {
t = strings.TrimSpace(t)
switch t {
case ModelTypeLabels, ModelTypeNsfw, ModelTypeFaceEmbeddings, ModelTypeCaption:
if !slices.Contains(types, t) {
types = append(types, t)
}
}
}
return types
}

View File

@@ -0,0 +1,24 @@
package vision
import (
"github.com/photoprism/photoprism/internal/thumb"
)
// Resolution returns the image resolution of the given model type.
func Resolution(modelType ModelType) int {
m := Config.Model(modelType)
if m == nil {
return DefaultResolution
} else if m.Resolution <= 0 {
return DefaultResolution
}
return m.Resolution
}
// Thumb returns the matching thumbnail size for the given model type.
func Thumb(modelType ModelType) (size thumb.Size) {
res := Resolution(modelType)
return thumb.Vision(res)
}

View File

@@ -0,0 +1,43 @@
package vision
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/thumb"
)
func TestResolution(t *testing.T) {
t.Run("Default", func(t *testing.T) {
result := Resolution("invalid")
assert.Equal(t, DefaultResolution, result)
})
t.Run("Facenet", func(t *testing.T) {
result := Resolution(ModelTypeFaceEmbeddings)
assert.Equal(t, FacenetModel.Resolution, result)
})
t.Run("Nasnet", func(t *testing.T) {
result := Resolution(ModelTypeLabels)
assert.Equal(t, 224, result)
})
}
func TestThumb(t *testing.T) {
t.Run("Default", func(t *testing.T) {
size := Thumb("invalid")
assert.Equal(t, thumb.SizeTile224, size)
})
t.Run("Facenet", func(t *testing.T) {
size := Thumb(ModelTypeFaceEmbeddings)
assert.Equal(t, thumb.SizeTile224, size)
})
t.Run("Nasnet", func(t *testing.T) {
size := Thumb(ModelTypeLabels)
assert.Equal(t, thumb.SizeTile224, size)
})
t.Run("Caption", func(t *testing.T) {
size := Thumb(ModelTypeCaption)
assert.Equal(t, thumb.SizeTile224, size)
})
}

View File

@@ -18,7 +18,7 @@ import (
var FindCommand = &cli.Command{
Name: "find",
Usage: "Searches the index for specific files",
ArgsUsage: "filter",
ArgsUsage: "[filter]",
Flags: append(report.CliFlags, &cli.UintFlag{
Name: "count",
Aliases: []string{"n"},

View File

@@ -1,24 +1,41 @@
package commands
import (
"strings"
"github.com/urfave/cli/v2"
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/workers"
)
// VisionRunCommand configures the command name, flags, and action.
var VisionRunCommand = &cli.Command{
Name: "run",
Usage: "Runs a computer vision model",
ArgsUsage: "[type]",
Action: visionRunAction,
Hidden: true,
ArgsUsage: "[filter]",
Flags: []cli.Flag{
&cli.StringFlag{
Name: "models",
Aliases: []string{"m"},
// TODO: Add captions to the list once the service can be used from the CLI.
Usage: "model types (labels, nsfw)",
},
&cli.BoolFlag{
Name: "force",
Aliases: []string{"f"},
Usage: "force existing metadata to be updated",
},
},
Action: visionRunAction,
Hidden: true,
}
// visionListAction displays existing user accounts.
func visionRunAction(ctx *cli.Context) error {
return CallWithDependencies(ctx, func(conf *config.Config) error {
log.Error("not implemented")
return nil
worker := workers.NewVision(conf)
return worker.Start(strings.TrimSpace(ctx.Args().First()), vision.ParseTypes(ctx.String("models")), ctx.Bool("force"))
})
}

View File

@@ -7,6 +7,7 @@ var (
BackupWorker = Activity{}
ShareWorker = Activity{}
MetaWorker = Activity{}
VisionWorker = Activity{}
FacesWorker = Activity{}
UpdatePeople = Activity{}
)
@@ -18,11 +19,12 @@ func CancelAll() {
BackupWorker.Cancel()
ShareWorker.Cancel()
MetaWorker.Cancel()
VisionWorker.Cancel()
FacesWorker.Cancel()
UpdatePeople.Cancel()
}
// WorkersRunning checks if a worker is currently running.
func WorkersRunning() bool {
return IndexWorker.Running() || SyncWorker.Running() || BackupWorker.Running() || ShareWorker.Running() || MetaWorker.Running() || FacesWorker.Running()
return IndexWorker.Running() || SyncWorker.Running() || BackupWorker.Running() || ShareWorker.Running() || MetaWorker.Running() || VisionWorker.Running() || FacesWorker.Running()
}

View File

@@ -0,0 +1,31 @@
package photoprism
import (
"time"
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/media"
)
// Caption returns generated caption for the specified media file.
func (ind *Index) Caption(file *MediaFile) (caption vision.CaptionResult, err error) {
start := time.Now()
size := vision.Thumb(vision.ModelTypeCaption)
// Get thumbnail filenames for the selected sizes.
fileName, fileErr := file.Thumbnail(Config().ThumbCachePath(), size.Name)
if fileErr != nil {
return caption, err
}
// Get matching labels from computer vision model.
if caption, err = vision.Caption(fileName, media.SrcLocal); err != nil {
} else if caption.Text != "" {
log.Infof("vision: generated caption for %s [%s]", clean.Log(file.BaseName()), time.Since(start))
}
return caption, err
}

View File

@@ -29,12 +29,12 @@ func (ind *Index) Faces(jpeg *MediaFile, expected int) face.Faces {
thumbName, err := jpeg.Thumbnail(Config().ThumbCachePath(), thumbSize)
if err != nil {
log.Debugf("index: %s in %s (faces)", err, clean.Log(jpeg.BaseName()))
log.Debugf("vision: %s in %s (detect faces)", err, clean.Log(jpeg.BaseName()))
return face.Faces{}
}
if thumbName == "" {
log.Debugf("index: thumb %s not found in %s (faces)", thumbSize, clean.Log(jpeg.BaseName()))
log.Debugf("vision: thumb %s not found in %s (detect faces)", thumbSize, clean.Log(jpeg.BaseName()))
return face.Faces{}
}
@@ -43,11 +43,11 @@ func (ind *Index) Faces(jpeg *MediaFile, expected int) face.Faces {
faces, err := vision.Faces(thumbName, Config().FaceSize(), true, expected)
if err != nil {
log.Debugf("%s in %s", err, clean.Log(jpeg.BaseName()))
log.Debugf("vision: %s in %s (detect faces)", err, clean.Log(jpeg.BaseName()))
}
if l := len(faces); l > 0 {
log.Infof("index: found %s in %s [%s]", english.Plural(l, "face", "faces"), clean.Log(jpeg.BaseName()), time.Since(start))
log.Infof("vision: found %s in %s [%s]", english.Plural(l, "face", "faces"), clean.Log(jpeg.BaseName()), time.Since(start))
}
return faces

View File

@@ -18,13 +18,13 @@ func (ind *Index) IsNsfw(m *MediaFile) bool {
}
if results, modelErr := vision.Nsfw([]string{filename}, media.SrcLocal); modelErr != nil {
log.Errorf("index: %s in %s (detect nsfw)", modelErr, m.RootRelName())
log.Errorf("vision: %s in %s (detect nsfw)", modelErr, m.RootRelName())
return false
} else if len(results) < 1 {
log.Errorf("index: nsfw model returned no result for %s", m.RootRelName())
log.Errorf("vision: nsfw model returned no result for %s", m.RootRelName())
return false
} else if results[0].IsNsfw(nsfw.ThresholdHigh) {
log.Warnf("index: %s might contain offensive content", clean.Log(m.RelName(Config().OriginalsPath())))
log.Warnf("vision: %s might contain offensive content", clean.Log(m.RelName(Config().OriginalsPath())))
return true
}

View File

@@ -1,6 +1,10 @@
package thumb
import "github.com/photoprism/photoprism/pkg/fs"
import (
"strings"
"github.com/photoprism/photoprism/pkg/fs"
)
// Name represents a thumbnail size name.
type Name string
@@ -73,3 +77,20 @@ func Find(pixels int) (name Name, size Size) {
return "", Size{}
}
// Vision returns a suitable tile size for computer vision applications.
func Vision(resolution int) (size Size) {
for _, size = range All {
if size.Height != size.Width {
continue
} else if !strings.HasPrefix(size.Name.String(), "tile_") {
continue
}
if size.Width >= resolution {
return size
}
}
return SizeTile224
}

View File

@@ -1,5 +1,9 @@
package thumb
import (
"slices"
)
// Default thumbnail size limits (cached and uncached).
var (
SizeCached = SizeFit1920.Width
@@ -34,6 +38,16 @@ func (m SizeMap) All() SizeList {
result = append(result, s)
}
slices.SortStableFunc(result, func(a, b Size) int {
if a.Width < b.Width {
return -1
} else if a.Width > b.Width {
return 1
} else {
return 0
}
})
return result
}
@@ -92,6 +106,9 @@ var Sizes = SizeMap{
Fit7680: SizeFit7680,
}
// All contains all thumbnail sizes sorted by width.
var All = Sizes.All()
func ParseSize(s string) Size {
return Sizes[Name(s)]
}

View File

@@ -81,7 +81,7 @@ func (w *Meta) Start(delay, interval time.Duration, force bool) (err error) {
for _, photo := range photos {
if mutex.MetaWorker.Canceled() {
return errors.New("index: metadata optimization canceled")
return errors.New("index: metadata worker canceled")
}
if done[photo.PhotoUID] {
@@ -106,7 +106,7 @@ func (w *Meta) Start(delay, interval time.Duration, force bool) (err error) {
}
if mutex.MetaWorker.Canceled() {
return errors.New("index: optimization canceled")
return errors.New("index: metadata worker canceled")
}
offset += limit

205
internal/workers/vision.go Normal file
View File

@@ -0,0 +1,205 @@
package workers
import (
"errors"
"fmt"
"path"
"runtime/debug"
"slices"
"strings"
"time"
"github.com/dustin/go-humanize/english"
"github.com/photoprism/photoprism/internal/ai/vision"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/internal/entity/query"
"github.com/photoprism/photoprism/internal/entity/search"
"github.com/photoprism/photoprism/internal/entity/sortby"
"github.com/photoprism/photoprism/internal/form"
"github.com/photoprism/photoprism/internal/mutex"
"github.com/photoprism/photoprism/internal/photoprism"
"github.com/photoprism/photoprism/internal/photoprism/get"
"github.com/photoprism/photoprism/pkg/clean"
)
// Vision represents a computer vision worker.
type Vision struct {
conf *config.Config
}
// NewVision returns a new Vision worker.
func NewVision(conf *config.Config) *Vision {
return &Vision{conf: conf}
}
// originalsPath returns the original media files path as string.
func (w *Vision) originalsPath() string {
return w.conf.OriginalsPath()
}
// Start runs the specified model types for the photos that match the search query.
func (w *Vision) Start(q string, models []string, force bool) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("vision: %s (worker panic)\nstack: %s", r, debug.Stack())
log.Error(err)
}
}()
if err = mutex.VisionWorker.Start(); err != nil {
return err
}
defer mutex.VisionWorker.Stop()
updateLabels := slices.Contains(models, vision.ModelTypeLabels)
updateNsfw := slices.Contains(models, vision.ModelTypeNsfw)
updateCaptions := slices.Contains(models, vision.ModelTypeCaption)
// Refresh index metadata.
if n := len(models); n == 0 {
log.Warnf("vision: no models were specified")
return nil
} else if n == 1 {
log.Infof("vision: running %s model", models[0])
} else {
log.Infof("vision: running %s models", strings.Join(models, " and "))
}
// Check time when worker was last executed.
updateIndex := false
start := time.Now()
done := make(map[string]bool)
limit := 1000
offset := 0
updated := 0
ind := get.Index()
for {
frm := form.SearchPhotos{
Query: strings.TrimSpace(q),
Primary: true,
Merged: false,
Count: limit,
Offset: offset,
Order: sortby.Oldest,
}
photos, _, queryErr := search.Photos(frm)
if queryErr != nil {
return queryErr
}
if len(photos) == 0 {
break
}
for _, photo := range photos {
if mutex.VisionWorker.Canceled() {
return errors.New("vision: worker canceled")
}
if done[photo.PhotoUID] {
continue
}
done[photo.PhotoUID] = true
photoName := path.Join(photo.PhotoPath, photo.PhotoName)
fileName := photoprism.FileName(photo.FileRoot, photo.FileName)
file, fileErr := photoprism.NewMediaFile(fileName)
if fileErr != nil {
log.Errorf("vision: failed to open %s (%s)", photoName, fileErr)
continue
}
m, loadErr := query.PhotoByUID(photo.PhotoUID)
if loadErr != nil {
log.Errorf("vision: failed to load %s (%s)", photoName, loadErr)
continue
}
changed := false
if updateLabels && (len(m.Labels) == 0 || force) {
if labels := ind.Labels(file); len(labels) > 0 {
m.AddLabels(labels)
changed = true
}
}
if updateNsfw && (!photo.PhotoPrivate || force) {
if isNsfw := ind.IsNsfw(file); photo.PhotoPrivate != isNsfw {
photo.PhotoPrivate = isNsfw
changed = true
log.Infof("vision: changed private flag of %s to %t", photoName, photo.PhotoPrivate)
}
}
if updateCaptions && (m.PhotoCaption == "" || force) {
if caption, captionErr := ind.Caption(file); captionErr != nil {
log.Warnf("vision: %s in %s (generate caption)", clean.Error(captionErr), photoName)
} else if caption.Text != "" {
if caption.Source == "" {
caption.Source = entity.SrcImage
}
if (entity.SrcPriority[caption.Source] > entity.SrcPriority[m.CaptionSrc]) || !m.HasCaption() {
m.SetCaption(caption.Text, caption.Source)
changed = true
log.Infof("vision: changed caption of %s to %t", photoName, clean.Log(m.PhotoCaption))
}
}
}
if changed {
if saveErr := m.GenerateAndSaveTitle(); saveErr != nil {
log.Infof("vision: failed to updated %s (%s)", photoName, clean.Error(saveErr))
} else {
updated++
log.Debugf("vision: updated %s", photoName)
}
}
}
if mutex.VisionWorker.Canceled() {
return errors.New("vision: worker canceled")
}
offset += limit
}
if updated > 0 {
log.Infof("vision: updated %s [%s]", english.Plural(updated, "photo", "photos"), time.Since(start))
updateIndex = true
}
// Only update index if photo metadata has changed or the force flag was used.
if updateIndex {
// Run moments worker.
if moments := photoprism.NewMoments(w.conf); moments == nil {
log.Errorf("vision: failed to update moments")
} else if err = moments.Start(); err != nil {
log.Warnf("moments: %s in optimization worker", err)
}
// Update precalculated photo and file counts.
if err = entity.UpdateCounts(); err != nil {
log.Warnf("vision: %s in optimization worker", err)
}
// Update album, subject, and label cover thumbs.
if err = query.UpdateCovers(); err != nil {
log.Warnf("vision: %s in optimization worker", err)
}
}
return nil
}