mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
Added Interval parameter to PhotoInput
This parameter allows us to rescale the input of the models because some of them need values between [0, 1] and other between [-1, 1].
This commit is contained in:
@@ -53,10 +53,14 @@ func NewNasnet(assetsPath string, disabled bool) *Model {
|
||||
TFVersion: "1.12.0",
|
||||
Tags: []string{"photoprism"},
|
||||
Input: &tensorflow.PhotoInput{
|
||||
Name: "input_1",
|
||||
Height: 224,
|
||||
Width: 224,
|
||||
Channels: 3,
|
||||
Name: "input_1",
|
||||
Height: 224,
|
||||
Width: 224,
|
||||
Channels: 3,
|
||||
Interval: &tensorflow.Interval{
|
||||
Start: -1,
|
||||
End: 1,
|
||||
},
|
||||
OutputIndex: 0,
|
||||
},
|
||||
Output: &tensorflow.ModelOutput{
|
||||
@@ -290,5 +294,5 @@ func (m *Model) createTensor(image []byte) (*tf.Tensor, error) {
|
||||
img = imaging.Fill(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Center, imaging.Lanczos)
|
||||
}
|
||||
|
||||
return tensorflow.Image(img, m.meta.Input.Resolution())
|
||||
return tensorflow.Image(img, m.meta.Input)
|
||||
}
|
||||
|
||||
@@ -59,6 +59,12 @@ var modelsInfo = map[string]*tensorflow.ModelInfo{
|
||||
},
|
||||
},
|
||||
"vision-transformer-tensorflow2-vit-b16-classification-v1.tar.gz": &tensorflow.ModelInfo{
|
||||
Input: &tensorflow.PhotoInput{
|
||||
Interval: &tensorflow.Interval{
|
||||
Start: -1.0,
|
||||
End: 1.0,
|
||||
},
|
||||
},
|
||||
Output: &tensorflow.ModelOutput{
|
||||
OutputsLogits: true,
|
||||
},
|
||||
|
||||
@@ -20,11 +20,11 @@ const (
|
||||
Scale = float32(1)
|
||||
)
|
||||
|
||||
func ImageFromFile(fileName string, resolution int) (*tf.Tensor, error) {
|
||||
func ImageFromFile(fileName string, input *PhotoInput) (*tf.Tensor, error) {
|
||||
if img, err := OpenImage(fileName); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return Image(img, resolution)
|
||||
return Image(img, input)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,39 +39,39 @@ func OpenImage(fileName string) (image.Image, error) {
|
||||
return img, err
|
||||
}
|
||||
|
||||
func ImageFromBytes(b []byte, resolution int) (*tf.Tensor, error) {
|
||||
func ImageFromBytes(b []byte, input *PhotoInput) (*tf.Tensor, error) {
|
||||
img, _, imgErr := image.Decode(bytes.NewReader(b))
|
||||
|
||||
if imgErr != nil {
|
||||
return nil, imgErr
|
||||
}
|
||||
|
||||
return Image(img, resolution)
|
||||
return Image(img, input)
|
||||
}
|
||||
|
||||
func Image(img image.Image, resolution int) (tfTensor *tf.Tensor, err error) {
|
||||
func Image(img image.Image, input *PhotoInput) (tfTensor *tf.Tensor, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("tensorflow: %s (panic)\nstack: %s", r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
if resolution <= 0 {
|
||||
if input.Resolution() <= 0 {
|
||||
return tfTensor, fmt.Errorf("tensorflow: resolution must be larger 0")
|
||||
}
|
||||
|
||||
var tfImage [1][][][3]float32
|
||||
|
||||
for j := 0; j < resolution; j++ {
|
||||
tfImage[0] = append(tfImage[0], make([][3]float32, resolution))
|
||||
for j := 0; j < input.Resolution(); j++ {
|
||||
tfImage[0] = append(tfImage[0], make([][3]float32, input.Resolution()))
|
||||
}
|
||||
|
||||
for i := 0; i < resolution; i++ {
|
||||
for j := 0; j < resolution; j++ {
|
||||
for i := 0; i < input.Resolution(); i++ {
|
||||
for j := 0; j < input.Resolution(); j++ {
|
||||
r, g, b, _ := img.At(i, j).RGBA()
|
||||
tfImage[0][j][i][0] = convertValue(r, 127.5)
|
||||
tfImage[0][j][i][1] = convertValue(g, 127.5)
|
||||
tfImage[0][j][i][2] = convertValue(b, 127.5)
|
||||
tfImage[0][j][i][0] = convertValue(r, input.GetInterval())
|
||||
tfImage[0][j][i][1] = convertValue(g, input.GetInterval())
|
||||
tfImage[0][j][i][2] = convertValue(b, input.GetInterval())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,10 +136,9 @@ func transformImageGraph(imageFormat fs.Type, resolution int) (graph *tf.Graph,
|
||||
return graph, input, output, err
|
||||
}
|
||||
|
||||
func convertValue(value uint32, mean float32) float32 {
|
||||
if mean == 0 {
|
||||
mean = 127.5
|
||||
}
|
||||
func convertValue(value uint32, interval *Interval) float32 {
|
||||
scale := interval.Size() / 255.0
|
||||
offset := interval.Start
|
||||
|
||||
return (float32(value>>8) - mean) / mean
|
||||
return (float32(value>>8))*scale + offset
|
||||
}
|
||||
|
||||
@@ -10,9 +10,15 @@ import (
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
)
|
||||
|
||||
var defaultImageInput = &PhotoInput{
|
||||
Height: 224,
|
||||
Width: 224,
|
||||
Channels: 3,
|
||||
}
|
||||
|
||||
func TestConvertValue(t *testing.T) {
|
||||
result := convertValue(uint32(98765432), 127.5)
|
||||
assert.Equal(t, float32(3024.898), result)
|
||||
result := convertValue(uint32(98765432), &Interval{Start: -1, End: 1})
|
||||
assert.Equal(t, float32(3024.8982), result)
|
||||
}
|
||||
|
||||
func TestImageFromBytes(t *testing.T) {
|
||||
@@ -26,7 +32,7 @@ func TestImageFromBytes(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result, err := ImageFromBytes(imageBuffer, 224)
|
||||
result, err := ImageFromBytes(imageBuffer, defaultImageInput)
|
||||
assert.Equal(t, tensorflow.DataType(0x1), result.DataType())
|
||||
assert.Equal(t, int64(1), result.Shape()[0])
|
||||
assert.Equal(t, int64(224), result.Shape()[2])
|
||||
@@ -34,7 +40,7 @@ func TestImageFromBytes(t *testing.T) {
|
||||
t.Run("Document", func(t *testing.T) {
|
||||
imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx")
|
||||
assert.Nil(t, err)
|
||||
result, err := ImageFromBytes(imageBuffer, 224)
|
||||
result, err := ImageFromBytes(imageBuffer, defaultImageInput)
|
||||
|
||||
assert.Empty(t, result)
|
||||
assert.EqualError(t, err, "image: unknown format")
|
||||
|
||||
@@ -11,13 +11,33 @@ import (
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
// Interval of allowed values
|
||||
type Interval struct {
|
||||
Start float32 `yaml:"Start,omitempty" json:"start,omitempty"`
|
||||
End float32 `yaml:"End,omitempty" json:"end,omitempty"`
|
||||
}
|
||||
|
||||
// The size of the interval
|
||||
func (i Interval) Size() float32 {
|
||||
return i.End - i.Start
|
||||
}
|
||||
|
||||
// The standard interval returned by decodeImage is [0, 1]
|
||||
func StandardInterval() *Interval {
|
||||
return &Interval{
|
||||
Start: 0.0,
|
||||
End: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
// Input description for a photo input for a model
|
||||
type PhotoInput struct {
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
|
||||
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
|
||||
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
|
||||
Channels int64 `yaml:"Channels,omitempty" json:"channels,omitempty"`
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
Interval *Interval `yaml:"Interval,omitempty" json:"interval,omitempty"`
|
||||
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
|
||||
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
|
||||
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
|
||||
Channels int64 `yaml:"Channels,omitempty" json:"channels,omitempty"`
|
||||
}
|
||||
|
||||
// When dimensions are not defined, it means the model accepts any size of
|
||||
@@ -37,12 +57,25 @@ func (p *PhotoInput) SetResolution(resolution int) {
|
||||
p.Width = int64(resolution)
|
||||
}
|
||||
|
||||
// Get the interval or the default one
|
||||
func (p PhotoInput) GetInterval() *Interval {
|
||||
if p.Interval == nil {
|
||||
return StandardInterval()
|
||||
} else {
|
||||
return p.Interval
|
||||
}
|
||||
}
|
||||
|
||||
// Merge other input with this.
|
||||
func (p *PhotoInput) Merge(other *PhotoInput) {
|
||||
if p.Name == "" {
|
||||
p.Name = other.Name
|
||||
}
|
||||
|
||||
if p.Interval == nil && other.Interval != nil {
|
||||
p.Interval = other.Interval
|
||||
}
|
||||
|
||||
if p.OutputIndex == 0 {
|
||||
p.OutputIndex = other.OutputIndex
|
||||
}
|
||||
|
||||
@@ -16,9 +16,13 @@ var (
|
||||
TFVersion: "1.12.0",
|
||||
Tags: []string{"photoprism"},
|
||||
Input: &tensorflow.PhotoInput{
|
||||
Name: "input_1",
|
||||
Height: 224,
|
||||
Width: 224,
|
||||
Name: "input_1",
|
||||
Height: 224,
|
||||
Width: 224,
|
||||
Interval: &tensorflow.Interval{
|
||||
Start: -1.0,
|
||||
End: 1.0,
|
||||
},
|
||||
Channels: 3,
|
||||
OutputIndex: 0,
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user