mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
New parameters have been added to define the input of the models: * ResizeOperation: previously a center-crop was always performed; now it is configurable. * InputOrder: previously RGB was always used as the order of the array values in the input tensor; now it can be configured. * InputInterval has been changed to InputIntervals (a slice), so every channel can have its own interval conversion. * InputInterval can now define stddev and mean, because sometimes the stddev and mean of the training data should be used instead of adjusting the interval.
423 lines
10 KiB
Go
423 lines
10 KiB
Go
package classify
|
|
|
|
import (
|
|
"archive/tar"
|
|
"compress/gzip"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/photoprism/photoprism/internal/ai/tensorflow"
|
|
)
|
|
|
|
// DefaultResolution is the resolution set on models whose input reports
// itself as dynamic.
const DefaultResolution = 224

// ExternalModelsTestLabel names the environment variable that must be set to
// a non-empty value for the external model tests to run.
const ExternalModelsTestLabel = "PHOTOPRISM_TEST_EXTERNAL_MODELS"
|
|
|
|
// baseUrl is the URL prefix to which a model archive name is appended when
// the external models are downloaded.
var baseUrl = "https://dl.photoprism.app/tensorflow/vision"
|
|
|
|
// To avoid downloading the models again on every run, serve them from a local
// mirror and point baseUrl at it instead:
// var baseUrl = "http://host.docker.internal:8000"
|
|
|
|
var modelsInfo = map[string]*tensorflow.ModelInfo{
|
|
"efficientnet-v2-tensorflow2-imagenet1k-b0-classification-v2.tar.gz": {
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
"efficientnet-v2-tensorflow2-imagenet1k-m-classification-v2.tar.gz": {
|
|
Input: &tensorflow.PhotoInput{
|
|
Height: 480,
|
|
Width: 480,
|
|
},
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
"efficientnet-v2-tensorflow2-imagenet21k-b0-classification-v1.tar.gz": {
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
"inception-v3-tensorflow2-classification-v2.tar.gz": {
|
|
Input: &tensorflow.PhotoInput{
|
|
Height: 299,
|
|
Width: 299,
|
|
},
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
"resnet-v2-tensorflow2-101-classification-v2.tar.gz": {
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
"resnet-v2-tensorflow2-152-classification-v2.tar.gz": {
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
"vision-transformer-tensorflow2-vit-b16-classification-v1.tar.gz": {
|
|
Input: &tensorflow.PhotoInput{
|
|
Intervals: []tensorflow.Interval{
|
|
{
|
|
Start: -1.0,
|
|
End: 1.0,
|
|
},
|
|
},
|
|
},
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
// isSafePath reports whether target is located inside baseDir (or is baseDir
// itself) once both paths have been resolved.
//
// target may be either relative, in which case it is interpreted relative to
// baseDir, or absolute, in which case it is checked as given. This lets
// callers validate a raw archive entry name as well as a path that has already
// been joined with an absolute baseDir.
//
// It returns false if baseDir cannot be made absolute or if the resolved
// target lies outside of it.
func isSafePath(target, baseDir string) bool {
	absBase, err := filepath.Abs(baseDir)
	if err != nil {
		return false
	}

	// Only relative targets are anchored at the base directory; joining an
	// already absolute path would nest it below the base and hide an escape.
	absTarget := target
	if !filepath.IsAbs(absTarget) {
		absTarget = filepath.Join(absBase, absTarget)
	}

	rel, err := filepath.Rel(absBase, filepath.Clean(absTarget))
	if err != nil {
		return false
	}

	// A plain prefix comparison would accept siblings such as "/base-evil"
	// for "/base", so test the relative path for a leading ".." element.
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}
|
|
|
|
func TestExternalModel_AllModels(t *testing.T) {
|
|
|
|
if os.Getenv(ExternalModelsTestLabel) == "" {
|
|
t.Skipf("Skipping external model tests. To test them add set env var %s=true",
|
|
ExternalModelsTestLabel)
|
|
}
|
|
|
|
tmpPath, err := os.MkdirTemp("", "*-photoprism")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer os.RemoveAll(tmpPath)
|
|
|
|
for k, v := range modelsInfo {
|
|
t.Run(k, func(*testing.T) {
|
|
log.Infof("Testing model %s", k)
|
|
|
|
downloadedModel := downloadRemoteModel(t, fmt.Sprintf("%s/%s", baseUrl, k), tmpPath)
|
|
log.Infof("Model downloaded to %s", downloadedModel)
|
|
|
|
model := NewModel(tmpPath, downloadedModel, modelPath, v, false)
|
|
if err := model.loadModel(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if model.meta.Input.IsDynamic() {
|
|
model.meta.Input.SetResolution(DefaultResolution)
|
|
}
|
|
|
|
testModel_LabelsFromFile(t, model)
|
|
testModel_Run(t, model)
|
|
})
|
|
}
|
|
}
|
|
|
|
func downloadRemoteModel(t *testing.T, url, tmpPath string) (model string) {
|
|
t.Logf("Downloading %s to %s", url, tmpPath)
|
|
|
|
modelPath := strings.TrimSuffix(path.Base(url), ".tar.gz")
|
|
tmpPath = filepath.Join(tmpPath, modelPath)
|
|
os.MkdirAll(tmpPath, 0755)
|
|
|
|
resp, err := http.Get(url)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
|
t.Fatalf("Invalid status code for url %s: %d", url, resp.StatusCode)
|
|
}
|
|
|
|
uncompressedBody, err := gzip.NewReader(resp.Body)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
tarReader := tar.NewReader(uncompressedBody)
|
|
for true {
|
|
header, err := tarReader.Next()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
|
|
if err != nil {
|
|
t.Fatalf("Could not extract the file: %v", err)
|
|
}
|
|
|
|
target := filepath.Join(tmpPath, header.Name)
|
|
if !isSafePath(target, tmpPath) {
|
|
t.Fatalf("The model file contains an invalid path: %s", header.Name)
|
|
}
|
|
|
|
switch header.Typeflag {
|
|
case tar.TypeDir:
|
|
if err := os.Mkdir(target, 0755); err != nil {
|
|
t.Fatalf("Could not make the dir %s: %v", header.Name, err)
|
|
}
|
|
case tar.TypeReg:
|
|
outFile, err := os.Create(target)
|
|
if err != nil {
|
|
t.Fatalf("Could not create file %s: %v", header.Name, err)
|
|
}
|
|
if _, err := io.Copy(outFile, tarReader); err != nil {
|
|
t.Fatalf("Could not copy file %s: %v", header.Name, err)
|
|
}
|
|
|
|
rootPath, fileName := filepath.Split(header.Name)
|
|
if fileName == "saved_model.pb" {
|
|
model = filepath.Join(modelPath, rootPath)
|
|
}
|
|
outFile.Close()
|
|
default:
|
|
t.Fatalf("Could not extract file. Unknown type %v in %s",
|
|
header.Typeflag,
|
|
header.Name)
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// containsAny reports whether s contains at least one of the given
// substrings. It returns false when substrings is empty.
func containsAny(s string, substrings []string) bool {
	for _, candidate := range substrings {
		if strings.Contains(s, candidate) {
			return true
		}
	}

	return false
}
|
|
|
|
func assertContainsAny(t *testing.T, s string, substrings []string) {
|
|
assert.Truef(t, containsAny(s, substrings),
|
|
"The result [%s] does not contain any of %v",
|
|
s, substrings)
|
|
}
|
|
|
|
func testModel_LabelsFromFile(t *testing.T, tensorFlow *Model) {
|
|
testName := func(name string) string {
|
|
return fmt.Sprintf("%s/%s", tensorFlow.modelPath, name)
|
|
}
|
|
|
|
t.Run(testName("chameleon_lime.jpg"), func(t *testing.T) {
|
|
result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10)
|
|
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, result)
|
|
assert.IsType(t, Labels{}, result)
|
|
assert.GreaterOrEqual(t, len(result), 1)
|
|
|
|
if len(result) != 1 {
|
|
t.Logf("Expected 1 result, but found %d", len(result))
|
|
t.Logf("Results: %#v", result)
|
|
}
|
|
|
|
if len(result) > 0 {
|
|
assert.Contains(t, result[0].Name, "chameleon")
|
|
//assert.Equal(t, 7, result[0].Uncertainty)
|
|
}
|
|
})
|
|
t.Run(testName("cat_224.jpeg"), func(t *testing.T) {
|
|
result, err := tensorFlow.File(examplesPath+"/cat_224.jpeg", 10)
|
|
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, result)
|
|
assert.IsType(t, Labels{}, result)
|
|
assert.GreaterOrEqual(t, len(result), 1)
|
|
|
|
if len(result) != 1 {
|
|
t.Logf("Expected 1 result, but found %d", len(result))
|
|
t.Logf("Results: %#v", result)
|
|
}
|
|
|
|
if len(result) > 0 {
|
|
assertContainsAny(t, result[0].Name, []string{"cat", "kitty"})
|
|
//assert.Equal(t, 59, result[0].Uncertainty)
|
|
}
|
|
})
|
|
t.Run(testName("cat_720.jpeg"), func(t *testing.T) {
|
|
result, err := tensorFlow.File(examplesPath+"/cat_720.jpeg", 10)
|
|
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, result)
|
|
assert.IsType(t, Labels{}, result)
|
|
//assert.Equal(t, 3, len(result))
|
|
assert.GreaterOrEqual(t, len(result), 1)
|
|
|
|
// t.Logf("labels: %#v", result)
|
|
if len(result) != 3 {
|
|
t.Logf("Expected 3 result, but found %d", len(result))
|
|
t.Logf("Results: %#v", result)
|
|
}
|
|
|
|
if len(result) > 0 {
|
|
assertContainsAny(t, result[0].Name, []string{"cat", "kitty"})
|
|
//assert.Equal(t, 60, result[0].Uncertainty)
|
|
}
|
|
})
|
|
t.Run(testName("green.jpg"), func(t *testing.T) {
|
|
result, err := tensorFlow.File(examplesPath+"/green.jpg", 10)
|
|
|
|
t.Logf("labels: %#v", result)
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, result)
|
|
assert.IsType(t, Labels{}, result)
|
|
assert.GreaterOrEqual(t, len(result), 1)
|
|
|
|
if len(result) != 1 {
|
|
t.Logf("Expected 1 result, but found %d", len(result))
|
|
t.Logf("Results: %#v", result)
|
|
}
|
|
|
|
if len(result) > 0 {
|
|
assert.Equal(t, "outdoor", result[0].Name)
|
|
|
|
//assert.Equal(t, 70, result[0].Uncertainty)
|
|
}
|
|
})
|
|
t.Run(testName("not existing file"), func(t *testing.T) {
|
|
result, err := tensorFlow.File(examplesPath+"/notexisting.jpg", 10)
|
|
assert.Contains(t, err.Error(), "no such file or directory")
|
|
assert.Empty(t, result)
|
|
})
|
|
t.Run(testName("disabled true"), func(t *testing.T) {
|
|
tensorFlow.disabled = true
|
|
defer func() { tensorFlow.disabled = false }()
|
|
|
|
result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10)
|
|
assert.Nil(t, err)
|
|
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
assert.Nil(t, result)
|
|
assert.IsType(t, Labels{}, result)
|
|
assert.Equal(t, 0, len(result))
|
|
|
|
t.Log(result)
|
|
})
|
|
}
|
|
|
|
func testModel_Run(t *testing.T, tensorFlow *Model) {
|
|
if testing.Short() {
|
|
t.Skip("skipping test in short mode.")
|
|
}
|
|
|
|
testName := func(name string) string {
|
|
return fmt.Sprintf("%s/%s", tensorFlow.modelPath, name)
|
|
}
|
|
|
|
t.Run(testName("chameleon_lime.jpg"), func(t *testing.T) {
|
|
if imageBuffer, err := os.ReadFile(examplesPath + "/chameleon_lime.jpg"); err != nil {
|
|
t.Error(err)
|
|
} else {
|
|
result, err := tensorFlow.Run(imageBuffer, 10)
|
|
|
|
t.Log(result)
|
|
|
|
assert.NotNil(t, result)
|
|
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
assert.IsType(t, Labels{}, result)
|
|
assert.GreaterOrEqual(t, len(result), 1)
|
|
|
|
if len(result) != 1 {
|
|
t.Logf("Expected 1 result, but found %d", len(result))
|
|
t.Logf("Results: %#v", result)
|
|
}
|
|
|
|
if len(result) > 0 {
|
|
assert.Contains(t, result[0].Name, "chameleon")
|
|
//assert.Equal(t, 100-93, result[0].Uncertainty)
|
|
}
|
|
}
|
|
})
|
|
t.Run(testName("dog_orange.jpg"), func(t *testing.T) {
|
|
if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil {
|
|
t.Error(err)
|
|
} else {
|
|
result, err := tensorFlow.Run(imageBuffer, 10)
|
|
|
|
t.Log(result)
|
|
|
|
assert.NotNil(t, result)
|
|
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
assert.IsType(t, Labels{}, result)
|
|
assert.GreaterOrEqual(t, len(result), 1)
|
|
|
|
if len(result) != 1 {
|
|
t.Logf("Expected 1 result, but found %d", len(result))
|
|
t.Logf("Results: %#v", result)
|
|
}
|
|
|
|
if len(result) > 0 {
|
|
assertContainsAny(t, result[0].Name, []string{"dog", "corgi"})
|
|
//assert.Equal(t, 34, result[0].Uncertainty)
|
|
}
|
|
}
|
|
})
|
|
t.Run(testName("Random.docx"), func(t *testing.T) {
|
|
if imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx"); err != nil {
|
|
t.Error(err)
|
|
} else {
|
|
result, err := tensorFlow.Run(imageBuffer, 10)
|
|
assert.Empty(t, result)
|
|
assert.Error(t, err)
|
|
}
|
|
})
|
|
t.Run(testName("6720px_white.jpg"), func(t *testing.T) {
|
|
if imageBuffer, err := os.ReadFile(examplesPath + "/6720px_white.jpg"); err != nil {
|
|
t.Error(err)
|
|
} else {
|
|
result, err := tensorFlow.Run(imageBuffer, 10)
|
|
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
assert.Empty(t, result)
|
|
}
|
|
})
|
|
t.Run(testName("disabled true"), func(t *testing.T) {
|
|
tensorFlow.disabled = true
|
|
defer func() { tensorFlow.disabled = false }()
|
|
if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil {
|
|
t.Error(err)
|
|
} else {
|
|
result, err := tensorFlow.Run(imageBuffer, 10)
|
|
|
|
t.Log(result)
|
|
|
|
assert.Nil(t, result)
|
|
|
|
assert.Nil(t, err)
|
|
assert.IsType(t, Labels{}, result)
|
|
assert.Equal(t, 0, len(result))
|
|
}
|
|
})
|
|
}
|