AI: Update log messages and tests in internal/ai/classify #127 #5011

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2025-08-04 10:31:39 +02:00
parent 523605f7d7
commit 3177a61f75
3 changed files with 38 additions and 37 deletions

View File

@@ -149,7 +149,7 @@ func (m *Model) Run(img []byte, confidenceThreshold int) (result Labels, err err
nil) nil)
if err != nil { if err != nil {
return result, fmt.Errorf("classify: %s (run inference)", err.Error()) return result, fmt.Errorf("classify: %s (run inference)", clean.Error(err))
} }
if len(output) < 1 { if len(output) < 1 {
@@ -173,7 +173,7 @@ func (m *Model) loadLabels(modelPath string) (err error) {
m.labels, err = tensorflow.LoadLabels(modelPath, numLabels) m.labels, err = tensorflow.LoadLabels(modelPath, numLabels)
if os.IsNotExist(err) { if os.IsNotExist(err) {
log.Infof("Model does not seem to have tags at %s, trying %s", modelPath, m.defaultLabelsPath) log.Infof("vision: model does not seem to have tags at %s, trying %s", clean.Log(modelPath), clean.Log(m.defaultLabelsPath))
m.labels, err = tensorflow.LoadLabels(m.defaultLabelsPath, numLabels) m.labels, err = tensorflow.LoadLabels(m.defaultLabelsPath, numLabels)
} }
return err return err
@@ -197,29 +197,30 @@ func (m *Model) loadModel() (err error) {
modelPath := path.Join(m.assetsPath, m.modelPath) modelPath := path.Join(m.assetsPath, m.modelPath)
if len(m.meta.Tags) == 0 { if len(m.meta.Tags) == 0 {
infos, err := tensorflow.GetModelInfo(modelPath) infos, modelErr := tensorflow.GetModelInfo(modelPath)
if err != nil { if modelErr != nil {
log.Errorf("classify: could not get the model info at %s: %v", clean.Log(modelPath), err) log.Errorf("classify: could not get info from model in %s (%s)", clean.Log(modelPath), clean.Error(modelErr))
} else if len(infos) == 1 { } else if len(infos) == 1 {
log.Debugf("classify: model info: %+v", infos[0]) log.Debugf("classify: model info: %+v", infos[0])
m.meta.Merge(&infos[0]) m.meta.Merge(&infos[0])
} else { } else {
log.Warnf("classify: found %d metagraphs... that's too many", len(infos)) log.Warnf("classify: found %d metagraphs, which is too many", len(infos))
} }
} }
m.model, err = tensorflow.SavedModel(modelPath, m.meta.Tags) m.model, err = tensorflow.SavedModel(modelPath, m.meta.Tags)
if err != nil { if err != nil {
return err return err
} }
if !m.meta.IsComplete() { if !m.meta.IsComplete() {
input, output, err := tensorflow.GetInputAndOutputFromSavedModel(m.model) input, output, modelErr := tensorflow.GetInputAndOutputFromSavedModel(m.model)
if err != nil { if modelErr != nil {
log.Errorf("classify: could not get info from signatures: %v", err) log.Errorf("classify: could not get info from signatures (%s)", clean.Error(modelErr))
input, output, err = tensorflow.GuessInputAndOutput(m.model) input, output, modelErr = tensorflow.GuessInputAndOutput(m.model)
if err != nil { if modelErr != nil {
return fmt.Errorf("classify: %w", err) return fmt.Errorf("classify: %w", modelErr)
} }
} }
@@ -232,7 +233,7 @@ func (m *Model) loadModel() (err error) {
if m.meta.Output.OutputsLogits { if m.meta.Output.OutputsLogits {
_, err = tensorflow.AddSoftmax(m.model.Graph, m.meta) _, err = tensorflow.AddSoftmax(m.model.Graph, m.meta)
if err != nil { if err != nil {
return fmt.Errorf("classify: could not add softmax: %w", nil) return fmt.Errorf("classify: could not add softmax (%s)", clean.Error(err))
} }
} }

View File

@@ -129,15 +129,15 @@ func TestExternalModel_AllModels(t *testing.T) {
for k, v := range modelsInfo { for k, v := range modelsInfo {
t.Run(k, func(*testing.T) { t.Run(k, func(*testing.T) {
log.Infof("Testing model %s", k) log.Infof("vision: testing model %s", k)
downloadedModel := downloadRemoteModel(t, fmt.Sprintf("%s/%s", baseUrl, k), tmpPath) downloadedModel := downloadRemoteModel(t, fmt.Sprintf("%s/%s", baseUrl, k), tmpPath)
log.Infof("Model downloaded to %s", downloadedModel) log.Infof("vision: model downloaded to %s", downloadedModel)
if v.Labels != "" { if v.Labels != "" {
modelPath := filepath.Join(tmpPath, downloadedModel) modelPath := filepath.Join(tmpPath, downloadedModel)
t.Logf("Model path: %s", modelPath) t.Logf("vision: model path is %s", modelPath)
downloadLabels(t, fmt.Sprintf("%s/%s", baseUrl, v.Labels), modelPath) downloadLabels(t, fmt.Sprintf("%s/%s", baseUrl, v.Labels), modelPath)
} }

View File

@@ -37,7 +37,7 @@ func TestModel_CenterCrop(t *testing.T) {
model.meta.Input.ResizeOperation = tensorflow.CenterCrop model.meta.Input.ResizeOperation = tensorflow.CenterCrop
t.Run("nasnet padding", func(t *testing.T) { t.Run("nasnet padding", func(t *testing.T) {
testModel_BasicLabels(t, model, 6) runBasicLabelsTest(t, model, 6)
}) })
} }
@@ -50,7 +50,7 @@ func TestModel_Padding(t *testing.T) {
model.meta.Input.ResizeOperation = tensorflow.Padding model.meta.Input.ResizeOperation = tensorflow.Padding
t.Run("nasnet padding", func(t *testing.T) { t.Run("nasnet padding", func(t *testing.T) {
testModel_BasicLabels(t, model, 6) runBasicLabelsTest(t, model, 6)
}) })
} }
@@ -63,11 +63,11 @@ func TestModel_ResizeBreakAspectRatio(t *testing.T) {
model.meta.Input.ResizeOperation = tensorflow.ResizeBreakAspectRatio model.meta.Input.ResizeOperation = tensorflow.ResizeBreakAspectRatio
t.Run("nasnet break aspect ratio", func(t *testing.T) { t.Run("nasnet break aspect ratio", func(t *testing.T) {
testModel_BasicLabels(t, model, 4) runBasicLabelsTest(t, model, 4)
}) })
} }
func testModel_BasicLabels(t *testing.T, model *Model, expectedUncertainty int) { func runBasicLabelsTest(t *testing.T, model *Model, expectedUncertainty int) {
result, err := model.File(examplesPath+"/zebra_green_brown.jpg", 10) result, err := model.File(examplesPath+"/zebra_green_brown.jpg", 10)
assert.NoError(t, err) assert.NoError(t, err)
@@ -146,14 +146,14 @@ func TestModel_LabelsFromFile(t *testing.T) {
assert.Equal(t, 70, result[0].Uncertainty) assert.Equal(t, 70, result[0].Uncertainty)
} }
}) })
t.Run("not existing file", func(t *testing.T) { t.Run("NotExistingFile", func(t *testing.T) {
tensorFlow := NewModelTest(t) tensorFlow := NewModelTest(t)
result, err := tensorFlow.File(examplesPath+"/notexisting.jpg", 10) result, err := tensorFlow.File(examplesPath+"/notexisting.jpg", 10)
assert.Contains(t, err.Error(), "no such file or directory") assert.Contains(t, err.Error(), "no such file or directory")
assert.Empty(t, result) assert.Empty(t, result)
}) })
t.Run("disabled true", func(t *testing.T) { t.Run("Disabled", func(t *testing.T) {
tensorFlow := NewNasnet(assetsPath, true) tensorFlow := NewNasnet(assetsPath, true)
result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10) result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10)
@@ -252,7 +252,7 @@ func TestModel_Run(t *testing.T) {
assert.Empty(t, result) assert.Empty(t, result)
} }
}) })
t.Run("disabled true", func(t *testing.T) { t.Run("Disabled", func(t *testing.T) {
tensorFlow := NewNasnet(assetsPath, true) tensorFlow := NewNasnet(assetsPath, true)
if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil { if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil {
@@ -272,16 +272,16 @@ func TestModel_Run(t *testing.T) {
} }
func TestModel_LoadModel(t *testing.T) { func TestModel_LoadModel(t *testing.T) {
t.Run("model loaded", func(t *testing.T) { t.Run("Success", func(t *testing.T) {
tf := NewModelTest(t) tf := NewModelTest(t)
assert.True(t, tf.ModelLoaded()) assert.True(t, tf.ModelLoaded())
}) })
t.Run("model path does not exist", func(t *testing.T) { t.Run("NotFound", func(t *testing.T) {
tensorFlow := NewNasnet(assetsPath+"foo", false) tensorFlow := NewNasnet(assetsPath+"foo", false)
err := tensorFlow.loadModel() err := tensorFlow.loadModel()
if err != nil { if err != nil {
assert.Contains(t, err.Error(), "could not find SavedModel") assert.Contains(t, err.Error(), "not find SavedModel")
} }
assert.Error(t, err) assert.Error(t, err)
@@ -289,17 +289,7 @@ func TestModel_LoadModel(t *testing.T) {
} }
func TestModel_BestLabels(t *testing.T) { func TestModel_BestLabels(t *testing.T) {
t.Run("labels not loaded", func(t *testing.T) { t.Run("Success", func(t *testing.T) {
tensorFlow := NewNasnet(assetsPath, false)
p := make([]float32, 1000)
p[666] = 0.5
result := tensorFlow.bestLabels(p, 10)
assert.Empty(t, result)
})
t.Run("labels loaded", func(t *testing.T) {
tensorFlow := NewNasnet(assetsPath, false) tensorFlow := NewNasnet(assetsPath, false)
if err := tensorFlow.loadLabels(modelPath); err != nil { if err := tensorFlow.loadLabels(modelPath); err != nil {
@@ -317,4 +307,14 @@ func TestModel_BestLabels(t *testing.T) {
assert.Equal(t, "image", result[0].Source) assert.Equal(t, "image", result[0].Source)
t.Log(result) t.Log(result)
}) })
t.Run("NotLoaded", func(t *testing.T) {
tensorFlow := NewNasnet(assetsPath, false)
p := make([]float32, 1000)
p[666] = 0.5
result := tensorFlow.bestLabels(p, 10)
assert.Empty(t, result)
})
} }