mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
New parameters have been added to define the input of the models: * ResizeOperation: center-crop was previously always performed; it is now configurable. * InputOrder: RGB was previously always used as the order of the values in the input tensor; it can now be configured. * InputInterval has been changed to InputIntervals (a slice), so every channel can have its own interval conversion. * InputInterval can now also define stddev and mean, because sometimes the stddev and mean of the training data should be used instead of adjusting the interval.
196 lines
3.7 KiB
Go
196 lines
3.7 KiB
Go
package tensorflow
|
|
|
|
import (
|
|
"encoding/json"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"gopkg.in/yaml.v2"
|
|
)
|
|
|
|
// Resize Operation Tests
|
|
var allOperations = []ResizeOperation{
|
|
UndefinedResizeOperation,
|
|
ResizeBreakAspectRatio,
|
|
CenterCrop,
|
|
Padding,
|
|
}
|
|
|
|
func TestResizeOperations(t *testing.T) {
|
|
for i := range allOperations {
|
|
text := allOperations[i].String()
|
|
|
|
op, err := NewResizeOperation(text)
|
|
if err != nil {
|
|
t.Fatalf("Invalid operation %s: %v", text, err)
|
|
}
|
|
|
|
assert.Equal(t, op, allOperations[i])
|
|
}
|
|
}
|
|
|
|
// exampleOperationJSON is a JSON document consisting of a single string that
// names a resize operation.
const exampleOperationJSON = "\"CenterCrop\""
|
|
|
|
func TestResizeOperationJSON(t *testing.T) {
|
|
var op ResizeOperation
|
|
|
|
err := json.Unmarshal(
|
|
[]byte(exampleOperationJSON), &op)
|
|
|
|
if err != nil {
|
|
t.Fatal("Could not unmarshal the example operation")
|
|
}
|
|
|
|
for i := range allOperations {
|
|
serialized, err := json.Marshal(allOperations[i])
|
|
if err != nil {
|
|
t.Fatalf("Could not marshal %v: %v",
|
|
allOperations[i], err)
|
|
}
|
|
|
|
err = json.Unmarshal(serialized, &op)
|
|
if err != nil {
|
|
t.Fatalf("Could not unmarshal %s: %v",
|
|
string(serialized), err)
|
|
}
|
|
|
|
assert.Equal(t, op, allOperations[i])
|
|
}
|
|
}
|
|
|
|
// exampleOperationYAML is a YAML document consisting of a single plain scalar
// that names a resize operation.
const exampleOperationYAML = `CenterCrop`
|
|
|
|
func TestResizeOperationYAML(t *testing.T) {
|
|
var op ResizeOperation
|
|
|
|
err := yaml.Unmarshal(
|
|
[]byte(exampleOperationYAML), &op)
|
|
|
|
if err != nil {
|
|
t.Fatal("Could not unmarshal the example operation")
|
|
}
|
|
|
|
for i := range allOperations {
|
|
serialized, err := yaml.Marshal(allOperations[i])
|
|
if err != nil {
|
|
t.Fatalf("Could not marshal %v: %v",
|
|
allOperations[i], err)
|
|
}
|
|
|
|
err = yaml.Unmarshal(serialized, &op)
|
|
if err != nil {
|
|
t.Fatalf("Could not unmarshal %s: %v",
|
|
string(serialized), err)
|
|
}
|
|
|
|
assert.Equal(t, op, allOperations[i])
|
|
}
|
|
}
|
|
|
|
// Input Order Tests
|
|
var allInputOrders = []InputOrder{
|
|
RGB,
|
|
RBG,
|
|
GRB,
|
|
GBR,
|
|
BRG,
|
|
BGR,
|
|
}
|
|
|
|
func TestInputOrders(t *testing.T) {
|
|
for i := range allInputOrders {
|
|
text := allInputOrders[i].String()
|
|
|
|
order, err := NewInputOrder(text)
|
|
if err != nil {
|
|
t.Fatalf("Invalid order %s: %v", text, err)
|
|
}
|
|
|
|
assert.Equal(t, order, allInputOrders[i])
|
|
}
|
|
}
|
|
|
|
// exampleOrderJSON is a JSON document consisting of a single string that
// names an input channel order.
const exampleOrderJSON = "\"RGB\""
|
|
|
|
func TestInputOrderJSON(t *testing.T) {
|
|
var order InputOrder
|
|
|
|
err := json.Unmarshal(
|
|
[]byte(exampleOrderJSON), &order)
|
|
|
|
if err != nil {
|
|
t.Fatal("Could not unmarshal the example operation")
|
|
}
|
|
|
|
for i := range allInputOrders {
|
|
serialized, err := json.Marshal(allInputOrders[i])
|
|
if err != nil {
|
|
t.Fatalf("Could not marshal %v: %v",
|
|
allInputOrders[i], err)
|
|
}
|
|
|
|
err = json.Unmarshal(serialized, &order)
|
|
if err != nil {
|
|
t.Fatalf("Could not unmarshal %s: %v",
|
|
string(serialized), err)
|
|
}
|
|
|
|
assert.Equal(t, order, allInputOrders[i])
|
|
}
|
|
}
|
|
|
|
// exampleOrderYAML is a YAML document consisting of a single plain scalar
// that names an input channel order.
const exampleOrderYAML = `RGB`
|
|
|
|
func TestInputOrderYAML(t *testing.T) {
|
|
var order InputOrder
|
|
|
|
err := yaml.Unmarshal(
|
|
[]byte(exampleOrderYAML), &order)
|
|
|
|
if err != nil {
|
|
t.Fatal("Could not unmarshal the example operation")
|
|
}
|
|
|
|
for i := range allInputOrders {
|
|
serialized, err := yaml.Marshal(allInputOrders[i])
|
|
if err != nil {
|
|
t.Fatalf("Could not marshal %v: %v",
|
|
allInputOrders[i], err)
|
|
}
|
|
|
|
err = yaml.Unmarshal(serialized, &order)
|
|
if err != nil {
|
|
t.Fatalf("Could not unmarshal %s: %v",
|
|
string(serialized), err)
|
|
}
|
|
|
|
assert.Equal(t, order, allInputOrders[i])
|
|
}
|
|
}
|
|
|
|
func TestOrderIndices(t *testing.T) {
|
|
r, g, b := UndefinedOrder.Indices()
|
|
|
|
assert.Equal(t, r, 0)
|
|
assert.Equal(t, g, 1)
|
|
assert.Equal(t, b, 2)
|
|
|
|
powerFx := func(i int) int {
|
|
switch i {
|
|
case 0:
|
|
return 100
|
|
case 1:
|
|
return 10
|
|
case 2:
|
|
return 1
|
|
default:
|
|
return -1
|
|
}
|
|
}
|
|
|
|
for i := range allInputOrders {
|
|
r, g, b = allInputOrders[i].Indices()
|
|
assert.Equal(t, powerFx(r)+2*powerFx(g)+3*powerFx(b), int(allInputOrders[i]))
|
|
}
|
|
}
|