Added new parameters to model input.

New parameters have been added to define the input of the models:
* ResizeOperation: by default center-crop was being performed, now it is
  configurable.
* InputOrder: by default RGB was being used as the order for the array
  values of the input tensor, now it can be configured.
* InputInterval has been changed to InputIntervals (a slice). This
  means that every channel can have its own interval conversion.
* InputInterval can now define stddev and mean, because sometimes
  instead of adjusting the interval, the stddev and mean of the training
  data should be used.
This commit is contained in:
raystlin
2025-07-15 13:31:31 +00:00
parent c682a94a07
commit adc4dc0f74
10 changed files with 800 additions and 40 deletions

View File

@@ -1,6 +1,7 @@
package tensorflow
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
@@ -18,15 +19,26 @@ const ExpectedChannels = 3
// Interval describes the range of values allowed for an input channel.
//
// Start and End are the bounds of the range. Mean and StdDev are optional
// (nil when not configured) and allow the statistics of the training data
// to be given instead of adjusting the interval.
type Interval struct {
	// Start is the lower bound of the interval.
	Start float32 `yaml:"Start,omitempty" json:"start,omitempty"`
	// End is the upper bound of the interval.
	End float32 `yaml:"End,omitempty" json:"end,omitempty"`
	// Mean is the optional mean of the training data.
	Mean *float32 `yaml:"Mean,omitempty" json:"mean,omitempty"`
	// StdDev is the optional standard deviation of the training data.
	StdDev *float32 `yaml:"StdDev,omitempty" json:"stdDev,omitempty"`
}
// Size returns the width of the interval, that is End minus Start.
//
// NOTE(review): the previous comment described this as "size/mean", but
// Mean is not read here; confirm whether callers handle Mean themselves.
func (i Interval) Size() float32 {
	width := i.End - i.Start
	return width
}
// Offset returns the offset of the interval: the configured StdDev when one
// is set, otherwise Start.
//
// NOTE(review): StdDev (not Mean) is what is returned as the offset; confirm
// this matches how callers apply the value.
func (i Interval) Offset() float32 {
	if sd := i.StdDev; sd != nil {
		return *sd
	}
	return i.Start
}
// The standard interval returned by decodeImage is [0, 1]
func StandardInterval() *Interval {
return &Interval{
@@ -35,13 +47,230 @@ func StandardInterval() *Interval {
}
}
// ResizeOperation selects how images are resized before being handed to a
// model. JSON and YAML (un)marshalling functions are provided so the value
// can be written by name in configuration files.
type ResizeOperation int

const (
	// UndefinedResizeOperation is the zero value: no operation configured.
	UndefinedResizeOperation ResizeOperation = iota
	// ResizeBreakAspectRatio is serialized as "ResizeBreakAspectRatio".
	ResizeBreakAspectRatio
	// CenterCrop is serialized as "CenterCrop".
	CenterCrop
	// Padding is serialized as "Padding".
	Padding
)
// String returns the configuration name of the operation, or "Unknown" for
// values outside the declared constants.
func (o ResizeOperation) String() string {
	names := [...]string{
		UndefinedResizeOperation: "Undefined",
		ResizeBreakAspectRatio:   "ResizeBreakAspectRatio",
		CenterCrop:               "CenterCrop",
		Padding:                  "Padding",
	}

	if o < 0 || int(o) >= len(names) {
		return "Unknown"
	}

	return names[o]
}
// NewResizeOperation parses the configuration name of a resize operation.
// The match is exact (case-sensitive). Unrecognized names yield
// UndefinedResizeOperation together with an error.
func NewResizeOperation(s string) (ResizeOperation, error) {
	byName := map[string]ResizeOperation{
		"Undefined":              UndefinedResizeOperation,
		"ResizeBreakAspectRatio": ResizeBreakAspectRatio,
		"CenterCrop":             CenterCrop,
		"Padding":                Padding,
	}

	if op, found := byName[s]; found {
		return op, nil
	}

	return UndefinedResizeOperation, fmt.Errorf("Invalid operation %s", s)
}
// MarshalJSON encodes the operation as a JSON string holding its String()
// name.
func (o ResizeOperation) MarshalJSON() ([]byte, error) {
	name := o.String()
	return json.Marshal(name)
}
// UnmarshalJSON decodes a JSON string and parses it with
// NewResizeOperation. The receiver is only modified when both steps succeed.
func (o *ResizeOperation) UnmarshalJSON(data []byte) error {
	var name string

	err := json.Unmarshal(data, &name)
	if err != nil {
		return err
	}

	parsed, err := NewResizeOperation(name)
	if err == nil {
		*o = parsed
	}

	return err
}
// MarshalYAML represents the operation in YAML by its String() name.
// It never returns an error.
func (o ResizeOperation) MarshalYAML() (any, error) {
	var node any = o.String()
	return node, nil
}
// UnmarshalYAML reads the YAML node as a string through the supplied
// unmarshal callback and parses it with NewResizeOperation. The receiver is
// only modified when both steps succeed.
func (o *ResizeOperation) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var name string

	err := unmarshal(&name)
	if err != nil {
		return err
	}

	parsed, err := NewResizeOperation(name)
	if err == nil {
		*o = parsed
	}

	return err
}
// InputOrder selects the order of the colour channels in the input tensor.
// JSON and YAML (un)marshalling functions are provided so the value can be
// written by name (for example "BGR") in configuration files.
//
// Each order is encoded as three decimal digits, read left to right, where
// 1 stands for R, 2 for G and 3 for B.
type InputOrder int

// Every constant carries the InputOrder type explicitly. Inside a const
// block an entry with its own "= value" does not inherit the type of the
// previous entry, so without the explicit type these would be untyped
// integer constants lacking the InputOrder methods.
const (
	// UndefinedOrder is the zero value; Indices and String treat it as RGB.
	UndefinedOrder InputOrder = 0
	RGB            InputOrder = 123
	RBG            InputOrder = 132
	GRB            InputOrder = 213
	GBR            InputOrder = 231
	BRG            InputOrder = 312
	BGR            InputOrder = 321
)
// Indices returns the position (0, 1 or 2) of the red, green and blue
// channels for this order. The zero value is treated as RGB. A channel whose
// digit does not appear in the value keeps position 0.
func (o InputOrder) Indices() (r, g, b int) {
	code := int(o)
	if code == 0 {
		code = 123
	}

	// slots[d] holds the position at which decimal digit d was found.
	// Digits are consumed from the least significant one, which is the
	// last position (2), towards the first.
	var slots [10]int
	for pos := 2; pos >= 0 && code > 0; pos-- {
		slots[code%10] = pos
		code /= 10
	}

	return slots[1], slots[2], slots[3]
}
// String renders the order as channel letters, for example "BGR". The zero
// value is rendered as "RGB", digits other than 1, 2 and 3 are rendered as
// "?", and negative values produce an empty string.
func (o InputOrder) String() string {
	code := int(o)
	if code == 0 {
		code = 123
	}

	// letters[d] is the letter for digit d; index 0 is the fallback.
	const letters = "?RGB"

	var out []byte
	for ; code > 0; code /= 10 {
		digit := code % 10
		ch := letters[0]
		if digit >= 1 && digit <= 3 {
			ch = letters[digit]
		}
		// Digits come out least significant first, so prepend.
		out = append([]byte{ch}, out...)
	}

	return string(out)
}
// NewInputOrder parses a channel order such as "RGB" or "BGR".
//
// val must consist of exactly the three upper-case letters R, G and B, each
// used once. Any other input yields UndefinedOrder and an error. Repeated
// letters (for example "RRG") are rejected because they would leave one
// channel without a position of its own.
func NewInputOrder(val string) (InputOrder, error) {
	if len(val) != 3 {
		return UndefinedOrder, fmt.Errorf("Invalid length, expected 3")
	}

	result := 0
	// seen is indexed by channel digit (1 = R, 2 = G, 3 = B).
	var seen [4]bool

	for _, c := range val {
		var digit int

		switch c {
		case 'R':
			digit = 1
		case 'G':
			digit = 2
		case 'B':
			digit = 3
		default:
			return UndefinedOrder, fmt.Errorf("Invalid val %c", c)
		}

		if seen[digit] {
			return UndefinedOrder, fmt.Errorf("Repeated val %c", c)
		}
		seen[digit] = true

		result = result*10 + digit
	}

	return InputOrder(result), nil
}
// MarshalJSON encodes the order as a JSON string holding its String() name.
func (o InputOrder) MarshalJSON() ([]byte, error) {
	name := o.String()
	return json.Marshal(name)
}
// UnmarshalJSON decodes a JSON string and parses it with NewInputOrder.
// The receiver is only modified when both steps succeed.
func (o *InputOrder) UnmarshalJSON(data []byte) error {
	var name string

	err := json.Unmarshal(data, &name)
	if err != nil {
		return err
	}

	parsed, err := NewInputOrder(name)
	if err == nil {
		*o = parsed
	}

	return err
}
// MarshalYAML represents the order in YAML by its String() name.
// It never returns an error.
func (o InputOrder) MarshalYAML() (any, error) {
	var node any = o.String()
	return node, nil
}
// UnmarshalYAML reads the YAML node as a string through the supplied
// unmarshal callback and parses it with NewInputOrder. The receiver is only
// modified when both steps succeed.
func (o *InputOrder) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var name string

	err := unmarshal(&name)
	if err != nil {
		return err
	}

	parsed, err := NewInputOrder(name)
	if err == nil {
		*o = parsed
	}

	return err
}
// PhotoInput describes a photo input of a model.
type PhotoInput struct {
	// Name of the input.
	Name string `yaml:"Name,omitempty" json:"name,omitempty"`
	// Intervals holds the value interval for each channel. A single entry
	// applies to every channel; when empty the standard interval is used
	// (see GetInterval).
	Intervals []Interval `yaml:"Intervals,omitempty" json:"intervals,omitempty"`
	// ResizeOperation selects how images are resized. The zero value is
	// omitted when serializing.
	ResizeOperation ResizeOperation `yaml:"ResizeOperation,omitempty" json:"resizeOperation,omitempty"`
	// InputOrder is the order of the colour channels in the input tensor.
	InputOrder InputOrder `yaml:"InputOrder,omitempty" json:"inputOrder,omitempty"`
	// OutputIndex is serialized under the key "Index" / "index".
	OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
	// Height of the input; zero means not defined.
	Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
	// Width of the input; zero means not defined.
	Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
}
// When dimensions are not defined, it means the model accepts any size of
@@ -61,12 +290,18 @@ func (p *PhotoInput) SetResolution(resolution int) {
p.Width = int64(resolution)
}
// Get the interval or the default one
func (p PhotoInput) GetInterval() *Interval {
if p.Interval == nil {
// Get the interval or the default one.
// If just one interval has been fixed, then we assume
// it is the same for every channel. If no intervals
// have been defined, the default [0, 1] is returned
func (p PhotoInput) GetInterval(channel int) *Interval {
if len(p.Intervals) <= channel {
if len(p.Intervals) == 1 {
return &p.Intervals[0]
}
return StandardInterval()
} else {
return p.Interval
return &p.Intervals[channel]
}
}
@@ -76,8 +311,8 @@ func (p *PhotoInput) Merge(other *PhotoInput) {
p.Name = other.Name
}
if p.Interval == nil && other.Interval != nil {
p.Interval = other.Interval
if p.Intervals == nil && other.Intervals != nil {
p.Intervals = other.Intervals
}
if p.OutputIndex == 0 {
@@ -91,6 +326,14 @@ func (p *PhotoInput) Merge(other *PhotoInput) {
if p.Width == 0 {
p.Width = other.Width
}
if p.ResizeOperation == UndefinedResizeOperation {
p.ResizeOperation = other.ResizeOperation
}
if p.InputOrder == UndefinedOrder {
p.InputOrder = other.InputOrder
}
}
// The output expected for a model