mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-11 16:24:11 +01:00
168 lines
4.3 KiB
Go
168 lines
4.3 KiB
Go
package tensorflow
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"image"
|
|
_ "image/jpeg" // register JPEG decoder
|
|
_ "image/png" // register PNG decoder
|
|
"math"
|
|
"os"
|
|
"runtime/debug"
|
|
|
|
tf "github.com/wamuir/graft/tensorflow"
|
|
"github.com/wamuir/graft/tensorflow/op"
|
|
|
|
"github.com/photoprism/photoprism/pkg/fs"
|
|
)
|
|
|
|
// Normalization parameters applied by the TensorFlow image transform graph:
// Mean is subtracted from every decoded pixel value, and the result is then
// divided by Scale.
const (
	// Mean is the default mean pixel value used during normalization.
	Mean float32 = 117

	// Scale is the default scale applied during normalization.
	Scale float32 = 1
)
|
|
|
|
// ImageFromFile decodes an image from disk and converts it to a tensor for inference.
|
|
func ImageFromFile(fileName string, input *PhotoInput) (*tf.Tensor, error) {
|
|
if img, err := OpenImage(fileName); err != nil {
|
|
return nil, err
|
|
} else {
|
|
return Image(img, input, nil)
|
|
}
|
|
}
|
|
|
|
// OpenImage opens an image file and decodes it using the registered decoders.
//
// The JPEG and PNG decoders are registered by this file's blank imports. The
// open error (for example a missing file) and the decode error (for example
// image.ErrFormat for unrecognized data) are returned unwrapped. The file is
// always closed before returning.
func OpenImage(fileName string) (image.Image, error) {
	file, openErr := os.Open(fileName) //nolint:gosec // fileName supplied by trusted caller; reading local images is expected
	if openErr != nil {
		return nil, openErr
	}
	// The file is only read, so a Close error carries no useful information.
	defer file.Close()

	decoded, _, decodeErr := image.Decode(file)

	return decoded, decodeErr
}
|
|
|
|
// ImageFromBytes converts raw image bytes into a tensor using the provided input definition.
//
// The bytes are decoded in memory with the registered image decoders (JPEG
// and PNG, via this file's blank imports). A decoding error is returned
// as-is. Otherwise the decoded image is handed to Image together with input
// and builder; a nil builder makes Image create a new one.
func ImageFromBytes(b []byte, input *PhotoInput, builder *ImageTensorBuilder) (*tf.Tensor, error) {
	reader := bytes.NewReader(b)

	decoded, _, err := image.Decode(reader)
	if err != nil {
		return nil, err
	}

	return Image(decoded, input, builder)
}
|
|
|
|
// Image converts a decoded image into a tensor matching the provided input description.
//
// The first input.Resolution() × input.Resolution() pixels of img are sampled.
// Sampling starts at img.Bounds().Min, so images whose bounds do not begin at
// (0, 0), such as sub-images, are read from their real top-left corner. No
// resizing happens here. Pixels outside img.Bounds() are read through img.At,
// which yields the zero color for the standard library image types.
//
// Each pixel's red, green and blue values (as returned by color.Color.RGBA)
// are converted with convertValue using input.GetInterval(0), (1) and (2)
// respectively. They are then written into builder at (i, j); alpha is
// ignored. When builder is nil, a new one is created with
// NewImageTensorBuilder(input).
//
// A nil img or input, or a non-positive resolution, results in an error. Any
// panic raised while building the tensor is recovered and returned as an
// error that includes the stack trace.
func Image(img image.Image, input *PhotoInput, builder *ImageTensorBuilder) (tfTensor *tf.Tensor, err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("tensorflow: %s (panic)\nstack: %s", r, debug.Stack())
		}
	}()

	// Report missing arguments explicitly instead of via a recovered nil-pointer panic.
	if img == nil {
		return nil, fmt.Errorf("tensorflow: image is nil")
	}

	if input == nil {
		return nil, fmt.Errorf("tensorflow: input is nil")
	}

	resolution := input.Resolution()

	if resolution <= 0 {
		return nil, fmt.Errorf("tensorflow: resolution must be larger than 0")
	}

	if builder == nil {
		if builder, err = NewImageTensorBuilder(input); err != nil {
			return nil, err
		}
	}

	// Look up the per-channel intervals once rather than for every pixel.
	// Although RGB can be disordered, we assume the input intervals are
	// given in RGB order.
	rInterval, gInterval, bInterval := input.GetInterval(0), input.GetInterval(1), input.GetInterval(2)

	// Offset by the bounds origin so pixel (0, 0) is the image's top-left corner.
	origin := img.Bounds().Min

	for i := 0; i < resolution; i++ {
		for j := 0; j < resolution; j++ {
			r, g, b, _ := img.At(origin.X+i, origin.Y+j).RGBA()
			builder.Set(i, j,
				convertValue(r, rInterval),
				convertValue(g, gInterval),
				convertValue(b, bInterval))
		}
	}

	return builder.BuildTensor()
}
|
|
|
|
// ImageTransform transforms the given image into a *tf.Tensor and returns it.
|
|
func ImageTransform(image []byte, imageFormat fs.Type, resolution int) (*tf.Tensor, error) {
|
|
tensor, err := tf.NewTensor(string(image))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
graph, input, output, err := transformImageGraph(imageFormat, resolution)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
session, err := tf.NewSession(graph, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer session.Close()
|
|
|
|
normalized, err := session.Run(
|
|
map[tf.Output]*tf.Tensor{input: tensor},
|
|
[]tf.Output{output},
|
|
nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return normalized[0], nil
|
|
}
|
|
|
|
// transformImageGraph builds a TensorFlow graph that turns encoded image bytes
// into a normalized float tensor.
//
// The graph reads a tf.String placeholder (returned as input) and decodes it
// with 3 channels, as PNG when imageFormat is fs.ImagePng and as JPEG
// otherwise. It then:
//   - casts the result to float,
//   - inserts a batch dimension at axis 0,
//   - resizes bilinearly to resolution × resolution,
//   - computes (pixel - Mean) / Scale (returned as output).
//
// An error is returned when resolution is not in the range 1..math.MaxInt32.
// In that case the graph is nil, while input is still the created placeholder.
func transformImageGraph(imageFormat fs.Type, resolution int) (graph *tf.Graph, input, output tf.Output, err error) {
	scope := op.NewScope()
	input = op.Placeholder(scope, tf.String)

	if !(resolution > 0 && resolution <= math.MaxInt32) {
		return nil, input, output, fmt.Errorf("tensorflow: resolution %d is out of bounds", resolution)
	}

	// JPEG is the fallback; PNG is only used when explicitly requested.
	var decoded tf.Output
	if imageFormat == fs.ImagePng {
		decoded = op.DecodePng(scope, input, op.DecodePngChannels(3))
	} else {
		decoded = op.DecodeJpeg(scope, input, op.DecodeJpegChannels(3))
	}

	edge := int32(resolution) //nolint:gosec // resolution is validated to be within int32 range above

	// Each step is created in the same order a single nested expression would
	// evaluate it, so op and sub-scope names are assigned identically.
	floats := op.Cast(scope, decoded, tf.Float)
	batched := op.ExpandDims(scope, floats, op.Const(scope.SubScope("make_batch"), int32(0)))
	resized := op.ResizeBilinear(scope, batched, op.Const(scope.SubScope("size"), []int32{edge, edge}))
	centered := op.Sub(scope, resized, op.Const(scope.SubScope("mean"), Mean))
	output = op.Div(scope, centered, op.Const(scope.SubScope("scale"), Scale))

	graph, err = scope.Finalize()

	return graph, input, output, err
}
|
|
|
|
func convertValue(value uint32, interval *Interval) float32 {
|
|
var scale float32
|
|
|
|
if interval.Mean != nil {
|
|
scale = *interval.Mean
|
|
} else {
|
|
scale = interval.Size() / 255.0
|
|
}
|
|
offset := interval.Offset()
|
|
|
|
return (float32(value>>8))*scale + offset
|
|
}
|