mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
Added labels download to model_external_test
If a model needs to have its labels downloaded from another source, it can now be added to the test information.
This commit is contained in:
@@ -27,56 +27,77 @@ var baseUrl = "https://dl.photoprism.app/tensorflow/vision"
|
|||||||
//To avoid downloading everything again and again...
|
//To avoid downloading everything again and again...
|
||||||
//var baseUrl = "http://host.docker.internal:8000"
|
//var baseUrl = "http://host.docker.internal:8000"
|
||||||
|
|
||||||
var modelsInfo = map[string]*tensorflow.ModelInfo{
|
type ModelTestCase struct {
|
||||||
|
Info *tensorflow.ModelInfo
|
||||||
|
Labels string
|
||||||
|
}
|
||||||
|
|
||||||
|
var modelsInfo = map[string]*ModelTestCase{
|
||||||
"efficientnet-v2-tensorflow2-imagenet1k-b0-classification-v2.tar.gz": {
|
"efficientnet-v2-tensorflow2-imagenet1k-b0-classification-v2.tar.gz": {
|
||||||
Output: &tensorflow.ModelOutput{
|
Info: &tensorflow.ModelInfo{
|
||||||
OutputsLogits: true,
|
Output: &tensorflow.ModelOutput{
|
||||||
|
OutputsLogits: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"efficientnet-v2-tensorflow2-imagenet1k-m-classification-v2.tar.gz": {
|
"efficientnet-v2-tensorflow2-imagenet1k-m-classification-v2.tar.gz": {
|
||||||
Input: &tensorflow.PhotoInput{
|
Info: &tensorflow.ModelInfo{
|
||||||
Height: 480,
|
|
||||||
Width: 480,
|
Input: &tensorflow.PhotoInput{
|
||||||
},
|
Height: 480,
|
||||||
Output: &tensorflow.ModelOutput{
|
Width: 480,
|
||||||
OutputsLogits: true,
|
},
|
||||||
|
Output: &tensorflow.ModelOutput{
|
||||||
|
OutputsLogits: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"efficientnet-v2-tensorflow2-imagenet21k-b0-classification-v1.tar.gz": {
|
"efficientnet-v2-tensorflow2-imagenet21k-b0-classification-v1.tar.gz": {
|
||||||
Output: &tensorflow.ModelOutput{
|
Info: &tensorflow.ModelInfo{
|
||||||
OutputsLogits: true,
|
Output: &tensorflow.ModelOutput{
|
||||||
|
OutputsLogits: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
Labels: "labels-imagenet21k.txt",
|
||||||
},
|
},
|
||||||
"inception-v3-tensorflow2-classification-v2.tar.gz": {
|
"inception-v3-tensorflow2-classification-v2.tar.gz": {
|
||||||
Input: &tensorflow.PhotoInput{
|
Info: &tensorflow.ModelInfo{
|
||||||
Height: 299,
|
Input: &tensorflow.PhotoInput{
|
||||||
Width: 299,
|
Height: 299,
|
||||||
},
|
Width: 299,
|
||||||
Output: &tensorflow.ModelOutput{
|
},
|
||||||
OutputsLogits: true,
|
Output: &tensorflow.ModelOutput{
|
||||||
|
OutputsLogits: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"resnet-v2-tensorflow2-101-classification-v2.tar.gz": {
|
"resnet-v2-tensorflow2-101-classification-v2.tar.gz": {
|
||||||
Output: &tensorflow.ModelOutput{
|
Info: &tensorflow.ModelInfo{
|
||||||
OutputsLogits: true,
|
Output: &tensorflow.ModelOutput{
|
||||||
|
OutputsLogits: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"resnet-v2-tensorflow2-152-classification-v2.tar.gz": {
|
"resnet-v2-tensorflow2-152-classification-v2.tar.gz": {
|
||||||
Output: &tensorflow.ModelOutput{
|
Info: &tensorflow.ModelInfo{
|
||||||
OutputsLogits: true,
|
Output: &tensorflow.ModelOutput{
|
||||||
|
OutputsLogits: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"vision-transformer-tensorflow2-vit-b16-classification-v1.tar.gz": {
|
"vision-transformer-tensorflow2-vit-b16-classification-v1.tar.gz": {
|
||||||
Input: &tensorflow.PhotoInput{
|
Info: &tensorflow.ModelInfo{
|
||||||
Intervals: []tensorflow.Interval{
|
Input: &tensorflow.PhotoInput{
|
||||||
{
|
Intervals: []tensorflow.Interval{
|
||||||
Start: -1.0,
|
{
|
||||||
End: 1.0,
|
Start: -1.0,
|
||||||
|
End: 1.0,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
Output: &tensorflow.ModelOutput{
|
||||||
Output: &tensorflow.ModelOutput{
|
OutputsLogits: true,
|
||||||
OutputsLogits: true,
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -113,7 +134,14 @@ func TestExternalModel_AllModels(t *testing.T) {
|
|||||||
downloadedModel := downloadRemoteModel(t, fmt.Sprintf("%s/%s", baseUrl, k), tmpPath)
|
downloadedModel := downloadRemoteModel(t, fmt.Sprintf("%s/%s", baseUrl, k), tmpPath)
|
||||||
log.Infof("Model downloaded to %s", downloadedModel)
|
log.Infof("Model downloaded to %s", downloadedModel)
|
||||||
|
|
||||||
model := NewModel(tmpPath, downloadedModel, modelPath, v, false)
|
if v.Labels != "" {
|
||||||
|
modelPath := filepath.Join(tmpPath, downloadedModel)
|
||||||
|
|
||||||
|
t.Logf("Model path: %s", modelPath)
|
||||||
|
downloadLabels(t, fmt.Sprintf("%s/%s", baseUrl, v.Labels), modelPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
model := NewModel(tmpPath, downloadedModel, modelPath, v.Info, false)
|
||||||
if err := model.loadModel(); err != nil {
|
if err := model.loadModel(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -128,6 +156,25 @@ func TestExternalModel_AllModels(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func downloadLabels(t *testing.T, url, dst string) {
|
||||||
|
resp, err := http.Get(url)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
output, err := os.Create(filepath.Join(dst, "labels.txt"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer output.Close()
|
||||||
|
|
||||||
|
_, err = io.Copy(output, resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func downloadRemoteModel(t *testing.T, url, tmpPath string) (model string) {
|
func downloadRemoteModel(t *testing.T, url, tmpPath string) (model string) {
|
||||||
t.Logf("Downloading %s to %s", url, tmpPath)
|
t.Logf("Downloading %s to %s", url, tmpPath)
|
||||||
|
|
||||||
@@ -248,7 +295,6 @@ func testModel_LabelsFromFile(t *testing.T, tensorFlow *Model) {
|
|||||||
|
|
||||||
if len(result) > 0 {
|
if len(result) > 0 {
|
||||||
assertContainsAny(t, result[0].Name, []string{"cat", "kitty"})
|
assertContainsAny(t, result[0].Name, []string{"cat", "kitty"})
|
||||||
//assert.Equal(t, 59, result[0].Uncertainty)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run(testName("cat_720.jpeg"), func(t *testing.T) {
|
t.Run(testName("cat_720.jpeg"), func(t *testing.T) {
|
||||||
@@ -268,7 +314,6 @@ func testModel_LabelsFromFile(t *testing.T, tensorFlow *Model) {
|
|||||||
|
|
||||||
if len(result) > 0 {
|
if len(result) > 0 {
|
||||||
assertContainsAny(t, result[0].Name, []string{"cat", "kitty"})
|
assertContainsAny(t, result[0].Name, []string{"cat", "kitty"})
|
||||||
//assert.Equal(t, 60, result[0].Uncertainty)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run(testName("green.jpg"), func(t *testing.T) {
|
t.Run(testName("green.jpg"), func(t *testing.T) {
|
||||||
@@ -287,8 +332,6 @@ func testModel_LabelsFromFile(t *testing.T, tensorFlow *Model) {
|
|||||||
|
|
||||||
if len(result) > 0 {
|
if len(result) > 0 {
|
||||||
assert.Equal(t, "outdoor", result[0].Name)
|
assert.Equal(t, "outdoor", result[0].Name)
|
||||||
|
|
||||||
//assert.Equal(t, 70, result[0].Uncertainty)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run(testName("not existing file"), func(t *testing.T) {
|
t.Run(testName("not existing file"), func(t *testing.T) {
|
||||||
@@ -348,7 +391,6 @@ func testModel_Run(t *testing.T, tensorFlow *Model) {
|
|||||||
|
|
||||||
if len(result) > 0 {
|
if len(result) > 0 {
|
||||||
assert.Contains(t, result[0].Name, "chameleon")
|
assert.Contains(t, result[0].Name, "chameleon")
|
||||||
//assert.Equal(t, 100-93, result[0].Uncertainty)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -376,7 +418,6 @@ func testModel_Run(t *testing.T, tensorFlow *Model) {
|
|||||||
|
|
||||||
if len(result) > 0 {
|
if len(result) > 0 {
|
||||||
assertContainsAny(t, result[0].Name, []string{"dog", "corgi"})
|
assertContainsAny(t, result[0].Name, []string{"dog", "corgi"})
|
||||||
//assert.Equal(t, 34, result[0].Uncertainty)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user