mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
Modified classify to add custom models
Vision input parameters have also been changed to support the new parameters needed for the models.
This commit is contained in:
@@ -13,44 +13,110 @@ import (
|
||||
|
||||
// Input description for a photo input for a model
|
||||
type PhotoInput struct {
|
||||
Name string
|
||||
OutputIndex int
|
||||
Height int64
|
||||
Width int64
|
||||
Channels int64
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
|
||||
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
|
||||
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
|
||||
Channels int64 `yaml:"Channels,omitempty" json:"channels,omitempty"`
|
||||
}
|
||||
|
||||
// When dimensions are not defined, it means the model accepts any size of
|
||||
// photo
|
||||
func (f PhotoInput) IsDynamic() bool {
|
||||
return f.Height == -1 && f.Width == -1
|
||||
func (p PhotoInput) IsDynamic() bool {
|
||||
return p.Height == -1 && p.Width == -1
|
||||
}
|
||||
|
||||
// Get the resolution
|
||||
func (f PhotoInput) Resolution() int {
|
||||
return int(f.Height)
|
||||
func (p PhotoInput) Resolution() int {
|
||||
return int(p.Height)
|
||||
}
|
||||
|
||||
// Set the resolution: same height and width
|
||||
func (f *PhotoInput) SetResolution(resolution int) {
|
||||
f.Height = int64(resolution)
|
||||
f.Width = int64(resolution)
|
||||
func (p *PhotoInput) SetResolution(resolution int) {
|
||||
p.Height = int64(resolution)
|
||||
p.Width = int64(resolution)
|
||||
}
|
||||
|
||||
// Merge other input with this.
|
||||
func (p *PhotoInput) Merge(other *PhotoInput) {
|
||||
if p.Name == "" {
|
||||
p.Name = other.Name
|
||||
}
|
||||
|
||||
if p.OutputIndex == 0 {
|
||||
p.OutputIndex = other.OutputIndex
|
||||
}
|
||||
|
||||
if p.Height == 0 {
|
||||
p.Height = other.Height
|
||||
}
|
||||
|
||||
if p.Width == 0 {
|
||||
p.Width = other.Width
|
||||
}
|
||||
|
||||
if p.Channels == 0 {
|
||||
p.Channels = other.Channels
|
||||
}
|
||||
}
|
||||
|
||||
// The output expected for a model
|
||||
type ModelOutput struct {
|
||||
Name string
|
||||
OutputIndex int
|
||||
NumOutputs int64
|
||||
OutputsLogits bool
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
|
||||
NumOutputs int64 `yaml:"Outputs,omitempty" json:"outputs,omitempty"`
|
||||
OutputsLogits bool `yaml:"Logits,omitempty" json:"logits,omitempty"`
|
||||
}
|
||||
|
||||
// Merge other output with this
|
||||
func (m *ModelOutput) Merge(other *ModelOutput) {
|
||||
if m.Name == "" {
|
||||
m.Name = other.Name
|
||||
}
|
||||
|
||||
if m.OutputIndex == 0 {
|
||||
m.OutputIndex = other.OutputIndex
|
||||
}
|
||||
|
||||
if m.NumOutputs == 0 {
|
||||
m.NumOutputs = other.NumOutputs
|
||||
}
|
||||
|
||||
if !m.OutputsLogits {
|
||||
m.OutputsLogits = other.OutputsLogits
|
||||
}
|
||||
}
|
||||
|
||||
// The meta information for the model
|
||||
type ModelInfo struct {
|
||||
TFVersion string
|
||||
Tags []string
|
||||
Input *PhotoInput
|
||||
Output *ModelOutput
|
||||
TFVersion string `yaml:"-" json:"-"`
|
||||
Tags []string `yaml:"Tags" json:"tags"`
|
||||
Input *PhotoInput `yaml:"Input" json:"input"`
|
||||
Output *ModelOutput `yaml:"Output" json:"output"`
|
||||
}
|
||||
|
||||
// Merge other model info. In case of having information
|
||||
// for a field, the current model will keep its current value
|
||||
func (m *ModelInfo) Merge(other *ModelInfo) {
|
||||
if m.TFVersion == "" {
|
||||
m.TFVersion = other.TFVersion
|
||||
}
|
||||
|
||||
if len(m.Tags) == 0 {
|
||||
m.Tags = other.Tags
|
||||
}
|
||||
|
||||
if m.Input == nil {
|
||||
m.Input = other.Input
|
||||
} else if other.Input != nil {
|
||||
m.Input.Merge(other.Input)
|
||||
}
|
||||
|
||||
if m.Output == nil {
|
||||
m.Output = other.Output
|
||||
} else if other.Output != nil {
|
||||
m.Output.Merge(other.Output)
|
||||
}
|
||||
}
|
||||
|
||||
// We consider a model complete if we know its inputs and outputs
|
||||
@@ -154,7 +220,7 @@ func GetModelInfo(path string) ([]ModelInfo, error) {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Could not get the inputs and outputs from signatures. (TF Version %s): %w", newModel.TFVersion, err)
|
||||
log.Errorf("Could not get the inputs and outputs from signatures. (TF Version %s): %w", newModel.TFVersion, err)
|
||||
}
|
||||
|
||||
models = append(models, newModel)
|
||||
|
||||
Reference in New Issue
Block a user