mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
Added Interval parameter to PhotoInput
This parameter allows us to rescale the model input, because some models expect values in [0, 1] and others in [-1, 1].
This commit is contained in:
@@ -57,6 +57,10 @@ func NewNasnet(assetsPath string, disabled bool) *Model {
|
|||||||
Height: 224,
|
Height: 224,
|
||||||
Width: 224,
|
Width: 224,
|
||||||
Channels: 3,
|
Channels: 3,
|
||||||
|
Interval: &tensorflow.Interval{
|
||||||
|
Start: -1,
|
||||||
|
End: 1,
|
||||||
|
},
|
||||||
OutputIndex: 0,
|
OutputIndex: 0,
|
||||||
},
|
},
|
||||||
Output: &tensorflow.ModelOutput{
|
Output: &tensorflow.ModelOutput{
|
||||||
@@ -290,5 +294,5 @@ func (m *Model) createTensor(image []byte) (*tf.Tensor, error) {
|
|||||||
img = imaging.Fill(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Center, imaging.Lanczos)
|
img = imaging.Fill(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Center, imaging.Lanczos)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tensorflow.Image(img, m.meta.Input.Resolution())
|
return tensorflow.Image(img, m.meta.Input)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,6 +59,12 @@ var modelsInfo = map[string]*tensorflow.ModelInfo{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
"vision-transformer-tensorflow2-vit-b16-classification-v1.tar.gz": &tensorflow.ModelInfo{
|
"vision-transformer-tensorflow2-vit-b16-classification-v1.tar.gz": &tensorflow.ModelInfo{
|
||||||
|
Input: &tensorflow.PhotoInput{
|
||||||
|
Interval: &tensorflow.Interval{
|
||||||
|
Start: -1.0,
|
||||||
|
End: 1.0,
|
||||||
|
},
|
||||||
|
},
|
||||||
Output: &tensorflow.ModelOutput{
|
Output: &tensorflow.ModelOutput{
|
||||||
OutputsLogits: true,
|
OutputsLogits: true,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -20,11 +20,11 @@ const (
|
|||||||
Scale = float32(1)
|
Scale = float32(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
func ImageFromFile(fileName string, resolution int) (*tf.Tensor, error) {
|
func ImageFromFile(fileName string, input *PhotoInput) (*tf.Tensor, error) {
|
||||||
if img, err := OpenImage(fileName); err != nil {
|
if img, err := OpenImage(fileName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else {
|
} else {
|
||||||
return Image(img, resolution)
|
return Image(img, input)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,39 +39,39 @@ func OpenImage(fileName string) (image.Image, error) {
|
|||||||
return img, err
|
return img, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func ImageFromBytes(b []byte, resolution int) (*tf.Tensor, error) {
|
func ImageFromBytes(b []byte, input *PhotoInput) (*tf.Tensor, error) {
|
||||||
img, _, imgErr := image.Decode(bytes.NewReader(b))
|
img, _, imgErr := image.Decode(bytes.NewReader(b))
|
||||||
|
|
||||||
if imgErr != nil {
|
if imgErr != nil {
|
||||||
return nil, imgErr
|
return nil, imgErr
|
||||||
}
|
}
|
||||||
|
|
||||||
return Image(img, resolution)
|
return Image(img, input)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Image(img image.Image, resolution int) (tfTensor *tf.Tensor, err error) {
|
func Image(img image.Image, input *PhotoInput) (tfTensor *tf.Tensor, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
err = fmt.Errorf("tensorflow: %s (panic)\nstack: %s", r, debug.Stack())
|
err = fmt.Errorf("tensorflow: %s (panic)\nstack: %s", r, debug.Stack())
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if resolution <= 0 {
|
if input.Resolution() <= 0 {
|
||||||
return tfTensor, fmt.Errorf("tensorflow: resolution must be larger 0")
|
return tfTensor, fmt.Errorf("tensorflow: resolution must be larger 0")
|
||||||
}
|
}
|
||||||
|
|
||||||
var tfImage [1][][][3]float32
|
var tfImage [1][][][3]float32
|
||||||
|
|
||||||
for j := 0; j < resolution; j++ {
|
for j := 0; j < input.Resolution(); j++ {
|
||||||
tfImage[0] = append(tfImage[0], make([][3]float32, resolution))
|
tfImage[0] = append(tfImage[0], make([][3]float32, input.Resolution()))
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < resolution; i++ {
|
for i := 0; i < input.Resolution(); i++ {
|
||||||
for j := 0; j < resolution; j++ {
|
for j := 0; j < input.Resolution(); j++ {
|
||||||
r, g, b, _ := img.At(i, j).RGBA()
|
r, g, b, _ := img.At(i, j).RGBA()
|
||||||
tfImage[0][j][i][0] = convertValue(r, 127.5)
|
tfImage[0][j][i][0] = convertValue(r, input.GetInterval())
|
||||||
tfImage[0][j][i][1] = convertValue(g, 127.5)
|
tfImage[0][j][i][1] = convertValue(g, input.GetInterval())
|
||||||
tfImage[0][j][i][2] = convertValue(b, 127.5)
|
tfImage[0][j][i][2] = convertValue(b, input.GetInterval())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,10 +136,9 @@ func transformImageGraph(imageFormat fs.Type, resolution int) (graph *tf.Graph,
|
|||||||
return graph, input, output, err
|
return graph, input, output, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertValue(value uint32, mean float32) float32 {
|
func convertValue(value uint32, interval *Interval) float32 {
|
||||||
if mean == 0 {
|
scale := interval.Size() / 255.0
|
||||||
mean = 127.5
|
offset := interval.Start
|
||||||
}
|
|
||||||
|
|
||||||
return (float32(value>>8) - mean) / mean
|
return (float32(value>>8))*scale + offset
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,9 +10,15 @@ import (
|
|||||||
"github.com/photoprism/photoprism/pkg/fs"
|
"github.com/photoprism/photoprism/pkg/fs"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var defaultImageInput = &PhotoInput{
|
||||||
|
Height: 224,
|
||||||
|
Width: 224,
|
||||||
|
Channels: 3,
|
||||||
|
}
|
||||||
|
|
||||||
func TestConvertValue(t *testing.T) {
|
func TestConvertValue(t *testing.T) {
|
||||||
result := convertValue(uint32(98765432), 127.5)
|
result := convertValue(uint32(98765432), &Interval{Start: -1, End: 1})
|
||||||
assert.Equal(t, float32(3024.898), result)
|
assert.Equal(t, float32(3024.8982), result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestImageFromBytes(t *testing.T) {
|
func TestImageFromBytes(t *testing.T) {
|
||||||
@@ -26,7 +32,7 @@ func TestImageFromBytes(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := ImageFromBytes(imageBuffer, 224)
|
result, err := ImageFromBytes(imageBuffer, defaultImageInput)
|
||||||
assert.Equal(t, tensorflow.DataType(0x1), result.DataType())
|
assert.Equal(t, tensorflow.DataType(0x1), result.DataType())
|
||||||
assert.Equal(t, int64(1), result.Shape()[0])
|
assert.Equal(t, int64(1), result.Shape()[0])
|
||||||
assert.Equal(t, int64(224), result.Shape()[2])
|
assert.Equal(t, int64(224), result.Shape()[2])
|
||||||
@@ -34,7 +40,7 @@ func TestImageFromBytes(t *testing.T) {
|
|||||||
t.Run("Document", func(t *testing.T) {
|
t.Run("Document", func(t *testing.T) {
|
||||||
imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx")
|
imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
result, err := ImageFromBytes(imageBuffer, 224)
|
result, err := ImageFromBytes(imageBuffer, defaultImageInput)
|
||||||
|
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.EqualError(t, err, "image: unknown format")
|
assert.EqualError(t, err, "image: unknown format")
|
||||||
|
|||||||
@@ -11,9 +11,29 @@ import (
|
|||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Interval is the range of values, from Start to End, that a model expects its input to be scaled to.
|
||||||
|
type Interval struct {
|
||||||
|
Start float32 `yaml:"Start,omitempty" json:"start,omitempty"`
|
||||||
|
End float32 `yaml:"End,omitempty" json:"end,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size returns the length of the interval (End - Start).
|
||||||
|
func (i Interval) Size() float32 {
|
||||||
|
return i.End - i.Start
|
||||||
|
}
|
||||||
|
|
||||||
|
// StandardInterval returns the default interval [0, 1], which is the range returned by decodeImage.
|
||||||
|
func StandardInterval() *Interval {
|
||||||
|
return &Interval{
|
||||||
|
Start: 0.0,
|
||||||
|
End: 1.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Input description for a photo input for a model
|
// Input description for a photo input for a model
|
||||||
type PhotoInput struct {
|
type PhotoInput struct {
|
||||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||||
|
Interval *Interval `yaml:"Interval,omitempty" json:"interval,omitempty"`
|
||||||
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
|
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
|
||||||
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
|
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
|
||||||
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
|
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
|
||||||
@@ -37,12 +57,25 @@ func (p *PhotoInput) SetResolution(resolution int) {
|
|||||||
p.Width = int64(resolution)
|
p.Width = int64(resolution)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetInterval returns the configured interval, or the standard interval if none is set.
|
||||||
|
func (p PhotoInput) GetInterval() *Interval {
|
||||||
|
if p.Interval == nil {
|
||||||
|
return StandardInterval()
|
||||||
|
} else {
|
||||||
|
return p.Interval
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Merge other input with this.
|
// Merge other input with this.
|
||||||
func (p *PhotoInput) Merge(other *PhotoInput) {
|
func (p *PhotoInput) Merge(other *PhotoInput) {
|
||||||
if p.Name == "" {
|
if p.Name == "" {
|
||||||
p.Name = other.Name
|
p.Name = other.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if p.Interval == nil && other.Interval != nil {
|
||||||
|
p.Interval = other.Interval
|
||||||
|
}
|
||||||
|
|
||||||
if p.OutputIndex == 0 {
|
if p.OutputIndex == 0 {
|
||||||
p.OutputIndex = other.OutputIndex
|
p.OutputIndex = other.OutputIndex
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,6 +19,10 @@ var (
|
|||||||
Name: "input_1",
|
Name: "input_1",
|
||||||
Height: 224,
|
Height: 224,
|
||||||
Width: 224,
|
Width: 224,
|
||||||
|
Interval: &tensorflow.Interval{
|
||||||
|
Start: -1.0,
|
||||||
|
End: 1.0,
|
||||||
|
},
|
||||||
Channels: 3,
|
Channels: 3,
|
||||||
OutputIndex: 0,
|
OutputIndex: 0,
|
||||||
},
|
},
|
||||||
|
|||||||
Reference in New Issue
Block a user