mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
AI: Added some intelligence to label loading.
Now when loading labels internal/ai/tensorflow package will try to look for all the files that match the glob label*.txt and will return the labels that match the expected number. Some models add a first label called background, which is a bias. Also, a new parameter has been added to models to allow a second path to look for the label files. This path is set to nasnet asset on internal/ai/vision.
This commit is contained in:
@@ -22,32 +22,34 @@ import (
|
|||||||
|
|
||||||
// Model represents a TensorFlow classification model.
|
// Model represents a TensorFlow classification model.
|
||||||
type Model struct {
|
type Model struct {
|
||||||
model *tf.SavedModel
|
model *tf.SavedModel
|
||||||
modelPath string
|
modelPath string
|
||||||
assetsPath string
|
assetsPath string
|
||||||
labels []string
|
defaultLabelsPath string
|
||||||
disabled bool
|
labels []string
|
||||||
meta *tensorflow.ModelInfo
|
disabled bool
|
||||||
mutex sync.Mutex
|
meta *tensorflow.ModelInfo
|
||||||
|
mutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewModel returns new TensorFlow classification model instance.
|
// NewModel returns new TensorFlow classification model instance.
|
||||||
func NewModel(assetsPath, modelPath string, meta *tensorflow.ModelInfo, disabled bool) *Model {
|
func NewModel(assetsPath, modelPath, defaultLabelsPath string, meta *tensorflow.ModelInfo, disabled bool) *Model {
|
||||||
if meta == nil {
|
if meta == nil {
|
||||||
meta = new(tensorflow.ModelInfo)
|
meta = new(tensorflow.ModelInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Model{
|
return &Model{
|
||||||
modelPath: modelPath,
|
modelPath: modelPath,
|
||||||
assetsPath: assetsPath,
|
assetsPath: assetsPath,
|
||||||
meta: meta,
|
defaultLabelsPath: defaultLabelsPath,
|
||||||
disabled: disabled,
|
meta: meta,
|
||||||
|
disabled: disabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNasnet returns new Nasnet TensorFlow classification model instance.
|
// NewNasnet returns new Nasnet TensorFlow classification model instance.
|
||||||
func NewNasnet(assetsPath string, disabled bool) *Model {
|
func NewNasnet(assetsPath string, disabled bool) *Model {
|
||||||
return NewModel(assetsPath, "nasnet", &tensorflow.ModelInfo{
|
return NewModel(assetsPath, "nasnet", "", &tensorflow.ModelInfo{
|
||||||
TFVersion: "1.12.0",
|
TFVersion: "1.12.0",
|
||||||
Tags: []string{"photoprism"},
|
Tags: []string{"photoprism"},
|
||||||
Input: &tensorflow.PhotoInput{
|
Input: &tensorflow.PhotoInput{
|
||||||
@@ -159,7 +161,13 @@ func (m *Model) Run(img []byte, confidenceThreshold int) (result Labels, err err
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) loadLabels(modelPath string) (err error) {
|
func (m *Model) loadLabels(modelPath string) (err error) {
|
||||||
m.labels, err = tensorflow.LoadLabels(modelPath)
|
numLabels := int(m.meta.Output.NumOutputs)
|
||||||
|
|
||||||
|
m.labels, err = tensorflow.LoadLabels(modelPath, numLabels)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
log.Infof("Model does not seem to have tags at %s, trying %s", modelPath, m.defaultLabelsPath)
|
||||||
|
m.labels, err = tensorflow.LoadLabels(m.defaultLabelsPath, numLabels)
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -216,7 +224,7 @@ func (m *Model) loadModel() (err error) {
|
|||||||
if m.meta.Output.OutputsLogits {
|
if m.meta.Output.OutputsLogits {
|
||||||
_, err = tensorflow.AddSoftmax(m.model.Graph, m.meta)
|
_, err = tensorflow.AddSoftmax(m.model.Graph, m.meta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("classify: could not add softmax: %w")
|
return fmt.Errorf("classify: could not add softmax: %w", nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -170,7 +170,7 @@ func (m *Model) loadModel() error {
|
|||||||
if m.meta.Output.OutputsLogits {
|
if m.meta.Output.OutputsLogits {
|
||||||
_, err = tensorflow.AddSoftmax(m.model.Graph, m.meta)
|
_, err = tensorflow.AddSoftmax(m.model.Graph, m.meta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nsfw: could not add softmax: %w")
|
return fmt.Errorf("nsfw: could not add softmax: %w", nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -178,7 +178,7 @@ func (m *Model) loadModel() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) loadLabels(modelPath string) (err error) {
|
func (m *Model) loadLabels(modelPath string) (err error) {
|
||||||
m.labels, err = tensorflow.LoadLabels(modelPath)
|
m.labels, err = tensorflow.LoadLabels(modelPath, int(m.meta.Output.NumOutputs))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,19 +2,17 @@ package tensorflow
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"io/fs"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LoadLabels loads the labels of classification models from the specified path and returns them.
|
func loadLabelsFromPath(path string) (labels []string, err error) {
|
||||||
func LoadLabels(modelPath string) (labels []string, err error) {
|
log.Infof("tensorflow: loading model labels from %s", path)
|
||||||
modelLabels := modelPath + "/labels.txt"
|
|
||||||
|
|
||||||
log.Infof("tensorflow: loading model labels from labels.txt")
|
|
||||||
|
|
||||||
f, err := os.Open(modelLabels)
|
|
||||||
|
|
||||||
|
f, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return labels, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
@@ -30,3 +28,34 @@ func LoadLabels(modelPath string) (labels []string, err error) {
|
|||||||
|
|
||||||
return labels, err
|
return labels, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoadLabels loads the labels of classification models from the specified path and returns them.
|
||||||
|
func LoadLabels(modelPath string, expectedLabels int) (labels []string, err error) {
|
||||||
|
|
||||||
|
dir := os.DirFS(modelPath)
|
||||||
|
matches, err := fs.Glob(dir, "labels*.txt")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range matches {
|
||||||
|
labels, err := loadLabelsFromPath(filepath.Join(modelPath, matches[i]))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch expectedLabels - len(labels) {
|
||||||
|
case 0:
|
||||||
|
log.Infof("Found a valid labels file: %s", matches[i])
|
||||||
|
return labels, nil
|
||||||
|
case 1:
|
||||||
|
log.Infof("Found a valid labels file %s but we have to add bias", matches[i])
|
||||||
|
|
||||||
|
return append([]string{"background"}, labels...), nil
|
||||||
|
default:
|
||||||
|
log.Infof("File not valid. Expected %d labels and have %d",
|
||||||
|
expectedLabels, len(labels))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, os.ErrNotExist
|
||||||
|
}
|
||||||
|
|||||||
@@ -104,17 +104,18 @@ func (m *Model) ClassifyModel() *classify.Model {
|
|||||||
// Set default thumbnail resolution if no tags are configured.
|
// Set default thumbnail resolution if no tags are configured.
|
||||||
if m.Resolution <= 0 {
|
if m.Resolution <= 0 {
|
||||||
m.Resolution = DefaultResolution
|
m.Resolution = DefaultResolution
|
||||||
} else {
|
|
||||||
if m.Meta.Input == nil {
|
|
||||||
m.Meta.Input = new(tensorflow.PhotoInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.Meta.Input.SetResolution(m.Resolution)
|
|
||||||
m.Meta.Input.Channels = 3
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.Meta.Input == nil {
|
||||||
|
m.Meta.Input = new(tensorflow.PhotoInput)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Meta.Input.SetResolution(m.Resolution)
|
||||||
|
m.Meta.Input.Channels = 3
|
||||||
|
|
||||||
// Try to load custom model based on the configuration values.
|
// Try to load custom model based on the configuration values.
|
||||||
if model := classify.NewModel(AssetsPath, m.Path, m.Meta, m.Disabled); model == nil {
|
defaultPath := filepath.Join(AssetsPath, "nasnet")
|
||||||
|
if model := classify.NewModel(AssetsPath, m.Path, defaultPath, m.Meta, m.Disabled); model == nil {
|
||||||
return nil
|
return nil
|
||||||
} else if err := model.Init(); err != nil {
|
} else if err := model.Init(); err != nil {
|
||||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||||
@@ -222,15 +223,15 @@ func (m *Model) NsfwModel() *nsfw.Model {
|
|||||||
// Set default thumbnail resolution if no tags are configured.
|
// Set default thumbnail resolution if no tags are configured.
|
||||||
if m.Resolution <= 0 {
|
if m.Resolution <= 0 {
|
||||||
m.Resolution = DefaultResolution
|
m.Resolution = DefaultResolution
|
||||||
} else {
|
|
||||||
if m.Meta.Input == nil {
|
|
||||||
m.Meta.Input = new(tensorflow.PhotoInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.Meta.Input.SetResolution(m.Resolution)
|
|
||||||
m.Meta.Input.Channels = 3
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.Meta.Input == nil {
|
||||||
|
m.Meta.Input = new(tensorflow.PhotoInput)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Meta.Input.SetResolution(m.Resolution)
|
||||||
|
m.Meta.Input.Channels = 3
|
||||||
|
|
||||||
if m.Meta == nil {
|
if m.Meta == nil {
|
||||||
m.Meta = &tensorflow.ModelInfo{}
|
m.Meta = &tensorflow.ModelInfo{}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user