CI: Apply more Go linter recommendations to "ai/classify" package #5330

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2025-11-22 11:30:58 +01:00
parent 4682791253
commit 780a870f5c
5 changed files with 83 additions and 36 deletions

View File

@@ -1,5 +1,4 @@
//go:build ignore
// +build ignore
// This generates stopwords.go by running "go generate"
package main
@@ -88,6 +87,7 @@ package classify
// Generated code, do not edit.
// Rules contains the generated label classification rules from rules.yml.
var Rules = LabelRules{
{{- range $key, $value := .Rules }}
{{ printf "%q" $key }}: {

View File

@@ -78,7 +78,7 @@ func NewNasnet(modelsPath string, disabled bool) *Model {
}, disabled)
}
// Init initialises tensorflow models if not disabled
// Init initializes tensorflow models if not disabled.
func (m *Model) Init() (err error) {
if m.disabled {
return nil
@@ -95,7 +95,7 @@ func (m *Model) File(fileName string, confidenceThreshold int) (result Labels, e
var data []byte
if data, err = os.ReadFile(fileName); err != nil {
if data, err = os.ReadFile(fileName); err != nil { //nolint:gosec // fileName is provided by trusted callers; reading arbitrary local files is expected behavior
return nil, err
}
@@ -203,13 +203,17 @@ func (m *Model) loadModel() (err error) {
if len(m.meta.Tags) == 0 {
infos, modelErr := tensorflow.GetModelTagsInfo(modelPath)
if modelErr != nil {
switch {
case modelErr != nil:
log.Errorf("classify: could not get info from model in %s (%s)", clean.Log(modelPath), clean.Error(modelErr))
} else if len(infos) == 1 {
case len(infos) == 1:
log.Debugf("classify: model info: %+v", infos[0])
m.meta.Merge(&infos[0])
} else {
case len(infos) > 1:
log.Warnf("classify: found %d metagraphs, which is too many", len(infos))
default:
log.Warnf("classify: no metagraphs found in %s", clean.Log(modelPath))
}
}

View File

@@ -15,11 +15,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/ai/tensorflow"
"github.com/photoprism/photoprism/pkg/fs"
)
const (
DefaultResolution = 224
ExternalModelsTestLabel = "PHOTOPRISM_TEST_EXTERNAL_MODELS"
maxArchiveFileSize = 2 * 1024 * 1024 * 1024 // 2 GiB limit to avoid decompression bombs in tests
)
var baseUrl = "https://dl.photoprism.app/tensorflow/models"
@@ -111,16 +113,43 @@ var modelsInfo = map[string]*ModelTestCase{
*/
}
func isSafePath(target, baseDir string) bool {
// safeArchivePath maps the archive entry name onto baseDir and returns the
// resulting absolute path. It returns an error instead of a path when the
// cleaned name is empty, is absolute or carries a volume name, starts with a
// ".." element, or resolves to a location outside of baseDir.
func safeArchivePath(baseDir, name string) (string, error) {
cleanName := filepath.Clean(name)
// Resolve the absolute path of the target
absTarget := filepath.Join(baseDir, target)
absBase, err := filepath.Abs(baseDir)
if err != nil {
return false
// filepath.Clean returns "." for an empty path, so both spellings mean
// that the entry does not name anything below baseDir.
if cleanName == "" || cleanName == "." {
return "", fmt.Errorf("empty archive path")
}
return strings.HasPrefix(absTarget, absBase)
// Entries must be relative; reject rooted paths and names with a volume prefix.
if filepath.IsAbs(cleanName) || filepath.VolumeName(cleanName) != "" {
return "", fmt.Errorf("absolute paths are not allowed")
}
// After cleaning, any remaining ".." can only be a leading element.
if cleanName == ".." || strings.HasPrefix(cleanName, ".."+string(os.PathSeparator)) {
return "", fmt.Errorf("path traversal detected")
}
target := filepath.Join(baseDir, cleanName) //nolint:gosec // target is validated below
absBase, err := filepath.Abs(baseDir)
if err != nil {
return "", err
}
absTarget, err := filepath.Abs(target)
if err != nil {
return "", err
}
// Second check on the resolved paths: a relative path from absBase to
// absTarget that begins with ".." points outside of the base directory.
rel, err := filepath.Rel(absBase, absTarget)
if err != nil {
return "", err
}
if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
return "", fmt.Errorf("path escapes base directory")
}
return absTarget, nil
}
func TestExternalModel_AllModels(t *testing.T) {
@@ -159,20 +188,20 @@ func TestExternalModel_AllModels(t *testing.T) {
model.meta.Input.SetResolution(DefaultResolution)
}
testModel_LabelsFromFile(t, model)
testModel_Run(t, model)
testModelLabelsFromFile(t, model)
testModelRun(t, model)
})
}
}
func downloadLabels(t *testing.T, url, dst string) {
resp, err := http.Get(url)
resp, err := http.Get(url) //nolint:gosec // test downloads from trusted PhotoPrism asset host
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
output, err := os.Create(filepath.Join(dst, "labels.txt"))
output, err := os.Create(filepath.Join(dst, "labels.txt")) //nolint:gosec // destination is within a controlled temporary test directory
if err != nil {
t.Fatal(err)
}
@@ -189,9 +218,11 @@ func downloadRemoteModel(t *testing.T, url, tmpPath string) (model string) {
modelPath := strings.TrimSuffix(path.Base(url), ".tar.gz")
tmpPath = filepath.Join(tmpPath, modelPath)
os.MkdirAll(tmpPath, 0755)
if err := os.MkdirAll(tmpPath, fs.ModeDir); err != nil { //nolint:gosec // fs.ModeDir is the project default for directories
t.Fatal(err)
}
resp, err := http.Get(url)
resp, err := http.Get(url) //nolint:gosec // test downloads from trusted PhotoPrism asset host
if err != nil {
t.Fatal(err)
}
@@ -207,7 +238,7 @@ func downloadRemoteModel(t *testing.T, url, tmpPath string) (model string) {
}
tarReader := tar.NewReader(uncompressedBody)
for true {
for {
header, err := tarReader.Next()
if err == io.EOF {
break
@@ -217,30 +248,41 @@ func downloadRemoteModel(t *testing.T, url, tmpPath string) (model string) {
t.Fatalf("could not extract the file: %v", err)
}
target := filepath.Join(tmpPath, header.Name)
if !isSafePath(target, tmpPath) {
t.Fatalf("The model file contains an invalid path: %s", header.Name)
if strings.HasPrefix(header.Name, "__MACOSX") {
continue
}
target, err := safeArchivePath(tmpPath, header.Name)
if err != nil {
t.Fatalf("The model file contains an invalid path %s: %v", header.Name, err)
}
switch header.Typeflag {
case tar.TypeDir:
if err := os.Mkdir(target, 0755); err != nil {
if err := os.Mkdir(target, fs.ModeDir); err != nil { //nolint:gosec // fs.ModeDir is intentional for extracted model directories
t.Fatalf("could not make the dir %s: %v", header.Name, err)
}
case tar.TypeReg:
outFile, err := os.Create(target)
outFile, err := os.Create(target) //nolint:gosec // target path validated by safeArchivePath and confined to tmpPath
if err != nil {
t.Fatalf("could not create file %s: %v", header.Name, err)
}
if _, err := io.Copy(outFile, tarReader); err != nil {
limitedReader := &io.LimitedReader{R: tarReader, N: maxArchiveFileSize}
if _, err := io.Copy(outFile, limitedReader); err != nil {
t.Fatalf("could not copy file %s: %v", header.Name, err)
}
if limitedReader.N == 0 {
t.Fatalf("file %s exceeds maximum allowed size of %d bytes", header.Name, maxArchiveFileSize)
}
rootPath, fileName := filepath.Split(header.Name)
if fileName == "saved_model.pb" {
model = filepath.Join(modelPath, rootPath)
}
outFile.Close()
if err := outFile.Close(); err != nil {
t.Fatalf("could not close file %s: %v", header.Name, err)
}
default:
t.Fatalf("could not extract file. Unknown type %v in %s",
header.Typeflag,
@@ -266,7 +308,7 @@ func assertContainsAny(t *testing.T, s string, substrings []string) {
s, substrings)
}
func testModel_LabelsFromFile(t *testing.T, tensorFlow *Model) {
func testModelLabelsFromFile(t *testing.T, tensorFlow *Model) {
testName := func(name string) string {
return fmt.Sprintf("%s/%s", tensorFlow.name, name)
}
@@ -367,7 +409,7 @@ func testModel_LabelsFromFile(t *testing.T, tensorFlow *Model) {
})
}
func testModel_Run(t *testing.T, tensorFlow *Model) {
func testModelRun(t *testing.T, tensorFlow *Model) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}

View File

@@ -181,7 +181,7 @@ func TestModel_Run(t *testing.T) {
t.Run("ChameleonLimeJpg", func(t *testing.T) {
tensorFlow := NewModelTest(t)
if imageBuffer, err := os.ReadFile(examplesPath + "/chameleon_lime.jpg"); err != nil {
if imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "chameleon_lime.jpg")); err != nil { //nolint:gosec // reading bundled test fixture
t.Error(err)
} else {
result, err := tensorFlow.Run(imageBuffer, 10)
@@ -206,7 +206,7 @@ func TestModel_Run(t *testing.T) {
t.Run("DogOrangeJpg", func(t *testing.T) {
tensorFlow := NewModelTest(t)
if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil {
if imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "dog_orange.jpg")); err != nil { //nolint:gosec // reading bundled test fixture
t.Error(err)
} else {
result, err := tensorFlow.Run(imageBuffer, 10)
@@ -231,7 +231,7 @@ func TestModel_Run(t *testing.T) {
t.Run("RandomDocx", func(t *testing.T) {
tensorFlow := NewModelTest(t)
if imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx"); err != nil {
if imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "Random.docx")); err != nil { //nolint:gosec // reading bundled test fixture
t.Error(err)
} else {
result, err := tensorFlow.Run(imageBuffer, 10)
@@ -242,7 +242,7 @@ func TestModel_Run(t *testing.T) {
t.Run("Num6720PxWhiteJpg", func(t *testing.T) {
tensorFlow := NewModelTest(t)
if imageBuffer, err := os.ReadFile(examplesPath + "/6720px_white.jpg"); err != nil {
if imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "6720px_white.jpg")); err != nil { //nolint:gosec // reading bundled test fixture
t.Error(err)
} else {
result, err := tensorFlow.Run(imageBuffer, 10)
@@ -257,7 +257,7 @@ func TestModel_Run(t *testing.T) {
t.Run("Disabled", func(t *testing.T) {
tensorFlow := NewNasnet(modelsPath, true)
if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil {
if imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "dog_orange.jpg")); err != nil { //nolint:gosec // reading bundled test fixture
t.Error(err)
} else {
result, err := tensorFlow.Run(imageBuffer, 10)
@@ -328,7 +328,7 @@ func BenchmarkModel_BestLabelWithOptimization(b *testing.B) {
b.Fatal(err)
}
imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg")
imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "dog_orange.jpg")) //nolint:gosec // reading bundled test fixture
if err != nil {
b.Fatal(err)
}
@@ -349,7 +349,7 @@ func BenchmarkModel_BestLabelsNoOptimization(b *testing.B) {
}
model.builder = nil
imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg")
imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "dog_orange.jpg")) //nolint:gosec // reading bundled test fixture
if err != nil {
b.Fatal(err)
}

View File

@@ -2,6 +2,7 @@ package classify
// Generated code, do not edit.
// Rules contains the generated label classification rules from rules.yml.
var Rules = LabelRules{
"abacus": {
Label: "",