mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
@@ -149,7 +149,7 @@ func (m *Model) Run(img []byte, confidenceThreshold int) (result Labels, err err
|
||||
nil)
|
||||
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("classify: %s (run inference)", err.Error())
|
||||
return result, fmt.Errorf("classify: %s (run inference)", clean.Error(err))
|
||||
}
|
||||
|
||||
if len(output) < 1 {
|
||||
@@ -173,7 +173,7 @@ func (m *Model) loadLabels(modelPath string) (err error) {
|
||||
|
||||
m.labels, err = tensorflow.LoadLabels(modelPath, numLabels)
|
||||
if os.IsNotExist(err) {
|
||||
log.Infof("Model does not seem to have tags at %s, trying %s", modelPath, m.defaultLabelsPath)
|
||||
log.Infof("vision: model does not seem to have tags at %s, trying %s", clean.Log(modelPath), clean.Log(m.defaultLabelsPath))
|
||||
m.labels, err = tensorflow.LoadLabels(m.defaultLabelsPath, numLabels)
|
||||
}
|
||||
return err
|
||||
@@ -197,29 +197,30 @@ func (m *Model) loadModel() (err error) {
|
||||
modelPath := path.Join(m.assetsPath, m.modelPath)
|
||||
|
||||
if len(m.meta.Tags) == 0 {
|
||||
infos, err := tensorflow.GetModelInfo(modelPath)
|
||||
if err != nil {
|
||||
log.Errorf("classify: could not get the model info at %s: %v", clean.Log(modelPath), err)
|
||||
infos, modelErr := tensorflow.GetModelInfo(modelPath)
|
||||
if modelErr != nil {
|
||||
log.Errorf("classify: could not get info from model in %s (%s)", clean.Log(modelPath), clean.Error(modelErr))
|
||||
} else if len(infos) == 1 {
|
||||
log.Debugf("classify: model info: %+v", infos[0])
|
||||
m.meta.Merge(&infos[0])
|
||||
} else {
|
||||
log.Warnf("classify: found %d metagraphs... that's too many", len(infos))
|
||||
log.Warnf("classify: found %d metagraphs, which is too many", len(infos))
|
||||
}
|
||||
}
|
||||
|
||||
m.model, err = tensorflow.SavedModel(modelPath, m.meta.Tags)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !m.meta.IsComplete() {
|
||||
input, output, err := tensorflow.GetInputAndOutputFromSavedModel(m.model)
|
||||
if err != nil {
|
||||
log.Errorf("classify: could not get info from signatures: %v", err)
|
||||
input, output, err = tensorflow.GuessInputAndOutput(m.model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("classify: %w", err)
|
||||
input, output, modelErr := tensorflow.GetInputAndOutputFromSavedModel(m.model)
|
||||
if modelErr != nil {
|
||||
log.Errorf("classify: could not get info from signatures (%s)", clean.Error(modelErr))
|
||||
input, output, modelErr = tensorflow.GuessInputAndOutput(m.model)
|
||||
if modelErr != nil {
|
||||
return fmt.Errorf("classify: %w", modelErr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -232,7 +233,7 @@ func (m *Model) loadModel() (err error) {
|
||||
if m.meta.Output.OutputsLogits {
|
||||
_, err = tensorflow.AddSoftmax(m.model.Graph, m.meta)
|
||||
if err != nil {
|
||||
return fmt.Errorf("classify: could not add softmax: %w", nil)
|
||||
return fmt.Errorf("classify: could not add softmax (%s)", clean.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -129,15 +129,15 @@ func TestExternalModel_AllModels(t *testing.T) {
|
||||
|
||||
for k, v := range modelsInfo {
|
||||
t.Run(k, func(*testing.T) {
|
||||
log.Infof("Testing model %s", k)
|
||||
log.Infof("vision: testing model %s", k)
|
||||
|
||||
downloadedModel := downloadRemoteModel(t, fmt.Sprintf("%s/%s", baseUrl, k), tmpPath)
|
||||
log.Infof("Model downloaded to %s", downloadedModel)
|
||||
log.Infof("vision: model downloaded to %s", downloadedModel)
|
||||
|
||||
if v.Labels != "" {
|
||||
modelPath := filepath.Join(tmpPath, downloadedModel)
|
||||
|
||||
t.Logf("Model path: %s", modelPath)
|
||||
t.Logf("vision: model path is %s", modelPath)
|
||||
downloadLabels(t, fmt.Sprintf("%s/%s", baseUrl, v.Labels), modelPath)
|
||||
}
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ func TestModel_CenterCrop(t *testing.T) {
|
||||
model.meta.Input.ResizeOperation = tensorflow.CenterCrop
|
||||
|
||||
t.Run("nasnet padding", func(t *testing.T) {
|
||||
testModel_BasicLabels(t, model, 6)
|
||||
runBasicLabelsTest(t, model, 6)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func TestModel_Padding(t *testing.T) {
|
||||
model.meta.Input.ResizeOperation = tensorflow.Padding
|
||||
|
||||
t.Run("nasnet padding", func(t *testing.T) {
|
||||
testModel_BasicLabels(t, model, 6)
|
||||
runBasicLabelsTest(t, model, 6)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -63,11 +63,11 @@ func TestModel_ResizeBreakAspectRatio(t *testing.T) {
|
||||
model.meta.Input.ResizeOperation = tensorflow.ResizeBreakAspectRatio
|
||||
|
||||
t.Run("nasnet break aspect ratio", func(t *testing.T) {
|
||||
testModel_BasicLabels(t, model, 4)
|
||||
runBasicLabelsTest(t, model, 4)
|
||||
})
|
||||
}
|
||||
|
||||
func testModel_BasicLabels(t *testing.T, model *Model, expectedUncertainty int) {
|
||||
func runBasicLabelsTest(t *testing.T, model *Model, expectedUncertainty int) {
|
||||
result, err := model.File(examplesPath+"/zebra_green_brown.jpg", 10)
|
||||
|
||||
assert.NoError(t, err)
|
||||
@@ -146,14 +146,14 @@ func TestModel_LabelsFromFile(t *testing.T) {
|
||||
assert.Equal(t, 70, result[0].Uncertainty)
|
||||
}
|
||||
})
|
||||
t.Run("not existing file", func(t *testing.T) {
|
||||
t.Run("NotExistingFile", func(t *testing.T) {
|
||||
tensorFlow := NewModelTest(t)
|
||||
|
||||
result, err := tensorFlow.File(examplesPath+"/notexisting.jpg", 10)
|
||||
assert.Contains(t, err.Error(), "no such file or directory")
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
t.Run("disabled true", func(t *testing.T) {
|
||||
t.Run("Disabled", func(t *testing.T) {
|
||||
tensorFlow := NewNasnet(assetsPath, true)
|
||||
|
||||
result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10)
|
||||
@@ -252,7 +252,7 @@ func TestModel_Run(t *testing.T) {
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
})
|
||||
t.Run("disabled true", func(t *testing.T) {
|
||||
t.Run("Disabled", func(t *testing.T) {
|
||||
tensorFlow := NewNasnet(assetsPath, true)
|
||||
|
||||
if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil {
|
||||
@@ -272,16 +272,16 @@ func TestModel_Run(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModel_LoadModel(t *testing.T) {
|
||||
t.Run("model loaded", func(t *testing.T) {
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
tf := NewModelTest(t)
|
||||
assert.True(t, tf.ModelLoaded())
|
||||
})
|
||||
t.Run("model path does not exist", func(t *testing.T) {
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
tensorFlow := NewNasnet(assetsPath+"foo", false)
|
||||
err := tensorFlow.loadModel()
|
||||
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), "could not find SavedModel")
|
||||
assert.Contains(t, err.Error(), "not find SavedModel")
|
||||
}
|
||||
|
||||
assert.Error(t, err)
|
||||
@@ -289,17 +289,7 @@ func TestModel_LoadModel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModel_BestLabels(t *testing.T) {
|
||||
t.Run("labels not loaded", func(t *testing.T) {
|
||||
tensorFlow := NewNasnet(assetsPath, false)
|
||||
|
||||
p := make([]float32, 1000)
|
||||
|
||||
p[666] = 0.5
|
||||
|
||||
result := tensorFlow.bestLabels(p, 10)
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
t.Run("labels loaded", func(t *testing.T) {
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
tensorFlow := NewNasnet(assetsPath, false)
|
||||
|
||||
if err := tensorFlow.loadLabels(modelPath); err != nil {
|
||||
@@ -317,4 +307,14 @@ func TestModel_BestLabels(t *testing.T) {
|
||||
assert.Equal(t, "image", result[0].Source)
|
||||
t.Log(result)
|
||||
})
|
||||
t.Run("NotLoaded", func(t *testing.T) {
|
||||
tensorFlow := NewNasnet(assetsPath, false)
|
||||
|
||||
p := make([]float32, 1000)
|
||||
|
||||
p[666] = 0.5
|
||||
|
||||
result := tensorFlow.bestLabels(p, 10)
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user