mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
Added new parameters to model input.
New parameters have been added to define the input of the models: * ResizeOperation: by default center-crop was being performed, now it is configurable. * InputOrder: by default RGB was being used as the order for the array values of the input tensor, now it can be configured. * InputInterval has been changed to InputIntervals (an slice). This means that every channel can have its own interval conversion. * InputInterval can define now stddev and mean, because sometimes instead of adjusting the interval, the stddev and mean of the training data should be use.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package tensorflow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -18,15 +19,26 @@ const ExpectedChannels = 3
|
||||
|
||||
// Interval of allowed values
|
||||
type Interval struct {
|
||||
Start float32 `yaml:"Start,omitempty" json:"start,omitempty"`
|
||||
End float32 `yaml:"End,omitempty" json:"end,omitempty"`
|
||||
Start float32 `yaml:"Start,omitempty" json:"start,omitempty"`
|
||||
End float32 `yaml:"End,omitempty" json:"end,omitempty"`
|
||||
Mean *float32 `yaml:"Mean,omitempty" json:"mean,omitempty"`
|
||||
StdDev *float32 `yaml:"StdDev,omitempty" json:"stdDev,omitempty"`
|
||||
}
|
||||
|
||||
// The size of the interval
|
||||
// The size/mean of the interval
|
||||
func (i Interval) Size() float32 {
|
||||
return i.End - i.Start
|
||||
}
|
||||
|
||||
// The offset of the interval
|
||||
func (i Interval) Offset() float32 {
|
||||
if i.StdDev == nil {
|
||||
return i.Start
|
||||
} else {
|
||||
return *i.StdDev
|
||||
}
|
||||
}
|
||||
|
||||
// The standard interval returned by decodeImage is [0, 1]
|
||||
func StandardInterval() *Interval {
|
||||
return &Interval{
|
||||
@@ -35,13 +47,230 @@ func StandardInterval() *Interval {
|
||||
}
|
||||
}
|
||||
|
||||
// How should we resize the images
|
||||
// JSON and YAML functions are given to make it
|
||||
// user friendly from the configuration files
|
||||
type ResizeOperation int
|
||||
|
||||
const (
|
||||
UndefinedResizeOperation ResizeOperation = iota
|
||||
ResizeBreakAspectRatio
|
||||
CenterCrop
|
||||
Padding
|
||||
)
|
||||
|
||||
func (o ResizeOperation) String() string {
|
||||
switch o {
|
||||
case UndefinedResizeOperation:
|
||||
return "Undefined"
|
||||
case ResizeBreakAspectRatio:
|
||||
return "ResizeBreakAspectRatio"
|
||||
case CenterCrop:
|
||||
return "CenterCrop"
|
||||
case Padding:
|
||||
return "Padding"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func NewResizeOperation(s string) (ResizeOperation, error) {
|
||||
switch s {
|
||||
case "Undefined":
|
||||
return UndefinedResizeOperation, nil
|
||||
case "ResizeBreakAspectRatio":
|
||||
return ResizeBreakAspectRatio, nil
|
||||
case "CenterCrop":
|
||||
return CenterCrop, nil
|
||||
case "Padding":
|
||||
return Padding, nil
|
||||
default:
|
||||
return UndefinedResizeOperation, fmt.Errorf("Invalid operation %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func (o ResizeOperation) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(o.String())
|
||||
}
|
||||
|
||||
func (o *ResizeOperation) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val, err := NewResizeOperation(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*o = val
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o ResizeOperation) MarshalYAML() (any, error) {
|
||||
return o.String(), nil
|
||||
}
|
||||
|
||||
func (o *ResizeOperation) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var s string
|
||||
if err := unmarshal(&s); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val, err := NewResizeOperation(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*o = val
|
||||
return nil
|
||||
}
|
||||
|
||||
// How should we order the input vectors
|
||||
// JSON and YAML functions are given to make it
|
||||
// user friendly from the configuration files
|
||||
type InputOrder int
|
||||
|
||||
const (
|
||||
UndefinedOrder InputOrder = 0
|
||||
RGB = 123
|
||||
RBG = 132
|
||||
GRB = 213
|
||||
GBR = 231
|
||||
BRG = 312
|
||||
BGR = 321
|
||||
)
|
||||
|
||||
func (o InputOrder) Indices() (r, g, b int) {
|
||||
i := int(o)
|
||||
|
||||
if i == 0 {
|
||||
i = 123
|
||||
}
|
||||
|
||||
for idx := 0; i > 0 && idx < 3; idx += 1 {
|
||||
remainder := i % 10
|
||||
i /= 10
|
||||
|
||||
switch remainder {
|
||||
case 1:
|
||||
r = 2 - idx
|
||||
case 2:
|
||||
g = 2 - idx
|
||||
case 3:
|
||||
b = 2 - idx
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (o InputOrder) String() string {
|
||||
value := int(o)
|
||||
|
||||
if value == 0 {
|
||||
value = 123
|
||||
}
|
||||
|
||||
convert := func(remainder int) string {
|
||||
switch remainder {
|
||||
case 1:
|
||||
return "R"
|
||||
case 2:
|
||||
return "G"
|
||||
case 3:
|
||||
return "B"
|
||||
default:
|
||||
return "?"
|
||||
}
|
||||
}
|
||||
|
||||
result := ""
|
||||
for value > 0 {
|
||||
remainder := value % 10
|
||||
value /= 10
|
||||
|
||||
result = convert(remainder) + result
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func NewInputOrder(val string) (InputOrder, error) {
|
||||
if len(val) != 3 {
|
||||
return UndefinedOrder, fmt.Errorf("Invalid length, expected 3")
|
||||
}
|
||||
|
||||
convert := func(c rune) int {
|
||||
switch c {
|
||||
case 'R':
|
||||
return 1
|
||||
case 'G':
|
||||
return 2
|
||||
case 'B':
|
||||
return 3
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
result := 0
|
||||
for _, c := range val {
|
||||
index := convert(c)
|
||||
if index == 0 {
|
||||
return UndefinedOrder, fmt.Errorf("Invalid val %c", c)
|
||||
}
|
||||
result = result*10 + index
|
||||
}
|
||||
return InputOrder(result), nil
|
||||
}
|
||||
|
||||
func (o InputOrder) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(o.String())
|
||||
}
|
||||
|
||||
func (o *InputOrder) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val, err := NewInputOrder(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*o = val
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o InputOrder) MarshalYAML() (any, error) {
|
||||
return o.String(), nil
|
||||
}
|
||||
|
||||
func (o *InputOrder) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var s string
|
||||
if err := unmarshal(&s); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val, err := NewInputOrder(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*o = val
|
||||
return nil
|
||||
}
|
||||
|
||||
// Input description for a photo input for a model
|
||||
type PhotoInput struct {
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
Interval *Interval `yaml:"Interval,omitempty" json:"interval,omitempty"`
|
||||
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
|
||||
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
|
||||
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
Intervals []Interval `yaml:"Intervals,omitempty" json:"intervals,omitempty"`
|
||||
ResizeOperation ResizeOperation `yaml:"ResizeOperation,omitempty" json:"resizeOperation,omitemty"`
|
||||
InputOrder InputOrder `yaml:"InputOrder,omitempty" json:"inputOrder,omitempty"`
|
||||
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
|
||||
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
|
||||
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
|
||||
}
|
||||
|
||||
// When dimensions are not defined, it means the model accepts any size of
|
||||
@@ -61,12 +290,18 @@ func (p *PhotoInput) SetResolution(resolution int) {
|
||||
p.Width = int64(resolution)
|
||||
}
|
||||
|
||||
// Get the interval or the default one
|
||||
func (p PhotoInput) GetInterval() *Interval {
|
||||
if p.Interval == nil {
|
||||
// Get the interval or the default one.
|
||||
// If just one interval has been fixed, then we assume
|
||||
// it is the same for every channel. If no intervals
|
||||
// have been defined, the default [0, 1] is returned
|
||||
func (p PhotoInput) GetInterval(channel int) *Interval {
|
||||
if len(p.Intervals) <= channel {
|
||||
if len(p.Intervals) == 1 {
|
||||
return &p.Intervals[0]
|
||||
}
|
||||
return StandardInterval()
|
||||
} else {
|
||||
return p.Interval
|
||||
return &p.Intervals[channel]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,8 +311,8 @@ func (p *PhotoInput) Merge(other *PhotoInput) {
|
||||
p.Name = other.Name
|
||||
}
|
||||
|
||||
if p.Interval == nil && other.Interval != nil {
|
||||
p.Interval = other.Interval
|
||||
if p.Intervals == nil && other.Intervals != nil {
|
||||
p.Intervals = other.Intervals
|
||||
}
|
||||
|
||||
if p.OutputIndex == 0 {
|
||||
@@ -91,6 +326,14 @@ func (p *PhotoInput) Merge(other *PhotoInput) {
|
||||
if p.Width == 0 {
|
||||
p.Width = other.Width
|
||||
}
|
||||
|
||||
if p.ResizeOperation == UndefinedResizeOperation {
|
||||
p.ResizeOperation = other.ResizeOperation
|
||||
}
|
||||
|
||||
if p.InputOrder == UndefinedOrder {
|
||||
p.InputOrder = other.InputOrder
|
||||
}
|
||||
}
|
||||
|
||||
// The output expected for a model
|
||||
|
||||
Reference in New Issue
Block a user