Files
photoprism/internal/ai/classify/model_test.go
raystlin 519a6ab34a AI: Add TensorFlow model shape detection #127 #5164
* AI: Added support for non-BHWC models

TensorFlow models use BHWC by default; however, converted models may
expect BCHW input instead. The input layout is now configurable via the
Shape parameter of the input definition (inputs must still have exactly
4 dimensions). In addition, model introspection tries to deduce the
input shape from the model signature.

* AI: Added more tests for enum parsing

ShapeComponent was missing from the tests

* AI: Modified external tests to the new url

The path has been moved from tensorflow/vision to tensorflow/models

* AI: Moved the builder to the model to reuse it

It should reduce the amount of allocations done

* AI: fixed errors after merge

Mainly incorrect paths and duplicated variables
2025-08-16 15:55:59 +02:00

364 lines
8.1 KiB
Go

package classify
import (
"os"
"path/filepath"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/ai/tensorflow"
"github.com/photoprism/photoprism/pkg/fs"
)
// Test fixture locations. fs.Abs is assumed to resolve the relative path
// against the working directory, which `go test` sets to the package directory.
var (
	assetsPath   = fs.Abs("../../../assets")
	examplesPath = filepath.Join(assetsPath, "examples")
	modelsPath   = filepath.Join(assetsPath, "models")
	modelPath    = filepath.Join(modelsPath, "nasnet")
)

// once guards the lazy creation of testInstance, the NasNet model shared by
// tests that obtain it through NewModelTest.
var (
	once         sync.Once
	testInstance *Model
)
// testInstanceErr holds the error returned when loading testInstance, so that
// every caller of NewModelTest observes a load failure, not only the first.
var testInstanceErr error

// NewModelTest returns the shared NasNet model used by the tests, creating and
// loading it on first use. If loading failed, the calling test is aborted with
// t.Fatal; this applies to every caller, not just the one that triggered the load.
func NewModelTest(t *testing.T) *Model {
	t.Helper()

	// Do not call t.Fatal inside once.Do: once would still be marked done and
	// later callers would silently receive a model that never loaded.
	once.Do(func() {
		testInstance = NewNasnet(modelsPath, false)
		testInstanceErr = testInstance.loadModel()
	})

	if testInstanceErr != nil {
		t.Fatal(testInstanceErr)
	}

	return testInstance
}
// TestModel_CenterCrop classifies the zebra example with a NasNet model whose
// input resize operation is set to tensorflow.CenterCrop.
func TestModel_CenterCrop(t *testing.T) {
	// A dedicated model instance is used because the resize operation is
	// mutated below and must not affect the shared test model.
	model := NewNasnet(modelsPath, false)

	if err := model.loadModel(); err != nil {
		t.Fatal(err)
	}

	model.meta.Input.ResizeOperation = tensorflow.CenterCrop

	t.Run("nasnet center crop", func(t *testing.T) {
		runBasicLabelsTest(t, model, 6)
	})
}
// TestModel_Padding classifies the zebra example with a NasNet model whose
// input resize operation is set to tensorflow.Padding.
func TestModel_Padding(t *testing.T) {
	// Use a private instance: the resize operation is changed below and must
	// not leak into the shared test model.
	m := NewNasnet(modelsPath, false)

	err := m.loadModel()
	if err != nil {
		t.Fatal(err)
	}

	m.meta.Input.ResizeOperation = tensorflow.Padding

	t.Run("nasnet padding", func(t *testing.T) {
		runBasicLabelsTest(t, m, 6)
	})
}
// TestModel_ResizeBreakAspectRatio classifies the zebra example with a NasNet
// model whose input resize operation is tensorflow.ResizeBreakAspectRatio.
func TestModel_ResizeBreakAspectRatio(t *testing.T) {
	// Use a private instance: the resize operation is changed below and must
	// not leak into the shared test model.
	m := NewNasnet(modelsPath, false)

	err := m.loadModel()
	if err != nil {
		t.Fatal(err)
	}

	m.meta.Input.ResizeOperation = tensorflow.ResizeBreakAspectRatio

	t.Run("nasnet break aspect ratio", func(t *testing.T) {
		runBasicLabelsTest(t, m, 4)
	})
}
// runBasicLabelsTest classifies the zebra_green_brown.jpg example with model and
// asserts that exactly one label is returned, named "zebra", with the given
// expected uncertainty.
func runBasicLabelsTest(t *testing.T, model *Model, expectedUncertainty int) {
	// Attribute assertion failures to the calling test, not to this helper.
	t.Helper()

	result, err := model.File(filepath.Join(examplesPath, "zebra_green_brown.jpg"), 10)

	assert.NoError(t, err)
	assert.NotNil(t, result)
	assert.IsType(t, Labels{}, result)
	assert.Len(t, result, 1)

	if len(result) > 0 {
		assert.Equal(t, "zebra", result[0].Name)
		assert.Equal(t, expectedUncertainty, result[0].Uncertainty)
	}
}
// TestModel_LabelsFromFile classifies example images read from disk via
// Model.File and checks the top label and its uncertainty. It also covers a
// missing file and a disabled model.
func TestModel_LabelsFromFile(t *testing.T) {
	t.Run("chameleon_lime.jpg", func(t *testing.T) {
		tensorFlow := NewModelTest(t)
		result, err := tensorFlow.File(filepath.Join(examplesPath, "chameleon_lime.jpg"), 10)
		assert.NoError(t, err)
		assert.NotNil(t, result)
		assert.IsType(t, Labels{}, result)
		assert.Len(t, result, 1)
		if len(result) > 0 {
			t.Logf("result: %#v", result[0])
			assert.Equal(t, "chameleon", result[0].Name)
			assert.Equal(t, 7, result[0].Uncertainty)
		}
	})
	t.Run("cat_224.jpeg", func(t *testing.T) {
		tensorFlow := NewModelTest(t)
		result, err := tensorFlow.File(filepath.Join(examplesPath, "cat_224.jpeg"), 10)
		assert.NoError(t, err)
		assert.NotNil(t, result)
		assert.IsType(t, Labels{}, result)
		assert.Len(t, result, 1)
		if len(result) > 0 {
			assert.Equal(t, "cat", result[0].Name)
			assert.Equal(t, 59, result[0].Uncertainty)
		}
	})
	t.Run("cat_720.jpeg", func(t *testing.T) {
		tensorFlow := NewModelTest(t)
		result, err := tensorFlow.File(filepath.Join(examplesPath, "cat_720.jpeg"), 10)
		assert.NoError(t, err)
		assert.NotNil(t, result)
		assert.IsType(t, Labels{}, result)
		assert.Len(t, result, 3)
		if len(result) > 0 {
			assert.Equal(t, "cat", result[0].Name)
			assert.Equal(t, 60, result[0].Uncertainty)
		}
	})
	t.Run("green.jpg", func(t *testing.T) {
		tensorFlow := NewModelTest(t)
		result, err := tensorFlow.File(filepath.Join(examplesPath, "green.jpg"), 10)
		t.Logf("labels: %#v", result)
		assert.NoError(t, err)
		assert.NotNil(t, result)
		assert.IsType(t, Labels{}, result)
		assert.Len(t, result, 1)
		if len(result) > 0 {
			assert.Equal(t, "outdoor", result[0].Name)
			assert.Equal(t, 70, result[0].Uncertainty)
		}
	})
	t.Run("NotExistingFile", func(t *testing.T) {
		tensorFlow := NewModelTest(t)
		result, err := tensorFlow.File(filepath.Join(examplesPath, "notexisting.jpg"), 10)
		// Only inspect the message when an error was returned; calling
		// Error() on a nil error would panic instead of failing the test.
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "no such file or directory")
		}
		assert.Empty(t, result)
	})
	t.Run("Disabled", func(t *testing.T) {
		// A disabled model is not loaded and is expected to return no labels.
		tensorFlow := NewNasnet(modelsPath, true)
		result, err := tensorFlow.File(filepath.Join(examplesPath, "chameleon_lime.jpg"), 10)
		if err != nil {
			t.Fatal(err)
		}
		assert.Nil(t, result)
		assert.IsType(t, Labels{}, result)
		assert.Len(t, result, 0)
		t.Log(result)
	})
}
// TestModel_Run feeds raw example file bytes to Model.Run and checks the
// returned labels. It covers valid images, a non-image document, a plain white
// image, and a disabled model. The test is skipped in short mode.
func TestModel_Run(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping test in short mode.")
	}

	// readExample returns the contents of the named example file. On failure
	// it records the error with t.Error and reports ok == false so the
	// calling subtest can end early.
	readExample := func(t *testing.T, name string) ([]byte, bool) {
		t.Helper()

		data, err := os.ReadFile(examplesPath + "/" + name)
		if err != nil {
			t.Error(err)
			return nil, false
		}

		return data, true
	}

	t.Run("chameleon_lime.jpg", func(t *testing.T) {
		model := NewModelTest(t)

		buf, ok := readExample(t, "chameleon_lime.jpg")
		if !ok {
			return
		}

		labels, err := model.Run(buf, 10)
		t.Log(labels)
		assert.NotNil(t, labels)

		if err != nil {
			t.Fatal(err)
		}

		assert.IsType(t, Labels{}, labels)
		assert.Equal(t, 1, len(labels))

		if len(labels) > 0 {
			assert.Equal(t, "chameleon", labels[0].Name)
			assert.Equal(t, 100-93, labels[0].Uncertainty)
		}
	})

	t.Run("dog_orange.jpg", func(t *testing.T) {
		model := NewModelTest(t)

		buf, ok := readExample(t, "dog_orange.jpg")
		if !ok {
			return
		}

		labels, err := model.Run(buf, 10)
		t.Log(labels)
		assert.NotNil(t, labels)

		if err != nil {
			t.Fatal(err)
		}

		assert.IsType(t, Labels{}, labels)
		assert.Equal(t, 1, len(labels))

		if len(labels) > 0 {
			assert.Equal(t, "dog", labels[0].Name)
			assert.Equal(t, 34, labels[0].Uncertainty)
		}
	})

	t.Run("Random.docx", func(t *testing.T) {
		model := NewModelTest(t)

		buf, ok := readExample(t, "Random.docx")
		if !ok {
			return
		}

		// A non-image document must yield an error and no labels.
		labels, err := model.Run(buf, 10)
		assert.Empty(t, labels)
		assert.Error(t, err)
	})

	t.Run("6720px_white.jpg", func(t *testing.T) {
		model := NewModelTest(t)

		buf, ok := readExample(t, "6720px_white.jpg")
		if !ok {
			return
		}

		// A blank white image must yield no labels, but also no error.
		labels, err := model.Run(buf, 10)
		if err != nil {
			t.Fatal(err)
		}

		assert.Empty(t, labels)
	})

	t.Run("Disabled", func(t *testing.T) {
		model := NewNasnet(modelsPath, true)

		buf, ok := readExample(t, "dog_orange.jpg")
		if !ok {
			return
		}

		// A disabled model must return nil labels and a nil error.
		labels, err := model.Run(buf, 10)
		t.Log(labels)
		assert.Nil(t, labels)
		assert.Nil(t, err)
		assert.IsType(t, Labels{}, labels)
		assert.Equal(t, 0, len(labels))
	})
}
// TestModel_LoadModel checks two things: the shared test model reports itself
// as loaded, and loading from a non-existent models directory fails with an
// error mentioning the missing SavedModel.
func TestModel_LoadModel(t *testing.T) {
	t.Run("Success", func(t *testing.T) {
		assert.True(t, NewModelTest(t).ModelLoaded())
	})
	t.Run("NotFound", func(t *testing.T) {
		missing := NewNasnet(modelsPath+"foo", false)

		// Only inspect the message once an error is known to exist.
		err := missing.loadModel()
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "not find SavedModel")
		}
	})
}
// TestModel_BestLabels exercises bestLabels directly with synthetic scores,
// first with labels loaded and then with no labels loaded.
func TestModel_BestLabels(t *testing.T) {
	t.Run("Success", func(t *testing.T) {
		tensorFlow := NewNasnet(modelsPath, false)

		if err := tensorFlow.loadLabels(modelPath); err != nil {
			t.Fatal(err)
		}

		// Synthetic scores where index 8 has the highest value. This assumes
		// bestLabels ranks labels by descending score, so index 8 comes first.
		p := make([]float32, 1000)
		p[8] = 0.7
		p[1] = 0.5

		result := tensorFlow.bestLabels(p, 10)
		t.Log(result)

		// Guard the index expressions below so an empty result fails the
		// test instead of panicking.
		if !assert.NotEmpty(t, result) {
			return
		}

		assert.Equal(t, "chicken", result[0].Name)
		if assert.NotEmpty(t, result[0].Categories) {
			assert.Equal(t, "bird", result[0].Categories[0])
		}
		assert.Equal(t, "image", result[0].Source)
	})
	t.Run("NotLoaded", func(t *testing.T) {
		// Without loadLabels there is nothing to map scores to.
		tensorFlow := NewNasnet(modelsPath, false)
		p := make([]float32, 1000)
		p[666] = 0.5
		result := tensorFlow.bestLabels(p, 10)
		assert.Empty(t, result)
	})
}
// BenchmarkModel_BestLabelWithOptimization measures Model.Run on the
// dog_orange.jpg example with the model's reusable builder left in place.
// Compare it with BenchmarkModel_BestLabelsNoOptimization.
func BenchmarkModel_BestLabelWithOptimization(b *testing.B) {
	// The NasNet model lives under modelsPath (see modelPath), which is also
	// what every test passes to NewNasnet; assetsPath is its parent directory.
	model := NewNasnet(modelsPath, false)

	if err := model.loadModel(); err != nil {
		b.Fatal(err)
	}

	imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "dog_orange.jpg"))
	if err != nil {
		b.Fatal(err)
	}

	// The builder reuse being compared is meant to reduce allocations, so
	// report them alongside timings.
	b.ReportAllocs()

	for b.Loop() {
		if _, err := model.Run(imageBuffer, 10); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkModel_BestLabelsNoOptimization measures Model.Run on the
// dog_orange.jpg example after clearing the model's builder. Compare it with
// BenchmarkModel_BestLabelWithOptimization.
func BenchmarkModel_BestLabelsNoOptimization(b *testing.B) {
	// The NasNet model lives under modelsPath (see modelPath), which is also
	// what every test passes to NewNasnet; assetsPath is its parent directory.
	model := NewNasnet(modelsPath, false)

	if err := model.loadModel(); err != nil {
		b.Fatal(err)
	}

	// NOTE(review): clearing the builder presumably makes Run fall back to
	// building its input without the reused builder — confirm in Model.Run.
	model.builder = nil

	imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "dog_orange.jpg"))
	if err != nil {
		b.Fatal(err)
	}

	// The builder reuse being compared is meant to reduce allocations, so
	// report them alongside timings.
	b.ReportAllocs()

	for b.Loop() {
		if _, err := model.Run(imageBuffer, 10); err != nil {
			b.Fatal(err)
		}
	}
}