diff --git a/internal/ai/tensorflow/image.go b/internal/ai/tensorflow/image.go index 3aea2d186..196b78d61 100644 --- a/internal/ai/tensorflow/image.go +++ b/internal/ai/tensorflow/image.go @@ -4,8 +4,9 @@ import ( "bytes" "fmt" "image" - _ "image/jpeg" - _ "image/png" + _ "image/jpeg" // register JPEG decoder + _ "image/png" // register PNG decoder + "math" "os" "runtime/debug" @@ -16,10 +17,13 @@ import ( ) const ( - Mean = float32(117) + // Mean is the default mean pixel value used during normalization. + Mean = float32(117) + // Scale is the default scale applied during normalization. Scale = float32(1) ) +// ImageFromFile decodes an image from disk and converts it to a tensor for inference. func ImageFromFile(fileName string, input *PhotoInput) (*tf.Tensor, error) { if img, err := OpenImage(fileName); err != nil { return nil, err @@ -28,8 +32,9 @@ func ImageFromFile(fileName string, input *PhotoInput) (*tf.Tensor, error) { } } +// OpenImage opens an image file and decodes it using the registered decoders. func OpenImage(fileName string) (image.Image, error) { - f, err := os.Open(fileName) + f, err := os.Open(fileName) //nolint:gosec // fileName supplied by trusted caller; reading local images is expected if err != nil { return nil, err } @@ -39,6 +44,7 @@ func OpenImage(fileName string) (image.Image, error) { return img, err } +// ImageFromBytes converts raw image bytes into a tensor using the provided input definition. func ImageFromBytes(b []byte, input *PhotoInput, builder *ImageTensorBuilder) (*tf.Tensor, error) { img, _, imgErr := image.Decode(bytes.NewReader(b)) @@ -49,6 +55,7 @@ func ImageFromBytes(b []byte, input *PhotoInput, builder *ImageTensorBuilder) (* return Image(img, input, builder) } +// Image converts a decoded image into a tensor matching the provided input description. func Image(img image.Image, input *PhotoInput, builder *ImageTensorBuilder) (tfTensor *tf.Tensor, err error) { defer func() { if r := recover(); r != nil { @@ -70,8 +77,8 @@ func Image(img image.Image, input *PhotoInput, builder *ImageTensorBuilder) (tfT for i := 0; i < input.Resolution(); i++ { for j := 0; j < input.Resolution(); j++ { r, g, b, _ := img.At(i, j).RGBA() - //Although RGB can be disordered, we assume the input intervals are - //given in RGB order. + // Although RGB can be disordered, we assume the input intervals are + // given in RGB order. builder.Set(i, j, convertValue(r, input.GetInterval(0)), convertValue(g, input.GetInterval(1)), @@ -116,6 +123,10 @@ func transformImageGraph(imageFormat fs.Type, resolution int) (graph *tf.Graph, s := op.NewScope() input = op.Placeholder(s, tf.String) + if resolution <= 0 || resolution > math.MaxInt32 { + return nil, input, output, fmt.Errorf("tensorflow: resolution %d is out of bounds", resolution) + } + // Assume the image is a JPEG, or a PNG if explicitly specified. var decodedImage tf.Output switch imageFormat { @@ -125,13 +136,15 @@ func transformImageGraph(imageFormat fs.Type, resolution int) (graph *tf.Graph, decodedImage = op.DecodeJpeg(s, input, op.DecodeJpegChannels(3)) } + size := int32(resolution) //nolint:gosec // resolution is validated to be within int32 range above + output = op.Div(s, op.Sub(s, op.ResizeBilinear(s, op.ExpandDims(s, op.Cast(s, decodedImage, tf.Float), op.Const(s.SubScope("make_batch"), int32(0))), - op.Const(s.SubScope("size"), []int32{int32(resolution), int32(resolution)})), + op.Const(s.SubScope("size"), []int32{size, size})), op.Const(s.SubScope("mean"), Mean)), op.Const(s.SubScope("scale"), Scale)) diff --git a/internal/ai/tensorflow/image_test.go b/internal/ai/tensorflow/image_test.go index 7dc857a1b..7bc0511ae 100644 --- a/internal/ai/tensorflow/image_test.go +++ b/internal/ai/tensorflow/image_test.go @@ -32,7 +32,7 @@ func TestConvertStdMean(t *testing.T) { func TestImageFromBytes(t *testing.T) { t.Run("CatJpeg", func(t *testing.T) { - imageBuffer, err := os.ReadFile(examplesPath + "/cat_brown.jpg") + imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "cat_brown.jpg")) //nolint:gosec // reading bundled test fixture if err != nil { t.Fatal(err) @@ -48,7 +48,7 @@ func TestImageFromBytes(t *testing.T) { assert.Equal(t, int64(224), result.Shape()[2]) }) t.Run("Document", func(t *testing.T) { - imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx") + imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "Random.docx")) //nolint:gosec // reading bundled test fixture assert.Nil(t, err) result, err := ImageFromBytes(imageBuffer, defaultImageInput, nil) diff --git a/internal/ai/tensorflow/info.go b/internal/ai/tensorflow/info.go index 3483d3ad7..581e45875 100644 --- a/internal/ai/tensorflow/info.go +++ b/internal/ai/tensorflow/info.go @@ -17,7 +17,7 @@ import ( // defined for input images as "what decodeImage returns". const ExpectedChannels = 3 -// Interval of allowed values +// Interval of allowed values. type Interval struct { Start float32 `yaml:"Start,omitempty" json:"start,omitempty"` End float32 `yaml:"End,omitempty" json:"end,omitempty"` @@ -53,9 +53,13 @@ func StandardInterval() *Interval { type ResizeOperation int const ( + // UndefinedResizeOperation indicates that no resize strategy was specified. UndefinedResizeOperation ResizeOperation = iota + // ResizeBreakAspectRatio resizes without preserving aspect ratio. ResizeBreakAspectRatio + // CenterCrop crops the center region after resizing to fill the target size. CenterCrop + // Padding resizes while preserving aspect ratio and pads the rest. Padding ) @@ -74,6 +78,7 @@ func (o ResizeOperation) String() string { } } +// NewResizeOperation parses a string into a ResizeOperation. func NewResizeOperation(s string) (ResizeOperation, error) { switch s { case "Undefined": @@ -85,14 +90,16 @@ func NewResizeOperation(s string) (ResizeOperation, error) { case "Padding": return Padding, nil default: - return UndefinedResizeOperation, fmt.Errorf("Invalid operation %s", s) + return UndefinedResizeOperation, fmt.Errorf("invalid operation %s", s) } } +// MarshalJSON encodes the resize operation as its string name. func (o ResizeOperation) MarshalJSON() ([]byte, error) { return json.Marshal(o.String()) } +// UnmarshalJSON decodes a resize operation from its string representation. func (o *ResizeOperation) UnmarshalJSON(data []byte) error { var s string if err := json.Unmarshal(data, &s); err != nil { @@ -108,10 +115,12 @@ func (o *ResizeOperation) UnmarshalJSON(data []byte) error { return nil } +// MarshalYAML encodes the resize operation for YAML output. func (o ResizeOperation) MarshalYAML() (any, error) { return o.String(), nil } +// UnmarshalYAML decodes the resize operation from YAML input. func (o *ResizeOperation) UnmarshalYAML(unmarshal func(interface{}) error) error { var s string if err := unmarshal(&s); err != nil { @@ -131,15 +140,23 @@ func (o *ResizeOperation) UnmarshalYAML(unmarshal func(interface{}) error) error type ColorChannelOrder int const ( + // UndefinedOrder leaves channel order unspecified, defaulting to RGB. UndefinedOrder ColorChannelOrder = 0 - RGB = 123 - RBG = 132 - GRB = 213 - GBR = 231 - BRG = 312 - BGR = 321 + // RGB represents Red-Green-Blue channel order. + RGB = 123 + // RBG represents Red-Blue-Green channel order. + RBG = 132 + // GRB represents Green-Red-Blue channel order. + GRB = 213 + // GBR represents Green-Blue-Red channel order. + GBR = 231 + // BRG represents Blue-Red-Green channel order. + BRG = 312 + // BGR represents Blue-Green-Red channel order. + BGR = 321 ) +// Indices returns the zero-based indices of the R, G, and B channels. func (o ColorChannelOrder) Indices() (r, g, b int) { i := int(o) @@ -147,7 +164,7 @@ func (o ColorChannelOrder) Indices() (r, g, b int) { i = 123 } - for idx := 0; i > 0 && idx < 3; idx += 1 { + for idx := 0; i > 0 && idx < 3; idx++ { remainder := i % 10 i /= 10 @@ -195,9 +212,10 @@ func (o ColorChannelOrder) String() string { return result } +// NewColorChannelOrder parses a string (e.g., "RGB") into a ColorChannelOrder. func NewColorChannelOrder(val string) (ColorChannelOrder, error) { if len(val) != 3 { - return UndefinedOrder, fmt.Errorf("Invalid length, expected 3") + return UndefinedOrder, fmt.Errorf("invalid length, expected 3") } convert := func(c rune) int { @@ -217,17 +235,19 @@ func NewColorChannelOrder(val string) (ColorChannelOrder, error) { for _, c := range val { index := convert(c) if index == 0 { - return UndefinedOrder, fmt.Errorf("Invalid val %c", c) + return UndefinedOrder, fmt.Errorf("invalid val %c", c) } result = result*10 + index } return ColorChannelOrder(result), nil } +// MarshalJSON encodes the channel order as its string name. func (o ColorChannelOrder) MarshalJSON() ([]byte, error) { return json.Marshal(o.String()) } +// UnmarshalJSON decodes a channel order from its string representation. func (o *ColorChannelOrder) UnmarshalJSON(data []byte) error { var s string if err := json.Unmarshal(data, &s); err != nil { @@ -243,10 +263,12 @@ func (o *ColorChannelOrder) UnmarshalJSON(data []byte) error { return nil } +// MarshalYAML encodes the channel order for YAML output. func (o ColorChannelOrder) MarshalYAML() (any, error) { return o.String(), nil } +// UnmarshalYAML decodes the channel order from YAML input. func (o *ColorChannelOrder) UnmarshalYAML(unmarshal func(interface{}) error) error { var s string if err := unmarshal(&s); err != nil { @@ -261,17 +283,22 @@ func (o *ColorChannelOrder) UnmarshalYAML(unmarshal func(interface{}) error) err return nil } -// The expected shape for the input layer of a mode. Usually this shape is -// (batch, resolution, resolution, channels) but sometimes it is not. +// ShapeComponent describes a single dimension of a model input shape. +// Usually this shape is (batch, resolution, resolution, channels) but sometimes it is not. type ShapeComponent string const ( - ShapeBatch ShapeComponent = "Batch" - ShapeWidth = "Width" - ShapeHeight = "Height" - ShapeColor = "Color" + // ShapeBatch represents the batch dimension. + ShapeBatch ShapeComponent = "Batch" + // ShapeWidth represents the width dimension. + ShapeWidth = "Width" + // ShapeHeight represents the height dimension. + ShapeHeight = "Height" + // ShapeColor represents the color/channel dimension. + ShapeColor = "Color" ) +// DefaultPhotoInputShape returns the standard BHWC input shape. func DefaultPhotoInputShape() []ShapeComponent { return []ShapeComponent{ ShapeBatch, @@ -285,7 +312,7 @@ func DefaultPhotoInputShape() []ShapeComponent { type PhotoInput struct { Name string `yaml:"Name,omitempty" json:"name,omitempty"` Intervals []Interval `yaml:"Intervals,omitempty" json:"intervals,omitempty"` - ResizeOperation ResizeOperation `yaml:"ResizeOperation,omitempty" json:"resizeOperation,omitemty"` + ResizeOperation ResizeOperation `yaml:"ResizeOperation,omitempty" json:"resizeOperation,omitempty"` ColorChannelOrder ColorChannelOrder `yaml:"ColorChannelOrder,omitempty" json:"inputOrder,omitempty"` OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"` Height int64 `yaml:"Height,omitempty" json:"height,omitempty"` @@ -427,10 +454,11 @@ func (m ModelInfo) IsComplete() bool { return m.Input != nil && m.Output != nil && m.Input.Shape != nil } +// GetModelTagsInfo reads a SavedModel and returns its available meta graph tags. func GetModelTagsInfo(savedModelPath string) ([]ModelInfo, error) { savedModel := filepath.Join(savedModelPath, "saved_model.pb") - data, err := os.ReadFile(savedModel) + data, err := os.ReadFile(savedModel) //nolint:gosec // savedModel path derived from trusted model directory if err != nil { return nil, fmt.Errorf("vision: failed to read %s (%s)", clean.Path(savedModel), clean.Error(err)) diff --git a/internal/ai/tensorflow/info_test.go b/internal/ai/tensorflow/info_test.go index 7ffaa4eff..21fca576a 100644 --- a/internal/ai/tensorflow/info_test.go +++ b/internal/ai/tensorflow/info_test.go @@ -24,11 +24,12 @@ func TestGetModelTagsInfo(t *testing.T) { t.Fatal(err) } - if len(info) != 1 { + switch { + case len(info) != 1: t.Fatalf("Expected 1 info but got %d", len(info)) - } else if len(info[0].Tags) != 1 { + case len(info[0].Tags) != 1: t.Fatalf("Expected 1 tag, but got %d", len(info[0].Tags)) - } else if info[0].Tags[0] != "photoprism" { + case info[0].Tags[0] != "photoprism": t.Fatalf("Expected tag photoprism, but have %s", info[0].Tags[0]) } } diff --git a/internal/ai/tensorflow/labels.go b/internal/ai/tensorflow/labels.go index c5e9f084f..4d5f552eb 100644 --- a/internal/ai/tensorflow/labels.go +++ b/internal/ai/tensorflow/labels.go @@ -12,7 +12,7 @@ import ( func loadLabelsFromPath(path string) (labels []string, err error) { log.Infof("vision: loading TensorFlow model labels from %s", path) - f, err := os.Open(path) + f, err := os.Open(path) //nolint:gosec // path originates from known model directory; reading labels is expected if err != nil { return nil, err } diff --git a/internal/ai/tensorflow/model.go b/internal/ai/tensorflow/model.go index 580b7c619..ebfbe9f32 100644 --- a/internal/ai/tensorflow/model.go +++ b/internal/ai/tensorflow/model.go @@ -23,7 +23,7 @@ func SavedModel(modelPath string, tags []string) (model *tf.SavedModel, err erro } // GuessInputAndOutput tries to inspect a loaded saved model to build the -// ModelInfo struct +// ModelInfo struct. func GuessInputAndOutput(model *tf.SavedModel) (input *PhotoInput, output *ModelOutput, err error) { if model == nil { return nil, nil, fmt.Errorf("tensorflow: GuessInputAndOutput received a nil input") @@ -39,9 +39,10 @@ func GuessInputAndOutput(model *tf.SavedModel) (input *PhotoInput, output *Model shape := modelOps[i].Output(0).Shape() var comps []ShapeComponent - if shape.Size(3) == ExpectedChannels { + switch { + case shape.Size(3) == ExpectedChannels: comps = []ShapeComponent{ShapeBatch, ShapeHeight, ShapeWidth, ShapeColor} - } else if shape.Size(1) == ExpectedChannels { // check the channels are 3 + case shape.Size(1) == ExpectedChannels: // check the channels are 3 comps = []ShapeComponent{ShapeBatch, ShapeColor, ShapeHeight, ShapeWidth, ShapeColor} } @@ -69,6 +70,7 @@ func GuessInputAndOutput(model *tf.SavedModel) (input *PhotoInput, output *Model return nil, nil, fmt.Errorf("could not guess the inputs and outputs") } +// GetInputAndOutputFromSavedModel reads signature definitions to derive input/output info. func GetInputAndOutputFromSavedModel(model *tf.SavedModel) (*PhotoInput, *ModelOutput, error) { if model == nil { return nil, nil, fmt.Errorf("GetInputAndOutputFromSavedModel: nil input") @@ -86,18 +88,20 @@ func GetInputAndOutputFromSavedModel(model *tf.SavedModel) (*PhotoInput, *ModelO for _, inputTensor := range inputs { if inputTensor.Shape.NumDimensions() == 4 { var comps []ShapeComponent - if inputTensor.Shape.Size(3) == ExpectedChannels { + + switch { + case inputTensor.Shape.Size(3) == ExpectedChannels: comps = []ShapeComponent{ShapeBatch, ShapeHeight, ShapeWidth, ShapeColor} - } else if inputTensor.Shape.Size(1) == ExpectedChannels { // check the channels are 3 + case inputTensor.Shape.Size(1) == ExpectedChannels: // check the channels are 3 comps = []ShapeComponent{ShapeBatch, ShapeColor, ShapeHeight, ShapeWidth} - } else { + default: log.Debugf("tensorflow: shape %d", inputTensor.Shape.Size(1)) } if comps == nil { log.Warnf("tensorflow: skipping signature %v because we could not find the color component", k) } else { - var inputIdx = 0 + inputIdx := 0 var err error inputName, inputIndex, found := strings.Cut(inputTensor.Name, ":") diff --git a/internal/ai/tensorflow/model_test.go b/internal/ai/tensorflow/model_test.go index 58c1aaae6..91fa51465 100644 --- a/internal/ai/tensorflow/model_test.go +++ b/internal/ai/tensorflow/model_test.go @@ -20,23 +20,24 @@ func TestTF1ModelLoad(t *testing.T) { t.Fatal(err) } - input, output, err := GetInputAndOutputFromSavedModel(model) + _, _, err = GetInputAndOutputFromSavedModel(model) if err == nil { t.Fatalf("TF1 does not have signatures, but GetInput worked") } - input, output, err = GuessInputAndOutput(model) + input, output, err := GuessInputAndOutput(model) if err != nil { t.Fatal(err) } - if input == nil { + switch { + case input == nil: t.Fatal("Could not get the input") - } else if output == nil { + case output == nil: t.Fatal("Could not get the output") - } else if input.Shape == nil { + case input.Shape == nil: t.Fatal("Could not get the shape") - } else { + default: t.Logf("Shape: %v", input.Shape) } } @@ -55,15 +56,15 @@ func TestTF2ModelLoad(t *testing.T) { t.Fatal(err) } - if input == nil { + switch { + case input == nil: t.Fatal("Could not get the input") - } else if output == nil { + case output == nil: t.Fatal("Could not get the output") - } else if input.Shape == nil { + case input.Shape == nil: t.Fatal("Could not get the shape") - } else if !slices.Equal(input.Shape, DefaultPhotoInputShape()) { - t.Fatalf("Invalid shape calculated. Expected BHWC, got %v", - input.Shape) + case !slices.Equal(input.Shape, DefaultPhotoInputShape()): + t.Fatalf("Invalid shape calculated. Expected BHWC, got %v", input.Shape) } } @@ -81,16 +82,16 @@ func TestTF2ModelBCHWLoad(t *testing.T) { t.Fatal(err) } - if input == nil { + switch { + case input == nil: t.Fatal("Could not get the input") - } else if output == nil { + case output == nil: t.Fatal("Could not get the output") - } else if input.Shape == nil { + case input.Shape == nil: t.Fatal("Could not get the shape") - } else if !slices.Equal(input.Shape, []ShapeComponent{ + case !slices.Equal(input.Shape, []ShapeComponent{ ShapeBatch, ShapeColor, ShapeHeight, ShapeWidth, - }) { - t.Fatalf("Invalid shape calculated. Expected BCHW, got %v", - input.Shape) + }): + t.Fatalf("Invalid shape calculated. Expected BCHW, got %v", input.Shape) } } diff --git a/internal/ai/tensorflow/op.go b/internal/ai/tensorflow/op.go index 9d954ffa5..15bc53d60 100644 --- a/internal/ai/tensorflow/op.go +++ b/internal/ai/tensorflow/op.go @@ -4,6 +4,7 @@ import ( tf "github.com/wamuir/graft/tensorflow" ) +// AddSoftmax appends a Softmax operation to the graph for the configured model output. func AddSoftmax(graph *tf.Graph, info *ModelInfo) (*tf.Operation, error) { randomName := randomString(10) diff --git a/internal/ai/tensorflow/tensor_builder.go b/internal/ai/tensorflow/tensor_builder.go index abe8b141b..2e25a157b 100644 --- a/internal/ai/tensorflow/tensor_builder.go +++ b/internal/ai/tensorflow/tensor_builder.go @@ -7,6 +7,7 @@ import ( tf "github.com/wamuir/graft/tensorflow" ) +// ImageTensorBuilder incrementally constructs an image tensor in BHWC or BCHW order. type ImageTensorBuilder struct { data []float32 shape []ShapeComponent @@ -29,6 +30,7 @@ func shapeLen(c ShapeComponent, res int) int { } } +// NewImageTensorBuilder creates a builder for the given photo input definition. func NewImageTensorBuilder(input *PhotoInput) (*ImageTensorBuilder, error) { if len(input.Shape) != 4 { @@ -62,6 +64,7 @@ func NewImageTensorBuilder(input *PhotoInput) (*ImageTensorBuilder, error) { }, nil } +// Set assigns the normalized RGB values for the pixel at (x,y). func (t *ImageTensorBuilder) Set(x, y int, r, g, b float32) { t.data[t.flatIndex(x, y, t.rIndex)] = r t.data[t.flatIndex(x, y, t.gIndex)] = g @@ -93,6 +96,7 @@ func (t *ImageTensorBuilder) flatIndex(x, y, c int) int { return idx } +// BuildTensor materializes the underlying data into a TensorFlow tensor. func (t *ImageTensorBuilder) BuildTensor() (*tf.Tensor, error) { arr := make([][][][]float32, shapeLen(t.shape[0], t.resolution)) diff --git a/internal/ai/tensorflow/util.go b/internal/ai/tensorflow/util.go index ce51f9dac..46b189bc0 100644 --- a/internal/ai/tensorflow/util.go +++ b/internal/ai/tensorflow/util.go @@ -6,11 +6,12 @@ func randomString(length int) string { const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" result := make([]byte, length) for i := range result { - result[i] = charset[rand.IntN(len(charset))] + result[i] = charset[rand.IntN(len(charset))] //nolint:gosec // pseudo-random is sufficient for non-cryptographic identifiers } return string(result) } +// GetOne returns an arbitrary key-value pair from the map or nils when empty. func GetOne[K comparable, V any](input map[K]V) (*K, *V) { for k, v := range input { return &k, &v @@ -19,6 +20,7 @@ func GetOne[K comparable, V any](input map[K]V) (*K, *V) { return nil, nil } +// Deref returns the value of a pointer or a default when the pointer is nil. func Deref[V any](input *V, defval V) V { if input == nil { return defval