mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
CI: Apply Go more linter recommendations to "ai/classify" package #5330
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
//go:build ignore
|
||||
// +build ignore
|
||||
|
||||
// This generates stopwords.go by running "go generate"
|
||||
package main
|
||||
@@ -88,6 +87,7 @@ package classify
|
||||
|
||||
// Generated code, do not edit.
|
||||
|
||||
// Rules contains the generated label classification rules from rules.yml.
|
||||
var Rules = LabelRules{
|
||||
{{- range $key, $value := .Rules }}
|
||||
{{ printf "%q" $key }}: {
|
||||
|
||||
@@ -78,7 +78,7 @@ func NewNasnet(modelsPath string, disabled bool) *Model {
|
||||
}, disabled)
|
||||
}
|
||||
|
||||
// Init initialises tensorflow models if not disabled
|
||||
// Init initializes tensorflow models if not disabled.
|
||||
func (m *Model) Init() (err error) {
|
||||
if m.disabled {
|
||||
return nil
|
||||
@@ -95,7 +95,7 @@ func (m *Model) File(fileName string, confidenceThreshold int) (result Labels, e
|
||||
|
||||
var data []byte
|
||||
|
||||
if data, err = os.ReadFile(fileName); err != nil {
|
||||
if data, err = os.ReadFile(fileName); err != nil { //nolint:gosec // fileName is provided by trusted callers; reading arbitrary local files is expected behavior
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -203,13 +203,17 @@ func (m *Model) loadModel() (err error) {
|
||||
|
||||
if len(m.meta.Tags) == 0 {
|
||||
infos, modelErr := tensorflow.GetModelTagsInfo(modelPath)
|
||||
if modelErr != nil {
|
||||
|
||||
switch {
|
||||
case modelErr != nil:
|
||||
log.Errorf("classify: could not get info from model in %s (%s)", clean.Log(modelPath), clean.Error(modelErr))
|
||||
} else if len(infos) == 1 {
|
||||
case len(infos) == 1:
|
||||
log.Debugf("classify: model info: %+v", infos[0])
|
||||
m.meta.Merge(&infos[0])
|
||||
} else {
|
||||
case len(infos) > 1:
|
||||
log.Warnf("classify: found %d metagraphs, which is too many", len(infos))
|
||||
default:
|
||||
log.Warnf("classify: no metagraphs found in %s", clean.Log(modelPath))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,11 +15,13 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/tensorflow"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultResolution = 224
|
||||
ExternalModelsTestLabel = "PHOTOPRISM_TEST_EXTERNAL_MODELS"
|
||||
maxArchiveFileSize = 2 * 1024 * 1024 * 1024 // 2 GiB limit to avoid decompression bombs in tests
|
||||
)
|
||||
|
||||
var baseUrl = "https://dl.photoprism.app/tensorflow/models"
|
||||
@@ -111,16 +113,43 @@ var modelsInfo = map[string]*ModelTestCase{
|
||||
*/
|
||||
}
|
||||
|
||||
func isSafePath(target, baseDir string) bool {
|
||||
func safeArchivePath(baseDir, name string) (string, error) {
|
||||
cleanName := filepath.Clean(name)
|
||||
|
||||
// Resolve the absolute path of the target
|
||||
absTarget := filepath.Join(baseDir, target)
|
||||
absBase, err := filepath.Abs(baseDir)
|
||||
if err != nil {
|
||||
return false
|
||||
if cleanName == "" || cleanName == "." {
|
||||
return "", fmt.Errorf("empty archive path")
|
||||
}
|
||||
|
||||
return strings.HasPrefix(absTarget, absBase)
|
||||
if filepath.IsAbs(cleanName) || filepath.VolumeName(cleanName) != "" {
|
||||
return "", fmt.Errorf("absolute paths are not allowed")
|
||||
}
|
||||
|
||||
if cleanName == ".." || strings.HasPrefix(cleanName, ".."+string(os.PathSeparator)) {
|
||||
return "", fmt.Errorf("path traversal detected")
|
||||
}
|
||||
|
||||
target := filepath.Join(baseDir, cleanName) //nolint:gosec // target is validated below
|
||||
|
||||
absBase, err := filepath.Abs(baseDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
absTarget, err := filepath.Abs(target)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(absBase, absTarget)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
|
||||
return "", fmt.Errorf("path escapes base directory")
|
||||
}
|
||||
|
||||
return absTarget, nil
|
||||
}
|
||||
|
||||
func TestExternalModel_AllModels(t *testing.T) {
|
||||
@@ -159,20 +188,20 @@ func TestExternalModel_AllModels(t *testing.T) {
|
||||
model.meta.Input.SetResolution(DefaultResolution)
|
||||
}
|
||||
|
||||
testModel_LabelsFromFile(t, model)
|
||||
testModel_Run(t, model)
|
||||
testModelLabelsFromFile(t, model)
|
||||
testModelRun(t, model)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func downloadLabels(t *testing.T, url, dst string) {
|
||||
resp, err := http.Get(url)
|
||||
resp, err := http.Get(url) //nolint:gosec // test downloads from trusted PhotoPrism asset host
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
output, err := os.Create(filepath.Join(dst, "labels.txt"))
|
||||
output, err := os.Create(filepath.Join(dst, "labels.txt")) //nolint:gosec // destination is within a controlled temporary test directory
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -189,9 +218,11 @@ func downloadRemoteModel(t *testing.T, url, tmpPath string) (model string) {
|
||||
|
||||
modelPath := strings.TrimSuffix(path.Base(url), ".tar.gz")
|
||||
tmpPath = filepath.Join(tmpPath, modelPath)
|
||||
os.MkdirAll(tmpPath, 0755)
|
||||
if err := os.MkdirAll(tmpPath, fs.ModeDir); err != nil { //nolint:gosec // fs.ModeDir is the project default for directories
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := http.Get(url)
|
||||
resp, err := http.Get(url) //nolint:gosec // test downloads from trusted PhotoPrism asset host
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -207,7 +238,7 @@ func downloadRemoteModel(t *testing.T, url, tmpPath string) (model string) {
|
||||
}
|
||||
|
||||
tarReader := tar.NewReader(uncompressedBody)
|
||||
for true {
|
||||
for {
|
||||
header, err := tarReader.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
@@ -217,30 +248,41 @@ func downloadRemoteModel(t *testing.T, url, tmpPath string) (model string) {
|
||||
t.Fatalf("could not extract the file: %v", err)
|
||||
}
|
||||
|
||||
target := filepath.Join(tmpPath, header.Name)
|
||||
if !isSafePath(target, tmpPath) {
|
||||
t.Fatalf("The model file contains an invalid path: %s", header.Name)
|
||||
if strings.HasPrefix(header.Name, "__MACOSX") {
|
||||
continue
|
||||
}
|
||||
|
||||
target, err := safeArchivePath(tmpPath, header.Name)
|
||||
if err != nil {
|
||||
t.Fatalf("The model file contains an invalid path %s: %v", header.Name, err)
|
||||
}
|
||||
|
||||
switch header.Typeflag {
|
||||
case tar.TypeDir:
|
||||
if err := os.Mkdir(target, 0755); err != nil {
|
||||
if err := os.Mkdir(target, fs.ModeDir); err != nil { //nolint:gosec // fs.ModeDir is intentional for extracted model directories
|
||||
t.Fatalf("could not make the dir %s: %v", header.Name, err)
|
||||
}
|
||||
case tar.TypeReg:
|
||||
outFile, err := os.Create(target)
|
||||
outFile, err := os.Create(target) //nolint:gosec // target path validated by isSafePath and confined to tmpPath
|
||||
if err != nil {
|
||||
t.Fatalf("could not create file %s: %v", header.Name, err)
|
||||
}
|
||||
if _, err := io.Copy(outFile, tarReader); err != nil {
|
||||
limitedReader := &io.LimitedReader{R: tarReader, N: maxArchiveFileSize}
|
||||
|
||||
if _, err := io.Copy(outFile, limitedReader); err != nil {
|
||||
t.Fatalf("could not copy file %s: %v", header.Name, err)
|
||||
}
|
||||
if limitedReader.N == 0 {
|
||||
t.Fatalf("file %s exceeds maximum allowed size of %d bytes", header.Name, maxArchiveFileSize)
|
||||
}
|
||||
|
||||
rootPath, fileName := filepath.Split(header.Name)
|
||||
if fileName == "saved_model.pb" {
|
||||
model = filepath.Join(modelPath, rootPath)
|
||||
}
|
||||
outFile.Close()
|
||||
if err := outFile.Close(); err != nil {
|
||||
t.Fatalf("could not close file %s: %v", header.Name, err)
|
||||
}
|
||||
default:
|
||||
t.Fatalf("could not extract file. Unknown type %v in %s",
|
||||
header.Typeflag,
|
||||
@@ -266,7 +308,7 @@ func assertContainsAny(t *testing.T, s string, substrings []string) {
|
||||
s, substrings)
|
||||
}
|
||||
|
||||
func testModel_LabelsFromFile(t *testing.T, tensorFlow *Model) {
|
||||
func testModelLabelsFromFile(t *testing.T, tensorFlow *Model) {
|
||||
testName := func(name string) string {
|
||||
return fmt.Sprintf("%s/%s", tensorFlow.name, name)
|
||||
}
|
||||
@@ -367,7 +409,7 @@ func testModel_LabelsFromFile(t *testing.T, tensorFlow *Model) {
|
||||
})
|
||||
}
|
||||
|
||||
func testModel_Run(t *testing.T, tensorFlow *Model) {
|
||||
func testModelRun(t *testing.T, tensorFlow *Model) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
}
|
||||
|
||||
@@ -181,7 +181,7 @@ func TestModel_Run(t *testing.T) {
|
||||
t.Run("ChameleonLimeJpg", func(t *testing.T) {
|
||||
tensorFlow := NewModelTest(t)
|
||||
|
||||
if imageBuffer, err := os.ReadFile(examplesPath + "/chameleon_lime.jpg"); err != nil {
|
||||
if imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "chameleon_lime.jpg")); err != nil { //nolint:gosec // reading bundled test fixture
|
||||
t.Error(err)
|
||||
} else {
|
||||
result, err := tensorFlow.Run(imageBuffer, 10)
|
||||
@@ -206,7 +206,7 @@ func TestModel_Run(t *testing.T) {
|
||||
t.Run("DogOrangeJpg", func(t *testing.T) {
|
||||
tensorFlow := NewModelTest(t)
|
||||
|
||||
if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil {
|
||||
if imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "dog_orange.jpg")); err != nil { //nolint:gosec // reading bundled test fixture
|
||||
t.Error(err)
|
||||
} else {
|
||||
result, err := tensorFlow.Run(imageBuffer, 10)
|
||||
@@ -231,7 +231,7 @@ func TestModel_Run(t *testing.T) {
|
||||
t.Run("RandomDocx", func(t *testing.T) {
|
||||
tensorFlow := NewModelTest(t)
|
||||
|
||||
if imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx"); err != nil {
|
||||
if imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "Random.docx")); err != nil { //nolint:gosec // reading bundled test fixture
|
||||
t.Error(err)
|
||||
} else {
|
||||
result, err := tensorFlow.Run(imageBuffer, 10)
|
||||
@@ -242,7 +242,7 @@ func TestModel_Run(t *testing.T) {
|
||||
t.Run("Num6720PxWhiteJpg", func(t *testing.T) {
|
||||
tensorFlow := NewModelTest(t)
|
||||
|
||||
if imageBuffer, err := os.ReadFile(examplesPath + "/6720px_white.jpg"); err != nil {
|
||||
if imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "6720px_white.jpg")); err != nil { //nolint:gosec // reading bundled test fixture
|
||||
t.Error(err)
|
||||
} else {
|
||||
result, err := tensorFlow.Run(imageBuffer, 10)
|
||||
@@ -257,7 +257,7 @@ func TestModel_Run(t *testing.T) {
|
||||
t.Run("Disabled", func(t *testing.T) {
|
||||
tensorFlow := NewNasnet(modelsPath, true)
|
||||
|
||||
if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil {
|
||||
if imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "dog_orange.jpg")); err != nil { //nolint:gosec // reading bundled test fixture
|
||||
t.Error(err)
|
||||
} else {
|
||||
result, err := tensorFlow.Run(imageBuffer, 10)
|
||||
@@ -328,7 +328,7 @@ func BenchmarkModel_BestLabelWithOptimization(b *testing.B) {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg")
|
||||
imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "dog_orange.jpg")) //nolint:gosec // reading bundled test fixture
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
@@ -349,7 +349,7 @@ func BenchmarkModel_BestLabelsNoOptimization(b *testing.B) {
|
||||
}
|
||||
model.builder = nil
|
||||
|
||||
imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg")
|
||||
imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "dog_orange.jpg")) //nolint:gosec // reading bundled test fixture
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package classify
|
||||
|
||||
// Generated code, do not edit.
|
||||
|
||||
// Rules contains the generated label classification rules from rules.yml.
|
||||
var Rules = LabelRules{
|
||||
"abacus": {
|
||||
Label: "",
|
||||
|
||||
Reference in New Issue
Block a user