Added new parameters to model input.

New parameters have been added to define the input of the models:
* ResizeOperation: by default center-crop was being performed, now it is
  configurable.
* InputOrder: by default RGB was being used as the order for the array
  values of the input tensor, now it can be configured.
* InputInterval has been changed to InputIntervals (an slice). This
  means that every channel can have its own interval conversion.
* InputInterval can define now stddev and mean, because sometimes
  instead of adjusting the interval, the stddev and mean of the training
data should be use.
This commit is contained in:
raystlin
2025-07-15 13:31:31 +00:00
parent c682a94a07
commit adc4dc0f74
10 changed files with 800 additions and 40 deletions

View File

@@ -0,0 +1,195 @@
package tensorflow
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
)
// Resize Operation Tests
var allOperations = []ResizeOperation{
UndefinedResizeOperation,
ResizeBreakAspectRatio,
CenterCrop,
Padding,
}
func TestResizeOperations(t *testing.T) {
for i := range allOperations {
text := allOperations[i].String()
op, err := NewResizeOperation(text)
if err != nil {
t.Fatalf("Invalid operation %s: %v", text, err)
}
assert.Equal(t, op, allOperations[i])
}
}
const exampleOperationJSON = `"CenterCrop"`
func TestResizeOperationJSON(t *testing.T) {
var op ResizeOperation
err := json.Unmarshal(
[]byte(exampleOperationJSON), &op)
if err != nil {
t.Fatal("Could not unmarshal the example operation")
}
for i := range allOperations {
serialized, err := json.Marshal(allOperations[i])
if err != nil {
t.Fatalf("Could not marshal %v: %v",
allOperations[i], err)
}
err = json.Unmarshal(serialized, &op)
if err != nil {
t.Fatalf("Could not unmarshal %s: %v",
string(serialized), err)
}
assert.Equal(t, op, allOperations[i])
}
}
const exampleOperationYAML = "CenterCrop"
func TestResizeOperationYAML(t *testing.T) {
var op ResizeOperation
err := yaml.Unmarshal(
[]byte(exampleOperationYAML), &op)
if err != nil {
t.Fatal("Could not unmarshal the example operation")
}
for i := range allOperations {
serialized, err := yaml.Marshal(allOperations[i])
if err != nil {
t.Fatalf("Could not marshal %v: %v",
allOperations[i], err)
}
err = yaml.Unmarshal(serialized, &op)
if err != nil {
t.Fatalf("Could not unmarshal %s: %v",
string(serialized), err)
}
assert.Equal(t, op, allOperations[i])
}
}
// Resize Operation Tests
var allInputOrders = []InputOrder{
RGB,
RBG,
GRB,
GBR,
BRG,
BGR,
}
func TestInputOrders(t *testing.T) {
for i := range allInputOrders {
text := allInputOrders[i].String()
order, err := NewInputOrder(text)
if err != nil {
t.Fatalf("Invalid order %s: %v", text, err)
}
assert.Equal(t, order, allInputOrders[i])
}
}
const exampleOrderJSON = `"RGB"`
func TestInputOrderJSON(t *testing.T) {
var order InputOrder
err := json.Unmarshal(
[]byte(exampleOrderJSON), &order)
if err != nil {
t.Fatal("Could not unmarshal the example operation")
}
for i := range allInputOrders {
serialized, err := json.Marshal(allInputOrders[i])
if err != nil {
t.Fatalf("Could not marshal %v: %v",
allInputOrders[i], err)
}
err = json.Unmarshal(serialized, &order)
if err != nil {
t.Fatalf("Could not unmarshal %s: %v",
string(serialized), err)
}
assert.Equal(t, order, allInputOrders[i])
}
}
const exampleOrderYAML = "RGB"
func TestInputOrderYAML(t *testing.T) {
var order InputOrder
err := yaml.Unmarshal(
[]byte(exampleOrderYAML), &order)
if err != nil {
t.Fatal("Could not unmarshal the example operation")
}
for i := range allInputOrders {
serialized, err := yaml.Marshal(allInputOrders[i])
if err != nil {
t.Fatalf("Could not marshal %v: %v",
allInputOrders[i], err)
}
err = yaml.Unmarshal(serialized, &order)
if err != nil {
t.Fatalf("Could not unmarshal %s: %v",
string(serialized), err)
}
assert.Equal(t, order, allInputOrders[i])
}
}
func TestOrderIndices(t *testing.T) {
r, g, b := UndefinedOrder.Indices()
assert.Equal(t, r, 0)
assert.Equal(t, g, 1)
assert.Equal(t, b, 2)
powerFx := func(i int) int {
switch i {
case 0:
return 100
case 1:
return 10
case 2:
return 1
default:
return -1
}
}
for i := range allInputOrders {
r, g, b = allInputOrders[i].Indices()
assert.Equal(t, powerFx(r)+2*powerFx(g)+3*powerFx(b), int(allInputOrders[i]))
}
}