Files
photoprism/internal/ai/tensorflow/image.go
raystlin adc4dc0f74 Added new parameters to model input.
New parameters have been added to define the input of the models:
* ResizeOperation: by default center-crop was being performed, now it is
  configurable.
* InputOrder: by default RGB was being used as the order for the array
  values of the input tensor, now it can be configured.
* InputInterval has been changed to InputIntervals (an slice). This
  means that every channel can have its own interval conversion.
* InputInterval can define now stddev and mean, because sometimes
  instead of adjusting the interval, the stddev and mean of the training
data should be use.
2025-07-15 13:31:31 +00:00

154 lines
3.5 KiB
Go

package tensorflow
import (
"bytes"
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"os"
"runtime/debug"
tf "github.com/wamuir/graft/tensorflow"
"github.com/wamuir/graft/tensorflow/op"
"github.com/photoprism/photoprism/pkg/fs"
)
const (
Mean = float32(117)
Scale = float32(1)
)
func ImageFromFile(fileName string, input *PhotoInput) (*tf.Tensor, error) {
if img, err := OpenImage(fileName); err != nil {
return nil, err
} else {
return Image(img, input)
}
}
func OpenImage(fileName string) (image.Image, error) {
f, err := os.Open(fileName)
if err != nil {
return nil, err
}
defer f.Close()
img, _, err := image.Decode(f)
return img, err
}
func ImageFromBytes(b []byte, input *PhotoInput) (*tf.Tensor, error) {
img, _, imgErr := image.Decode(bytes.NewReader(b))
if imgErr != nil {
return nil, imgErr
}
return Image(img, input)
}
func Image(img image.Image, input *PhotoInput) (tfTensor *tf.Tensor, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("tensorflow: %s (panic)\nstack: %s", r, debug.Stack())
}
}()
if input.Resolution() <= 0 {
return tfTensor, fmt.Errorf("tensorflow: resolution must be larger 0")
}
var tfImage [1][][][3]float32
rIndex, gIndex, bIndex := input.InputOrder.Indices()
for j := 0; j < input.Resolution(); j++ {
tfImage[0] = append(tfImage[0], make([][3]float32, input.Resolution()))
}
for i := 0; i < input.Resolution(); i++ {
for j := 0; j < input.Resolution(); j++ {
r, g, b, _ := img.At(i, j).RGBA()
//Although RGB can be disordered, we assume the input intervals are
//given in RGB order.
tfImage[0][j][i][rIndex] = convertValue(r, input.GetInterval(0))
tfImage[0][j][i][gIndex] = convertValue(g, input.GetInterval(1))
tfImage[0][j][i][bIndex] = convertValue(b, input.GetInterval(2))
}
}
return tf.NewTensor(tfImage)
}
// ImageTransform transforms the given image into a *tf.Tensor and returns it.
func ImageTransform(image []byte, imageFormat fs.Type, resolution int) (*tf.Tensor, error) {
tensor, err := tf.NewTensor(string(image))
if err != nil {
return nil, err
}
graph, input, output, err := transformImageGraph(imageFormat, resolution)
if err != nil {
return nil, err
}
session, err := tf.NewSession(graph, nil)
if err != nil {
return nil, err
}
defer session.Close()
normalized, err := session.Run(
map[tf.Output]*tf.Tensor{input: tensor},
[]tf.Output{output},
nil)
if err != nil {
return nil, err
}
return normalized[0], nil
}
func transformImageGraph(imageFormat fs.Type, resolution int) (graph *tf.Graph, input, output tf.Output, err error) {
s := op.NewScope()
input = op.Placeholder(s, tf.String)
// Assume the image is a JPEG, or a PNG if explicitly specified.
var decodedImage tf.Output
switch imageFormat {
case fs.ImagePng:
decodedImage = op.DecodePng(s, input, op.DecodePngChannels(3))
default:
decodedImage = op.DecodeJpeg(s, input, op.DecodeJpegChannels(3))
}
output = op.Div(s,
op.Sub(s,
op.ResizeBilinear(s,
op.ExpandDims(s,
op.Cast(s, decodedImage, tf.Float),
op.Const(s.SubScope("make_batch"), int32(0))),
op.Const(s.SubScope("size"), []int32{int32(resolution), int32(resolution)})),
op.Const(s.SubScope("mean"), Mean)),
op.Const(s.SubScope("scale"), Scale))
graph, err = s.Finalize()
return graph, input, output, err
}
func convertValue(value uint32, interval *Interval) float32 {
var scale float32
if interval.Mean != nil {
scale = *interval.Mean
} else {
scale = interval.Size() / 255.0
}
offset := interval.Offset()
return (float32(value>>8))*scale + offset
}