mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
Added new parameters to model input.
New parameters have been added to define the input of the models: * ResizeOperation: by default center-crop was being performed, now it is configurable. * InputOrder: by default RGB was being used as the order for the array values of the input tensor, now it can be configured. * InputInterval has been changed to InputIntervals (an slice). This means that every channel can have its own interval conversion. * InputInterval can define now stddev and mean, because sometimes instead of adjusting the interval, the stddev and mean of the training data should be use.
This commit is contained in:
195
internal/ai/tensorflow/info_test.go
Normal file
195
internal/ai/tensorflow/info_test.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package tensorflow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// Resize Operation Tests
|
||||
var allOperations = []ResizeOperation{
|
||||
UndefinedResizeOperation,
|
||||
ResizeBreakAspectRatio,
|
||||
CenterCrop,
|
||||
Padding,
|
||||
}
|
||||
|
||||
func TestResizeOperations(t *testing.T) {
|
||||
for i := range allOperations {
|
||||
text := allOperations[i].String()
|
||||
|
||||
op, err := NewResizeOperation(text)
|
||||
if err != nil {
|
||||
t.Fatalf("Invalid operation %s: %v", text, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, op, allOperations[i])
|
||||
}
|
||||
}
|
||||
|
||||
const exampleOperationJSON = `"CenterCrop"`
|
||||
|
||||
func TestResizeOperationJSON(t *testing.T) {
|
||||
var op ResizeOperation
|
||||
|
||||
err := json.Unmarshal(
|
||||
[]byte(exampleOperationJSON), &op)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("Could not unmarshal the example operation")
|
||||
}
|
||||
|
||||
for i := range allOperations {
|
||||
serialized, err := json.Marshal(allOperations[i])
|
||||
if err != nil {
|
||||
t.Fatalf("Could not marshal %v: %v",
|
||||
allOperations[i], err)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(serialized, &op)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not unmarshal %s: %v",
|
||||
string(serialized), err)
|
||||
}
|
||||
|
||||
assert.Equal(t, op, allOperations[i])
|
||||
}
|
||||
}
|
||||
|
||||
const exampleOperationYAML = "CenterCrop"
|
||||
|
||||
func TestResizeOperationYAML(t *testing.T) {
|
||||
var op ResizeOperation
|
||||
|
||||
err := yaml.Unmarshal(
|
||||
[]byte(exampleOperationYAML), &op)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("Could not unmarshal the example operation")
|
||||
}
|
||||
|
||||
for i := range allOperations {
|
||||
serialized, err := yaml.Marshal(allOperations[i])
|
||||
if err != nil {
|
||||
t.Fatalf("Could not marshal %v: %v",
|
||||
allOperations[i], err)
|
||||
}
|
||||
|
||||
err = yaml.Unmarshal(serialized, &op)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not unmarshal %s: %v",
|
||||
string(serialized), err)
|
||||
}
|
||||
|
||||
assert.Equal(t, op, allOperations[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Resize Operation Tests
|
||||
var allInputOrders = []InputOrder{
|
||||
RGB,
|
||||
RBG,
|
||||
GRB,
|
||||
GBR,
|
||||
BRG,
|
||||
BGR,
|
||||
}
|
||||
|
||||
func TestInputOrders(t *testing.T) {
|
||||
for i := range allInputOrders {
|
||||
text := allInputOrders[i].String()
|
||||
|
||||
order, err := NewInputOrder(text)
|
||||
if err != nil {
|
||||
t.Fatalf("Invalid order %s: %v", text, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, order, allInputOrders[i])
|
||||
}
|
||||
}
|
||||
|
||||
const exampleOrderJSON = `"RGB"`
|
||||
|
||||
func TestInputOrderJSON(t *testing.T) {
|
||||
var order InputOrder
|
||||
|
||||
err := json.Unmarshal(
|
||||
[]byte(exampleOrderJSON), &order)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("Could not unmarshal the example operation")
|
||||
}
|
||||
|
||||
for i := range allInputOrders {
|
||||
serialized, err := json.Marshal(allInputOrders[i])
|
||||
if err != nil {
|
||||
t.Fatalf("Could not marshal %v: %v",
|
||||
allInputOrders[i], err)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(serialized, &order)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not unmarshal %s: %v",
|
||||
string(serialized), err)
|
||||
}
|
||||
|
||||
assert.Equal(t, order, allInputOrders[i])
|
||||
}
|
||||
}
|
||||
|
||||
const exampleOrderYAML = "RGB"
|
||||
|
||||
func TestInputOrderYAML(t *testing.T) {
|
||||
var order InputOrder
|
||||
|
||||
err := yaml.Unmarshal(
|
||||
[]byte(exampleOrderYAML), &order)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("Could not unmarshal the example operation")
|
||||
}
|
||||
|
||||
for i := range allInputOrders {
|
||||
serialized, err := yaml.Marshal(allInputOrders[i])
|
||||
if err != nil {
|
||||
t.Fatalf("Could not marshal %v: %v",
|
||||
allInputOrders[i], err)
|
||||
}
|
||||
|
||||
err = yaml.Unmarshal(serialized, &order)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not unmarshal %s: %v",
|
||||
string(serialized), err)
|
||||
}
|
||||
|
||||
assert.Equal(t, order, allInputOrders[i])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrderIndices(t *testing.T) {
|
||||
r, g, b := UndefinedOrder.Indices()
|
||||
|
||||
assert.Equal(t, r, 0)
|
||||
assert.Equal(t, g, 1)
|
||||
assert.Equal(t, b, 2)
|
||||
|
||||
powerFx := func(i int) int {
|
||||
switch i {
|
||||
case 0:
|
||||
return 100
|
||||
case 1:
|
||||
return 10
|
||||
case 2:
|
||||
return 1
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
for i := range allInputOrders {
|
||||
r, g, b = allInputOrders[i].Indices()
|
||||
assert.Equal(t, powerFx(r)+2*powerFx(g)+3*powerFx(b), int(allInputOrders[i]))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user