CI: Apply Go linter recommendations to "ai/tensorflow" package #5330

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2025-11-22 11:47:17 +01:00
parent b954de52e9
commit 75bc6d754c
10 changed files with 113 additions and 59 deletions

View File

@@ -4,8 +4,9 @@ import (
"bytes"
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
_ "image/jpeg" // register JPEG decoder
_ "image/png" // register PNG decoder
"math"
"os"
"runtime/debug"
@@ -16,10 +17,13 @@ import (
)
const (
Mean = float32(117)
// Mean is the default mean pixel value used during normalization.
Mean = float32(117)
// Scale is the default scale applied during normalization.
Scale = float32(1)
)
// ImageFromFile decodes an image from disk and converts it to a tensor for inference.
func ImageFromFile(fileName string, input *PhotoInput) (*tf.Tensor, error) {
if img, err := OpenImage(fileName); err != nil {
return nil, err
@@ -28,8 +32,9 @@ func ImageFromFile(fileName string, input *PhotoInput) (*tf.Tensor, error) {
}
}
// OpenImage opens an image file and decodes it using the registered decoders.
func OpenImage(fileName string) (image.Image, error) {
f, err := os.Open(fileName)
f, err := os.Open(fileName) //nolint:gosec // fileName supplied by trusted caller; reading local images is expected
if err != nil {
return nil, err
}
@@ -39,6 +44,7 @@ func OpenImage(fileName string) (image.Image, error) {
return img, err
}
// ImageFromBytes converts raw image bytes into a tensor using the provided input definition.
func ImageFromBytes(b []byte, input *PhotoInput, builder *ImageTensorBuilder) (*tf.Tensor, error) {
img, _, imgErr := image.Decode(bytes.NewReader(b))
@@ -49,6 +55,7 @@ func ImageFromBytes(b []byte, input *PhotoInput, builder *ImageTensorBuilder) (*
return Image(img, input, builder)
}
// Image converts a decoded image into a tensor matching the provided input description.
func Image(img image.Image, input *PhotoInput, builder *ImageTensorBuilder) (tfTensor *tf.Tensor, err error) {
defer func() {
if r := recover(); r != nil {
@@ -70,8 +77,8 @@ func Image(img image.Image, input *PhotoInput, builder *ImageTensorBuilder) (tfT
for i := 0; i < input.Resolution(); i++ {
for j := 0; j < input.Resolution(); j++ {
r, g, b, _ := img.At(i, j).RGBA()
//Although RGB can be disordered, we assume the input intervals are
//given in RGB order.
// Although RGB can be disordered, we assume the input intervals are
// given in RGB order.
builder.Set(i, j,
convertValue(r, input.GetInterval(0)),
convertValue(g, input.GetInterval(1)),
@@ -116,6 +123,10 @@ func transformImageGraph(imageFormat fs.Type, resolution int) (graph *tf.Graph,
s := op.NewScope()
input = op.Placeholder(s, tf.String)
if resolution <= 0 || resolution > math.MaxInt32 {
return nil, input, output, fmt.Errorf("tensorflow: resolution %d is out of bounds", resolution)
}
// Assume the image is a JPEG, or a PNG if explicitly specified.
var decodedImage tf.Output
switch imageFormat {
@@ -125,13 +136,15 @@ func transformImageGraph(imageFormat fs.Type, resolution int) (graph *tf.Graph,
decodedImage = op.DecodeJpeg(s, input, op.DecodeJpegChannels(3))
}
size := int32(resolution) //nolint:gosec // resolution is validated to be within int32 range above
output = op.Div(s,
op.Sub(s,
op.ResizeBilinear(s,
op.ExpandDims(s,
op.Cast(s, decodedImage, tf.Float),
op.Const(s.SubScope("make_batch"), int32(0))),
op.Const(s.SubScope("size"), []int32{int32(resolution), int32(resolution)})),
op.Const(s.SubScope("size"), []int32{size, size})),
op.Const(s.SubScope("mean"), Mean)),
op.Const(s.SubScope("scale"), Scale))

View File

@@ -32,7 +32,7 @@ func TestConvertStdMean(t *testing.T) {
func TestImageFromBytes(t *testing.T) {
t.Run("CatJpeg", func(t *testing.T) {
imageBuffer, err := os.ReadFile(examplesPath + "/cat_brown.jpg")
imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "cat_brown.jpg")) //nolint:gosec // reading bundled test fixture
if err != nil {
t.Fatal(err)
@@ -48,7 +48,7 @@ func TestImageFromBytes(t *testing.T) {
assert.Equal(t, int64(224), result.Shape()[2])
})
t.Run("Document", func(t *testing.T) {
imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx")
imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "Random.docx")) //nolint:gosec // reading bundled test fixture
assert.Nil(t, err)
result, err := ImageFromBytes(imageBuffer, defaultImageInput, nil)

View File

@@ -17,7 +17,7 @@ import (
// defined for input images as "what decodeImage returns".
const ExpectedChannels = 3
// Interval of allowed values
// Interval of allowed values.
type Interval struct {
Start float32 `yaml:"Start,omitempty" json:"start,omitempty"`
End float32 `yaml:"End,omitempty" json:"end,omitempty"`
@@ -53,9 +53,13 @@ func StandardInterval() *Interval {
type ResizeOperation int
const (
// UndefinedResizeOperation indicates that no resize strategy was specified.
UndefinedResizeOperation ResizeOperation = iota
// ResizeBreakAspectRatio resizes without preserving aspect ratio.
ResizeBreakAspectRatio
// CenterCrop crops the center region after resizing to fill the target size.
CenterCrop
// Padding resizes while preserving aspect ratio and pads the rest.
Padding
)
@@ -74,6 +78,7 @@ func (o ResizeOperation) String() string {
}
}
// NewResizeOperation parses a string into a ResizeOperation.
func NewResizeOperation(s string) (ResizeOperation, error) {
switch s {
case "Undefined":
@@ -85,14 +90,16 @@ func NewResizeOperation(s string) (ResizeOperation, error) {
case "Padding":
return Padding, nil
default:
return UndefinedResizeOperation, fmt.Errorf("Invalid operation %s", s)
return UndefinedResizeOperation, fmt.Errorf("invalid operation %s", s)
}
}
// MarshalJSON encodes the resize operation as its string name.
func (o ResizeOperation) MarshalJSON() ([]byte, error) {
return json.Marshal(o.String())
}
// UnmarshalJSON decodes a resize operation from its string representation.
func (o *ResizeOperation) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
@@ -108,10 +115,12 @@ func (o *ResizeOperation) UnmarshalJSON(data []byte) error {
return nil
}
// MarshalYAML encodes the resize operation for YAML output.
func (o ResizeOperation) MarshalYAML() (any, error) {
return o.String(), nil
}
// UnmarshalYAML decodes the resize operation from YAML input.
func (o *ResizeOperation) UnmarshalYAML(unmarshal func(interface{}) error) error {
var s string
if err := unmarshal(&s); err != nil {
@@ -131,15 +140,23 @@ func (o *ResizeOperation) UnmarshalYAML(unmarshal func(interface{}) error) error
type ColorChannelOrder int
const (
// UndefinedOrder leaves channel order unspecified, defaulting to RGB.
UndefinedOrder ColorChannelOrder = 0
RGB = 123
RBG = 132
GRB = 213
GBR = 231
BRG = 312
BGR = 321
// RGB represents Red-Green-Blue channel order.
RGB = 123
// RBG represents Red-Blue-Green channel order.
RBG = 132
// GRB represents Green-Red-Blue channel order.
GRB = 213
// GBR represents Green-Blue-Red channel order.
GBR = 231
// BRG represents Blue-Red-Green channel order.
BRG = 312
// BGR represents Blue-Green-Red channel order.
BGR = 321
)
// Indices returns the zero-based indices of the R, G, and B channels.
func (o ColorChannelOrder) Indices() (r, g, b int) {
i := int(o)
@@ -147,7 +164,7 @@ func (o ColorChannelOrder) Indices() (r, g, b int) {
i = 123
}
for idx := 0; i > 0 && idx < 3; idx += 1 {
for idx := 0; i > 0 && idx < 3; idx++ {
remainder := i % 10
i /= 10
@@ -195,9 +212,10 @@ func (o ColorChannelOrder) String() string {
return result
}
// NewColorChannelOrder parses a string (e.g., "RGB") into a ColorChannelOrder.
func NewColorChannelOrder(val string) (ColorChannelOrder, error) {
if len(val) != 3 {
return UndefinedOrder, fmt.Errorf("Invalid length, expected 3")
return UndefinedOrder, fmt.Errorf("invalid length, expected 3")
}
convert := func(c rune) int {
@@ -217,17 +235,19 @@ func NewColorChannelOrder(val string) (ColorChannelOrder, error) {
for _, c := range val {
index := convert(c)
if index == 0 {
return UndefinedOrder, fmt.Errorf("Invalid val %c", c)
return UndefinedOrder, fmt.Errorf("invalid val %c", c)
}
result = result*10 + index
}
return ColorChannelOrder(result), nil
}
// MarshalJSON encodes the channel order as its string name.
func (o ColorChannelOrder) MarshalJSON() ([]byte, error) {
return json.Marshal(o.String())
}
// UnmarshalJSON decodes a channel order from its string representation.
func (o *ColorChannelOrder) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
@@ -243,10 +263,12 @@ func (o *ColorChannelOrder) UnmarshalJSON(data []byte) error {
return nil
}
// MarshalYAML encodes the channel order for YAML output.
func (o ColorChannelOrder) MarshalYAML() (any, error) {
return o.String(), nil
}
// UnmarshalYAML decodes the channel order from YAML input.
func (o *ColorChannelOrder) UnmarshalYAML(unmarshal func(interface{}) error) error {
var s string
if err := unmarshal(&s); err != nil {
@@ -261,17 +283,22 @@ func (o *ColorChannelOrder) UnmarshalYAML(unmarshal func(interface{}) error) err
return nil
}
// The expected shape for the input layer of a mode. Usually this shape is
// (batch, resolution, resolution, channels) but sometimes it is not.
// ShapeComponent describes a single dimension of a model input shape.
// Usually this shape is (batch, resolution, resolution, channels) but sometimes it is not.
type ShapeComponent string
const (
ShapeBatch ShapeComponent = "Batch"
ShapeWidth = "Width"
ShapeHeight = "Height"
ShapeColor = "Color"
// ShapeBatch represents the batch dimension.
ShapeBatch ShapeComponent = "Batch"
// ShapeWidth represents the width dimension.
ShapeWidth = "Width"
// ShapeHeight represents the height dimension.
ShapeHeight = "Height"
// ShapeColor represents the color/channel dimension.
ShapeColor = "Color"
)
// DefaultPhotoInputShape returns the standard BHWC input shape.
func DefaultPhotoInputShape() []ShapeComponent {
return []ShapeComponent{
ShapeBatch,
@@ -285,7 +312,7 @@ func DefaultPhotoInputShape() []ShapeComponent {
type PhotoInput struct {
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
Intervals []Interval `yaml:"Intervals,omitempty" json:"intervals,omitempty"`
ResizeOperation ResizeOperation `yaml:"ResizeOperation,omitempty" json:"resizeOperation,omitemty"`
ResizeOperation ResizeOperation `yaml:"ResizeOperation,omitempty" json:"resizeOperation,omitempty"`
ColorChannelOrder ColorChannelOrder `yaml:"ColorChannelOrder,omitempty" json:"inputOrder,omitempty"`
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
@@ -427,10 +454,11 @@ func (m ModelInfo) IsComplete() bool {
return m.Input != nil && m.Output != nil && m.Input.Shape != nil
}
// GetModelTagsInfo reads a SavedModel and returns its available meta graph tags.
func GetModelTagsInfo(savedModelPath string) ([]ModelInfo, error) {
savedModel := filepath.Join(savedModelPath, "saved_model.pb")
data, err := os.ReadFile(savedModel)
data, err := os.ReadFile(savedModel) //nolint:gosec // savedModel path derived from trusted model directory
if err != nil {
return nil, fmt.Errorf("vision: failed to read %s (%s)", clean.Path(savedModel), clean.Error(err))

View File

@@ -24,11 +24,12 @@ func TestGetModelTagsInfo(t *testing.T) {
t.Fatal(err)
}
if len(info) != 1 {
switch {
case len(info) != 1:
t.Fatalf("Expected 1 info but got %d", len(info))
} else if len(info[0].Tags) != 1 {
case len(info[0].Tags) != 1:
t.Fatalf("Expected 1 tag, but got %d", len(info[0].Tags))
} else if info[0].Tags[0] != "photoprism" {
case info[0].Tags[0] != "photoprism":
t.Fatalf("Expected tag photoprism, but have %s", info[0].Tags[0])
}
}

View File

@@ -12,7 +12,7 @@ import (
func loadLabelsFromPath(path string) (labels []string, err error) {
log.Infof("vision: loading TensorFlow model labels from %s", path)
f, err := os.Open(path)
f, err := os.Open(path) //nolint:gosec // path originates from known model directory; reading labels is expected
if err != nil {
return nil, err
}

View File

@@ -23,7 +23,7 @@ func SavedModel(modelPath string, tags []string) (model *tf.SavedModel, err erro
}
// GuessInputAndOutput tries to inspect a loaded saved model to build the
// ModelInfo struct
// ModelInfo struct.
func GuessInputAndOutput(model *tf.SavedModel) (input *PhotoInput, output *ModelOutput, err error) {
if model == nil {
return nil, nil, fmt.Errorf("tensorflow: GuessInputAndOutput received a nil input")
@@ -39,9 +39,10 @@ func GuessInputAndOutput(model *tf.SavedModel) (input *PhotoInput, output *Model
shape := modelOps[i].Output(0).Shape()
var comps []ShapeComponent
if shape.Size(3) == ExpectedChannels {
switch {
case shape.Size(3) == ExpectedChannels:
comps = []ShapeComponent{ShapeBatch, ShapeHeight, ShapeWidth, ShapeColor}
} else if shape.Size(1) == ExpectedChannels { // check the channels are 3
case shape.Size(1) == ExpectedChannels: // check the channels are 3
comps = []ShapeComponent{ShapeBatch, ShapeColor, ShapeHeight, ShapeWidth, ShapeColor}
}
@@ -69,6 +70,7 @@ func GuessInputAndOutput(model *tf.SavedModel) (input *PhotoInput, output *Model
return nil, nil, fmt.Errorf("could not guess the inputs and outputs")
}
// GetInputAndOutputFromSavedModel reads signature definitions to derive input/output info.
func GetInputAndOutputFromSavedModel(model *tf.SavedModel) (*PhotoInput, *ModelOutput, error) {
if model == nil {
return nil, nil, fmt.Errorf("GetInputAndOutputFromSavedModel: nil input")
@@ -86,18 +88,20 @@ func GetInputAndOutputFromSavedModel(model *tf.SavedModel) (*PhotoInput, *ModelO
for _, inputTensor := range inputs {
if inputTensor.Shape.NumDimensions() == 4 {
var comps []ShapeComponent
if inputTensor.Shape.Size(3) == ExpectedChannels {
switch {
case inputTensor.Shape.Size(3) == ExpectedChannels:
comps = []ShapeComponent{ShapeBatch, ShapeHeight, ShapeWidth, ShapeColor}
} else if inputTensor.Shape.Size(1) == ExpectedChannels { // check the channels are 3
case inputTensor.Shape.Size(1) == ExpectedChannels: // check the channels are 3
comps = []ShapeComponent{ShapeBatch, ShapeColor, ShapeHeight, ShapeWidth}
} else {
default:
log.Debugf("tensorflow: shape %d", inputTensor.Shape.Size(1))
}
if comps == nil {
log.Warnf("tensorflow: skipping signature %v because we could not find the color component", k)
} else {
var inputIdx = 0
inputIdx := 0
var err error
inputName, inputIndex, found := strings.Cut(inputTensor.Name, ":")

View File

@@ -20,23 +20,24 @@ func TestTF1ModelLoad(t *testing.T) {
t.Fatal(err)
}
input, output, err := GetInputAndOutputFromSavedModel(model)
_, _, err = GetInputAndOutputFromSavedModel(model)
if err == nil {
t.Fatalf("TF1 does not have signatures, but GetInput worked")
}
input, output, err = GuessInputAndOutput(model)
input, output, err := GuessInputAndOutput(model)
if err != nil {
t.Fatal(err)
}
if input == nil {
switch {
case input == nil:
t.Fatal("Could not get the input")
} else if output == nil {
case output == nil:
t.Fatal("Could not get the output")
} else if input.Shape == nil {
case input.Shape == nil:
t.Fatal("Could not get the shape")
} else {
default:
t.Logf("Shape: %v", input.Shape)
}
}
@@ -55,15 +56,15 @@ func TestTF2ModelLoad(t *testing.T) {
t.Fatal(err)
}
if input == nil {
switch {
case input == nil:
t.Fatal("Could not get the input")
} else if output == nil {
case output == nil:
t.Fatal("Could not get the output")
} else if input.Shape == nil {
case input.Shape == nil:
t.Fatal("Could not get the shape")
} else if !slices.Equal(input.Shape, DefaultPhotoInputShape()) {
t.Fatalf("Invalid shape calculated. Expected BHWC, got %v",
input.Shape)
case !slices.Equal(input.Shape, DefaultPhotoInputShape()):
t.Fatalf("Invalid shape calculated. Expected BHWC, got %v", input.Shape)
}
}
@@ -81,16 +82,16 @@ func TestTF2ModelBCHWLoad(t *testing.T) {
t.Fatal(err)
}
if input == nil {
switch {
case input == nil:
t.Fatal("Could not get the input")
} else if output == nil {
case output == nil:
t.Fatal("Could not get the output")
} else if input.Shape == nil {
case input.Shape == nil:
t.Fatal("Could not get the shape")
} else if !slices.Equal(input.Shape, []ShapeComponent{
case !slices.Equal(input.Shape, []ShapeComponent{
ShapeBatch, ShapeColor, ShapeHeight, ShapeWidth,
}) {
t.Fatalf("Invalid shape calculated. Expected BCHW, got %v",
input.Shape)
}):
t.Fatalf("Invalid shape calculated. Expected BCHW, got %v", input.Shape)
}
}

View File

@@ -4,6 +4,7 @@ import (
tf "github.com/wamuir/graft/tensorflow"
)
// AddSoftmax appends a Softmax operation to the graph for the configured model output.
func AddSoftmax(graph *tf.Graph, info *ModelInfo) (*tf.Operation, error) {
randomName := randomString(10)

View File

@@ -7,6 +7,7 @@ import (
tf "github.com/wamuir/graft/tensorflow"
)
// ImageTensorBuilder incrementally constructs an image tensor in BHWC or BCHW order.
type ImageTensorBuilder struct {
data []float32
shape []ShapeComponent
@@ -29,6 +30,7 @@ func shapeLen(c ShapeComponent, res int) int {
}
}
// NewImageTensorBuilder creates a builder for the given photo input definition.
func NewImageTensorBuilder(input *PhotoInput) (*ImageTensorBuilder, error) {
if len(input.Shape) != 4 {
@@ -62,6 +64,7 @@ func NewImageTensorBuilder(input *PhotoInput) (*ImageTensorBuilder, error) {
}, nil
}
// Set assigns the normalized RGB values for the pixel at (x,y).
func (t *ImageTensorBuilder) Set(x, y int, r, g, b float32) {
t.data[t.flatIndex(x, y, t.rIndex)] = r
t.data[t.flatIndex(x, y, t.gIndex)] = g
@@ -93,6 +96,7 @@ func (t *ImageTensorBuilder) flatIndex(x, y, c int) int {
return idx
}
// BuildTensor materializes the underlying data into a TensorFlow tensor.
func (t *ImageTensorBuilder) BuildTensor() (*tf.Tensor, error) {
arr := make([][][][]float32, shapeLen(t.shape[0], t.resolution))

View File

@@ -6,11 +6,12 @@ func randomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
for i := range result {
result[i] = charset[rand.IntN(len(charset))]
result[i] = charset[rand.IntN(len(charset))] //nolint:gosec // pseudo-random is sufficient for non-cryptographic identifiers
}
return string(result)
}
// GetOne returns an arbitrary key-value pair from the map or nils when empty.
func GetOne[K comparable, V any](input map[K]V) (*K, *V) {
for k, v := range input {
return &k, &v
@@ -19,6 +20,7 @@ func GetOne[K comparable, V any](input map[K]V) (*K, *V) {
return nil, nil
}
// Deref returns the value of a pointer or a default when the pointer is nil.
func Deref[V any](input *V, defval V) V {
if input == nil {
return defval