Added Interval parameter to PhotoInput

This parameter lets us rescale the input of the models, because some of
them expect values in [0, 1] and others expect values in [-1, 1].
This commit is contained in:
raystlin
2025-04-13 20:32:02 +00:00
parent e55536e581
commit 5521d06bc0
6 changed files with 87 additions and 35 deletions

View File

@@ -53,10 +53,14 @@ func NewNasnet(assetsPath string, disabled bool) *Model {
TFVersion: "1.12.0", TFVersion: "1.12.0",
Tags: []string{"photoprism"}, Tags: []string{"photoprism"},
Input: &tensorflow.PhotoInput{ Input: &tensorflow.PhotoInput{
Name: "input_1", Name: "input_1",
Height: 224, Height: 224,
Width: 224, Width: 224,
Channels: 3, Channels: 3,
Interval: &tensorflow.Interval{
Start: -1,
End: 1,
},
OutputIndex: 0, OutputIndex: 0,
}, },
Output: &tensorflow.ModelOutput{ Output: &tensorflow.ModelOutput{
@@ -290,5 +294,5 @@ func (m *Model) createTensor(image []byte) (*tf.Tensor, error) {
img = imaging.Fill(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Center, imaging.Lanczos) img = imaging.Fill(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Center, imaging.Lanczos)
} }
return tensorflow.Image(img, m.meta.Input.Resolution()) return tensorflow.Image(img, m.meta.Input)
} }

View File

@@ -59,6 +59,12 @@ var modelsInfo = map[string]*tensorflow.ModelInfo{
}, },
}, },
"vision-transformer-tensorflow2-vit-b16-classification-v1.tar.gz": &tensorflow.ModelInfo{ "vision-transformer-tensorflow2-vit-b16-classification-v1.tar.gz": &tensorflow.ModelInfo{
Input: &tensorflow.PhotoInput{
Interval: &tensorflow.Interval{
Start: -1.0,
End: 1.0,
},
},
Output: &tensorflow.ModelOutput{ Output: &tensorflow.ModelOutput{
OutputsLogits: true, OutputsLogits: true,
}, },

View File

@@ -20,11 +20,11 @@ const (
Scale = float32(1) Scale = float32(1)
) )
func ImageFromFile(fileName string, resolution int) (*tf.Tensor, error) { func ImageFromFile(fileName string, input *PhotoInput) (*tf.Tensor, error) {
if img, err := OpenImage(fileName); err != nil { if img, err := OpenImage(fileName); err != nil {
return nil, err return nil, err
} else { } else {
return Image(img, resolution) return Image(img, input)
} }
} }
@@ -39,39 +39,39 @@ func OpenImage(fileName string) (image.Image, error) {
return img, err return img, err
} }
func ImageFromBytes(b []byte, resolution int) (*tf.Tensor, error) { func ImageFromBytes(b []byte, input *PhotoInput) (*tf.Tensor, error) {
img, _, imgErr := image.Decode(bytes.NewReader(b)) img, _, imgErr := image.Decode(bytes.NewReader(b))
if imgErr != nil { if imgErr != nil {
return nil, imgErr return nil, imgErr
} }
return Image(img, resolution) return Image(img, input)
} }
func Image(img image.Image, resolution int) (tfTensor *tf.Tensor, err error) { func Image(img image.Image, input *PhotoInput) (tfTensor *tf.Tensor, err error) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
err = fmt.Errorf("tensorflow: %s (panic)\nstack: %s", r, debug.Stack()) err = fmt.Errorf("tensorflow: %s (panic)\nstack: %s", r, debug.Stack())
} }
}() }()
if resolution <= 0 { if input.Resolution() <= 0 {
return tfTensor, fmt.Errorf("tensorflow: resolution must be larger 0") return tfTensor, fmt.Errorf("tensorflow: resolution must be larger 0")
} }
var tfImage [1][][][3]float32 var tfImage [1][][][3]float32
for j := 0; j < resolution; j++ { for j := 0; j < input.Resolution(); j++ {
tfImage[0] = append(tfImage[0], make([][3]float32, resolution)) tfImage[0] = append(tfImage[0], make([][3]float32, input.Resolution()))
} }
for i := 0; i < resolution; i++ { for i := 0; i < input.Resolution(); i++ {
for j := 0; j < resolution; j++ { for j := 0; j < input.Resolution(); j++ {
r, g, b, _ := img.At(i, j).RGBA() r, g, b, _ := img.At(i, j).RGBA()
tfImage[0][j][i][0] = convertValue(r, 127.5) tfImage[0][j][i][0] = convertValue(r, input.GetInterval())
tfImage[0][j][i][1] = convertValue(g, 127.5) tfImage[0][j][i][1] = convertValue(g, input.GetInterval())
tfImage[0][j][i][2] = convertValue(b, 127.5) tfImage[0][j][i][2] = convertValue(b, input.GetInterval())
} }
} }
@@ -136,10 +136,9 @@ func transformImageGraph(imageFormat fs.Type, resolution int) (graph *tf.Graph,
return graph, input, output, err return graph, input, output, err
} }
func convertValue(value uint32, mean float32) float32 { func convertValue(value uint32, interval *Interval) float32 {
if mean == 0 { scale := interval.Size() / 255.0
mean = 127.5 offset := interval.Start
}
return (float32(value>>8) - mean) / mean return (float32(value>>8))*scale + offset
} }

View File

@@ -10,9 +10,15 @@ import (
"github.com/photoprism/photoprism/pkg/fs" "github.com/photoprism/photoprism/pkg/fs"
) )
// defaultImageInput is the shared 224x224, three-channel input description
// passed to ImageFromBytes in the tests. Its Interval field is left unset.
var defaultImageInput = &PhotoInput{
	Channels: 3,
	Width:    224,
	Height:   224,
}
func TestConvertValue(t *testing.T) { func TestConvertValue(t *testing.T) {
result := convertValue(uint32(98765432), 127.5) result := convertValue(uint32(98765432), &Interval{Start: -1, End: 1})
assert.Equal(t, float32(3024.898), result) assert.Equal(t, float32(3024.8982), result)
} }
func TestImageFromBytes(t *testing.T) { func TestImageFromBytes(t *testing.T) {
@@ -26,7 +32,7 @@ func TestImageFromBytes(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
result, err := ImageFromBytes(imageBuffer, 224) result, err := ImageFromBytes(imageBuffer, defaultImageInput)
assert.Equal(t, tensorflow.DataType(0x1), result.DataType()) assert.Equal(t, tensorflow.DataType(0x1), result.DataType())
assert.Equal(t, int64(1), result.Shape()[0]) assert.Equal(t, int64(1), result.Shape()[0])
assert.Equal(t, int64(224), result.Shape()[2]) assert.Equal(t, int64(224), result.Shape()[2])
@@ -34,7 +40,7 @@ func TestImageFromBytes(t *testing.T) {
t.Run("Document", func(t *testing.T) { t.Run("Document", func(t *testing.T) {
imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx") imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx")
assert.Nil(t, err) assert.Nil(t, err)
result, err := ImageFromBytes(imageBuffer, 224) result, err := ImageFromBytes(imageBuffer, defaultImageInput)
assert.Empty(t, result) assert.Empty(t, result)
assert.EqualError(t, err, "image: unknown format") assert.EqualError(t, err, "image: unknown format")

View File

@@ -11,13 +11,33 @@ import (
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
) )
// Interval describes a range of allowed values bounded by Start and End.
type Interval struct {
	Start float32 `yaml:"Start,omitempty" json:"start,omitempty"`
	End   float32 `yaml:"End,omitempty" json:"end,omitempty"`
}

// Size returns the width of the interval, i.e. End minus Start.
func (iv Interval) Size() float32 {
	width := iv.End - iv.Start
	return width
}

// StandardInterval returns a newly allocated [0, 1] interval. The original
// comment states that this is the range returned by decodeImage.
func StandardInterval() *Interval {
	// Start is left at its zero value, which is the lower bound 0.
	standard := Interval{End: 1.0}
	return &standard
}
// Input description for a photo input for a model // Input description for a photo input for a model
type PhotoInput struct { type PhotoInput struct {
Name string `yaml:"Name,omitempty" json:"name,omitempty"` Name string `yaml:"Name,omitempty" json:"name,omitempty"`
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"` Interval *Interval `yaml:"Interval,omitempty" json:"interval,omitempty"`
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"` OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"` Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
Channels int64 `yaml:"Channels,omitempty" json:"channels,omitempty"` Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
Channels int64 `yaml:"Channels,omitempty" json:"channels,omitempty"`
} }
// When dimensions are not defined, it means the model accepts any size of // When dimensions are not defined, it means the model accepts any size of
@@ -37,12 +57,25 @@ func (p *PhotoInput) SetResolution(resolution int) {
p.Width = int64(resolution) p.Width = int64(resolution)
} }
// GetInterval returns the interval configured on the input. When Interval is
// nil it returns the result of StandardInterval instead.
func (p PhotoInput) GetInterval() *Interval {
	if configured := p.Interval; configured != nil {
		return configured
	}

	return StandardInterval()
}
// Merge other input with this. // Merge other input with this.
func (p *PhotoInput) Merge(other *PhotoInput) { func (p *PhotoInput) Merge(other *PhotoInput) {
if p.Name == "" { if p.Name == "" {
p.Name = other.Name p.Name = other.Name
} }
if p.Interval == nil && other.Interval != nil {
p.Interval = other.Interval
}
if p.OutputIndex == 0 { if p.OutputIndex == 0 {
p.OutputIndex = other.OutputIndex p.OutputIndex = other.OutputIndex
} }

View File

@@ -16,9 +16,13 @@ var (
TFVersion: "1.12.0", TFVersion: "1.12.0",
Tags: []string{"photoprism"}, Tags: []string{"photoprism"},
Input: &tensorflow.PhotoInput{ Input: &tensorflow.PhotoInput{
Name: "input_1", Name: "input_1",
Height: 224, Height: 224,
Width: 224, Width: 224,
Interval: &tensorflow.Interval{
Start: -1.0,
End: 1.0,
},
Channels: 3, Channels: 3,
OutputIndex: 0, OutputIndex: 0,
}, },