AI: Added some intelligence to label loading.

Now when loading labels internal/ai/tensorflow package will try to look
for all the files that match the glob label*.txt and will return the
labels that match the expected number. Some models add a first label
called background, which is a bias.
Also, a new parameter has been added to models to allow a second path to
look for the label files. This path is set to nasnet asset on
internal/ai/vision.
This commit is contained in:
raystlin
2025-04-13 14:49:54 +00:00
parent 0895a085a1
commit d993eb2a85
4 changed files with 78 additions and 40 deletions

View File

@@ -2,19 +2,17 @@ package tensorflow
import (
"bufio"
"io/fs"
"os"
"path/filepath"
)
// LoadLabels loads the labels of classification models from the specified path and returns them.
func LoadLabels(modelPath string) (labels []string, err error) {
modelLabels := modelPath + "/labels.txt"
log.Infof("tensorflow: loading model labels from labels.txt")
f, err := os.Open(modelLabels)
func loadLabelsFromPath(path string) (labels []string, err error) {
log.Infof("tensorflow: loading model labels from %s", path)
f, err := os.Open(path)
if err != nil {
return labels, err
return nil, err
}
defer f.Close()
@@ -30,3 +28,34 @@ func LoadLabels(modelPath string) (labels []string, err error) {
return labels, err
}
// LoadLabels loads the labels of classification models from the specified path and returns them.
func LoadLabels(modelPath string, expectedLabels int) (labels []string, err error) {
dir := os.DirFS(modelPath)
matches, err := fs.Glob(dir, "labels*.txt")
if err != nil {
return nil, err
}
for i := range matches {
labels, err := loadLabelsFromPath(filepath.Join(modelPath, matches[i]))
if err != nil {
return nil, err
}
switch expectedLabels - len(labels) {
case 0:
log.Infof("Found a valid labels file: %s", matches[i])
return labels, nil
case 1:
log.Infof("Found a valid labels file %s but we have to add bias", matches[i])
return append([]string{"background"}, labels...), nil
default:
log.Infof("File not valid. Expected %d labels and have %d",
expectedLabels, len(labels))
}
}
return nil, os.ErrNotExist
}