mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 08:44:04 +01:00
If a model needs to have its labels downloaded from another source, it can now be added to the test information.
464 lines
11 KiB
Go
464 lines
11 KiB
Go
package classify
|
|
|
|
import (
|
|
"archive/tar"
|
|
"compress/gzip"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/photoprism/photoprism/internal/ai/tensorflow"
|
|
)
|
|
|
|
// Settings shared by the external model tests.
const (
	// ExternalModelsTestLabel is the name of the environment variable that
	// must be set to a non-empty value for the external model tests to run.
	ExternalModelsTestLabel = "PHOTOPRISM_TEST_EXTERNAL_MODELS"

	// DefaultResolution is the resolution applied to models whose input
	// size is reported as dynamic.
	DefaultResolution = 224
)
|
|
|
|
// baseUrl is the location the model archives and label files are fetched
// from.
//
// To avoid downloading everything again and again, it can be pointed at a
// local mirror instead, e.g.:
//
//	var baseUrl = "http://host.docker.internal:8000"
var baseUrl = "https://dl.photoprism.app/tensorflow/vision"
|
|
|
|
type ModelTestCase struct {
|
|
Info *tensorflow.ModelInfo
|
|
Labels string
|
|
}
|
|
|
|
var modelsInfo = map[string]*ModelTestCase{
|
|
"efficientnet-v2-tensorflow2-imagenet1k-b0-classification-v2.tar.gz": {
|
|
Info: &tensorflow.ModelInfo{
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
},
|
|
"efficientnet-v2-tensorflow2-imagenet1k-m-classification-v2.tar.gz": {
|
|
Info: &tensorflow.ModelInfo{
|
|
|
|
Input: &tensorflow.PhotoInput{
|
|
Height: 480,
|
|
Width: 480,
|
|
},
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
},
|
|
"efficientnet-v2-tensorflow2-imagenet21k-b0-classification-v1.tar.gz": {
|
|
Info: &tensorflow.ModelInfo{
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
Labels: "labels-imagenet21k.txt",
|
|
},
|
|
"inception-v3-tensorflow2-classification-v2.tar.gz": {
|
|
Info: &tensorflow.ModelInfo{
|
|
Input: &tensorflow.PhotoInput{
|
|
Height: 299,
|
|
Width: 299,
|
|
},
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
},
|
|
"resnet-v2-tensorflow2-101-classification-v2.tar.gz": {
|
|
Info: &tensorflow.ModelInfo{
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
},
|
|
"resnet-v2-tensorflow2-152-classification-v2.tar.gz": {
|
|
Info: &tensorflow.ModelInfo{
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
},
|
|
"vision-transformer-tensorflow2-vit-b16-classification-v1.tar.gz": {
|
|
Info: &tensorflow.ModelInfo{
|
|
Input: &tensorflow.PhotoInput{
|
|
Intervals: []tensorflow.Interval{
|
|
{
|
|
Start: -1.0,
|
|
End: 1.0,
|
|
},
|
|
},
|
|
},
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
// isSafePath reports whether target is located inside baseDir (or is baseDir
// itself), so that extracting an archive entry to it cannot escape baseDir.
//
// target may be an absolute path, which is checked as is, or a relative path,
// which is interpreted relative to baseDir. Both paths are compared purely
// lexically; symbolic links are not resolved. It returns false if baseDir
// cannot be made absolute.
func isSafePath(target, baseDir string) bool {
	absBase, err := filepath.Abs(baseDir)
	if err != nil {
		return false
	}

	// Only relative targets are anchored at the base directory. Joining an
	// already absolute target onto it would hide any ".." traversal and
	// make the check succeed unconditionally.
	absTarget := target
	if !filepath.IsAbs(absTarget) {
		absTarget = filepath.Join(absBase, absTarget)
	}
	absTarget = filepath.Clean(absTarget)

	rel, err := filepath.Rel(absBase, absTarget)
	if err != nil {
		return false
	}

	// A relative path that starts with a ".." element leaves the base
	// directory. Comparing whole path elements (instead of a raw string
	// prefix) also rejects sibling directories such as "<base>-other".
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}
|
|
|
|
func TestExternalModel_AllModels(t *testing.T) {
|
|
|
|
if os.Getenv(ExternalModelsTestLabel) == "" {
|
|
t.Skipf("Skipping external model tests. To test them add set env var %s=true",
|
|
ExternalModelsTestLabel)
|
|
}
|
|
|
|
tmpPath, err := os.MkdirTemp("", "*-photoprism")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer os.RemoveAll(tmpPath)
|
|
|
|
for k, v := range modelsInfo {
|
|
t.Run(k, func(*testing.T) {
|
|
log.Infof("Testing model %s", k)
|
|
|
|
downloadedModel := downloadRemoteModel(t, fmt.Sprintf("%s/%s", baseUrl, k), tmpPath)
|
|
log.Infof("Model downloaded to %s", downloadedModel)
|
|
|
|
if v.Labels != "" {
|
|
modelPath := filepath.Join(tmpPath, downloadedModel)
|
|
|
|
t.Logf("Model path: %s", modelPath)
|
|
downloadLabels(t, fmt.Sprintf("%s/%s", baseUrl, v.Labels), modelPath)
|
|
}
|
|
|
|
model := NewModel(tmpPath, downloadedModel, modelPath, v.Info, false)
|
|
if err := model.loadModel(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if model.meta.Input.IsDynamic() {
|
|
model.meta.Input.SetResolution(DefaultResolution)
|
|
}
|
|
|
|
testModel_LabelsFromFile(t, model)
|
|
testModel_Run(t, model)
|
|
})
|
|
}
|
|
}
|
|
|
|
// downloadLabels fetches the label file at url and stores it as "labels.txt"
// in the directory dst, which must already exist.
//
// Any failure, including a non-2xx HTTP response, aborts the calling test.
func downloadLabels(t *testing.T, url, dst string) {
	t.Helper()

	resp, err := http.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	// Reject error responses so that an error page is never stored as the
	// label file.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		t.Fatalf("Invalid status code for url %s: %d", url, resp.StatusCode)
	}

	output, err := os.Create(filepath.Join(dst, "labels.txt"))
	if err != nil {
		t.Fatal(err)
	}

	if _, err = io.Copy(output, resp.Body); err != nil {
		_ = output.Close() // The copy error is the one worth reporting.
		t.Fatal(err)
	}

	// A failing Close can indicate that buffered data was not written.
	if err = output.Close(); err != nil {
		t.Fatal(err)
	}
}
|
|
|
|
func downloadRemoteModel(t *testing.T, url, tmpPath string) (model string) {
|
|
t.Logf("Downloading %s to %s", url, tmpPath)
|
|
|
|
modelPath := strings.TrimSuffix(path.Base(url), ".tar.gz")
|
|
tmpPath = filepath.Join(tmpPath, modelPath)
|
|
os.MkdirAll(tmpPath, 0755)
|
|
|
|
resp, err := http.Get(url)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
|
t.Fatalf("Invalid status code for url %s: %d", url, resp.StatusCode)
|
|
}
|
|
|
|
uncompressedBody, err := gzip.NewReader(resp.Body)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
tarReader := tar.NewReader(uncompressedBody)
|
|
for true {
|
|
header, err := tarReader.Next()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
|
|
if err != nil {
|
|
t.Fatalf("Could not extract the file: %v", err)
|
|
}
|
|
|
|
target := filepath.Join(tmpPath, header.Name)
|
|
if !isSafePath(target, tmpPath) {
|
|
t.Fatalf("The model file contains an invalid path: %s", header.Name)
|
|
}
|
|
|
|
switch header.Typeflag {
|
|
case tar.TypeDir:
|
|
if err := os.Mkdir(target, 0755); err != nil {
|
|
t.Fatalf("Could not make the dir %s: %v", header.Name, err)
|
|
}
|
|
case tar.TypeReg:
|
|
outFile, err := os.Create(target)
|
|
if err != nil {
|
|
t.Fatalf("Could not create file %s: %v", header.Name, err)
|
|
}
|
|
if _, err := io.Copy(outFile, tarReader); err != nil {
|
|
t.Fatalf("Could not copy file %s: %v", header.Name, err)
|
|
}
|
|
|
|
rootPath, fileName := filepath.Split(header.Name)
|
|
if fileName == "saved_model.pb" {
|
|
model = filepath.Join(modelPath, rootPath)
|
|
}
|
|
outFile.Close()
|
|
default:
|
|
t.Fatalf("Could not extract file. Unknown type %v in %s",
|
|
header.Typeflag,
|
|
header.Name)
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// containsAny reports whether s contains at least one of the given
// substrings. It returns false when substrings is empty.
func containsAny(s string, substrings []string) bool {
	for _, candidate := range substrings {
		if strings.Contains(s, candidate) {
			return true
		}
	}

	return false
}
|
|
|
|
func assertContainsAny(t *testing.T, s string, substrings []string) {
|
|
assert.Truef(t, containsAny(s, substrings),
|
|
"The result [%s] does not contain any of %v",
|
|
s, substrings)
|
|
}
|
|
|
|
func testModel_LabelsFromFile(t *testing.T, tensorFlow *Model) {
|
|
testName := func(name string) string {
|
|
return fmt.Sprintf("%s/%s", tensorFlow.modelPath, name)
|
|
}
|
|
|
|
t.Run(testName("chameleon_lime.jpg"), func(t *testing.T) {
|
|
result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10)
|
|
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, result)
|
|
assert.IsType(t, Labels{}, result)
|
|
assert.GreaterOrEqual(t, len(result), 1)
|
|
|
|
if len(result) != 1 {
|
|
t.Logf("Expected 1 result, but found %d", len(result))
|
|
t.Logf("Results: %#v", result)
|
|
}
|
|
|
|
if len(result) > 0 {
|
|
assert.Contains(t, result[0].Name, "chameleon")
|
|
//assert.Equal(t, 7, result[0].Uncertainty)
|
|
}
|
|
})
|
|
t.Run(testName("cat_224.jpeg"), func(t *testing.T) {
|
|
result, err := tensorFlow.File(examplesPath+"/cat_224.jpeg", 10)
|
|
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, result)
|
|
assert.IsType(t, Labels{}, result)
|
|
assert.GreaterOrEqual(t, len(result), 1)
|
|
|
|
if len(result) != 1 {
|
|
t.Logf("Expected 1 result, but found %d", len(result))
|
|
t.Logf("Results: %#v", result)
|
|
}
|
|
|
|
if len(result) > 0 {
|
|
assertContainsAny(t, result[0].Name, []string{"cat", "kitty"})
|
|
}
|
|
})
|
|
t.Run(testName("cat_720.jpeg"), func(t *testing.T) {
|
|
result, err := tensorFlow.File(examplesPath+"/cat_720.jpeg", 10)
|
|
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, result)
|
|
assert.IsType(t, Labels{}, result)
|
|
//assert.Equal(t, 3, len(result))
|
|
assert.GreaterOrEqual(t, len(result), 1)
|
|
|
|
// t.Logf("labels: %#v", result)
|
|
if len(result) != 3 {
|
|
t.Logf("Expected 3 result, but found %d", len(result))
|
|
t.Logf("Results: %#v", result)
|
|
}
|
|
|
|
if len(result) > 0 {
|
|
assertContainsAny(t, result[0].Name, []string{"cat", "kitty"})
|
|
}
|
|
})
|
|
t.Run(testName("green.jpg"), func(t *testing.T) {
|
|
result, err := tensorFlow.File(examplesPath+"/green.jpg", 10)
|
|
|
|
t.Logf("labels: %#v", result)
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, result)
|
|
assert.IsType(t, Labels{}, result)
|
|
assert.GreaterOrEqual(t, len(result), 1)
|
|
|
|
if len(result) != 1 {
|
|
t.Logf("Expected 1 result, but found %d", len(result))
|
|
t.Logf("Results: %#v", result)
|
|
}
|
|
|
|
if len(result) > 0 {
|
|
assert.Equal(t, "outdoor", result[0].Name)
|
|
}
|
|
})
|
|
t.Run(testName("not existing file"), func(t *testing.T) {
|
|
result, err := tensorFlow.File(examplesPath+"/notexisting.jpg", 10)
|
|
assert.Contains(t, err.Error(), "no such file or directory")
|
|
assert.Empty(t, result)
|
|
})
|
|
t.Run(testName("disabled true"), func(t *testing.T) {
|
|
tensorFlow.disabled = true
|
|
defer func() { tensorFlow.disabled = false }()
|
|
|
|
result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10)
|
|
assert.Nil(t, err)
|
|
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
assert.Nil(t, result)
|
|
assert.IsType(t, Labels{}, result)
|
|
assert.Equal(t, 0, len(result))
|
|
|
|
t.Log(result)
|
|
})
|
|
}
|
|
|
|
func testModel_Run(t *testing.T, tensorFlow *Model) {
|
|
if testing.Short() {
|
|
t.Skip("skipping test in short mode.")
|
|
}
|
|
|
|
testName := func(name string) string {
|
|
return fmt.Sprintf("%s/%s", tensorFlow.modelPath, name)
|
|
}
|
|
|
|
t.Run(testName("chameleon_lime.jpg"), func(t *testing.T) {
|
|
if imageBuffer, err := os.ReadFile(examplesPath + "/chameleon_lime.jpg"); err != nil {
|
|
t.Error(err)
|
|
} else {
|
|
result, err := tensorFlow.Run(imageBuffer, 10)
|
|
|
|
t.Log(result)
|
|
|
|
assert.NotNil(t, result)
|
|
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
assert.IsType(t, Labels{}, result)
|
|
assert.GreaterOrEqual(t, len(result), 1)
|
|
|
|
if len(result) != 1 {
|
|
t.Logf("Expected 1 result, but found %d", len(result))
|
|
t.Logf("Results: %#v", result)
|
|
}
|
|
|
|
if len(result) > 0 {
|
|
assert.Contains(t, result[0].Name, "chameleon")
|
|
}
|
|
}
|
|
})
|
|
t.Run(testName("dog_orange.jpg"), func(t *testing.T) {
|
|
if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil {
|
|
t.Error(err)
|
|
} else {
|
|
result, err := tensorFlow.Run(imageBuffer, 10)
|
|
|
|
t.Log(result)
|
|
|
|
assert.NotNil(t, result)
|
|
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
assert.IsType(t, Labels{}, result)
|
|
assert.GreaterOrEqual(t, len(result), 1)
|
|
|
|
if len(result) != 1 {
|
|
t.Logf("Expected 1 result, but found %d", len(result))
|
|
t.Logf("Results: %#v", result)
|
|
}
|
|
|
|
if len(result) > 0 {
|
|
assertContainsAny(t, result[0].Name, []string{"dog", "corgi"})
|
|
}
|
|
}
|
|
})
|
|
t.Run(testName("Random.docx"), func(t *testing.T) {
|
|
if imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx"); err != nil {
|
|
t.Error(err)
|
|
} else {
|
|
result, err := tensorFlow.Run(imageBuffer, 10)
|
|
assert.Empty(t, result)
|
|
assert.Error(t, err)
|
|
}
|
|
})
|
|
t.Run(testName("6720px_white.jpg"), func(t *testing.T) {
|
|
if imageBuffer, err := os.ReadFile(examplesPath + "/6720px_white.jpg"); err != nil {
|
|
t.Error(err)
|
|
} else {
|
|
result, err := tensorFlow.Run(imageBuffer, 10)
|
|
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
assert.Empty(t, result)
|
|
}
|
|
})
|
|
t.Run(testName("disabled true"), func(t *testing.T) {
|
|
tensorFlow.disabled = true
|
|
defer func() { tensorFlow.disabled = false }()
|
|
if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil {
|
|
t.Error(err)
|
|
} else {
|
|
result, err := tensorFlow.Run(imageBuffer, 10)
|
|
|
|
t.Log(result)
|
|
|
|
assert.Nil(t, result)
|
|
|
|
assert.Nil(t, err)
|
|
assert.IsType(t, Labels{}, result)
|
|
assert.Equal(t, 0, len(result))
|
|
}
|
|
})
|
|
}
|