AI: Add TensorFlow model shape detection #127 #5164

* AI: Added support for non BHWC models

Tensorflow models use BHWC by default, however, if we are using
converted models, we can find that the expected input is BCHW. Now the
input is configurable (although the restriction of being dimesion 4 is
still there) via Shape parameter on the input definition. Also, the
model instrospection will try to deduce the input shape from the model
signature.

* AI: Added more tests for enum parsing

ShapeComponent was missing from the tests

* AI: Modified external tests to the new url

The path has been moved from tensorflow/vision to tensorflow/models

* AI: Moved the builder to the model to reuse it

It should reduce the amount of allocations done

* AI: fixed errors after merge

Mainly incorrect paths and duplicated variables
This commit is contained in:
raystlin
2025-08-16 15:55:59 +02:00
committed by GitHub
parent 2a7351ee9a
commit 519a6ab34a
15 changed files with 502 additions and 157 deletions

View File

@@ -2,6 +2,7 @@ package tensorflow
import (
"encoding/json"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
@@ -16,6 +17,22 @@ var allOperations = []ResizeOperation{
Padding,
}
func TestGetModelTagsInfo(t *testing.T) {
info, err := GetModelTagsInfo(
filepath.Join(assetsPath, "models", "nasnet"))
if err != nil {
t.Fatal(err)
}
if len(info) != 1 {
t.Fatalf("Expected 1 info but got %d", len(info))
} else if len(info[0].Tags) != 1 {
t.Fatalf("Expected 1 tag, but got %d", len(info[0].Tags))
} else if info[0].Tags[0] != "photoprism" {
t.Fatalf("Expected tag photoprism, but have %s", info[0].Tags[0])
}
}
func TestResizeOperations(t *testing.T) {
for i := range allOperations {
text := allOperations[i].String()
@@ -119,7 +136,7 @@ func TestColorChannelOrderJSON(t *testing.T) {
[]byte(exampleOrderJSON), &order)
if err != nil {
t.Fatal("could not unmarshal the example operation")
t.Fatal("could not unmarshal the example color order")
}
for i := range allColorChannelOrders {
@@ -148,7 +165,7 @@ func TestColorChannelOrderYAML(t *testing.T) {
[]byte(exampleOrderYAML), &order)
if err != nil {
t.Fatal("could not unmarshal the example operation")
t.Fatal("could not unmarshal the example color order")
}
for i := range allColorChannelOrders {
@@ -193,3 +210,68 @@ func TestOrderIndices(t *testing.T) {
assert.Equal(t, powerFx(r)+2*powerFx(g)+3*powerFx(b), int(allColorChannelOrders[i]))
}
}
var allShapeComponents = []ShapeComponent{
ShapeBatch,
ShapeWidth,
ShapeHeight,
ShapeColor,
}
const exampleShapeComponentJSON = `"Batch"`
func TestShapeComponentJSON(t *testing.T) {
var comp ShapeComponent
err := json.Unmarshal(
[]byte(exampleShapeComponentJSON), &comp)
if err != nil {
t.Fatal("could not unmarshal the example shape component")
}
for i := range allShapeComponents {
serialized, err := json.Marshal(allShapeComponents[i])
if err != nil {
t.Fatalf("could not marshal %v: %v",
allShapeComponents[i], err)
}
err = json.Unmarshal(serialized, &comp)
if err != nil {
t.Fatalf("could not unmarshal %s: %v",
string(serialized), err)
}
assert.Equal(t, comp, allShapeComponents[i])
}
}
const exampleShapeComponentYAML = "Batch"
func TestShapeComponentYAML(t *testing.T) {
var comp ShapeComponent
err := yaml.Unmarshal(
[]byte(exampleShapeComponentYAML), &comp)
if err != nil {
t.Fatal("could not unmarshal the example operation")
}
for i := range allShapeComponents {
serialized, err := yaml.Marshal(allShapeComponents[i])
if err != nil {
t.Fatalf("could not marshal %v: %v",
allShapeComponents[i], err)
}
err = yaml.Unmarshal(serialized, &comp)
if err != nil {
t.Fatalf("could not unmarshal %s: %v",
string(serialized), err)
}
assert.Equal(t, comp, allShapeComponents[i])
}
}