mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
Modified classify to add custom models
Vision input parameters have also been changed to support the new parameters needed for the models.
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
||||
tf "github.com/wamuir/graft/tensorflow"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/tensorflow"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
@@ -24,21 +25,45 @@ type Model struct {
|
||||
model *tf.SavedModel
|
||||
modelPath string
|
||||
assetsPath string
|
||||
resolution int
|
||||
modelTags []string
|
||||
labels []string
|
||||
disabled bool
|
||||
meta *tensorflow.ModelInfo
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewModel returns new TensorFlow classification model instance.
|
||||
func NewModel(assetsPath, modelPath string, resolution int, modelTags []string, disabled bool) *Model {
|
||||
return &Model{assetsPath: assetsPath, modelPath: modelPath, resolution: resolution, modelTags: modelTags, disabled: disabled}
|
||||
func NewModel(assetsPath, modelPath string, meta *tensorflow.ModelInfo, disabled bool) *Model {
|
||||
if meta == nil {
|
||||
meta = new(tensorflow.ModelInfo)
|
||||
}
|
||||
|
||||
return &Model{
|
||||
modelPath: modelPath,
|
||||
assetsPath: assetsPath,
|
||||
meta: meta,
|
||||
disabled: disabled,
|
||||
}
|
||||
}
|
||||
|
||||
// NewNasnet returns new Nasnet TensorFlow classification model instance.
|
||||
func NewNasnet(assetsPath string, disabled bool) *Model {
|
||||
return NewModel(assetsPath, "nasnet", 224, []string{"photoprism"}, disabled)
|
||||
return NewModel(assetsPath, "nasnet", &tensorflow.ModelInfo{
|
||||
TFVersion: "1.12.0",
|
||||
Tags: []string{"photoprism"},
|
||||
Input: &tensorflow.PhotoInput{
|
||||
Name: "input_1",
|
||||
Height: 224,
|
||||
Width: 224,
|
||||
Channels: 3,
|
||||
OutputIndex: 0,
|
||||
},
|
||||
Output: &tensorflow.ModelOutput{
|
||||
Name: "predictions/Softmax",
|
||||
NumOutputs: 1000,
|
||||
OutputIndex: 0,
|
||||
OutputsLogits: false,
|
||||
},
|
||||
}, disabled)
|
||||
}
|
||||
|
||||
// Init initialises tensorflow models if not disabled
|
||||
@@ -106,10 +131,10 @@ func (m *Model) Run(img []byte, confidenceThreshold int) (result Labels, err err
|
||||
// Run inference.
|
||||
output, err := m.model.Session.Run(
|
||||
map[tf.Output]*tf.Tensor{
|
||||
m.model.Graph.Operation("input_1").Output(0): tensor,
|
||||
m.model.Graph.Operation(m.meta.Input.Name).Output(m.meta.Input.OutputIndex): tensor,
|
||||
},
|
||||
[]tf.Output{
|
||||
m.model.Graph.Operation("predictions/Softmax").Output(0),
|
||||
m.model.Graph.Operation(m.meta.Output.Name).Output(m.meta.Output.OutputIndex),
|
||||
},
|
||||
nil)
|
||||
|
||||
@@ -155,7 +180,45 @@ func (m *Model) loadModel() (err error) {
|
||||
|
||||
modelPath := path.Join(m.assetsPath, m.modelPath)
|
||||
|
||||
m.model, err = tensorflow.SavedModel(modelPath, m.modelTags)
|
||||
if len(m.meta.Tags) == 0 {
|
||||
infos, err := tensorflow.GetModelInfo(modelPath)
|
||||
if err != nil {
|
||||
log.Errorf("classify: could not get the model info at %s: %v", clean.Log(modelPath), err)
|
||||
} else if len(infos) == 1 {
|
||||
log.Debugf("classify: model info: %+v", infos[0])
|
||||
m.meta.Merge(&infos[0])
|
||||
} else {
|
||||
log.Warnf("classify: found %d metagraphs... thats too many", len(infos))
|
||||
}
|
||||
}
|
||||
|
||||
m.model, err = tensorflow.SavedModel(modelPath, m.meta.Tags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !m.meta.IsComplete() {
|
||||
input, output, err := tensorflow.GetInputAndOutputFromSavedModel(m.model)
|
||||
if err != nil {
|
||||
log.Errorf("classify: could not get info from signatures: %v", err)
|
||||
input, output, err = tensorflow.GuessInputAndOutput(m.model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("classify: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.meta.Merge(&tensorflow.ModelInfo{
|
||||
Input: input,
|
||||
Output: output,
|
||||
})
|
||||
}
|
||||
|
||||
if m.meta.Output.OutputsLogits {
|
||||
_, err = tensorflow.AddSoftmax(m.model.Graph, m.meta)
|
||||
if err != nil {
|
||||
return fmt.Errorf("classify: could not add softmax: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return m.loadLabels(modelPath)
|
||||
}
|
||||
@@ -215,9 +278,9 @@ func (m *Model) createTensor(image []byte) (*tf.Tensor, error) {
|
||||
}
|
||||
|
||||
// Resize the image only if its resolution does not match the model.
|
||||
if img.Bounds().Dx() != m.resolution || img.Bounds().Dy() != m.resolution {
|
||||
img = imaging.Fill(img, m.resolution, m.resolution, imaging.Center, imaging.Lanczos)
|
||||
if img.Bounds().Dx() != m.meta.Input.Resolution() || img.Bounds().Dy() != m.meta.Input.Resolution() {
|
||||
img = imaging.Fill(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Center, imaging.Lanczos)
|
||||
}
|
||||
|
||||
return tensorflow.Image(img, m.resolution)
|
||||
return tensorflow.Image(img, m.meta.Input.Resolution())
|
||||
}
|
||||
|
||||
@@ -13,44 +13,110 @@ import (
|
||||
|
||||
// PhotoInput describes a photo input for a model.
|
||||
type PhotoInput struct {
|
||||
Name string
|
||||
OutputIndex int
|
||||
Height int64
|
||||
Width int64
|
||||
Channels int64
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
|
||||
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
|
||||
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
|
||||
Channels int64 `yaml:"Channels,omitempty" json:"channels,omitempty"`
|
||||
}
|
||||
|
||||
// IsDynamic reports whether height and width are both -1, meaning the
|
||||
// dimensions are not defined and the model accepts photos of any size.
|
||||
func (f PhotoInput) IsDynamic() bool {
|
||||
return f.Height == -1 && f.Width == -1
|
||||
func (p PhotoInput) IsDynamic() bool {
|
||||
return p.Height == -1 && p.Width == -1
|
||||
}
|
||||
|
||||
// Resolution returns the input resolution, which is the height as an int.
|
||||
func (f PhotoInput) Resolution() int {
|
||||
return int(f.Height)
|
||||
func (p PhotoInput) Resolution() int {
|
||||
return int(p.Height)
|
||||
}
|
||||
|
||||
// SetResolution sets both height and width to the given resolution.
|
||||
func (f *PhotoInput) SetResolution(resolution int) {
|
||||
f.Height = int64(resolution)
|
||||
f.Width = int64(resolution)
|
||||
func (p *PhotoInput) SetResolution(resolution int) {
|
||||
p.Height = int64(resolution)
|
||||
p.Width = int64(resolution)
|
||||
}
|
||||
|
||||
// Merge other input with this.
|
||||
func (p *PhotoInput) Merge(other *PhotoInput) {
|
||||
if p.Name == "" {
|
||||
p.Name = other.Name
|
||||
}
|
||||
|
||||
if p.OutputIndex == 0 {
|
||||
p.OutputIndex = other.OutputIndex
|
||||
}
|
||||
|
||||
if p.Height == 0 {
|
||||
p.Height = other.Height
|
||||
}
|
||||
|
||||
if p.Width == 0 {
|
||||
p.Width = other.Width
|
||||
}
|
||||
|
||||
if p.Channels == 0 {
|
||||
p.Channels = other.Channels
|
||||
}
|
||||
}
|
||||
|
||||
// ModelOutput describes the output expected for a model.
|
||||
type ModelOutput struct {
|
||||
Name string
|
||||
OutputIndex int
|
||||
NumOutputs int64
|
||||
OutputsLogits bool
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
|
||||
NumOutputs int64 `yaml:"Outputs,omitempty" json:"outputs,omitempty"`
|
||||
OutputsLogits bool `yaml:"Logits,omitempty" json:"logits,omitempty"`
|
||||
}
|
||||
|
||||
// Merge other output with this
|
||||
func (m *ModelOutput) Merge(other *ModelOutput) {
|
||||
if m.Name == "" {
|
||||
m.Name = other.Name
|
||||
}
|
||||
|
||||
if m.OutputIndex == 0 {
|
||||
m.OutputIndex = other.OutputIndex
|
||||
}
|
||||
|
||||
if m.NumOutputs == 0 {
|
||||
m.NumOutputs = other.NumOutputs
|
||||
}
|
||||
|
||||
if !m.OutputsLogits {
|
||||
m.OutputsLogits = other.OutputsLogits
|
||||
}
|
||||
}
|
||||
|
||||
// ModelInfo holds the meta information for a model.
|
||||
type ModelInfo struct {
|
||||
TFVersion string
|
||||
Tags []string
|
||||
Input *PhotoInput
|
||||
Output *ModelOutput
|
||||
TFVersion string `yaml:"-" json:"-"`
|
||||
Tags []string `yaml:"Tags" json:"tags"`
|
||||
Input *PhotoInput `yaml:"Input" json:"input"`
|
||||
Output *ModelOutput `yaml:"Output" json:"output"`
|
||||
}
|
||||
|
||||
// Merge other model info. In case of having information
|
||||
// for a field, the current model will keep its current value
|
||||
func (m *ModelInfo) Merge(other *ModelInfo) {
|
||||
if m.TFVersion == "" {
|
||||
m.TFVersion = other.TFVersion
|
||||
}
|
||||
|
||||
if len(m.Tags) == 0 {
|
||||
m.Tags = other.Tags
|
||||
}
|
||||
|
||||
if m.Input == nil {
|
||||
m.Input = other.Input
|
||||
} else if other.Input != nil {
|
||||
m.Input.Merge(other.Input)
|
||||
}
|
||||
|
||||
if m.Output == nil {
|
||||
m.Output = other.Output
|
||||
} else if other.Output != nil {
|
||||
m.Output.Merge(other.Output)
|
||||
}
|
||||
}
|
||||
|
||||
// We consider a model complete if we know its inputs and outputs
|
||||
@@ -154,7 +220,7 @@ func GetModelInfo(path string) ([]ModelInfo, error) {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Could not get the inputs and outputs from signatures. (TF Version %s): %w", newModel.TFVersion, err)
|
||||
log.Errorf("Could not get the inputs and outputs from signatures. (TF Version %s): %v", newModel.TFVersion, err)
|
||||
}
|
||||
|
||||
models = append(models, newModel)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/photoprism/photoprism/internal/ai/classify"
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/internal/ai/tensorflow"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
)
|
||||
|
||||
@@ -16,16 +17,16 @@ var modelMutex = sync.Mutex{}
|
||||
|
||||
// Model represents a computer vision model configuration.
|
||||
type Model struct {
|
||||
Type ModelType `yaml:"Type,omitempty" json:"type,omitempty"`
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
|
||||
Resolution int `yaml:"Resolution,omitempty" json:"resolution,omitempty"`
|
||||
Uri string `yaml:"Uri,omitempty" json:"-"`
|
||||
Key string `yaml:"Key,omitempty" json:"-"`
|
||||
Method string `yaml:"Method,omitempty" json:"-"`
|
||||
Path string `yaml:"Path,omitempty" json:"-"`
|
||||
Tags []string `yaml:"Tags,omitempty" json:"-"`
|
||||
Disabled bool `yaml:"Disabled,omitempty" json:"-"`
|
||||
Type ModelType `yaml:"Type,omitempty" json:"type,omitempty"`
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
|
||||
Resolution int `yaml:"Resolution,omitempty" json:"resolution,omitempty"`
|
||||
Meta *tensorflow.ModelInfo `yaml:"Meta,omitempty" json:"meta,omitempty"`
|
||||
Uri string `yaml:"Uri,omitempty" json:"-"`
|
||||
Key string `yaml:"Key,omitempty" json:"-"`
|
||||
Method string `yaml:"Method,omitempty" json:"-"`
|
||||
Path string `yaml:"Path,omitempty" json:"-"`
|
||||
Disabled bool `yaml:"Disabled,omitempty" json:"-"`
|
||||
classifyModel *classify.Model
|
||||
faceModel *face.Model
|
||||
nsfwModel *nsfw.Model
|
||||
@@ -96,18 +97,24 @@ func (m *Model) ClassifyModel() *classify.Model {
|
||||
m.Path = clean.TypeLowerUnderscore(m.Name)
|
||||
}
|
||||
|
||||
if m.Meta == nil {
|
||||
m.Meta = &tensorflow.ModelInfo{}
|
||||
}
|
||||
|
||||
// Set default thumbnail resolution if no resolution is configured.
|
||||
if m.Resolution <= 0 {
|
||||
m.Resolution = DefaultResolution
|
||||
}
|
||||
} else {
|
||||
if m.Meta.Input == nil {
|
||||
m.Meta.Input = new(tensorflow.PhotoInput)
|
||||
}
|
||||
|
||||
// Set default tag if no tags are configured.
|
||||
if len(m.Tags) == 0 {
|
||||
m.Tags = []string{"serve"}
|
||||
m.Meta.Input.SetResolution(m.Resolution)
|
||||
m.Meta.Input.Channels = 3
|
||||
}
|
||||
|
||||
// Try to load custom model based on the configuration values.
|
||||
if model := classify.NewModel(AssetsPath, m.Path, m.Resolution, m.Tags, m.Disabled); model == nil {
|
||||
if model := classify.NewModel(AssetsPath, m.Path, m.Meta, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
@@ -138,7 +145,7 @@ func (m *Model) FaceModel() *face.Model {
|
||||
return nil
|
||||
case FacenetModel.Name, "facenet":
|
||||
// Load and initialize the FaceNet model.
|
||||
if model := face.NewModel(FaceNetModelPath, CachePath, m.Resolution, m.Tags, m.Disabled); model == nil {
|
||||
if model := face.NewModel(FaceNetModelPath, CachePath, m.Resolution, m.Meta.Tags, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
@@ -157,13 +164,17 @@ func (m *Model) FaceModel() *face.Model {
|
||||
m.Resolution = DefaultResolution
|
||||
}
|
||||
|
||||
if m.Meta == nil {
|
||||
m.Meta = &tensorflow.ModelInfo{}
|
||||
}
|
||||
|
||||
// Set default tag if no tags are configured.
|
||||
if len(m.Tags) == 0 {
|
||||
m.Tags = []string{"serve"}
|
||||
if len(m.Meta.Tags) == 0 {
|
||||
m.Meta.Tags = []string{"serve"}
|
||||
}
|
||||
|
||||
// Try to load custom model based on the configuration values.
|
||||
if model := face.NewModel(filepath.Join(AssetsPath, m.Path), CachePath, m.Resolution, m.Tags, m.Disabled); model == nil {
|
||||
if model := face.NewModel(filepath.Join(AssetsPath, m.Path), CachePath, m.Resolution, m.Meta.Tags, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
@@ -194,7 +205,7 @@ func (m *Model) NsfwModel() *nsfw.Model {
|
||||
return nil
|
||||
case NsfwModel.Name, "nsfw":
|
||||
// Load and initialize the NSFW model.
|
||||
if model := nsfw.NewModel(NsfwModelPath, m.Resolution, m.Tags, m.Disabled); model == nil {
|
||||
if model := nsfw.NewModel(NsfwModelPath, m.Resolution, m.Meta.Tags, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
@@ -213,13 +224,17 @@ func (m *Model) NsfwModel() *nsfw.Model {
|
||||
m.Resolution = DefaultResolution
|
||||
}
|
||||
|
||||
if m.Meta == nil {
|
||||
m.Meta = &tensorflow.ModelInfo{}
|
||||
}
|
||||
|
||||
// Set default tag if no tags are configured.
|
||||
if len(m.Tags) == 0 {
|
||||
m.Tags = []string{"serve"}
|
||||
if len(m.Meta.Tags) == 0 {
|
||||
m.Meta.Tags = []string{"serve"}
|
||||
}
|
||||
|
||||
// Try to load custom model based on the configuration values.
|
||||
if model := nsfw.NewModel(filepath.Join(AssetsPath, m.Path), m.Resolution, m.Tags, m.Disabled); model == nil {
|
||||
if model := nsfw.NewModel(filepath.Join(AssetsPath, m.Path), m.Resolution, m.Meta.Tags, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package vision
|
||||
|
||||
import "github.com/photoprism/photoprism/internal/ai/tensorflow"
|
||||
|
||||
// Default computer vision model configuration.
|
||||
var (
|
||||
NasnetModel = &Model{
|
||||
@@ -7,21 +9,69 @@ var (
|
||||
Name: "NASNet",
|
||||
Version: "Mobile",
|
||||
Resolution: 224,
|
||||
Tags: []string{"photoprism"},
|
||||
Meta: &tensorflow.ModelInfo{
|
||||
TFVersion: "1.12.0",
|
||||
Tags: []string{"photoprism"},
|
||||
Input: &tensorflow.PhotoInput{
|
||||
Name: "input_1",
|
||||
Height: 224,
|
||||
Width: 224,
|
||||
Channels: 3,
|
||||
OutputIndex: 0,
|
||||
},
|
||||
Output: &tensorflow.ModelOutput{
|
||||
Name: "predictions/Softmax",
|
||||
NumOutputs: 1000,
|
||||
OutputIndex: 0,
|
||||
OutputsLogits: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
NsfwModel = &Model{
|
||||
Type: ModelTypeNsfw,
|
||||
Name: "Nsfw",
|
||||
Version: "",
|
||||
Resolution: 224,
|
||||
Tags: []string{"serve"},
|
||||
Meta: &tensorflow.ModelInfo{
|
||||
TFVersion: "1.12.0",
|
||||
Tags: []string{"serve"},
|
||||
Input: &tensorflow.PhotoInput{
|
||||
Name: "input_tensor",
|
||||
Height: 224,
|
||||
Width: 224,
|
||||
Channels: 3,
|
||||
OutputIndex: 0,
|
||||
},
|
||||
Output: &tensorflow.ModelOutput{
|
||||
Name: "nsfw_cls_model/final_prediction",
|
||||
NumOutputs: 5,
|
||||
OutputIndex: 0,
|
||||
OutputsLogits: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
FacenetModel = &Model{
|
||||
Type: ModelTypeFace,
|
||||
Name: "FaceNet",
|
||||
Version: "",
|
||||
Resolution: 160,
|
||||
Tags: []string{"serve"},
|
||||
Meta: &tensorflow.ModelInfo{
|
||||
TFVersion: "1.7.1",
|
||||
Tags: []string{"serve"},
|
||||
Input: &tensorflow.PhotoInput{
|
||||
Name: "input",
|
||||
Height: 160,
|
||||
Width: 160,
|
||||
Channels: 3,
|
||||
OutputIndex: 0,
|
||||
},
|
||||
Output: &tensorflow.ModelOutput{
|
||||
Name: "embeddings",
|
||||
NumOutputs: 512,
|
||||
OutputIndex: 0,
|
||||
OutputsLogits: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
CaptionModel = &Model{
|
||||
Type: ModelTypeCaption,
|
||||
|
||||
@@ -45,7 +45,7 @@ func visionListAction(ctx *cli.Context) error {
|
||||
model.Version,
|
||||
fmt.Sprintf("%d", model.Resolution),
|
||||
modelUri,
|
||||
strings.Join(model.Tags, ", "),
|
||||
strings.Join(model.Meta.Tags, ", "),
|
||||
report.Bool(model.Disabled, report.Yes, report.No),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user