Added Interval parameter to PhotoInput

This parameter allows us to rescale the input of the models because some
of them need values between [0, 1] and other between [-1, 1].
This commit is contained in:
raystlin
2025-04-13 20:32:02 +00:00
parent e55536e581
commit 5521d06bc0
6 changed files with 87 additions and 35 deletions

View File

@@ -53,10 +53,14 @@ func NewNasnet(assetsPath string, disabled bool) *Model {
TFVersion: "1.12.0",
Tags: []string{"photoprism"},
Input: &tensorflow.PhotoInput{
Name: "input_1",
Height: 224,
Width: 224,
Channels: 3,
Name: "input_1",
Height: 224,
Width: 224,
Channels: 3,
Interval: &tensorflow.Interval{
Start: -1,
End: 1,
},
OutputIndex: 0,
},
Output: &tensorflow.ModelOutput{
@@ -290,5 +294,5 @@ func (m *Model) createTensor(image []byte) (*tf.Tensor, error) {
img = imaging.Fill(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Center, imaging.Lanczos)
}
return tensorflow.Image(img, m.meta.Input.Resolution())
return tensorflow.Image(img, m.meta.Input)
}

View File

@@ -59,6 +59,12 @@ var modelsInfo = map[string]*tensorflow.ModelInfo{
},
},
"vision-transformer-tensorflow2-vit-b16-classification-v1.tar.gz": &tensorflow.ModelInfo{
Input: &tensorflow.PhotoInput{
Interval: &tensorflow.Interval{
Start: -1.0,
End: 1.0,
},
},
Output: &tensorflow.ModelOutput{
OutputsLogits: true,
},

View File

@@ -20,11 +20,11 @@ const (
Scale = float32(1)
)
func ImageFromFile(fileName string, resolution int) (*tf.Tensor, error) {
func ImageFromFile(fileName string, input *PhotoInput) (*tf.Tensor, error) {
if img, err := OpenImage(fileName); err != nil {
return nil, err
} else {
return Image(img, resolution)
return Image(img, input)
}
}
@@ -39,39 +39,39 @@ func OpenImage(fileName string) (image.Image, error) {
return img, err
}
func ImageFromBytes(b []byte, resolution int) (*tf.Tensor, error) {
func ImageFromBytes(b []byte, input *PhotoInput) (*tf.Tensor, error) {
img, _, imgErr := image.Decode(bytes.NewReader(b))
if imgErr != nil {
return nil, imgErr
}
return Image(img, resolution)
return Image(img, input)
}
func Image(img image.Image, resolution int) (tfTensor *tf.Tensor, err error) {
func Image(img image.Image, input *PhotoInput) (tfTensor *tf.Tensor, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("tensorflow: %s (panic)\nstack: %s", r, debug.Stack())
}
}()
if resolution <= 0 {
if input.Resolution() <= 0 {
return tfTensor, fmt.Errorf("tensorflow: resolution must be larger 0")
}
var tfImage [1][][][3]float32
for j := 0; j < resolution; j++ {
tfImage[0] = append(tfImage[0], make([][3]float32, resolution))
for j := 0; j < input.Resolution(); j++ {
tfImage[0] = append(tfImage[0], make([][3]float32, input.Resolution()))
}
for i := 0; i < resolution; i++ {
for j := 0; j < resolution; j++ {
for i := 0; i < input.Resolution(); i++ {
for j := 0; j < input.Resolution(); j++ {
r, g, b, _ := img.At(i, j).RGBA()
tfImage[0][j][i][0] = convertValue(r, 127.5)
tfImage[0][j][i][1] = convertValue(g, 127.5)
tfImage[0][j][i][2] = convertValue(b, 127.5)
tfImage[0][j][i][0] = convertValue(r, input.GetInterval())
tfImage[0][j][i][1] = convertValue(g, input.GetInterval())
tfImage[0][j][i][2] = convertValue(b, input.GetInterval())
}
}
@@ -136,10 +136,9 @@ func transformImageGraph(imageFormat fs.Type, resolution int) (graph *tf.Graph,
return graph, input, output, err
}
func convertValue(value uint32, mean float32) float32 {
if mean == 0 {
mean = 127.5
}
func convertValue(value uint32, interval *Interval) float32 {
scale := interval.Size() / 255.0
offset := interval.Start
return (float32(value>>8) - mean) / mean
return (float32(value>>8))*scale + offset
}

View File

@@ -10,9 +10,15 @@ import (
"github.com/photoprism/photoprism/pkg/fs"
)
var defaultImageInput = &PhotoInput{
Height: 224,
Width: 224,
Channels: 3,
}
func TestConvertValue(t *testing.T) {
result := convertValue(uint32(98765432), 127.5)
assert.Equal(t, float32(3024.898), result)
result := convertValue(uint32(98765432), &Interval{Start: -1, End: 1})
assert.Equal(t, float32(3024.8982), result)
}
func TestImageFromBytes(t *testing.T) {
@@ -26,7 +32,7 @@ func TestImageFromBytes(t *testing.T) {
t.Fatal(err)
}
result, err := ImageFromBytes(imageBuffer, 224)
result, err := ImageFromBytes(imageBuffer, defaultImageInput)
assert.Equal(t, tensorflow.DataType(0x1), result.DataType())
assert.Equal(t, int64(1), result.Shape()[0])
assert.Equal(t, int64(224), result.Shape()[2])
@@ -34,7 +40,7 @@ func TestImageFromBytes(t *testing.T) {
t.Run("Document", func(t *testing.T) {
imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx")
assert.Nil(t, err)
result, err := ImageFromBytes(imageBuffer, 224)
result, err := ImageFromBytes(imageBuffer, defaultImageInput)
assert.Empty(t, result)
assert.EqualError(t, err, "image: unknown format")

View File

@@ -11,13 +11,33 @@ import (
"google.golang.org/protobuf/proto"
)
// Interval of allowed values
type Interval struct {
Start float32 `yaml:"Start,omitempty" json:"start,omitempty"`
End float32 `yaml:"End,omitempty" json:"end,omitempty"`
}
// The size of the interval
func (i Interval) Size() float32 {
return i.End - i.Start
}
// The standard interval returned by decodeImage is [0, 1]
func StandardInterval() *Interval {
return &Interval{
Start: 0.0,
End: 1.0,
}
}
// Input description for a photo input for a model
type PhotoInput struct {
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
Channels int64 `yaml:"Channels,omitempty" json:"channels,omitempty"`
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
Interval *Interval `yaml:"Interval,omitempty" json:"interval,omitempty"`
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
Channels int64 `yaml:"Channels,omitempty" json:"channels,omitempty"`
}
// When dimensions are not defined, it means the model accepts any size of
@@ -37,12 +57,25 @@ func (p *PhotoInput) SetResolution(resolution int) {
p.Width = int64(resolution)
}
// Get the interval or the default one
func (p PhotoInput) GetInterval() *Interval {
if p.Interval == nil {
return StandardInterval()
} else {
return p.Interval
}
}
// Merge other input with this.
func (p *PhotoInput) Merge(other *PhotoInput) {
if p.Name == "" {
p.Name = other.Name
}
if p.Interval == nil && other.Interval != nil {
p.Interval = other.Interval
}
if p.OutputIndex == 0 {
p.OutputIndex = other.OutputIndex
}

View File

@@ -16,9 +16,13 @@ var (
TFVersion: "1.12.0",
Tags: []string{"photoprism"},
Input: &tensorflow.PhotoInput{
Name: "input_1",
Height: 224,
Width: 224,
Name: "input_1",
Height: 224,
Width: 224,
Interval: &tensorflow.Interval{
Start: -1.0,
End: 1.0,
},
Channels: 3,
OutputIndex: 0,
},