mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
CI: Apply Go linter recommendations to "ai/tensorflow" package #5330
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
@@ -4,8 +4,9 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
_ "image/jpeg" // register JPEG decoder
|
||||
_ "image/png" // register PNG decoder
|
||||
"math"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
|
||||
@@ -16,10 +17,13 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
Mean = float32(117)
|
||||
// Mean is the default mean pixel value used during normalization.
|
||||
Mean = float32(117)
|
||||
// Scale is the default scale applied during normalization.
|
||||
Scale = float32(1)
|
||||
)
|
||||
|
||||
// ImageFromFile decodes an image from disk and converts it to a tensor for inference.
|
||||
func ImageFromFile(fileName string, input *PhotoInput) (*tf.Tensor, error) {
|
||||
if img, err := OpenImage(fileName); err != nil {
|
||||
return nil, err
|
||||
@@ -28,8 +32,9 @@ func ImageFromFile(fileName string, input *PhotoInput) (*tf.Tensor, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// OpenImage opens an image file and decodes it using the registered decoders.
|
||||
func OpenImage(fileName string) (image.Image, error) {
|
||||
f, err := os.Open(fileName)
|
||||
f, err := os.Open(fileName) //nolint:gosec // fileName supplied by trusted caller; reading local images is expected
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -39,6 +44,7 @@ func OpenImage(fileName string) (image.Image, error) {
|
||||
return img, err
|
||||
}
|
||||
|
||||
// ImageFromBytes converts raw image bytes into a tensor using the provided input definition.
|
||||
func ImageFromBytes(b []byte, input *PhotoInput, builder *ImageTensorBuilder) (*tf.Tensor, error) {
|
||||
img, _, imgErr := image.Decode(bytes.NewReader(b))
|
||||
|
||||
@@ -49,6 +55,7 @@ func ImageFromBytes(b []byte, input *PhotoInput, builder *ImageTensorBuilder) (*
|
||||
return Image(img, input, builder)
|
||||
}
|
||||
|
||||
// Image converts a decoded image into a tensor matching the provided input description.
|
||||
func Image(img image.Image, input *PhotoInput, builder *ImageTensorBuilder) (tfTensor *tf.Tensor, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
@@ -70,8 +77,8 @@ func Image(img image.Image, input *PhotoInput, builder *ImageTensorBuilder) (tfT
|
||||
for i := 0; i < input.Resolution(); i++ {
|
||||
for j := 0; j < input.Resolution(); j++ {
|
||||
r, g, b, _ := img.At(i, j).RGBA()
|
||||
//Although RGB can be disordered, we assume the input intervals are
|
||||
//given in RGB order.
|
||||
// Although RGB can be disordered, we assume the input intervals are
|
||||
// given in RGB order.
|
||||
builder.Set(i, j,
|
||||
convertValue(r, input.GetInterval(0)),
|
||||
convertValue(g, input.GetInterval(1)),
|
||||
@@ -116,6 +123,10 @@ func transformImageGraph(imageFormat fs.Type, resolution int) (graph *tf.Graph,
|
||||
s := op.NewScope()
|
||||
input = op.Placeholder(s, tf.String)
|
||||
|
||||
if resolution <= 0 || resolution > math.MaxInt32 {
|
||||
return nil, input, output, fmt.Errorf("tensorflow: resolution %d is out of bounds", resolution)
|
||||
}
|
||||
|
||||
// Assume the image is a JPEG, or a PNG if explicitly specified.
|
||||
var decodedImage tf.Output
|
||||
switch imageFormat {
|
||||
@@ -125,13 +136,15 @@ func transformImageGraph(imageFormat fs.Type, resolution int) (graph *tf.Graph,
|
||||
decodedImage = op.DecodeJpeg(s, input, op.DecodeJpegChannels(3))
|
||||
}
|
||||
|
||||
size := int32(resolution) //nolint:gosec // resolution is validated to be within int32 range above
|
||||
|
||||
output = op.Div(s,
|
||||
op.Sub(s,
|
||||
op.ResizeBilinear(s,
|
||||
op.ExpandDims(s,
|
||||
op.Cast(s, decodedImage, tf.Float),
|
||||
op.Const(s.SubScope("make_batch"), int32(0))),
|
||||
op.Const(s.SubScope("size"), []int32{int32(resolution), int32(resolution)})),
|
||||
op.Const(s.SubScope("size"), []int32{size, size})),
|
||||
op.Const(s.SubScope("mean"), Mean)),
|
||||
op.Const(s.SubScope("scale"), Scale))
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestConvertStdMean(t *testing.T) {
|
||||
|
||||
func TestImageFromBytes(t *testing.T) {
|
||||
t.Run("CatJpeg", func(t *testing.T) {
|
||||
imageBuffer, err := os.ReadFile(examplesPath + "/cat_brown.jpg")
|
||||
imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "cat_brown.jpg")) //nolint:gosec // reading bundled test fixture
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -48,7 +48,7 @@ func TestImageFromBytes(t *testing.T) {
|
||||
assert.Equal(t, int64(224), result.Shape()[2])
|
||||
})
|
||||
t.Run("Document", func(t *testing.T) {
|
||||
imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx")
|
||||
imageBuffer, err := os.ReadFile(filepath.Join(examplesPath, "Random.docx")) //nolint:gosec // reading bundled test fixture
|
||||
assert.Nil(t, err)
|
||||
result, err := ImageFromBytes(imageBuffer, defaultImageInput, nil)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
// defined for input images as "what decodeImage returns".
|
||||
const ExpectedChannels = 3
|
||||
|
||||
// Interval of allowed values
|
||||
// Interval of allowed values.
|
||||
type Interval struct {
|
||||
Start float32 `yaml:"Start,omitempty" json:"start,omitempty"`
|
||||
End float32 `yaml:"End,omitempty" json:"end,omitempty"`
|
||||
@@ -53,9 +53,13 @@ func StandardInterval() *Interval {
|
||||
type ResizeOperation int
|
||||
|
||||
const (
|
||||
// UndefinedResizeOperation indicates that no resize strategy was specified.
|
||||
UndefinedResizeOperation ResizeOperation = iota
|
||||
// ResizeBreakAspectRatio resizes without preserving aspect ratio.
|
||||
ResizeBreakAspectRatio
|
||||
// CenterCrop crops the center region after resizing to fill the target size.
|
||||
CenterCrop
|
||||
// Padding resizes while preserving aspect ratio and pads the rest.
|
||||
Padding
|
||||
)
|
||||
|
||||
@@ -74,6 +78,7 @@ func (o ResizeOperation) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
// NewResizeOperation parses a string into a ResizeOperation.
|
||||
func NewResizeOperation(s string) (ResizeOperation, error) {
|
||||
switch s {
|
||||
case "Undefined":
|
||||
@@ -85,14 +90,16 @@ func NewResizeOperation(s string) (ResizeOperation, error) {
|
||||
case "Padding":
|
||||
return Padding, nil
|
||||
default:
|
||||
return UndefinedResizeOperation, fmt.Errorf("Invalid operation %s", s)
|
||||
return UndefinedResizeOperation, fmt.Errorf("invalid operation %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON encodes the resize operation as its string name.
|
||||
func (o ResizeOperation) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(o.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON decodes a resize operation from its string representation.
|
||||
func (o *ResizeOperation) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
@@ -108,10 +115,12 @@ func (o *ResizeOperation) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalYAML encodes the resize operation for YAML output.
|
||||
func (o ResizeOperation) MarshalYAML() (any, error) {
|
||||
return o.String(), nil
|
||||
}
|
||||
|
||||
// UnmarshalYAML decodes the resize operation from YAML input.
|
||||
func (o *ResizeOperation) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var s string
|
||||
if err := unmarshal(&s); err != nil {
|
||||
@@ -131,15 +140,23 @@ func (o *ResizeOperation) UnmarshalYAML(unmarshal func(interface{}) error) error
|
||||
type ColorChannelOrder int
|
||||
|
||||
const (
|
||||
// UndefinedOrder leaves channel order unspecified, defaulting to RGB.
|
||||
UndefinedOrder ColorChannelOrder = 0
|
||||
RGB = 123
|
||||
RBG = 132
|
||||
GRB = 213
|
||||
GBR = 231
|
||||
BRG = 312
|
||||
BGR = 321
|
||||
// RGB represents Red-Green-Blue channel order.
|
||||
RGB = 123
|
||||
// RBG represents Red-Blue-Green channel order.
|
||||
RBG = 132
|
||||
// GRB represents Green-Red-Blue channel order.
|
||||
GRB = 213
|
||||
// GBR represents Green-Blue-Red channel order.
|
||||
GBR = 231
|
||||
// BRG represents Blue-Red-Green channel order.
|
||||
BRG = 312
|
||||
// BGR represents Blue-Green-Red channel order.
|
||||
BGR = 321
|
||||
)
|
||||
|
||||
// Indices returns the zero-based indices of the R, G, and B channels.
|
||||
func (o ColorChannelOrder) Indices() (r, g, b int) {
|
||||
i := int(o)
|
||||
|
||||
@@ -147,7 +164,7 @@ func (o ColorChannelOrder) Indices() (r, g, b int) {
|
||||
i = 123
|
||||
}
|
||||
|
||||
for idx := 0; i > 0 && idx < 3; idx += 1 {
|
||||
for idx := 0; i > 0 && idx < 3; idx++ {
|
||||
remainder := i % 10
|
||||
i /= 10
|
||||
|
||||
@@ -195,9 +212,10 @@ func (o ColorChannelOrder) String() string {
|
||||
return result
|
||||
}
|
||||
|
||||
// NewColorChannelOrder parses a string (e.g., "RGB") into a ColorChannelOrder.
|
||||
func NewColorChannelOrder(val string) (ColorChannelOrder, error) {
|
||||
if len(val) != 3 {
|
||||
return UndefinedOrder, fmt.Errorf("Invalid length, expected 3")
|
||||
return UndefinedOrder, fmt.Errorf("invalid length, expected 3")
|
||||
}
|
||||
|
||||
convert := func(c rune) int {
|
||||
@@ -217,17 +235,19 @@ func NewColorChannelOrder(val string) (ColorChannelOrder, error) {
|
||||
for _, c := range val {
|
||||
index := convert(c)
|
||||
if index == 0 {
|
||||
return UndefinedOrder, fmt.Errorf("Invalid val %c", c)
|
||||
return UndefinedOrder, fmt.Errorf("invalid val %c", c)
|
||||
}
|
||||
result = result*10 + index
|
||||
}
|
||||
return ColorChannelOrder(result), nil
|
||||
}
|
||||
|
||||
// MarshalJSON encodes the channel order as its string name.
|
||||
func (o ColorChannelOrder) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(o.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON decodes a channel order from its string representation.
|
||||
func (o *ColorChannelOrder) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
@@ -243,10 +263,12 @@ func (o *ColorChannelOrder) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalYAML encodes the channel order for YAML output.
|
||||
func (o ColorChannelOrder) MarshalYAML() (any, error) {
|
||||
return o.String(), nil
|
||||
}
|
||||
|
||||
// UnmarshalYAML decodes the channel order from YAML input.
|
||||
func (o *ColorChannelOrder) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var s string
|
||||
if err := unmarshal(&s); err != nil {
|
||||
@@ -261,17 +283,22 @@ func (o *ColorChannelOrder) UnmarshalYAML(unmarshal func(interface{}) error) err
|
||||
return nil
|
||||
}
|
||||
|
||||
// The expected shape for the input layer of a mode. Usually this shape is
|
||||
// (batch, resolution, resolution, channels) but sometimes it is not.
|
||||
// ShapeComponent describes a single dimension of a model input shape.
|
||||
// Usually this shape is (batch, resolution, resolution, channels) but sometimes it is not.
|
||||
type ShapeComponent string
|
||||
|
||||
const (
|
||||
ShapeBatch ShapeComponent = "Batch"
|
||||
ShapeWidth = "Width"
|
||||
ShapeHeight = "Height"
|
||||
ShapeColor = "Color"
|
||||
// ShapeBatch represents the batch dimension.
|
||||
ShapeBatch ShapeComponent = "Batch"
|
||||
// ShapeWidth represents the width dimension.
|
||||
ShapeWidth = "Width"
|
||||
// ShapeHeight represents the height dimension.
|
||||
ShapeHeight = "Height"
|
||||
// ShapeColor represents the color/channel dimension.
|
||||
ShapeColor = "Color"
|
||||
)
|
||||
|
||||
// DefaultPhotoInputShape returns the standard BHWC input shape.
|
||||
func DefaultPhotoInputShape() []ShapeComponent {
|
||||
return []ShapeComponent{
|
||||
ShapeBatch,
|
||||
@@ -285,7 +312,7 @@ func DefaultPhotoInputShape() []ShapeComponent {
|
||||
type PhotoInput struct {
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
Intervals []Interval `yaml:"Intervals,omitempty" json:"intervals,omitempty"`
|
||||
ResizeOperation ResizeOperation `yaml:"ResizeOperation,omitempty" json:"resizeOperation,omitemty"`
|
||||
ResizeOperation ResizeOperation `yaml:"ResizeOperation,omitempty" json:"resizeOperation,omitempty"`
|
||||
ColorChannelOrder ColorChannelOrder `yaml:"ColorChannelOrder,omitempty" json:"inputOrder,omitempty"`
|
||||
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
|
||||
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
|
||||
@@ -427,10 +454,11 @@ func (m ModelInfo) IsComplete() bool {
|
||||
return m.Input != nil && m.Output != nil && m.Input.Shape != nil
|
||||
}
|
||||
|
||||
// GetModelTagsInfo reads a SavedModel and returns its available meta graph tags.
|
||||
func GetModelTagsInfo(savedModelPath string) ([]ModelInfo, error) {
|
||||
savedModel := filepath.Join(savedModelPath, "saved_model.pb")
|
||||
|
||||
data, err := os.ReadFile(savedModel)
|
||||
data, err := os.ReadFile(savedModel) //nolint:gosec // savedModel path derived from trusted model directory
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("vision: failed to read %s (%s)", clean.Path(savedModel), clean.Error(err))
|
||||
|
||||
@@ -24,11 +24,12 @@ func TestGetModelTagsInfo(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(info) != 1 {
|
||||
switch {
|
||||
case len(info) != 1:
|
||||
t.Fatalf("Expected 1 info but got %d", len(info))
|
||||
} else if len(info[0].Tags) != 1 {
|
||||
case len(info[0].Tags) != 1:
|
||||
t.Fatalf("Expected 1 tag, but got %d", len(info[0].Tags))
|
||||
} else if info[0].Tags[0] != "photoprism" {
|
||||
case info[0].Tags[0] != "photoprism":
|
||||
t.Fatalf("Expected tag photoprism, but have %s", info[0].Tags[0])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
func loadLabelsFromPath(path string) (labels []string, err error) {
|
||||
log.Infof("vision: loading TensorFlow model labels from %s", path)
|
||||
|
||||
f, err := os.Open(path)
|
||||
f, err := os.Open(path) //nolint:gosec // path originates from known model directory; reading labels is expected
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ func SavedModel(modelPath string, tags []string) (model *tf.SavedModel, err erro
|
||||
}
|
||||
|
||||
// GuessInputAndOutput tries to inspect a loaded saved model to build the
|
||||
// ModelInfo struct
|
||||
// ModelInfo struct.
|
||||
func GuessInputAndOutput(model *tf.SavedModel) (input *PhotoInput, output *ModelOutput, err error) {
|
||||
if model == nil {
|
||||
return nil, nil, fmt.Errorf("tensorflow: GuessInputAndOutput received a nil input")
|
||||
@@ -39,9 +39,10 @@ func GuessInputAndOutput(model *tf.SavedModel) (input *PhotoInput, output *Model
|
||||
shape := modelOps[i].Output(0).Shape()
|
||||
|
||||
var comps []ShapeComponent
|
||||
if shape.Size(3) == ExpectedChannels {
|
||||
switch {
|
||||
case shape.Size(3) == ExpectedChannels:
|
||||
comps = []ShapeComponent{ShapeBatch, ShapeHeight, ShapeWidth, ShapeColor}
|
||||
} else if shape.Size(1) == ExpectedChannels { // check the channels are 3
|
||||
case shape.Size(1) == ExpectedChannels: // check the channels are 3
|
||||
comps = []ShapeComponent{ShapeBatch, ShapeColor, ShapeHeight, ShapeWidth, ShapeColor}
|
||||
}
|
||||
|
||||
@@ -69,6 +70,7 @@ func GuessInputAndOutput(model *tf.SavedModel) (input *PhotoInput, output *Model
|
||||
return nil, nil, fmt.Errorf("could not guess the inputs and outputs")
|
||||
}
|
||||
|
||||
// GetInputAndOutputFromSavedModel reads signature definitions to derive input/output info.
|
||||
func GetInputAndOutputFromSavedModel(model *tf.SavedModel) (*PhotoInput, *ModelOutput, error) {
|
||||
if model == nil {
|
||||
return nil, nil, fmt.Errorf("GetInputAndOutputFromSavedModel: nil input")
|
||||
@@ -86,18 +88,20 @@ func GetInputAndOutputFromSavedModel(model *tf.SavedModel) (*PhotoInput, *ModelO
|
||||
for _, inputTensor := range inputs {
|
||||
if inputTensor.Shape.NumDimensions() == 4 {
|
||||
var comps []ShapeComponent
|
||||
if inputTensor.Shape.Size(3) == ExpectedChannels {
|
||||
|
||||
switch {
|
||||
case inputTensor.Shape.Size(3) == ExpectedChannels:
|
||||
comps = []ShapeComponent{ShapeBatch, ShapeHeight, ShapeWidth, ShapeColor}
|
||||
} else if inputTensor.Shape.Size(1) == ExpectedChannels { // check the channels are 3
|
||||
case inputTensor.Shape.Size(1) == ExpectedChannels: // check the channels are 3
|
||||
comps = []ShapeComponent{ShapeBatch, ShapeColor, ShapeHeight, ShapeWidth}
|
||||
} else {
|
||||
default:
|
||||
log.Debugf("tensorflow: shape %d", inputTensor.Shape.Size(1))
|
||||
}
|
||||
|
||||
if comps == nil {
|
||||
log.Warnf("tensorflow: skipping signature %v because we could not find the color component", k)
|
||||
} else {
|
||||
var inputIdx = 0
|
||||
inputIdx := 0
|
||||
var err error
|
||||
|
||||
inputName, inputIndex, found := strings.Cut(inputTensor.Name, ":")
|
||||
|
||||
@@ -20,23 +20,24 @@ func TestTF1ModelLoad(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
input, output, err := GetInputAndOutputFromSavedModel(model)
|
||||
_, _, err = GetInputAndOutputFromSavedModel(model)
|
||||
if err == nil {
|
||||
t.Fatalf("TF1 does not have signatures, but GetInput worked")
|
||||
}
|
||||
|
||||
input, output, err = GuessInputAndOutput(model)
|
||||
input, output, err := GuessInputAndOutput(model)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if input == nil {
|
||||
switch {
|
||||
case input == nil:
|
||||
t.Fatal("Could not get the input")
|
||||
} else if output == nil {
|
||||
case output == nil:
|
||||
t.Fatal("Could not get the output")
|
||||
} else if input.Shape == nil {
|
||||
case input.Shape == nil:
|
||||
t.Fatal("Could not get the shape")
|
||||
} else {
|
||||
default:
|
||||
t.Logf("Shape: %v", input.Shape)
|
||||
}
|
||||
}
|
||||
@@ -55,15 +56,15 @@ func TestTF2ModelLoad(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if input == nil {
|
||||
switch {
|
||||
case input == nil:
|
||||
t.Fatal("Could not get the input")
|
||||
} else if output == nil {
|
||||
case output == nil:
|
||||
t.Fatal("Could not get the output")
|
||||
} else if input.Shape == nil {
|
||||
case input.Shape == nil:
|
||||
t.Fatal("Could not get the shape")
|
||||
} else if !slices.Equal(input.Shape, DefaultPhotoInputShape()) {
|
||||
t.Fatalf("Invalid shape calculated. Expected BHWC, got %v",
|
||||
input.Shape)
|
||||
case !slices.Equal(input.Shape, DefaultPhotoInputShape()):
|
||||
t.Fatalf("Invalid shape calculated. Expected BHWC, got %v", input.Shape)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,16 +82,16 @@ func TestTF2ModelBCHWLoad(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if input == nil {
|
||||
switch {
|
||||
case input == nil:
|
||||
t.Fatal("Could not get the input")
|
||||
} else if output == nil {
|
||||
case output == nil:
|
||||
t.Fatal("Could not get the output")
|
||||
} else if input.Shape == nil {
|
||||
case input.Shape == nil:
|
||||
t.Fatal("Could not get the shape")
|
||||
} else if !slices.Equal(input.Shape, []ShapeComponent{
|
||||
case !slices.Equal(input.Shape, []ShapeComponent{
|
||||
ShapeBatch, ShapeColor, ShapeHeight, ShapeWidth,
|
||||
}) {
|
||||
t.Fatalf("Invalid shape calculated. Expected BCHW, got %v",
|
||||
input.Shape)
|
||||
}):
|
||||
t.Fatalf("Invalid shape calculated. Expected BCHW, got %v", input.Shape)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
tf "github.com/wamuir/graft/tensorflow"
|
||||
)
|
||||
|
||||
// AddSoftmax appends a Softmax operation to the graph for the configured model output.
|
||||
func AddSoftmax(graph *tf.Graph, info *ModelInfo) (*tf.Operation, error) {
|
||||
|
||||
randomName := randomString(10)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
tf "github.com/wamuir/graft/tensorflow"
|
||||
)
|
||||
|
||||
// ImageTensorBuilder incrementally constructs an image tensor in BHWC or BCHW order.
|
||||
type ImageTensorBuilder struct {
|
||||
data []float32
|
||||
shape []ShapeComponent
|
||||
@@ -29,6 +30,7 @@ func shapeLen(c ShapeComponent, res int) int {
|
||||
}
|
||||
}
|
||||
|
||||
// NewImageTensorBuilder creates a builder for the given photo input definition.
|
||||
func NewImageTensorBuilder(input *PhotoInput) (*ImageTensorBuilder, error) {
|
||||
|
||||
if len(input.Shape) != 4 {
|
||||
@@ -62,6 +64,7 @@ func NewImageTensorBuilder(input *PhotoInput) (*ImageTensorBuilder, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Set assigns the normalized RGB values for the pixel at (x,y).
|
||||
func (t *ImageTensorBuilder) Set(x, y int, r, g, b float32) {
|
||||
t.data[t.flatIndex(x, y, t.rIndex)] = r
|
||||
t.data[t.flatIndex(x, y, t.gIndex)] = g
|
||||
@@ -93,6 +96,7 @@ func (t *ImageTensorBuilder) flatIndex(x, y, c int) int {
|
||||
return idx
|
||||
}
|
||||
|
||||
// BuildTensor materializes the underlying data into a TensorFlow tensor.
|
||||
func (t *ImageTensorBuilder) BuildTensor() (*tf.Tensor, error) {
|
||||
|
||||
arr := make([][][][]float32, shapeLen(t.shape[0], t.resolution))
|
||||
|
||||
@@ -6,11 +6,12 @@ func randomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
result := make([]byte, length)
|
||||
for i := range result {
|
||||
result[i] = charset[rand.IntN(len(charset))]
|
||||
result[i] = charset[rand.IntN(len(charset))] //nolint:gosec // pseudo-random is sufficient for non-cryptographic identifiers
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
// GetOne returns an arbitrary key-value pair from the map or nils when empty.
|
||||
func GetOne[K comparable, V any](input map[K]V) (*K, *V) {
|
||||
for k, v := range input {
|
||||
return &k, &v
|
||||
@@ -19,6 +20,7 @@ func GetOne[K comparable, V any](input map[K]V) (*K, *V) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Deref returns the value of a pointer or a default when the pointer is nil.
|
||||
func Deref[V any](input *V, defval V) V {
|
||||
if input == nil {
|
||||
return defval
|
||||
|
||||
Reference in New Issue
Block a user