AI: Added some intelligence to label loading.

Now, when loading labels, the internal/ai/tensorflow package looks for all
files that match the glob labels*.txt and returns the labels from the first
file whose label count matches the expected number. Some models add a first
label called background, which is a bias; if a file has exactly one label
fewer than expected, that background label is prepended.
Also, a new parameter has been added to the classification model constructor
to provide a second, fallback path in which to look for the label files when
none are found in the model path. internal/ai/vision sets this path to the
nasnet assets directory.
This commit is contained in:
raystlin
2025-04-13 14:49:54 +00:00
parent 0895a085a1
commit d993eb2a85
4 changed files with 78 additions and 40 deletions

View File

@@ -25,6 +25,7 @@ type Model struct {
model *tf.SavedModel model *tf.SavedModel
modelPath string modelPath string
assetsPath string assetsPath string
defaultLabelsPath string
labels []string labels []string
disabled bool disabled bool
meta *tensorflow.ModelInfo meta *tensorflow.ModelInfo
@@ -32,7 +33,7 @@ type Model struct {
} }
// NewModel returns new TensorFlow classification model instance. // NewModel returns new TensorFlow classification model instance.
func NewModel(assetsPath, modelPath string, meta *tensorflow.ModelInfo, disabled bool) *Model { func NewModel(assetsPath, modelPath, defaultLabelsPath string, meta *tensorflow.ModelInfo, disabled bool) *Model {
if meta == nil { if meta == nil {
meta = new(tensorflow.ModelInfo) meta = new(tensorflow.ModelInfo)
} }
@@ -40,6 +41,7 @@ func NewModel(assetsPath, modelPath string, meta *tensorflow.ModelInfo, disabled
return &Model{ return &Model{
modelPath: modelPath, modelPath: modelPath,
assetsPath: assetsPath, assetsPath: assetsPath,
defaultLabelsPath: defaultLabelsPath,
meta: meta, meta: meta,
disabled: disabled, disabled: disabled,
} }
@@ -47,7 +49,7 @@ func NewModel(assetsPath, modelPath string, meta *tensorflow.ModelInfo, disabled
// NewNasnet returns new Nasnet TensorFlow classification model instance. // NewNasnet returns new Nasnet TensorFlow classification model instance.
func NewNasnet(assetsPath string, disabled bool) *Model { func NewNasnet(assetsPath string, disabled bool) *Model {
return NewModel(assetsPath, "nasnet", &tensorflow.ModelInfo{ return NewModel(assetsPath, "nasnet", "", &tensorflow.ModelInfo{
TFVersion: "1.12.0", TFVersion: "1.12.0",
Tags: []string{"photoprism"}, Tags: []string{"photoprism"},
Input: &tensorflow.PhotoInput{ Input: &tensorflow.PhotoInput{
@@ -159,7 +161,13 @@ func (m *Model) Run(img []byte, confidenceThreshold int) (result Labels, err err
} }
func (m *Model) loadLabels(modelPath string) (err error) { func (m *Model) loadLabels(modelPath string) (err error) {
m.labels, err = tensorflow.LoadLabels(modelPath) numLabels := int(m.meta.Output.NumOutputs)
m.labels, err = tensorflow.LoadLabels(modelPath, numLabels)
if os.IsNotExist(err) {
log.Infof("Model does not seem to have tags at %s, trying %s", modelPath, m.defaultLabelsPath)
m.labels, err = tensorflow.LoadLabels(m.defaultLabelsPath, numLabels)
}
return err return err
} }
@@ -216,7 +224,7 @@ func (m *Model) loadModel() (err error) {
if m.meta.Output.OutputsLogits { if m.meta.Output.OutputsLogits {
_, err = tensorflow.AddSoftmax(m.model.Graph, m.meta) _, err = tensorflow.AddSoftmax(m.model.Graph, m.meta)
if err != nil { if err != nil {
return fmt.Errorf("classify: could not add softmax: %w") return fmt.Errorf("classify: could not add softmax: %w", nil)
} }
} }

View File

@@ -170,7 +170,7 @@ func (m *Model) loadModel() error {
if m.meta.Output.OutputsLogits { if m.meta.Output.OutputsLogits {
_, err = tensorflow.AddSoftmax(m.model.Graph, m.meta) _, err = tensorflow.AddSoftmax(m.model.Graph, m.meta)
if err != nil { if err != nil {
return fmt.Errorf("nsfw: could not add softmax: %w") return fmt.Errorf("nsfw: could not add softmax: %w", nil)
} }
} }
@@ -178,7 +178,7 @@ func (m *Model) loadModel() error {
} }
func (m *Model) loadLabels(modelPath string) (err error) { func (m *Model) loadLabels(modelPath string) (err error) {
m.labels, err = tensorflow.LoadLabels(modelPath) m.labels, err = tensorflow.LoadLabels(modelPath, int(m.meta.Output.NumOutputs))
return nil return nil
} }

View File

@@ -2,19 +2,17 @@ package tensorflow
import ( import (
"bufio" "bufio"
"io/fs"
"os" "os"
"path/filepath"
) )
// LoadLabels loads the labels of classification models from the specified path and returns them. func loadLabelsFromPath(path string) (labels []string, err error) {
func LoadLabels(modelPath string) (labels []string, err error) { log.Infof("tensorflow: loading model labels from %s", path)
modelLabels := modelPath + "/labels.txt"
log.Infof("tensorflow: loading model labels from labels.txt")
f, err := os.Open(modelLabels)
f, err := os.Open(path)
if err != nil { if err != nil {
return labels, err return nil, err
} }
defer f.Close() defer f.Close()
@@ -30,3 +28,34 @@ func LoadLabels(modelPath string) (labels []string, err error) {
return labels, err return labels, err
} }
// LoadLabels searches modelPath for files matching "labels*.txt" and returns
// the labels of the first file whose entry count fits expectedLabels.
//
// A file with exactly expectedLabels entries is returned unchanged. A file
// with one entry fewer is returned with "background" prepended as the first
// label. Files with any other entry count are skipped. If a matching file
// cannot be loaded, that error is returned immediately without trying the
// remaining files. os.ErrNotExist is returned when no file qualifies.
func LoadLabels(modelPath string, expectedLabels int) (labels []string, err error) {
	candidates, err := fs.Glob(os.DirFS(modelPath), "labels*.txt")
	if err != nil {
		return nil, err
	}

	for _, name := range candidates {
		found, loadErr := loadLabelsFromPath(filepath.Join(modelPath, name))
		if loadErr != nil {
			return nil, loadErr
		}

		// Number of labels the file is short of the expected count.
		missing := expectedLabels - len(found)

		if missing == 0 {
			log.Infof("Found a valid labels file: %s", name)
			return found, nil
		}

		if missing == 1 {
			log.Infof("Found a valid labels file %s but we have to add bias", name)
			return append([]string{"background"}, found...), nil
		}

		log.Infof("File not valid. Expected %d labels and have %d",
			expectedLabels, len(found))
	}

	// Returned unwrapped on purpose: the sentinel is handed back as-is.
	return nil, os.ErrNotExist
}

View File

@@ -104,17 +104,18 @@ func (m *Model) ClassifyModel() *classify.Model {
// Set default thumbnail resolution if no tags are configured. // Set default thumbnail resolution if no tags are configured.
if m.Resolution <= 0 { if m.Resolution <= 0 {
m.Resolution = DefaultResolution m.Resolution = DefaultResolution
} else { }
if m.Meta.Input == nil { if m.Meta.Input == nil {
m.Meta.Input = new(tensorflow.PhotoInput) m.Meta.Input = new(tensorflow.PhotoInput)
} }
m.Meta.Input.SetResolution(m.Resolution) m.Meta.Input.SetResolution(m.Resolution)
m.Meta.Input.Channels = 3 m.Meta.Input.Channels = 3
}
// Try to load custom model based on the configuration values. // Try to load custom model based on the configuration values.
if model := classify.NewModel(AssetsPath, m.Path, m.Meta, m.Disabled); model == nil { defaultPath := filepath.Join(AssetsPath, "nasnet")
if model := classify.NewModel(AssetsPath, m.Path, defaultPath, m.Meta, m.Disabled); model == nil {
return nil return nil
} else if err := model.Init(); err != nil { } else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path) log.Errorf("vision: %s (init %s)", err, m.Path)
@@ -222,14 +223,14 @@ func (m *Model) NsfwModel() *nsfw.Model {
// Set default thumbnail resolution if no tags are configured. // Set default thumbnail resolution if no tags are configured.
if m.Resolution <= 0 { if m.Resolution <= 0 {
m.Resolution = DefaultResolution m.Resolution = DefaultResolution
} else { }
if m.Meta.Input == nil { if m.Meta.Input == nil {
m.Meta.Input = new(tensorflow.PhotoInput) m.Meta.Input = new(tensorflow.PhotoInput)
} }
m.Meta.Input.SetResolution(m.Resolution) m.Meta.Input.SetResolution(m.Resolution)
m.Meta.Input.Channels = 3 m.Meta.Input.Channels = 3
}
if m.Meta == nil { if m.Meta == nil {
m.Meta = &tensorflow.ModelInfo{} m.Meta = &tensorflow.ModelInfo{}