mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-11 16:24:11 +01:00
98 lines
2.0 KiB
Go
98 lines
2.0 KiB
Go
package tensorflow
|
|
|
|
import (
|
|
"path/filepath"
|
|
"slices"
|
|
"testing"
|
|
|
|
"github.com/photoprism/photoprism/pkg/fs"
|
|
)
|
|
|
|
var assetsPath = fs.Abs("../../../assets")
|
|
var testDataPath = fs.Abs("testdata")
|
|
|
|
// TestTF1ModelLoad loads the "nasnet" model from assets/models using the
// "photoprism" tag. The test requires GetInputAndOutputFromSavedModel to
// return an error for this model; per the failure message, TF1 models are
// treated as having no signatures. GuessInputAndOutput must then locate a
// non-nil input and output, and the input must report a non-nil Shape,
// which is logged.
func TestTF1ModelLoad(t *testing.T) {
	model, err := SavedModel(
		filepath.Join(assetsPath, "models", "nasnet"),
		[]string{"photoprism"})

	if err != nil {
		t.Fatal(err)
	}

	// The signature-based lookup is expected to fail here; success would mean
	// the test's assumption about this TF1 model no longer holds.
	_, _, err = GetInputAndOutputFromSavedModel(model)
	if err == nil {
		// Constant message with no format verbs: use Fatal rather than Fatalf.
		t.Fatal("TF1 does not have signatures, but GetInput worked")
	}

	// Fall back to the heuristic lookup used for models without signatures.
	input, output, err := GuessInputAndOutput(model)
	if err != nil {
		t.Fatal(err)
	}

	switch {
	case input == nil:
		t.Fatal("Could not get the input")
	case output == nil:
		t.Fatal("Could not get the output")
	case input.Shape == nil:
		t.Fatal("Could not get the shape")
	default:
		// The guessed shape is not asserted against a fixed layout; it is
		// only logged for inspection.
		t.Logf("Shape: %v", input.Shape)
	}
}
|
|
|
|
// TestTF2ModelLoad loads the SavedModel in testdata/tf2 with the "serve"
// tag, reads its input and output through GetInputAndOutputFromSavedModel,
// and requires the input shape to equal DefaultPhotoInputShape() (the
// failure message refers to this layout as BHWC).
func TestTF2ModelLoad(t *testing.T) {
	modelDir := filepath.Join(testDataPath, "tf2")
	savedModel, loadErr := SavedModel(modelDir, []string{"serve"})
	if loadErr != nil {
		t.Fatal(loadErr)
	}

	in, out, sigErr := GetInputAndOutputFromSavedModel(savedModel)
	if sigErr != nil {
		t.Fatal(sigErr)
	}

	// t.Fatal stops the test, so each check below runs only if the
	// previous ones passed.
	if in == nil {
		t.Fatal("Could not get the input")
	}
	if out == nil {
		t.Fatal("Could not get the output")
	}
	if in.Shape == nil {
		t.Fatal("Could not get the shape")
	}
	if !slices.Equal(in.Shape, DefaultPhotoInputShape()) {
		t.Fatalf("Invalid shape calculated. Expected BHWC, got %v", in.Shape)
	}
}
|
|
|
|
// TestTF2ModelBCHWLoad loads the SavedModel in testdata/tf2_bchw with the
// "serve" tag and requires the input shape reported by
// GetInputAndOutputFromSavedModel to be exactly
// [ShapeBatch, ShapeColor, ShapeHeight, ShapeWidth], i.e. channels-first.
func TestTF2ModelBCHWLoad(t *testing.T) {
	modelDir := filepath.Join(testDataPath, "tf2_bchw")
	savedModel, loadErr := SavedModel(modelDir, []string{"serve"})
	if loadErr != nil {
		t.Fatal(loadErr)
	}

	in, out, sigErr := GetInputAndOutputFromSavedModel(savedModel)
	if sigErr != nil {
		t.Fatal(sigErr)
	}

	// Guard clauses: t.Fatal ends the test on the first failed check.
	if in == nil {
		t.Fatal("Could not get the input")
	}
	if out == nil {
		t.Fatal("Could not get the output")
	}
	if in.Shape == nil {
		t.Fatal("Could not get the shape")
	}

	channelsFirst := []ShapeComponent{
		ShapeBatch, ShapeColor, ShapeHeight, ShapeWidth,
	}
	if !slices.Equal(in.Shape, channelsFirst) {
		t.Fatalf("Invalid shape calculated. Expected BCHW, got %v", in.Shape)
	}
}
|