AI: Add TensorFlow utility package and improve model loading #127 #1090

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2025-04-07 05:26:45 +02:00
parent 35e9294d87
commit bfdb839d01
17 changed files with 421 additions and 261 deletions

1
.gitignore vendored
View File

@@ -50,6 +50,7 @@ frontend/coverage/
/assets/nsfw /assets/nsfw
/assets/static/build/ /assets/static/build/
/assets/*net /assets/*net
/assets/vision
/pro /pro
/plus /plus

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 98 KiB

BIN
assets/examples/green.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

View File

@@ -1,14 +1,11 @@
package classify package classify
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "fmt"
"image"
"math" "math"
"os" "os"
"path" "path"
"path/filepath"
"runtime/debug" "runtime/debug"
"sort" "sort"
"strings" "strings"
@@ -17,7 +14,7 @@ import (
"github.com/disintegration/imaging" "github.com/disintegration/imaging"
tf "github.com/wamuir/graft/tensorflow" tf "github.com/wamuir/graft/tensorflow"
"github.com/photoprism/photoprism/pkg/clean" "github.com/photoprism/photoprism/internal/ai/tensorflow"
) )
// Model represents a TensorFlow classification model. // Model represents a TensorFlow classification model.
@@ -82,7 +79,7 @@ func (m *Model) Labels(img []byte, confidenceThreshold int) (result Labels, err
return nil, loadErr return nil, loadErr
} }
// Create tensor from image. // Create input tensor from image.
tensor, err := m.createTensor(img) tensor, err := m.createTensor(img)
if err != nil { if err != nil {
@@ -112,45 +109,26 @@ func (m *Model) Labels(img []byte, confidenceThreshold int) (result Labels, err
if len(result) > 0 { if len(result) > 0 {
log.Tracef("classify: image classified as %+v", result) log.Tracef("classify: image classified as %+v", result)
} else {
result = Labels{}
} }
return result, nil return result, nil
} }
func (m *Model) loadLabels(path string) error { func (m *Model) loadLabels(modelPath string) (err error) {
modelLabels := path + "/labels.txt" m.labels, err = tensorflow.LoadLabels(modelPath)
log.Infof("classify: loading labels from labels.txt")
// Load labels
f, err := os.Open(modelLabels)
if err != nil {
return err return err
} }
defer f.Close()
scanner := bufio.NewScanner(f)
// Labels are separated by newlines
for scanner.Scan() {
m.labels = append(m.labels, scanner.Text())
}
if err := scanner.Err(); err != nil {
return err
}
return nil
}
// ModelLoaded tests if the TensorFlow model is loaded. // ModelLoaded tests if the TensorFlow model is loaded.
func (m *Model) ModelLoaded() bool { func (m *Model) ModelLoaded() bool {
return m.model != nil return m.model != nil
} }
func (m *Model) loadModel() error { func (m *Model) loadModel() (err error) {
// Use mutex to prevent the model from being loaded and
// initialized twice by different indexing workers.
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -160,16 +138,7 @@ func (m *Model) loadModel() error {
modelPath := path.Join(m.assetsPath, m.modelPath) modelPath := path.Join(m.assetsPath, m.modelPath)
log.Infof("classify: loading %s", clean.Log(filepath.Base(modelPath))) m.model, err = tensorflow.SavedModel(modelPath, m.modelTags)
// Load model
model, err := tf.LoadSavedModel(modelPath, m.modelTags, nil)
if err != nil {
return err
}
m.model = model
return m.loadLabels(modelPath) return m.loadLabels(modelPath)
} }
@@ -184,8 +153,10 @@ func (m *Model) bestLabels(probabilities []float32, confidenceThreshold int) Lab
break break
} }
confidence := int(math.Round(float64(p * 100)))
// discard labels with low probabilities // discard labels with low probabilities
if p < 0.1 { if confidence < confidenceThreshold {
continue continue
} }
@@ -204,13 +175,8 @@ func (m *Model) bestLabels(probabilities []float32, confidenceThreshold int) Lab
} }
labelText = strings.TrimSpace(labelText) labelText = strings.TrimSpace(labelText)
confidence := int(math.Round(float64(p * 100)))
if confidence >= confidenceThreshold {
result = append(result, Label{Name: labelText, Source: SrcImage, Uncertainty: 100 - confidence, Priority: rule.Priority, Categories: rule.Categories}) result = append(result, Label{Name: labelText, Source: SrcImage, Uncertainty: 100 - confidence, Priority: rule.Priority, Categories: rule.Categories})
} }
}
// Sort by probability // Sort by probability
sort.Sort(result) sort.Sort(result)
@@ -231,42 +197,7 @@ func (m *Model) createTensor(image []byte) (*tf.Tensor, error) {
return nil, err return nil, err
} }
width, height := m.resolution, m.resolution img = imaging.Fill(img, m.resolution, m.resolution, imaging.Center, imaging.Lanczos)
img = imaging.Fill(img, width, height, imaging.Center, imaging.Lanczos) return tensorflow.Image(img, m.resolution)
return imageToTensor(img, width, height)
}
func imageToTensor(img image.Image, imageHeight, imageWidth int) (tfTensor *tf.Tensor, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("classify: %s (panic)\nstack: %s", r, debug.Stack())
}
}()
if imageHeight <= 0 || imageWidth <= 0 {
return tfTensor, fmt.Errorf("classify: image width and height must be > 0")
}
var tfImage [1][][][3]float32
for j := 0; j < imageHeight; j++ {
tfImage[0] = append(tfImage[0], make([][3]float32, imageWidth))
}
for i := 0; i < imageWidth; i++ {
for j := 0; j < imageHeight; j++ {
r, g, b, _ := img.At(i, j).RGBA()
tfImage[0][j][i][0] = convertValue(r)
tfImage[0][j][i][1] = convertValue(g)
tfImage[0][j][i][2] = convertValue(b)
}
}
return tf.NewTensor(tfImage)
}
func convertValue(value uint32) float32 {
return (float32(value>>8) - float32(127.5)) / float32(127.5)
} }

View File

@@ -6,7 +6,6 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/wamuir/graft/tensorflow"
"github.com/photoprism/photoprism/pkg/fs" "github.com/photoprism/photoprism/pkg/fs"
) )
@@ -31,24 +30,66 @@ func NewModelTest(t *testing.T) *Model {
func TestModel_LabelsFromFile(t *testing.T) { func TestModel_LabelsFromFile(t *testing.T) {
t.Run("chameleon_lime.jpg", func(t *testing.T) { t.Run("chameleon_lime.jpg", func(t *testing.T) {
tensorFlow := NewModelTest(t) tensorFlow := NewModelTest(t)
result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10) result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10)
assert.Nil(t, err) assert.NoError(t, err)
if err != nil {
t.Fatal(err)
}
assert.NotNil(t, result) assert.NotNil(t, result)
assert.IsType(t, Labels{}, result) assert.IsType(t, Labels{}, result)
assert.Equal(t, 1, len(result)) assert.Equal(t, 1, len(result))
t.Log(result) if len(result) > 0 {
t.Logf("result: %#v", result[0])
assert.Equal(t, "chameleon", result[0].Name) assert.Equal(t, "chameleon", result[0].Name)
assert.Equal(t, 7, result[0].Uncertainty) assert.Equal(t, 7, result[0].Uncertainty)
}
})
t.Run("cat_224.jpeg", func(t *testing.T) {
tensorFlow := NewModelTest(t)
result, err := tensorFlow.File(examplesPath+"/cat_224.jpeg", 10)
assert.NoError(t, err)
assert.NotNil(t, result)
assert.IsType(t, Labels{}, result)
assert.Equal(t, 1, len(result))
if len(result) > 0 {
assert.Equal(t, "cat", result[0].Name)
assert.Equal(t, 59, result[0].Uncertainty)
}
})
t.Run("cat_720.jpeg", func(t *testing.T) {
tensorFlow := NewModelTest(t)
result, err := tensorFlow.File(examplesPath+"/cat_720.jpeg", 10)
assert.NoError(t, err)
assert.NotNil(t, result)
assert.IsType(t, Labels{}, result)
assert.Equal(t, 3, len(result))
// t.Logf("labels: %#v", result)
if len(result) > 0 {
assert.Equal(t, "cat", result[0].Name)
assert.Equal(t, 60, result[0].Uncertainty)
}
})
t.Run("green.jpg", func(t *testing.T) {
tensorFlow := NewModelTest(t)
result, err := tensorFlow.File(examplesPath+"/green.jpg", 10)
t.Logf("labels: %#v", result)
assert.NoError(t, err)
assert.NotNil(t, result)
assert.IsType(t, Labels{}, result)
assert.Equal(t, 1, len(result))
if len(result) > 0 {
assert.Equal(t, "outdoor", result[0].Name)
assert.Equal(t, 70, result[0].Uncertainty)
}
}) })
t.Run("not existing file", func(t *testing.T) { t.Run("not existing file", func(t *testing.T) {
tensorFlow := NewModelTest(t) tensorFlow := NewModelTest(t)
@@ -180,11 +221,13 @@ func TestModel_LoadModel(t *testing.T) {
}) })
t.Run("model path does not exist", func(t *testing.T) { t.Run("model path does not exist", func(t *testing.T) {
tensorFlow := NewNasnet(assetsPath+"foo", false) tensorFlow := NewNasnet(assetsPath+"foo", false)
if err := tensorFlow.loadModel(); err != nil { err := tensorFlow.loadModel()
assert.Contains(t, err.Error(), "Could not find SavedModel")
} else { if err != nil {
t.Fatal("err should NOT be nil") assert.Contains(t, err.Error(), "no such file or directory")
} }
assert.Error(t, err)
}) })
} }
@@ -218,35 +261,3 @@ func TestModel_BestLabels(t *testing.T) {
t.Log(result) t.Log(result)
}) })
} }
func TestModel_MakeTensor(t *testing.T) {
t.Run("cat_brown.jpg", func(t *testing.T) {
tensorFlow := NewModelTest(t)
imageBuffer, err := os.ReadFile(examplesPath + "/cat_brown.jpg")
if err != nil {
t.Fatal(err)
}
result, err := tensorFlow.createTensor(imageBuffer)
assert.Equal(t, tensorflow.DataType(0x1), result.DataType())
assert.Equal(t, int64(1), result.Shape()[0])
assert.Equal(t, int64(224), result.Shape()[2])
})
t.Run("Random.docx", func(t *testing.T) {
tensorFlow := NewModelTest(t)
imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx")
assert.Nil(t, err)
result, err := tensorFlow.createTensor(imageBuffer)
assert.Empty(t, result)
assert.EqualError(t, err, "image: unknown format")
})
}
func Test_convertValue(t *testing.T) {
result := convertValue(uint32(98765432))
assert.Equal(t, float32(3024.898), result)
}

View File

@@ -32,7 +32,7 @@ func NewModel(modelPath, cachePath string, disabled bool) *Model {
} }
// Detect runs the detection and facenet algorithms over the provided source image. // Detect runs the detection and facenet algorithms over the provided source image.
func (t *Model) Detect(fileName string, minSize int, cacheCrop bool, expected int) (faces Faces, err error) { func (m *Model) Detect(fileName string, minSize int, cacheCrop bool, expected int) (faces Faces, err error) {
faces, err = Detect(fileName, false, minSize) faces, err = Detect(fileName, false, minSize)
if err != nil { if err != nil {
@@ -40,13 +40,13 @@ func (t *Model) Detect(fileName string, minSize int, cacheCrop bool, expected in
} }
// Skip FaceNet? // Skip FaceNet?
if t.disabled { if m.disabled {
return faces, nil return faces, nil
} else if c := len(faces); c == 0 || expected > 0 && c == expected { } else if c := len(faces); c == 0 || expected > 0 && c == expected {
return faces, nil return faces, nil
} }
err = t.loadModel() err = m.loadModel()
if err != nil { if err != nil {
return faces, err return faces, err
@@ -59,7 +59,7 @@ func (t *Model) Detect(fileName string, minSize int, cacheCrop bool, expected in
if img, imgErr := crop.ImageFromThumb(fileName, f.CropArea(), CropSize, cacheCrop); imgErr != nil { if img, imgErr := crop.ImageFromThumb(fileName, f.CropArea(), CropSize, cacheCrop); imgErr != nil {
log.Errorf("faces: failed to decode image: %s", imgErr) log.Errorf("faces: failed to decode image: %s", imgErr)
} else if embeddings := t.getEmbeddings(img); !embeddings.Empty() { } else if embeddings := m.getEmbeddings(img); !embeddings.Empty() {
faces[i].Embeddings = embeddings faces[i].Embeddings = embeddings
} }
} }
@@ -68,38 +68,40 @@ func (t *Model) Detect(fileName string, minSize int, cacheCrop bool, expected in
} }
// ModelLoaded tests if the TensorFlow model is loaded. // ModelLoaded tests if the TensorFlow model is loaded.
func (t *Model) ModelLoaded() bool { func (m *Model) ModelLoaded() bool {
return t.model != nil return m.model != nil
} }
// loadModel loads the TensorFlow model. // loadModel loads the TensorFlow model.
func (t *Model) loadModel() error { func (m *Model) loadModel() error {
t.mutex.Lock() // Use mutex to prevent the model from being loaded and
defer t.mutex.Unlock() // initialized twice by different indexing workers.
m.mutex.Lock()
defer m.mutex.Unlock()
if t.ModelLoaded() { if m.ModelLoaded() {
return nil return nil
} }
modelPath := path.Join(t.modelPath) modelPath := path.Join(m.modelPath)
log.Infof("faces: loading %s", clean.Log(filepath.Base(modelPath))) log.Infof("faces: loading %s", clean.Log(filepath.Base(modelPath)))
// Load model // Load model
model, err := tf.LoadSavedModel(modelPath, t.modelTags, nil) model, err := tf.LoadSavedModel(modelPath, m.modelTags, nil)
if err != nil { if err != nil {
return err return err
} }
t.model = model m.model = model
return nil return nil
} }
// getEmbeddings returns the face embeddings for an image. // getEmbeddings returns the face embeddings for an image.
func (t *Model) getEmbeddings(img image.Image) Embeddings { func (m *Model) getEmbeddings(img image.Image) Embeddings {
tensor, err := imageToTensor(img, t.resolution) tensor, err := imageToTensor(img, m.resolution)
if err != nil { if err != nil {
log.Errorf("faces: failed to convert image to tensor: %s", err) log.Errorf("faces: failed to convert image to tensor: %s", err)
@@ -109,13 +111,13 @@ func (t *Model) getEmbeddings(img image.Image) Embeddings {
trainPhaseBoolTensor, err := tf.NewTensor(false) trainPhaseBoolTensor, err := tf.NewTensor(false)
output, err := t.model.Session.Run( output, err := m.model.Session.Run(
map[tf.Output]*tf.Tensor{ map[tf.Output]*tf.Tensor{
t.model.Graph.Operation("input").Output(0): tensor, m.model.Graph.Operation("input").Output(0): tensor,
t.model.Graph.Operation("phase_train").Output(0): trainPhaseBoolTensor, m.model.Graph.Operation("phase_train").Output(0): trainPhaseBoolTensor,
}, },
[]tf.Output{ []tf.Output{
t.model.Graph.Operation("embeddings").Output(0), m.model.Graph.Operation("embeddings").Output(0),
}, },
nil) nil)

View File

@@ -1,25 +1,19 @@
package nsfw package nsfw
import ( import (
"bufio"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"sync" "sync"
tf "github.com/wamuir/graft/tensorflow" tf "github.com/wamuir/graft/tensorflow"
"github.com/wamuir/graft/tensorflow/op"
"github.com/photoprism/photoprism/internal/ai/tensorflow"
"github.com/photoprism/photoprism/pkg/clean" "github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/fs" "github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/media/http/header" "github.com/photoprism/photoprism/pkg/media/http/header"
) )
const (
Mean = float32(117)
Scale = float32(1)
)
// Model uses TensorFlow to label drawing, hentai, neutral, porn and sexy images. // Model uses TensorFlow to label drawing, hentai, neutral, porn and sexy images.
type Model struct { type Model struct {
model *tf.SavedModel model *tf.SavedModel
@@ -36,7 +30,7 @@ func NewModel(modelPath string) *Model {
} }
// File returns matching labels for a jpeg media file. // File returns matching labels for a jpeg media file.
func (t *Model) File(filename string) (result Labels, err error) { func (m *Model) File(filename string) (result Labels, err error) {
if fs.MimeType(filename) != header.ContentTypeJpeg { if fs.MimeType(filename) != header.ContentTypeJpeg {
return result, fmt.Errorf("nsfw: %s is not a jpeg file", clean.Log(filepath.Base(filename))) return result, fmt.Errorf("nsfw: %s is not a jpeg file", clean.Log(filepath.Base(filename)))
} }
@@ -47,29 +41,29 @@ func (t *Model) File(filename string) (result Labels, err error) {
return result, err return result, err
} }
return t.Labels(imageBuffer) return m.Labels(imageBuffer)
} }
// Labels returns matching labels for a jpeg media string. // Labels returns matching labels for a jpeg media string.
func (t *Model) Labels(img []byte) (result Labels, err error) { func (m *Model) Labels(img []byte) (result Labels, err error) {
if loadErr := t.loadModel(); loadErr != nil { if loadErr := m.loadModel(); loadErr != nil {
return result, loadErr return result, loadErr
} }
// Make tensor // Create input tensor from image.
tensor, err := createTensorFromImage(img, "jpeg", t.resolution) input, err := tensorflow.ImageTransform(img, fs.ImageJpeg, m.resolution)
if err != nil { if err != nil {
return result, fmt.Errorf("nsfw: %s", err) return result, fmt.Errorf("nsfw: %s", err)
} }
// Run inference // Run inference.
output, err := t.model.Session.Run( output, err := m.model.Session.Run(
map[tf.Output]*tf.Tensor{ map[tf.Output]*tf.Tensor{
t.model.Graph.Operation("input_tensor").Output(0): tensor, m.model.Graph.Operation("input_tensor").Output(0): input,
}, },
[]tf.Output{ []tf.Output{
t.model.Graph.Operation("nsfw_cls_model/final_prediction").Output(0), m.model.Graph.Operation("nsfw_cls_model/final_prediction").Output(0),
}, },
nil) nil)
@@ -81,66 +75,45 @@ func (t *Model) Labels(img []byte) (result Labels, err error) {
return result, fmt.Errorf("nsfw: inference failed, no output") return result, fmt.Errorf("nsfw: inference failed, no output")
} }
// Return best labels // Return best labels.
result = t.getLabels(output[0].Value().([][]float32)[0]) result = m.getLabels(output[0].Value().([][]float32)[0])
log.Tracef("nsfw: image classified as %+v", result) log.Tracef("nsfw: image classified as %+v", result)
return result, nil return result, nil
} }
func (t *Model) loadLabels(path string) error { func (m *Model) loadLabels(modelPath string) (err error) {
modelLabels := path + "/labels.txt" m.labels, err = tensorflow.LoadLabels(modelPath)
log.Infof("nsfw: loading labels from labels.txt")
// Load labels
f, err := os.Open(modelLabels)
if err != nil {
return err
}
defer f.Close()
scanner := bufio.NewScanner(f)
// Labels are separated by newlines
for scanner.Scan() {
t.labels = append(t.labels, scanner.Text())
}
if err := scanner.Err(); err != nil {
return err
}
return nil return nil
} }
func (t *Model) loadModel() error { func (m *Model) loadModel() error {
t.mutex.Lock() // Use mutex to prevent the model from being loaded and
defer t.mutex.Unlock() // initialized twice by different indexing workers.
m.mutex.Lock()
defer m.mutex.Unlock()
if t.model != nil { if m.model != nil {
// Already loaded // Already loaded
return nil return nil
} }
log.Infof("nsfw: loading %s", clean.Log(filepath.Base(t.modelPath))) log.Infof("nsfw: loading %s", clean.Log(filepath.Base(m.modelPath)))
// Load model // Load saved TensorFlow model from the specified path.
model, err := tf.LoadSavedModel(t.modelPath, t.modelTags, nil) model, err := tensorflow.SavedModel(m.modelPath, m.modelTags)
if err != nil { if err != nil {
return err return err
} }
t.model = model m.model = model
return t.loadLabels(t.modelPath) return m.loadLabels(m.modelPath)
} }
func (t *Model) getLabels(p []float32) Labels { func (m *Model) getLabels(p []float32) Labels {
return Labels{ return Labels{
Drawing: p[0], Drawing: p[0],
Hentai: p[1], Hentai: p[1],
@@ -149,56 +122,3 @@ func (t *Model) getLabels(p []float32) Labels {
Sexy: p[4], Sexy: p[4],
} }
} }
func transformImageGraph(imageFormat string, resolution int) (graph *tf.Graph, input, output tf.Output, err error) {
var H, W = int32(resolution), int32(resolution)
s := op.NewScope()
input = op.Placeholder(s, tf.String)
// Decode PNG or JPEG
var decode tf.Output
if imageFormat == "png" {
decode = op.DecodePng(s, input, op.DecodePngChannels(3))
} else {
decode = op.DecodeJpeg(s, input, op.DecodeJpegChannels(3))
}
// Div and Sub perform (value-Mean)/Scale for each pixel
output = op.Div(s,
op.Sub(s,
// Resize to 224x224 with bilinear interpolation
op.ResizeBilinear(s,
// Create a batch containing a single image
op.ExpandDims(s,
// Use decoded pixel values
op.Cast(s, decode, tf.Float),
op.Const(s.SubScope("make_batch"), int32(0))),
op.Const(s.SubScope("size"), []int32{H, W})),
op.Const(s.SubScope("mean"), Mean)),
op.Const(s.SubScope("scale"), Scale))
graph, err = s.Finalize()
return graph, input, output, err
}
func createTensorFromImage(image []byte, imageFormat string, resolution int) (*tf.Tensor, error) {
tensor, err := tf.NewTensor(string(image))
if err != nil {
return nil, err
}
graph, input, output, err := transformImageGraph(imageFormat, resolution)
if err != nil {
return nil, err
}
session, err := tf.NewSession(graph, nil)
if err != nil {
return nil, err
}
defer session.Close()
normalized, err := session.Run(
map[tf.Output]*tf.Tensor{input: tensor},
[]tf.Output{output},
nil)
if err != nil {
return nil, err
}
return normalized[0], nil
}

View File

@@ -0,0 +1,145 @@
package tensorflow
import (
"bytes"
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"os"
"runtime/debug"
tf "github.com/wamuir/graft/tensorflow"
"github.com/wamuir/graft/tensorflow/op"
"github.com/photoprism/photoprism/pkg/fs"
)
const (
Mean = float32(117)
Scale = float32(1)
)
func ImageFromFile(fileName string, resolution int) (*tf.Tensor, error) {
if img, err := OpenImage(fileName); err != nil {
return nil, err
} else {
return Image(img, resolution)
}
}
func OpenImage(fileName string) (image.Image, error) {
f, err := os.Open(fileName)
if err != nil {
return nil, err
}
defer f.Close()
img, _, err := image.Decode(f)
return img, err
}
func ImageFromBytes(b []byte, resolution int) (*tf.Tensor, error) {
img, _, imgErr := image.Decode(bytes.NewReader(b))
if imgErr != nil {
return nil, imgErr
}
return Image(img, resolution)
}
func Image(img image.Image, resolution int) (tfTensor *tf.Tensor, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("tensorflow: %s (panic)\nstack: %s", r, debug.Stack())
}
}()
if resolution <= 0 {
return tfTensor, fmt.Errorf("tensorflow: resolution must be larger 0")
}
var tfImage [1][][][3]float32
for j := 0; j < resolution; j++ {
tfImage[0] = append(tfImage[0], make([][3]float32, resolution))
}
for i := 0; i < resolution; i++ {
for j := 0; j < resolution; j++ {
r, g, b, _ := img.At(i, j).RGBA()
tfImage[0][j][i][0] = convertValue(r, 127.5)
tfImage[0][j][i][1] = convertValue(g, 127.5)
tfImage[0][j][i][2] = convertValue(b, 127.5)
}
}
return tf.NewTensor(tfImage)
}
// ImageTransform transforms the given image into a *tf.Tensor and returns it.
func ImageTransform(image []byte, imageFormat fs.Type, resolution int) (*tf.Tensor, error) {
tensor, err := tf.NewTensor(string(image))
if err != nil {
return nil, err
}
graph, input, output, err := transformImageGraph(imageFormat, resolution)
if err != nil {
return nil, err
}
session, err := tf.NewSession(graph, nil)
if err != nil {
return nil, err
}
defer session.Close()
normalized, err := session.Run(
map[tf.Output]*tf.Tensor{input: tensor},
[]tf.Output{output},
nil)
if err != nil {
return nil, err
}
return normalized[0], nil
}
func transformImageGraph(imageFormat fs.Type, resolution int) (graph *tf.Graph, input, output tf.Output, err error) {
s := op.NewScope()
input = op.Placeholder(s, tf.String)
// Assume the image is a JPEG, or a PNG if explicitly specified.
var decodedImage tf.Output
switch imageFormat {
case fs.ImagePng:
decodedImage = op.DecodePng(s, input, op.DecodePngChannels(3))
default:
decodedImage = op.DecodeJpeg(s, input, op.DecodeJpegChannels(3))
}
output = op.Div(s,
op.Sub(s,
op.ResizeBilinear(s,
op.ExpandDims(s,
op.Cast(s, decodedImage, tf.Float),
op.Const(s.SubScope("make_batch"), int32(0))),
op.Const(s.SubScope("size"), []int32{int32(resolution), int32(resolution)})),
op.Const(s.SubScope("mean"), Mean)),
op.Const(s.SubScope("scale"), Scale))
graph, err = s.Finalize()
return graph, input, output, err
}
func convertValue(value uint32, mean float32) float32 {
if mean == 0 {
mean = 127.5
}
return (float32(value>>8) - mean) / mean
}

View File

@@ -0,0 +1,42 @@
package tensorflow
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/wamuir/graft/tensorflow"
"github.com/photoprism/photoprism/pkg/fs"
)
func TestConvertValue(t *testing.T) {
result := convertValue(uint32(98765432), 127.5)
assert.Equal(t, float32(3024.898), result)
}
func TestImageFromBytes(t *testing.T) {
var assetsPath = fs.Abs("../../../assets")
var examplesPath = assetsPath + "/examples"
t.Run("CatJpeg", func(t *testing.T) {
imageBuffer, err := os.ReadFile(examplesPath + "/cat_brown.jpg")
if err != nil {
t.Fatal(err)
}
result, err := ImageFromBytes(imageBuffer, 224)
assert.Equal(t, tensorflow.DataType(0x1), result.DataType())
assert.Equal(t, int64(1), result.Shape()[0])
assert.Equal(t, int64(224), result.Shape()[2])
})
t.Run("Document", func(t *testing.T) {
imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx")
assert.Nil(t, err)
result, err := ImageFromBytes(imageBuffer, 224)
assert.Empty(t, result)
assert.EqualError(t, err, "image: unknown format")
})
}

View File

@@ -0,0 +1,32 @@
package tensorflow
import (
"bufio"
"os"
)
// LoadLabels loads the labels of classification models from the specified path and returns them.
func LoadLabels(modelPath string) (labels []string, err error) {
modelLabels := modelPath + "/labels.txt"
log.Infof("tensorflow: loading model labels from labels.txt")
f, err := os.Open(modelLabels)
if err != nil {
return labels, err
}
defer f.Close()
scanner := bufio.NewScanner(f)
// Labels are separated by newlines
for scanner.Scan() {
labels = append(labels, scanner.Text())
}
err = scanner.Err()
return labels, err
}

View File

@@ -0,0 +1,20 @@
package tensorflow
import (
"path/filepath"
tf "github.com/wamuir/graft/tensorflow"
"github.com/photoprism/photoprism/pkg/clean"
)
// SavedModel loads a saved TensorFlow model from the specified path.
func SavedModel(modelPath string, tags []string) (model *tf.SavedModel, err error) {
log.Infof("tensorflow: loading %s", clean.Log(filepath.Base(modelPath)))
if len(tags) == 0 {
tags = []string{"serve"}
}
return tf.LoadSavedModel(modelPath, tags, nil)
}

View File

@@ -0,0 +1,31 @@
/*
Package tensorflow provides TensorFlow utility functions.
Copyright (c) 2018 - 2025 PhotoPrism UG. All rights reserved.
This program is free software: you can redistribute it and/or modify
it under Version 3 of the GNU Affero General Public License (the "AGPL"):
<https://docs.photoprism.app/license/agpl>
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
The AGPL is supplemented by our Trademark and Brand Guidelines,
which describe how our Brand Assets may be used:
<https://www.photoprism.app/trademark>
Feel free to send an email to hello@photoprism.app if you have questions,
want to support our work, or just want to say hello.
Additional information can be found in our Developer Guide:
<https://docs.photoprism.app/developer-guide/>
*/
package tensorflow
import (
"github.com/photoprism/photoprism/internal/event"
)
var log = event.Log

View File

@@ -51,7 +51,7 @@ func Labels(thumbnails []string) (result classify.Labels, err error) {
} }
if !found { if !found {
result = append(result, labels...) result = append(result, labels[j])
} }
} }
} }

View File

@@ -25,6 +25,19 @@ func TestLabels(t *testing.T) {
assert.Equal(t, "chameleon", result[0].Name) assert.Equal(t, "chameleon", result[0].Name)
assert.Equal(t, 7, result[0].Uncertainty) assert.Equal(t, 7, result[0].Uncertainty)
}) })
t.Run("Cats", func(t *testing.T) {
result, err := Labels([]string{examplesPath + "/cat_720.jpeg"})
assert.NoError(t, err)
assert.IsType(t, classify.Labels{}, result)
assert.Equal(t, 1, len(result))
t.Log(result)
assert.Equal(t, "cat", result[0].Name)
assert.Equal(t, 60, result[0].Uncertainty)
assert.Equal(t, 40, result[0].Confidence())
})
t.Run("InvalidFile", func(t *testing.T) { t.Run("InvalidFile", func(t *testing.T) {
_, err := Labels([]string{examplesPath + "/notexisting.jpg"}) _, err := Labels([]string{examplesPath + "/notexisting.jpg"})
assert.Error(t, err) assert.Error(t, err)

View File

@@ -28,9 +28,12 @@ type Models []*Model
// ClassifyModel returns the matching classify model instance, if any. // ClassifyModel returns the matching classify model instance, if any.
func (m *Model) ClassifyModel() *classify.Model { func (m *Model) ClassifyModel() *classify.Model {
// Use mutex to prevent models from being loaded and
// initialized twice by different indexing workers.
modelMutex.Lock() modelMutex.Lock()
defer modelMutex.Unlock() defer modelMutex.Unlock()
// Return the existing model instance if it has already been created.
if m.classifyModel != nil { if m.classifyModel != nil {
return m.classifyModel return m.classifyModel
} }
@@ -40,6 +43,7 @@ func (m *Model) ClassifyModel() *classify.Model {
log.Warnf("vision: missing name, model instance cannot be created") log.Warnf("vision: missing name, model instance cannot be created")
return nil return nil
case NasnetModel.Name, "nasnet": case NasnetModel.Name, "nasnet":
// Load and initialize the Nasnet image classification model.
if model := classify.NewNasnet(AssetsPath, m.Disabled); model == nil { if model := classify.NewNasnet(AssetsPath, m.Disabled); model == nil {
return nil return nil
} else if err := model.Init(); err != nil { } else if err := model.Init(); err != nil {
@@ -49,14 +53,22 @@ func (m *Model) ClassifyModel() *classify.Model {
m.classifyModel = model m.classifyModel = model
} }
default: default:
// Set model path from model name if no path is configured.
if m.Path == "" { if m.Path == "" {
m.Path = clean.TypeLowerUnderscore(m.Name) m.Path = clean.TypeLowerUnderscore(m.Name)
} }
// Set default thumbnail resolution if no tags are configured.
if m.Resolution <= 0 { if m.Resolution <= 0 {
m.Resolution = DefaultResolution m.Resolution = DefaultResolution
} }
// Set default tag if no tags are configured.
if len(m.Tags) == 0 {
m.Tags = []string{"serve"}
}
// Try to load custom model based on the configuration values.
if model := classify.NewModel(AssetsPath, m.Path, m.Resolution, m.Tags, m.Disabled); model == nil { if model := classify.NewModel(AssetsPath, m.Path, m.Resolution, m.Tags, m.Disabled); model == nil {
return nil return nil
} else if err := model.Init(); err != nil { } else if err := model.Init(); err != nil {

View File

@@ -12,5 +12,5 @@ var DefaultResolution = 224
// NasnetModel is a standard TensorFlow model used for label generation. // NasnetModel is a standard TensorFlow model used for label generation.
var ( var (
NasnetModel = &Model{Name: "Nasnet", Resolution: 224, Tags: []string{"photoprism"}} NasnetModel = &Model{Name: "Nasnet", Version: "Mobile", Resolution: 224, Tags: []string{"photoprism"}}
) )