Files
photoprism/internal/ai/tensorflow/model_test.go
2025-11-22 11:47:17 +01:00

98 lines
2.0 KiB
Go

package tensorflow
import (
"path/filepath"
"slices"
"testing"
"github.com/photoprism/photoprism/pkg/fs"
)
var (
	// assetsPath resolves "../../../assets" relative to the package directory
	// (the working directory under `go test`) via fs.Abs.
	assetsPath = fs.Abs("../../../assets")

	// testDataPath resolves this package's local "testdata" directory via fs.Abs.
	testDataPath = fs.Abs("testdata")
)
func TestTF1ModelLoad(t *testing.T) {
model, err := SavedModel(
filepath.Join(assetsPath, "models", "nasnet"),
[]string{"photoprism"})
if err != nil {
t.Fatal(err)
}
_, _, err = GetInputAndOutputFromSavedModel(model)
if err == nil {
t.Fatalf("TF1 does not have signatures, but GetInput worked")
}
input, output, err := GuessInputAndOutput(model)
if err != nil {
t.Fatal(err)
}
switch {
case input == nil:
t.Fatal("Could not get the input")
case output == nil:
t.Fatal("Could not get the output")
case input.Shape == nil:
t.Fatal("Could not get the shape")
default:
t.Logf("Shape: %v", input.Shape)
}
}
func TestTF2ModelLoad(t *testing.T) {
model, err := SavedModel(
filepath.Join(testDataPath, "tf2"),
[]string{"serve"})
if err != nil {
t.Fatal(err)
}
input, output, err := GetInputAndOutputFromSavedModel(model)
if err != nil {
t.Fatal(err)
}
switch {
case input == nil:
t.Fatal("Could not get the input")
case output == nil:
t.Fatal("Could not get the output")
case input.Shape == nil:
t.Fatal("Could not get the shape")
case !slices.Equal(input.Shape, DefaultPhotoInputShape()):
t.Fatalf("Invalid shape calculated. Expected BHWC, got %v", input.Shape)
}
}
func TestTF2ModelBCHWLoad(t *testing.T) {
model, err := SavedModel(
filepath.Join(testDataPath, "tf2_bchw"),
[]string{"serve"})
if err != nil {
t.Fatal(err)
}
input, output, err := GetInputAndOutputFromSavedModel(model)
if err != nil {
t.Fatal(err)
}
switch {
case input == nil:
t.Fatal("Could not get the input")
case output == nil:
t.Fatal("Could not get the output")
case input.Shape == nil:
t.Fatal("Could not get the shape")
case !slices.Equal(input.Shape, []ShapeComponent{
ShapeBatch, ShapeColor, ShapeHeight, ShapeWidth,
}):
t.Fatalf("Invalid shape calculated. Expected BCHW, got %v", input.Shape)
}
}