Modified classify to add custom models

The vision model configuration has also been updated to support the new
input and output parameters that these models require.
This commit is contained in:
raystlin
2025-04-11 20:01:36 +00:00
parent 8bc7121394
commit 88508679b0
5 changed files with 253 additions and 59 deletions
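
As a rough illustration of the new configuration surface, a custom classification model could now be declared along these lines; the model name, path, and tensor names below are hypothetical, and the model type and registration details are elided:

customModel := &Model{
	Name:       "MyNet",
	Path:       "mynet",
	Resolution: 299,
	Meta: &tensorflow.ModelInfo{
		Tags: []string{"serve"},
		Input: &tensorflow.PhotoInput{
			Name:     "serving_default_input", // hypothetical input tensor name
			Height:   299,
			Width:    299,
			Channels: 3,
		},
		Output: &tensorflow.ModelOutput{
			Name:          "StatefulPartitionedCall", // hypothetical output tensor name
			NumOutputs:    1000,
			OutputsLogits: true, // a softmax layer is appended when the model emits logits
		},
	},
}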

View File

@@ -15,6 +15,7 @@ import (
tf "github.com/wamuir/graft/tensorflow"
"github.com/photoprism/photoprism/internal/ai/tensorflow"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/media"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
)
@@ -24,21 +25,45 @@ type Model struct {
model *tf.SavedModel
modelPath string
assetsPath string
resolution int
modelTags []string
labels []string
disabled bool
meta *tensorflow.ModelInfo
mutex sync.Mutex
}
// NewModel returns a new TensorFlow classification model instance.
func NewModel(assetsPath, modelPath string, resolution int, modelTags []string, disabled bool) *Model {
return &Model{assetsPath: assetsPath, modelPath: modelPath, resolution: resolution, modelTags: modelTags, disabled: disabled}
func NewModel(assetsPath, modelPath string, meta *tensorflow.ModelInfo, disabled bool) *Model {
if meta == nil {
meta = new(tensorflow.ModelInfo)
}
return &Model{
modelPath: modelPath,
assetsPath: assetsPath,
meta: meta,
disabled: disabled,
}
}
// NewNasnet returns a new NASNet TensorFlow classification model instance.
func NewNasnet(assetsPath string, disabled bool) *Model {
return NewModel(assetsPath, "nasnet", 224, []string{"photoprism"}, disabled)
return NewModel(assetsPath, "nasnet", &tensorflow.ModelInfo{
TFVersion: "1.12.0",
Tags: []string{"photoprism"},
Input: &tensorflow.PhotoInput{
Name: "input_1",
Height: 224,
Width: 224,
Channels: 3,
OutputIndex: 0,
},
Output: &tensorflow.ModelOutput{
Name: "predictions/Softmax",
NumOutputs: 1000,
OutputIndex: 0,
OutputsLogits: false,
},
}, disabled)
}
// Init initialises the TensorFlow model if not disabled.
@@ -106,10 +131,10 @@ func (m *Model) Run(img []byte, confidenceThreshold int) (result Labels, err err
// Run inference.
output, err := m.model.Session.Run(
map[tf.Output]*tf.Tensor{
m.model.Graph.Operation("input_1").Output(0): tensor,
m.model.Graph.Operation(m.meta.Input.Name).Output(m.meta.Input.OutputIndex): tensor,
},
[]tf.Output{
m.model.Graph.Operation("predictions/Softmax").Output(0),
m.model.Graph.Operation(m.meta.Output.Name).Output(m.meta.Output.OutputIndex),
},
nil)
@@ -155,7 +180,45 @@ func (m *Model) loadModel() (err error) {
modelPath := path.Join(m.assetsPath, m.modelPath)
m.model, err = tensorflow.SavedModel(modelPath, m.modelTags)
if len(m.meta.Tags) == 0 {
infos, err := tensorflow.GetModelInfo(modelPath)
if err != nil {
log.Errorf("classify: could not get the model info at %s: %v", clean.Log(modelPath), err)
} else if len(infos) == 1 {
log.Debugf("classify: model info: %+v", infos[0])
m.meta.Merge(&infos[0])
} else {
log.Warnf("classify: found %d metagraphs... thats too many", len(infos))
}
}
m.model, err = tensorflow.SavedModel(modelPath, m.meta.Tags)
if err != nil {
return err
}
if !m.meta.IsComplete() {
input, output, err := tensorflow.GetInputAndOutputFromSavedModel(m.model)
if err != nil {
log.Errorf("classify: could not get info from signatures: %v", err)
input, output, err = tensorflow.GuessInputAndOutput(m.model)
if err != nil {
return fmt.Errorf("classify: %w", err)
}
}
m.meta.Merge(&tensorflow.ModelInfo{
Input: input,
Output: output,
})
}
if m.meta.Output.OutputsLogits {
_, err = tensorflow.AddSoftmax(m.model.Graph, m.meta)
if err != nil {
return fmt.Errorf("classify: could not add softmax: %w")
}
}
return m.loadLabels(modelPath)
}
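// A minimal usage sketch with a hypothetical asset path and model name: when
// no metadata is configured, loadModel falls back to the SavedModel's
// metagraph info, then to its signatures, and finally to a heuristic guess,
// so a custom model can be loaded without any manual tensor configuration.
//
//	m := NewModel("/opt/photoprism/assets", "mobilenet_v2", nil, false)
//	if err := m.Init(); err != nil {
//		log.Errorf("classify: init failed: %s", err)
//	}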
@@ -215,9 +278,9 @@ func (m *Model) createTensor(image []byte) (*tf.Tensor, error) {
}
// Resize the image only if its resolution does not match the model.
if img.Bounds().Dx() != m.resolution || img.Bounds().Dy() != m.resolution {
img = imaging.Fill(img, m.resolution, m.resolution, imaging.Center, imaging.Lanczos)
if img.Bounds().Dx() != m.meta.Input.Resolution() || img.Bounds().Dy() != m.meta.Input.Resolution() {
img = imaging.Fill(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Center, imaging.Lanczos)
}
return tensorflow.Image(img, m.resolution)
return tensorflow.Image(img, m.meta.Input.Resolution())
}

View File

@@ -13,44 +13,110 @@ import (
// PhotoInput describes the photo input of a model.
type PhotoInput struct {
Name string
OutputIndex int
Height int64
Width int64
Channels int64
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
Channels int64 `yaml:"Channels,omitempty" json:"channels,omitempty"`
}
// IsDynamic reports whether the model accepts photos of any size,
// i.e. its input dimensions are not fixed.
func (f PhotoInput) IsDynamic() bool {
return f.Height == -1 && f.Width == -1
func (p PhotoInput) IsDynamic() bool {
return p.Height == -1 && p.Width == -1
}
// Resolution returns the input resolution in pixels; square input is assumed, so the height is used.
func (f PhotoInput) Resolution() int {
return int(f.Height)
func (p PhotoInput) Resolution() int {
return int(p.Height)
}
// SetResolution sets the height and width to the same resolution.
func (f *PhotoInput) SetResolution(resolution int) {
f.Height = int64(resolution)
f.Width = int64(resolution)
func (p *PhotoInput) SetResolution(resolution int) {
p.Height = int64(resolution)
p.Width = int64(resolution)
}
// Merge fills any unset fields of this input from the other input.
func (p *PhotoInput) Merge(other *PhotoInput) {
if p.Name == "" {
p.Name = other.Name
}
if p.OutputIndex == 0 {
p.OutputIndex = other.OutputIndex
}
if p.Height == 0 {
p.Height = other.Height
}
if p.Width == 0 {
p.Width = other.Width
}
if p.Channels == 0 {
p.Channels = other.Channels
}
}
// ModelOutput describes the output expected from a model.
type ModelOutput struct {
Name string
OutputIndex int
NumOutputs int64
OutputsLogits bool
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
NumOutputs int64 `yaml:"Outputs,omitempty" json:"outputs,omitempty"`
OutputsLogits bool `yaml:"Logits,omitempty" json:"logits,omitempty"`
}
// Merge fills any unset fields of this output from the other output.
func (m *ModelOutput) Merge(other *ModelOutput) {
if m.Name == "" {
m.Name = other.Name
}
if m.OutputIndex == 0 {
m.OutputIndex = other.OutputIndex
}
if m.NumOutputs == 0 {
m.NumOutputs = other.NumOutputs
}
if !m.OutputsLogits {
m.OutputsLogits = other.OutputsLogits
}
}
// ModelInfo holds the meta information for a model.
type ModelInfo struct {
TFVersion string
Tags []string
Input *PhotoInput
Output *ModelOutput
TFVersion string `yaml:"-" json:"-"`
Tags []string `yaml:"Tags" json:"tags"`
Input *PhotoInput `yaml:"Input" json:"input"`
Output *ModelOutput `yaml:"Output" json:"output"`
}
// Merge fills any unset fields from the other model info. A field that
// already has a value keeps its current value.
func (m *ModelInfo) Merge(other *ModelInfo) {
if m.TFVersion == "" {
m.TFVersion = other.TFVersion
}
if len(m.Tags) == 0 {
m.Tags = other.Tags
}
if m.Input == nil {
m.Input = other.Input
} else if other.Input != nil {
m.Input.Merge(other.Input)
}
if m.Output == nil {
m.Output = other.Output
} else if other.Output != nil {
m.Output.Merge(other.Output)
}
}
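// A small sketch of the merge semantics with hypothetical values: fields that
// are already set on the receiver keep their value, while unset fields are
// filled in from the other model info.
//
//	configured := &ModelInfo{Tags: []string{"photoprism"}}
//	detected := &ModelInfo{
//		Tags:  []string{"serve"},
//		Input: &PhotoInput{Name: "input_1", Height: 224, Width: 224, Channels: 3},
//	}
//	configured.Merge(detected)
//	// configured.Tags remains []string{"photoprism"}, while configured.Input
//	// is taken from the detected info.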
// We consider a model complete if we know its inputs and outputs
@@ -154,7 +220,7 @@ func GetModelInfo(path string) ([]ModelInfo, error) {
}
if err != nil {
log.Printf("Could not get the inputs and outputs from signatures. (TF Version %s): %w", newModel.TFVersion, err)
log.Errorf("Could not get the inputs and outputs from signatures. (TF Version %s): %w", newModel.TFVersion, err)
}
models = append(models, newModel)

View File

@@ -9,6 +9,7 @@ import (
"github.com/photoprism/photoprism/internal/ai/classify"
"github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/internal/ai/tensorflow"
"github.com/photoprism/photoprism/pkg/clean"
)
@@ -16,16 +17,16 @@ var modelMutex = sync.Mutex{}
// Model represents a computer vision model configuration.
type Model struct {
Type ModelType `yaml:"Type,omitempty" json:"type,omitempty"`
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
Resolution int `yaml:"Resolution,omitempty" json:"resolution,omitempty"`
Uri string `yaml:"Uri,omitempty" json:"-"`
Key string `yaml:"Key,omitempty" json:"-"`
Method string `yaml:"Method,omitempty" json:"-"`
Path string `yaml:"Path,omitempty" json:"-"`
Tags []string `yaml:"Tags,omitempty" json:"-"`
Disabled bool `yaml:"Disabled,omitempty" json:"-"`
Type ModelType `yaml:"Type,omitempty" json:"type,omitempty"`
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
Resolution int `yaml:"Resolution,omitempty" json:"resolution,omitempty"`
Meta *tensorflow.ModelInfo `yaml:"Meta,omitempty" json:"meta,omitempty"`
Uri string `yaml:"Uri,omitempty" json:"-"`
Key string `yaml:"Key,omitempty" json:"-"`
Method string `yaml:"Method,omitempty" json:"-"`
Path string `yaml:"Path,omitempty" json:"-"`
Disabled bool `yaml:"Disabled,omitempty" json:"-"`
classifyModel *classify.Model
faceModel *face.Model
nsfwModel *nsfw.Model
@@ -96,18 +97,24 @@ func (m *Model) ClassifyModel() *classify.Model {
m.Path = clean.TypeLowerUnderscore(m.Name)
}
if m.Meta == nil {
m.Meta = &tensorflow.ModelInfo{}
}
// Set default thumbnail resolution if none is configured.
if m.Resolution <= 0 {
m.Resolution = DefaultResolution
}
} else {
if m.Meta.Input == nil {
m.Meta.Input = new(tensorflow.PhotoInput)
}
// Set default tag if no tags are configured.
if len(m.Tags) == 0 {
m.Tags = []string{"serve"}
m.Meta.Input.SetResolution(m.Resolution)
m.Meta.Input.Channels = 3
}
// Try to load custom model based on the configuration values.
if model := classify.NewModel(AssetsPath, m.Path, m.Resolution, m.Tags, m.Disabled); model == nil {
if model := classify.NewModel(AssetsPath, m.Path, m.Meta, m.Disabled); model == nil {
return nil
} else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path)
@@ -138,7 +145,7 @@ func (m *Model) FaceModel() *face.Model {
return nil
case FacenetModel.Name, "facenet":
// Load and initialize the FaceNet face recognition model.
if model := face.NewModel(FaceNetModelPath, CachePath, m.Resolution, m.Tags, m.Disabled); model == nil {
if model := face.NewModel(FaceNetModelPath, CachePath, m.Resolution, m.Meta.Tags, m.Disabled); model == nil {
return nil
} else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path)
@@ -157,13 +164,17 @@ func (m *Model) FaceModel() *face.Model {
m.Resolution = DefaultResolution
}
if m.Meta == nil {
m.Meta = &tensorflow.ModelInfo{}
}
// Set default tag if no tags are configured.
if len(m.Tags) == 0 {
m.Tags = []string{"serve"}
if len(m.Meta.Tags) == 0 {
m.Meta.Tags = []string{"serve"}
}
// Try to load custom model based on the configuration values.
if model := face.NewModel(filepath.Join(AssetsPath, m.Path), CachePath, m.Resolution, m.Tags, m.Disabled); model == nil {
if model := face.NewModel(filepath.Join(AssetsPath, m.Path), CachePath, m.Resolution, m.Meta.Tags, m.Disabled); model == nil {
return nil
} else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path)
@@ -194,7 +205,7 @@ func (m *Model) NsfwModel() *nsfw.Model {
return nil
case NsfwModel.Name, "nsfw":
// Load and initialize the NSFW image classification model.
if model := nsfw.NewModel(NsfwModelPath, m.Resolution, m.Tags, m.Disabled); model == nil {
if model := nsfw.NewModel(NsfwModelPath, m.Resolution, m.Meta.Tags, m.Disabled); model == nil {
return nil
} else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path)
@@ -213,13 +224,17 @@ func (m *Model) NsfwModel() *nsfw.Model {
m.Resolution = DefaultResolution
}
if m.Meta == nil {
m.Meta = &tensorflow.ModelInfo{}
}
// Set default tag if no tags are configured.
if len(m.Tags) == 0 {
m.Tags = []string{"serve"}
if len(m.Meta.Tags) == 0 {
m.Meta.Tags = []string{"serve"}
}
// Try to load custom model based on the configuration values.
if model := nsfw.NewModel(filepath.Join(AssetsPath, m.Path), m.Resolution, m.Tags, m.Disabled); model == nil {
if model := nsfw.NewModel(filepath.Join(AssetsPath, m.Path), m.Resolution, m.Meta.Tags, m.Disabled); model == nil {
return nil
} else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path)

View File

@@ -1,5 +1,7 @@
package vision
import "github.com/photoprism/photoprism/internal/ai/tensorflow"
// Default computer vision model configuration.
var (
NasnetModel = &Model{
@@ -7,21 +9,69 @@ var (
Name: "NASNet",
Version: "Mobile",
Resolution: 224,
Tags: []string{"photoprism"},
Meta: &tensorflow.ModelInfo{
TFVersion: "1.12.0",
Tags: []string{"photoprism"},
Input: &tensorflow.PhotoInput{
Name: "input_1",
Height: 224,
Width: 224,
Channels: 3,
OutputIndex: 0,
},
Output: &tensorflow.ModelOutput{
Name: "predictions/Softmax",
NumOutputs: 1000,
OutputIndex: 0,
OutputsLogits: false,
},
},
}
NsfwModel = &Model{
Type: ModelTypeNsfw,
Name: "Nsfw",
Version: "",
Resolution: 224,
Tags: []string{"serve"},
Meta: &tensorflow.ModelInfo{
TFVersion: "1.12.0",
Tags: []string{"serve"},
Input: &tensorflow.PhotoInput{
Name: "input_tensor",
Height: 224,
Width: 224,
Channels: 3,
OutputIndex: 0,
},
Output: &tensorflow.ModelOutput{
Name: "nsfw_cls_model/final_prediction",
NumOutputs: 5,
OutputIndex: 0,
OutputsLogits: false,
},
},
}
FacenetModel = &Model{
Type: ModelTypeFace,
Name: "FaceNet",
Version: "",
Resolution: 160,
Tags: []string{"serve"},
Meta: &tensorflow.ModelInfo{
TFVersion: "1.7.1",
Tags: []string{"serve"},
Input: &tensorflow.PhotoInput{
Name: "input",
Height: 160,
Width: 160,
Channels: 3,
OutputIndex: 0,
},
Output: &tensorflow.ModelOutput{
Name: "embeddings",
NumOutputs: 512,
OutputIndex: 0,
OutputsLogits: false,
},
},
}
CaptionModel = &Model{
Type: ModelTypeCaption,

View File

@@ -45,7 +45,7 @@ func visionListAction(ctx *cli.Context) error {
model.Version,
fmt.Sprintf("%d", model.Resolution),
modelUri,
strings.Join(model.Tags, ", "),
strings.Join(model.Meta.Tags, ", "),
report.Bool(model.Disabled, report.Yes, report.No),
}
}