Files
photoprism/internal/ai/tensorflow/info_test.go
2025-11-22 11:47:17 +01:00

279 lines
5.6 KiB
Go

package tensorflow
import (
"encoding/json"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
)
// Resize Operation Tests
// allOperations holds the ResizeOperation values that the resize operation
// tests in this file iterate over.
var allOperations = []ResizeOperation{UndefinedResizeOperation, ResizeBreakAspectRatio, CenterCrop, Padding}
func TestGetModelTagsInfo(t *testing.T) {
info, err := GetModelTagsInfo(
filepath.Join(assetsPath, "models", "nasnet"))
if err != nil {
t.Fatal(err)
}
switch {
case len(info) != 1:
t.Fatalf("Expected 1 info but got %d", len(info))
case len(info[0].Tags) != 1:
t.Fatalf("Expected 1 tag, but got %d", len(info[0].Tags))
case info[0].Tags[0] != "photoprism":
t.Fatalf("Expected tag photoprism, but have %s", info[0].Tags[0])
}
}
// TestResizeOperations checks that every operation in allOperations is
// returned unchanged by NewResizeOperation when given the operation's own
// String() representation.
func TestResizeOperations(t *testing.T) {
	for _, want := range allOperations {
		text := want.String()

		got, err := NewResizeOperation(text)
		if err != nil {
			t.Fatalf("Invalid operation %s: %v", text, err)
		}

		// testify takes (expected, actual); this order keeps the labels in a
		// failure report correct.
		assert.Equal(t, want, got)
	}
}
// exampleOperationJSON is a JSON string literal used as sample input for
// decoding a ResizeOperation.
const exampleOperationJSON = `"CenterCrop"`

// TestResizeOperationJSON first decodes exampleOperationJSON into a
// ResizeOperation, then marshals every operation in allOperations to JSON and
// verifies that unmarshalling the result yields the same operation.
func TestResizeOperationJSON(t *testing.T) {
	var op ResizeOperation

	if err := json.Unmarshal([]byte(exampleOperationJSON), &op); err != nil {
		// Report the underlying error so a decoding failure can be diagnosed.
		t.Fatalf("could not unmarshal the example operation: %v", err)
	}

	for _, want := range allOperations {
		serialized, err := json.Marshal(want)
		if err != nil {
			t.Fatalf("could not marshal %v: %v", want, err)
		}

		if err = json.Unmarshal(serialized, &op); err != nil {
			t.Fatalf("could not unmarshal %s: %v", string(serialized), err)
		}

		// testify takes (expected, actual).
		assert.Equal(t, want, op)
	}
}
// exampleOperationYAML is a YAML scalar used as sample input for decoding a
// ResizeOperation.
const exampleOperationYAML = "CenterCrop"

// TestResizeOperationYAML first decodes exampleOperationYAML into a
// ResizeOperation, then marshals every operation in allOperations to YAML and
// verifies that unmarshalling the result yields the same operation.
func TestResizeOperationYAML(t *testing.T) {
	var op ResizeOperation

	if err := yaml.Unmarshal([]byte(exampleOperationYAML), &op); err != nil {
		// Report the underlying error so a decoding failure can be diagnosed.
		t.Fatalf("could not unmarshal the example operation: %v", err)
	}

	for _, want := range allOperations {
		serialized, err := yaml.Marshal(want)
		if err != nil {
			t.Fatalf("could not marshal %v: %v", want, err)
		}

		if err = yaml.Unmarshal(serialized, &op); err != nil {
			t.Fatalf("could not unmarshal %s: %v", string(serialized), err)
		}

		// testify takes (expected, actual).
		assert.Equal(t, want, op)
	}
}
// Color Channel Order Tests
// allColorChannelOrders holds the ColorChannelOrder values that the color
// channel order tests in this file iterate over.
var allColorChannelOrders = []ColorChannelOrder{RGB, RBG, GRB, GBR, BRG, BGR}
// TestColorChannelOrders checks that every order in allColorChannelOrders is
// returned unchanged by NewColorChannelOrder when given the order's own
// String() representation.
func TestColorChannelOrders(t *testing.T) {
	for _, want := range allColorChannelOrders {
		text := want.String()

		got, err := NewColorChannelOrder(text)
		if err != nil {
			t.Fatalf("Invalid order %s: %v", text, err)
		}

		// testify takes (expected, actual); this order keeps the labels in a
		// failure report correct.
		assert.Equal(t, want, got)
	}
}
// exampleOrderJSON is a JSON string literal used as sample input for decoding
// a ColorChannelOrder.
const exampleOrderJSON = `"RGB"`

// TestColorChannelOrderJSON first decodes exampleOrderJSON into a
// ColorChannelOrder, then marshals every order in allColorChannelOrders to
// JSON and verifies that unmarshalling the result yields the same order.
func TestColorChannelOrderJSON(t *testing.T) {
	var order ColorChannelOrder

	if err := json.Unmarshal([]byte(exampleOrderJSON), &order); err != nil {
		// Report the underlying error so a decoding failure can be diagnosed.
		t.Fatalf("could not unmarshal the example color order: %v", err)
	}

	for _, want := range allColorChannelOrders {
		serialized, err := json.Marshal(want)
		if err != nil {
			t.Fatalf("could not marshal %v: %v", want, err)
		}

		if err = json.Unmarshal(serialized, &order); err != nil {
			t.Fatalf("could not unmarshal %s: %v", string(serialized), err)
		}

		// testify takes (expected, actual).
		assert.Equal(t, want, order)
	}
}
// exampleOrderYAML is a YAML scalar used as sample input for decoding a
// ColorChannelOrder.
const exampleOrderYAML = "RGB"

// TestColorChannelOrderYAML first decodes exampleOrderYAML into a
// ColorChannelOrder, then marshals every order in allColorChannelOrders to
// YAML and verifies that unmarshalling the result yields the same order.
func TestColorChannelOrderYAML(t *testing.T) {
	var order ColorChannelOrder

	if err := yaml.Unmarshal([]byte(exampleOrderYAML), &order); err != nil {
		// Report the underlying error so a decoding failure can be diagnosed.
		t.Fatalf("could not unmarshal the example color order: %v", err)
	}

	for _, want := range allColorChannelOrders {
		serialized, err := yaml.Marshal(want)
		if err != nil {
			t.Fatalf("could not marshal %v: %v", want, err)
		}

		if err = yaml.Unmarshal(serialized, &order); err != nil {
			t.Fatalf("could not unmarshal %s: %v", string(serialized), err)
		}

		// testify takes (expected, actual).
		assert.Equal(t, want, order)
	}
}
// TestOrderIndices verifies the indices returned by ColorChannelOrder.Indices.
//
// UndefinedOrder must yield 0, 1, 2 for r, g, b. For every order in
// allColorChannelOrders the returned indices must reproduce the order's
// integer value when the channel numbers 1 (r), 2 (g) and 3 (b) are written
// as decimal digits at the positions the indices name: index 0 is the
// hundreds digit, index 1 the tens digit and index 2 the ones digit.
func TestOrderIndices(t *testing.T) {
	r, g, b := UndefinedOrder.Indices()

	// testify takes (expected, actual).
	assert.Equal(t, 0, r)
	assert.Equal(t, 1, g)
	assert.Equal(t, 2, b)

	// placeValue maps a channel index to the decimal weight of that digit
	// position and returns -1 for any index outside 0..2.
	placeValue := func(i int) int {
		switch i {
		case 0:
			return 100
		case 1:
			return 10
		case 2:
			return 1
		default:
			return -1
		}
	}

	for _, order := range allColorChannelOrders {
		r, g, b = order.Indices()

		// Rebuild the three-digit number from the indices under test; the
		// order constant itself is the reference value.
		encoded := placeValue(r) + 2*placeValue(g) + 3*placeValue(b)
		assert.Equal(t, int(order), encoded)
	}
}
// allShapeComponents holds the ShapeComponent values that the shape component
// tests in this file iterate over.
var allShapeComponents = []ShapeComponent{ShapeBatch, ShapeWidth, ShapeHeight, ShapeColor}
// exampleShapeComponentJSON is a JSON string literal used as sample input for
// decoding a ShapeComponent.
const exampleShapeComponentJSON = `"Batch"`

// TestShapeComponentJSON first decodes exampleShapeComponentJSON into a
// ShapeComponent, then marshals every component in allShapeComponents to JSON
// and verifies that unmarshalling the result yields the same component.
func TestShapeComponentJSON(t *testing.T) {
	var comp ShapeComponent

	if err := json.Unmarshal([]byte(exampleShapeComponentJSON), &comp); err != nil {
		// Report the underlying error so a decoding failure can be diagnosed.
		t.Fatalf("could not unmarshal the example shape component: %v", err)
	}

	for _, want := range allShapeComponents {
		serialized, err := json.Marshal(want)
		if err != nil {
			t.Fatalf("could not marshal %v: %v", want, err)
		}

		if err = json.Unmarshal(serialized, &comp); err != nil {
			t.Fatalf("could not unmarshal %s: %v", string(serialized), err)
		}

		// testify takes (expected, actual).
		assert.Equal(t, want, comp)
	}
}
// exampleShapeComponentYAML is a YAML scalar used as sample input for decoding
// a ShapeComponent.
const exampleShapeComponentYAML = "Batch"

// TestShapeComponentYAML first decodes exampleShapeComponentYAML into a
// ShapeComponent, then marshals every component in allShapeComponents to YAML
// and verifies that unmarshalling the result yields the same component.
func TestShapeComponentYAML(t *testing.T) {
	var comp ShapeComponent

	if err := yaml.Unmarshal([]byte(exampleShapeComponentYAML), &comp); err != nil {
		// The message names the shape component (it previously said
		// "operation") and includes the underlying error.
		t.Fatalf("could not unmarshal the example shape component: %v", err)
	}

	for _, want := range allShapeComponents {
		serialized, err := yaml.Marshal(want)
		if err != nil {
			t.Fatalf("could not marshal %v: %v", want, err)
		}

		if err = yaml.Unmarshal(serialized, &comp); err != nil {
			t.Fatalf("could not unmarshal %s: %v", string(serialized), err)
		}

		// testify takes (expected, actual).
		assert.Equal(t, want, comp)
	}
}