mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
Added new parameters to model input.
New parameters have been added to define the input of the models: * ResizeOperation: by default center-crop was being performed, now it is configurable. * InputOrder: by default RGB was being used as the order for the array values of the input tensor, now it can be configured. * InputInterval has been changed to InputIntervals (a slice). This means that every channel can have its own interval conversion. * InputInterval can now define stddev and mean, because sometimes, instead of adjusting the interval, the stddev and mean of the training data should be used.
This commit is contained in:
@@ -3,6 +3,7 @@ package classify
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"image/color"
|
||||
"math"
|
||||
"os"
|
||||
"path"
|
||||
@@ -53,12 +54,16 @@ func NewNasnet(assetsPath string, disabled bool) *Model {
|
||||
TFVersion: "1.12.0",
|
||||
Tags: []string{"photoprism"},
|
||||
Input: &tensorflow.PhotoInput{
|
||||
Name: "input_1",
|
||||
Height: 224,
|
||||
Width: 224,
|
||||
Interval: &tensorflow.Interval{
|
||||
Start: -1,
|
||||
End: 1,
|
||||
Name: "input_1",
|
||||
Height: 224,
|
||||
Width: 224,
|
||||
ResizeOperation: tensorflow.CenterCrop,
|
||||
InputOrder: tensorflow.RGB,
|
||||
Intervals: []tensorflow.Interval{
|
||||
{
|
||||
Start: -1,
|
||||
End: 1,
|
||||
},
|
||||
},
|
||||
OutputIndex: 0,
|
||||
},
|
||||
@@ -290,7 +295,18 @@ func (m *Model) createTensor(image []byte) (*tf.Tensor, error) {
|
||||
|
||||
// Resize the image only if its resolution does not match the model.
|
||||
if img.Bounds().Dx() != m.meta.Input.Resolution() || img.Bounds().Dy() != m.meta.Input.Resolution() {
|
||||
img = imaging.Fill(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Center, imaging.Lanczos)
|
||||
switch m.meta.Input.ResizeOperation {
|
||||
case tensorflow.ResizeBreakAspectRatio:
|
||||
imaging.Resize(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Lanczos)
|
||||
case tensorflow.CenterCrop:
|
||||
img = imaging.Fill(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Center, imaging.Lanczos)
|
||||
case tensorflow.Padding:
|
||||
resized := imaging.Fit(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Lanczos)
|
||||
dst := imaging.New(m.meta.Input.Resolution(), m.meta.Input.Resolution(), color.NRGBA{0, 0, 0, 255})
|
||||
img = imaging.PasteCenter(dst, resized)
|
||||
default:
|
||||
img = imaging.Fill(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Center, imaging.Lanczos)
|
||||
}
|
||||
}
|
||||
|
||||
return tensorflow.Image(img, m.meta.Input)
|
||||
|
||||
@@ -68,9 +68,11 @@ var modelsInfo = map[string]*tensorflow.ModelInfo{
|
||||
},
|
||||
"vision-transformer-tensorflow2-vit-b16-classification-v1.tar.gz": {
|
||||
Input: &tensorflow.PhotoInput{
|
||||
Interval: &tensorflow.Interval{
|
||||
Start: -1.0,
|
||||
End: 1.0,
|
||||
Intervals: []tensorflow.Interval{
|
||||
{
|
||||
Start: -1.0,
|
||||
End: 1.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
Output: &tensorflow.ModelOutput{
|
||||
|
||||
@@ -140,9 +140,10 @@ func TestModel_Run(t *testing.T) {
|
||||
assert.IsType(t, Labels{}, result)
|
||||
assert.Equal(t, 1, len(result))
|
||||
|
||||
assert.Equal(t, "chameleon", result[0].Name)
|
||||
|
||||
assert.Equal(t, 100-93, result[0].Uncertainty)
|
||||
if len(result) > 0 {
|
||||
assert.Equal(t, "chameleon", result[0].Name)
|
||||
assert.Equal(t, 100-93, result[0].Uncertainty)
|
||||
}
|
||||
}
|
||||
})
|
||||
t.Run("dog_orange.jpg", func(t *testing.T) {
|
||||
@@ -164,9 +165,10 @@ func TestModel_Run(t *testing.T) {
|
||||
assert.IsType(t, Labels{}, result)
|
||||
assert.Equal(t, 1, len(result))
|
||||
|
||||
assert.Equal(t, "dog", result[0].Name)
|
||||
|
||||
assert.Equal(t, 34, result[0].Uncertainty)
|
||||
if len(result) > 0 {
|
||||
assert.Equal(t, "dog", result[0].Name)
|
||||
assert.Equal(t, 34, result[0].Uncertainty)
|
||||
}
|
||||
}
|
||||
})
|
||||
t.Run("Random.docx", func(t *testing.T) {
|
||||
|
||||
@@ -61,6 +61,7 @@ func Image(img image.Image, input *PhotoInput) (tfTensor *tf.Tensor, err error)
|
||||
}
|
||||
|
||||
var tfImage [1][][][3]float32
|
||||
rIndex, gIndex, bIndex := input.InputOrder.Indices()
|
||||
|
||||
for j := 0; j < input.Resolution(); j++ {
|
||||
tfImage[0] = append(tfImage[0], make([][3]float32, input.Resolution()))
|
||||
@@ -69,9 +70,11 @@ func Image(img image.Image, input *PhotoInput) (tfTensor *tf.Tensor, err error)
|
||||
for i := 0; i < input.Resolution(); i++ {
|
||||
for j := 0; j < input.Resolution(); j++ {
|
||||
r, g, b, _ := img.At(i, j).RGBA()
|
||||
tfImage[0][j][i][0] = convertValue(r, input.GetInterval())
|
||||
tfImage[0][j][i][1] = convertValue(g, input.GetInterval())
|
||||
tfImage[0][j][i][2] = convertValue(b, input.GetInterval())
|
||||
//Although RGB can be disordered, we assume the input intervals are
|
||||
//given in RGB order.
|
||||
tfImage[0][j][i][rIndex] = convertValue(r, input.GetInterval(0))
|
||||
tfImage[0][j][i][gIndex] = convertValue(g, input.GetInterval(1))
|
||||
tfImage[0][j][i][bIndex] = convertValue(b, input.GetInterval(2))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,8 +140,14 @@ func transformImageGraph(imageFormat fs.Type, resolution int) (graph *tf.Graph,
|
||||
}
|
||||
|
||||
func convertValue(value uint32, interval *Interval) float32 {
|
||||
scale := interval.Size() / 255.0
|
||||
offset := interval.Start
|
||||
var scale float32
|
||||
|
||||
if interval.Mean != nil {
|
||||
scale = *interval.Mean
|
||||
} else {
|
||||
scale = interval.Size() / 255.0
|
||||
}
|
||||
offset := interval.Offset()
|
||||
|
||||
return (float32(value>>8))*scale + offset
|
||||
}
|
||||
|
||||
@@ -20,6 +20,14 @@ func TestConvertValue(t *testing.T) {
|
||||
assert.Equal(t, float32(3024.8982), result)
|
||||
}
|
||||
|
||||
func TestConvertStdMean(t *testing.T) {
|
||||
mean := float32(1.0 / 127.5)
|
||||
stdDev := float32(-1.0)
|
||||
|
||||
result := convertValue(uint32(98765432), &Interval{Mean: &mean, StdDev: &stdDev})
|
||||
assert.Equal(t, float32(3024.8982), result)
|
||||
}
|
||||
|
||||
func TestImageFromBytes(t *testing.T) {
|
||||
var assetsPath = fs.Abs("../../../assets")
|
||||
var examplesPath = assetsPath + "/examples"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package tensorflow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -18,15 +19,26 @@ const ExpectedChannels = 3
|
||||
|
||||
// Interval describes how channel values are mapped to model input values,
// either as a [Start, End] range or through the optional Mean and StdDev
// pointers, which are nil when not configured.
type Interval struct {
	Start  float32  `yaml:"Start,omitempty" json:"start,omitempty"`
	End    float32  `yaml:"End,omitempty" json:"end,omitempty"`
	Mean   *float32 `yaml:"Mean,omitempty" json:"mean,omitempty"`
	StdDev *float32 `yaml:"StdDev,omitempty" json:"stdDev,omitempty"`
}

// Size returns the width of the interval, i.e. End - Start.
func (i Interval) Size() float32 {
	return i.End - i.Start
}

// Offset returns *StdDev when StdDev is set, and Start otherwise.
func (i Interval) Offset() float32 {
	if i.StdDev != nil {
		return *i.StdDev
	}

	return i.Start
}
|
||||
|
||||
// The standard interval returned by decodeImage is [0, 1]
|
||||
func StandardInterval() *Interval {
|
||||
return &Interval{
|
||||
@@ -35,13 +47,230 @@ func StandardInterval() *Interval {
|
||||
}
|
||||
}
|
||||
|
||||
// ResizeOperation selects how an image is brought to the input resolution
// of a model. The JSON and YAML (un)marshalers use the names returned by
// String so that configuration files stay human-readable.
type ResizeOperation int

const (
	// UndefinedResizeOperation is the zero value, used when no operation
	// has been configured.
	UndefinedResizeOperation ResizeOperation = iota
	// ResizeBreakAspectRatio resizes without preserving the aspect ratio.
	ResizeBreakAspectRatio
	// CenterCrop fills the target size and crops around the center.
	CenterCrop
	// Padding fits the image into the target size and pads the rest.
	Padding
)

// String returns the configuration name of the operation, or "Unknown"
// for values that are not one of the declared constants.
func (o ResizeOperation) String() string {
	switch o {
	case UndefinedResizeOperation:
		return "Undefined"
	case ResizeBreakAspectRatio:
		return "ResizeBreakAspectRatio"
	case CenterCrop:
		return "CenterCrop"
	case Padding:
		return "Padding"
	default:
		return "Unknown"
	}
}

// NewResizeOperation parses a name as returned by String. It returns
// UndefinedResizeOperation together with an error if the name is unknown.
func NewResizeOperation(s string) (ResizeOperation, error) {
	switch s {
	case "Undefined":
		return UndefinedResizeOperation, nil
	case "ResizeBreakAspectRatio":
		return ResizeBreakAspectRatio, nil
	case "CenterCrop":
		return CenterCrop, nil
	case "Padding":
		return Padding, nil
	default:
		return UndefinedResizeOperation, fmt.Errorf("invalid resize operation %q", s)
	}
}

// MarshalJSON encodes the operation as its JSON string name.
func (o ResizeOperation) MarshalJSON() ([]byte, error) {
	return json.Marshal(o.String())
}

// UnmarshalJSON decodes a JSON string name; the receiver is left untouched
// when decoding or parsing fails.
func (o *ResizeOperation) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return err
	}

	val, err := NewResizeOperation(s)
	if err != nil {
		return err
	}

	*o = val

	return nil
}

// MarshalYAML encodes the operation as its string name.
func (o ResizeOperation) MarshalYAML() (any, error) {
	return o.String(), nil
}

// UnmarshalYAML decodes a YAML string name; the receiver is left untouched
// when decoding or parsing fails.
func (o *ResizeOperation) UnmarshalYAML(unmarshal func(any) error) error {
	var s string
	if err := unmarshal(&s); err != nil {
		return err
	}

	val, err := NewResizeOperation(s)
	if err != nil {
		return err
	}

	*o = val

	return nil
}
|
||||
|
||||
// InputOrder describes the order of the color channels in the input
// tensor. The value is a three-digit decimal number read from left to
// right in which 1 stands for R, 2 for G and 3 for B, e.g. 321 is BGR.
// The JSON and YAML (un)marshalers use the letter form ("RGB", "BGR", ...)
// so that configuration files stay human-readable.
type InputOrder int

// All orders carry the InputOrder type so that they can be compared with
// InputOrder values and have the methods below called on them directly.
const (
	UndefinedOrder InputOrder = 0
	RGB            InputOrder = 123
	RBG            InputOrder = 132
	GRB            InputOrder = 213
	GBR            InputOrder = 231
	BRG            InputOrder = 312
	BGR            InputOrder = 321
)

// Indices returns the positions (0-2) of the red, green and blue channel
// for this order. UndefinedOrder is treated as RGB. Digits other than
// 1, 2 and 3 are ignored, which leaves the affected results at 0.
func (o InputOrder) Indices() (r, g, b int) {
	v := int(o)

	if v == 0 {
		v = int(RGB)
	}

	// Walk the digits from least to most significant, which corresponds
	// to channel positions 2 down to 0.
	for pos := 2; pos >= 0 && v > 0; pos-- {
		switch v % 10 {
		case 1:
			r = pos
		case 2:
			g = pos
		case 3:
			b = pos
		}

		v /= 10
	}

	return r, g, b
}

// String returns the letter form of the order, e.g. "BGR". UndefinedOrder
// is rendered as "RGB" and every digit other than 1, 2 or 3 as "?".
func (o InputOrder) String() string {
	v := int(o)

	if v == 0 {
		v = int(RGB)
	}

	result := ""

	// Digits are read from the right, so each letter is prepended.
	for ; v > 0; v /= 10 {
		switch v % 10 {
		case 1:
			result = "R" + result
		case 2:
			result = "G" + result
		case 3:
			result = "B" + result
		default:
			result = "?" + result
		}
	}

	return result
}

// NewInputOrder parses the letter form of an order. The value must consist
// of exactly the letters R, G and B, each used once; otherwise
// UndefinedOrder and an error are returned.
func NewInputOrder(val string) (InputOrder, error) {
	if len(val) != 3 {
		return UndefinedOrder, fmt.Errorf("invalid input order %q: expected 3 channels", val)
	}

	var seen [4]bool

	result := 0

	for _, c := range val {
		digit := 0

		switch c {
		case 'R':
			digit = 1
		case 'G':
			digit = 2
		case 'B':
			digit = 3
		default:
			return UndefinedOrder, fmt.Errorf("invalid input order %q: unknown channel %q", val, c)
		}

		// A repeated channel would leave another one without a position.
		if seen[digit] {
			return UndefinedOrder, fmt.Errorf("invalid input order %q: duplicate channel %q", val, c)
		}

		seen[digit] = true
		result = result*10 + digit
	}

	return InputOrder(result), nil
}

// MarshalJSON encodes the order as its JSON letter string.
func (o InputOrder) MarshalJSON() ([]byte, error) {
	return json.Marshal(o.String())
}

// UnmarshalJSON decodes a JSON letter string; the receiver is left
// untouched when decoding or parsing fails.
func (o *InputOrder) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return err
	}

	val, err := NewInputOrder(s)
	if err != nil {
		return err
	}

	*o = val

	return nil
}

// MarshalYAML encodes the order as its letter string.
func (o InputOrder) MarshalYAML() (any, error) {
	return o.String(), nil
}

// UnmarshalYAML decodes a YAML letter string; the receiver is left
// untouched when decoding or parsing fails.
func (o *InputOrder) UnmarshalYAML(unmarshal func(any) error) error {
	var s string
	if err := unmarshal(&s); err != nil {
		return err
	}

	val, err := NewInputOrder(s)
	if err != nil {
		return err
	}

	*o = val

	return nil
}
|
||||
|
||||
// Input description for a photo input for a model
|
||||
type PhotoInput struct {
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
Interval *Interval `yaml:"Interval,omitempty" json:"interval,omitempty"`
|
||||
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
|
||||
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
|
||||
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
Intervals []Interval `yaml:"Intervals,omitempty" json:"intervals,omitempty"`
|
||||
ResizeOperation ResizeOperation `yaml:"ResizeOperation,omitempty" json:"resizeOperation,omitemty"`
|
||||
InputOrder InputOrder `yaml:"InputOrder,omitempty" json:"inputOrder,omitempty"`
|
||||
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
|
||||
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
|
||||
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
|
||||
}
|
||||
|
||||
// When dimensions are not defined, it means the model accepts any size of
|
||||
@@ -61,12 +290,18 @@ func (p *PhotoInput) SetResolution(resolution int) {
|
||||
p.Width = int64(resolution)
|
||||
}
|
||||
|
||||
// Get the interval or the default one
|
||||
func (p PhotoInput) GetInterval() *Interval {
|
||||
if p.Interval == nil {
|
||||
// Get the interval or the default one.
|
||||
// If just one interval has been fixed, then we assume
|
||||
// it is the same for every channel. If no intervals
|
||||
// have been defined, the default [0, 1] is returned
|
||||
func (p PhotoInput) GetInterval(channel int) *Interval {
|
||||
if len(p.Intervals) <= channel {
|
||||
if len(p.Intervals) == 1 {
|
||||
return &p.Intervals[0]
|
||||
}
|
||||
return StandardInterval()
|
||||
} else {
|
||||
return p.Interval
|
||||
return &p.Intervals[channel]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,8 +311,8 @@ func (p *PhotoInput) Merge(other *PhotoInput) {
|
||||
p.Name = other.Name
|
||||
}
|
||||
|
||||
if p.Interval == nil && other.Interval != nil {
|
||||
p.Interval = other.Interval
|
||||
if p.Intervals == nil && other.Intervals != nil {
|
||||
p.Intervals = other.Intervals
|
||||
}
|
||||
|
||||
if p.OutputIndex == 0 {
|
||||
@@ -91,6 +326,14 @@ func (p *PhotoInput) Merge(other *PhotoInput) {
|
||||
if p.Width == 0 {
|
||||
p.Width = other.Width
|
||||
}
|
||||
|
||||
if p.ResizeOperation == UndefinedResizeOperation {
|
||||
p.ResizeOperation = other.ResizeOperation
|
||||
}
|
||||
|
||||
if p.InputOrder == UndefinedOrder {
|
||||
p.InputOrder = other.InputOrder
|
||||
}
|
||||
}
|
||||
|
||||
// The output expected for a model
|
||||
|
||||
195
internal/ai/tensorflow/info_test.go
Normal file
195
internal/ai/tensorflow/info_test.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package tensorflow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// Resize Operation Tests
|
||||
var allOperations = []ResizeOperation{
|
||||
UndefinedResizeOperation,
|
||||
ResizeBreakAspectRatio,
|
||||
CenterCrop,
|
||||
Padding,
|
||||
}
|
||||
|
||||
func TestResizeOperations(t *testing.T) {
|
||||
for i := range allOperations {
|
||||
text := allOperations[i].String()
|
||||
|
||||
op, err := NewResizeOperation(text)
|
||||
if err != nil {
|
||||
t.Fatalf("Invalid operation %s: %v", text, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, op, allOperations[i])
|
||||
}
|
||||
}
|
||||
|
||||
const exampleOperationJSON = `"CenterCrop"`
|
||||
|
||||
func TestResizeOperationJSON(t *testing.T) {
|
||||
var op ResizeOperation
|
||||
|
||||
err := json.Unmarshal(
|
||||
[]byte(exampleOperationJSON), &op)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("Could not unmarshal the example operation")
|
||||
}
|
||||
|
||||
for i := range allOperations {
|
||||
serialized, err := json.Marshal(allOperations[i])
|
||||
if err != nil {
|
||||
t.Fatalf("Could not marshal %v: %v",
|
||||
allOperations[i], err)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(serialized, &op)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not unmarshal %s: %v",
|
||||
string(serialized), err)
|
||||
}
|
||||
|
||||
assert.Equal(t, op, allOperations[i])
|
||||
}
|
||||
}
|
||||
|
||||
const exampleOperationYAML = "CenterCrop"
|
||||
|
||||
func TestResizeOperationYAML(t *testing.T) {
|
||||
var op ResizeOperation
|
||||
|
||||
err := yaml.Unmarshal(
|
||||
[]byte(exampleOperationYAML), &op)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("Could not unmarshal the example operation")
|
||||
}
|
||||
|
||||
for i := range allOperations {
|
||||
serialized, err := yaml.Marshal(allOperations[i])
|
||||
if err != nil {
|
||||
t.Fatalf("Could not marshal %v: %v",
|
||||
allOperations[i], err)
|
||||
}
|
||||
|
||||
err = yaml.Unmarshal(serialized, &op)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not unmarshal %s: %v",
|
||||
string(serialized), err)
|
||||
}
|
||||
|
||||
assert.Equal(t, op, allOperations[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Resize Operation Tests
|
||||
var allInputOrders = []InputOrder{
|
||||
RGB,
|
||||
RBG,
|
||||
GRB,
|
||||
GBR,
|
||||
BRG,
|
||||
BGR,
|
||||
}
|
||||
|
||||
func TestInputOrders(t *testing.T) {
|
||||
for i := range allInputOrders {
|
||||
text := allInputOrders[i].String()
|
||||
|
||||
order, err := NewInputOrder(text)
|
||||
if err != nil {
|
||||
t.Fatalf("Invalid order %s: %v", text, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, order, allInputOrders[i])
|
||||
}
|
||||
}
|
||||
|
||||
const exampleOrderJSON = `"RGB"`
|
||||
|
||||
func TestInputOrderJSON(t *testing.T) {
|
||||
var order InputOrder
|
||||
|
||||
err := json.Unmarshal(
|
||||
[]byte(exampleOrderJSON), &order)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("Could not unmarshal the example operation")
|
||||
}
|
||||
|
||||
for i := range allInputOrders {
|
||||
serialized, err := json.Marshal(allInputOrders[i])
|
||||
if err != nil {
|
||||
t.Fatalf("Could not marshal %v: %v",
|
||||
allInputOrders[i], err)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(serialized, &order)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not unmarshal %s: %v",
|
||||
string(serialized), err)
|
||||
}
|
||||
|
||||
assert.Equal(t, order, allInputOrders[i])
|
||||
}
|
||||
}
|
||||
|
||||
const exampleOrderYAML = "RGB"
|
||||
|
||||
func TestInputOrderYAML(t *testing.T) {
|
||||
var order InputOrder
|
||||
|
||||
err := yaml.Unmarshal(
|
||||
[]byte(exampleOrderYAML), &order)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("Could not unmarshal the example operation")
|
||||
}
|
||||
|
||||
for i := range allInputOrders {
|
||||
serialized, err := yaml.Marshal(allInputOrders[i])
|
||||
if err != nil {
|
||||
t.Fatalf("Could not marshal %v: %v",
|
||||
allInputOrders[i], err)
|
||||
}
|
||||
|
||||
err = yaml.Unmarshal(serialized, &order)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not unmarshal %s: %v",
|
||||
string(serialized), err)
|
||||
}
|
||||
|
||||
assert.Equal(t, order, allInputOrders[i])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrderIndices(t *testing.T) {
|
||||
r, g, b := UndefinedOrder.Indices()
|
||||
|
||||
assert.Equal(t, r, 0)
|
||||
assert.Equal(t, g, 1)
|
||||
assert.Equal(t, b, 2)
|
||||
|
||||
powerFx := func(i int) int {
|
||||
switch i {
|
||||
case 0:
|
||||
return 100
|
||||
case 1:
|
||||
return 10
|
||||
case 2:
|
||||
return 1
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
for i := range allInputOrders {
|
||||
r, g, b = allInputOrders[i].Indices()
|
||||
assert.Equal(t, powerFx(r)+2*powerFx(g)+3*powerFx(b), int(allInputOrders[i]))
|
||||
}
|
||||
}
|
||||
283
internal/ai/vision/model.go.orig
Normal file
283
internal/ai/vision/model.go.orig
Normal file
@@ -0,0 +1,283 @@
|
||||
package vision
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/classify"
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/internal/ai/tensorflow"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
var modelMutex = sync.Mutex{}
|
||||
|
||||
// Model represents a computer vision model configuration.
|
||||
type Model struct {
|
||||
<<<<<<< HEAD
|
||||
Type ModelType `yaml:"Type,omitempty" json:"type,omitempty"`
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
|
||||
Resolution int `yaml:"Resolution,omitempty" json:"resolution,omitempty"`
|
||||
Meta *tensorflow.ModelInfo `yaml:"Meta,omitempty" json:"meta,omitempty"`
|
||||
Service Service `yaml:"Service,omitempty" json:"Service,omitempty"`
|
||||
Path string `yaml:"Path,omitempty" json:"-"`
|
||||
Disabled bool `yaml:"Disabled,omitempty" json:"disabled,omitempty"`
|
||||
=======
|
||||
Type ModelType `yaml:"Type,omitempty" json:"type,omitempty"`
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
|
||||
Prompt string `yaml:"Prompt,omitempty" json:"prompt,omitempty"`
|
||||
Resolution int `yaml:"Resolution,omitempty" json:"resolution,omitempty"`
|
||||
Service Service `yaml:"Service,omitempty" json:"Service,omitempty"`
|
||||
Path string `yaml:"Path,omitempty" json:"-"`
|
||||
Tags []string `yaml:"Tags,omitempty" json:"-"`
|
||||
Disabled bool `yaml:"Disabled,omitempty" json:"disabled,omitempty"`
|
||||
>>>>>>> upstream/develop
|
||||
classifyModel *classify.Model
|
||||
faceModel *face.Model
|
||||
nsfwModel *nsfw.Model
|
||||
}
|
||||
|
||||
// Models represents a set of computer vision models.
|
||||
type Models []*Model
|
||||
|
||||
// Endpoint returns the remote service request method and endpoint URL, if any.
|
||||
func (m *Model) Endpoint() (uri, method string) {
|
||||
if uri, method = m.Service.Endpoint(); uri != "" && method != "" {
|
||||
return uri, method
|
||||
} else if ServiceUri == "" {
|
||||
return "", ""
|
||||
} else if serviceType := clean.TypeLowerUnderscore(m.Type); serviceType == "" {
|
||||
return "", ""
|
||||
} else {
|
||||
return fmt.Sprintf("%s/%s", ServiceUri, serviceType), ServiceMethod
|
||||
}
|
||||
}
|
||||
|
||||
// EndpointKey returns the access token belonging to the remote service endpoint, if any.
|
||||
func (m *Model) EndpointKey() (key string) {
|
||||
if key = m.Service.EndpointKey(); key != "" {
|
||||
return key
|
||||
} else {
|
||||
return ServiceKey
|
||||
}
|
||||
}
|
||||
|
||||
// EndpointFileScheme returns the endpoint API request file scheme type.
|
||||
func (m *Model) EndpointFileScheme() (fileScheme scheme.Type) {
|
||||
if fileScheme = m.Service.EndpointFileScheme(); fileScheme != "" {
|
||||
return fileScheme
|
||||
}
|
||||
|
||||
return ServiceFileScheme
|
||||
}
|
||||
|
||||
// EndpointRequestFormat returns the endpoint API request format.
|
||||
func (m *Model) EndpointRequestFormat() (format ApiFormat) {
|
||||
if format = m.Service.EndpointRequestFormat(); format != "" {
|
||||
return format
|
||||
}
|
||||
|
||||
return ServiceRequestFormat
|
||||
}
|
||||
|
||||
// EndpointResponseFormat returns the endpoint API response format.
|
||||
func (m *Model) EndpointResponseFormat() (format ApiFormat) {
|
||||
if format = m.Service.EndpointResponseFormat(); format != "" {
|
||||
return format
|
||||
}
|
||||
|
||||
return ServiceResponseFormat
|
||||
}
|
||||
|
||||
// ClassifyModel returns the matching classify model instance, if any.
|
||||
func (m *Model) ClassifyModel() *classify.Model {
|
||||
// Use mutex to prevent models from being loaded and
|
||||
// initialized twice by different indexing workers.
|
||||
modelMutex.Lock()
|
||||
defer modelMutex.Unlock()
|
||||
|
||||
// Return the existing model instance if it has already been created.
|
||||
if m.classifyModel != nil {
|
||||
return m.classifyModel
|
||||
}
|
||||
|
||||
switch m.Name {
|
||||
case "":
|
||||
log.Warnf("vision: missing name, model instance cannot be created")
|
||||
return nil
|
||||
case NasnetModel.Name, "nasnet":
|
||||
// Load and initialize the Nasnet image classification model.
|
||||
if model := classify.NewNasnet(AssetsPath, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init nasnet model)", err)
|
||||
return nil
|
||||
} else {
|
||||
m.classifyModel = model
|
||||
}
|
||||
default:
|
||||
// Set model path from model name if no path is configured.
|
||||
if m.Path == "" {
|
||||
m.Path = clean.TypeLowerUnderscore(m.Name)
|
||||
}
|
||||
|
||||
if m.Meta == nil {
|
||||
m.Meta = &tensorflow.ModelInfo{}
|
||||
}
|
||||
|
||||
// Set default thumbnail resolution if no tags are configured.
|
||||
if m.Resolution <= 0 {
|
||||
m.Resolution = DefaultResolution
|
||||
}
|
||||
|
||||
if m.Meta.Input == nil {
|
||||
m.Meta.Input = new(tensorflow.PhotoInput)
|
||||
}
|
||||
|
||||
m.Meta.Input.SetResolution(m.Resolution)
|
||||
|
||||
// Try to load custom model based on the configuration values.
|
||||
defaultPath := filepath.Join(AssetsPath, "nasnet")
|
||||
if model := classify.NewModel(AssetsPath, m.Path, defaultPath, m.Meta, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
return nil
|
||||
} else {
|
||||
m.classifyModel = model
|
||||
}
|
||||
}
|
||||
|
||||
return m.classifyModel
|
||||
}
|
||||
|
||||
// FaceModel returns the matching face model instance, if any.
|
||||
func (m *Model) FaceModel() *face.Model {
|
||||
// Use mutex to prevent models from being loaded and
|
||||
// initialized twice by different indexing workers.
|
||||
modelMutex.Lock()
|
||||
defer modelMutex.Unlock()
|
||||
|
||||
// Return the existing model instance if it has already been created.
|
||||
if m.faceModel != nil {
|
||||
return m.faceModel
|
||||
}
|
||||
|
||||
switch m.Name {
|
||||
case "":
|
||||
log.Warnf("vision: missing name, model instance cannot be created")
|
||||
return nil
|
||||
case FacenetModel.Name, "facenet":
|
||||
// Load and initialize the Nasnet image classification model.
|
||||
if model := face.NewModel(FaceNetModelPath, CachePath, m.Resolution, m.Meta, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
return nil
|
||||
} else {
|
||||
m.faceModel = model
|
||||
}
|
||||
default:
|
||||
// Set model path from model name if no path is configured.
|
||||
if m.Path == "" {
|
||||
m.Path = clean.TypeLowerUnderscore(m.Name)
|
||||
}
|
||||
|
||||
// Set default thumbnail resolution if no tags are configured.
|
||||
if m.Resolution <= 0 {
|
||||
m.Resolution = DefaultResolution
|
||||
}
|
||||
|
||||
if m.Meta == nil {
|
||||
m.Meta = &tensorflow.ModelInfo{}
|
||||
}
|
||||
|
||||
// Set default tag if no tags are configured.
|
||||
if len(m.Meta.Tags) == 0 {
|
||||
m.Meta.Tags = []string{"serve"}
|
||||
}
|
||||
|
||||
// Try to load custom model based on the configuration values.
|
||||
if model := face.NewModel(filepath.Join(AssetsPath, m.Path), CachePath, m.Resolution, m.Meta, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
return nil
|
||||
} else {
|
||||
m.faceModel = model
|
||||
}
|
||||
}
|
||||
|
||||
return m.faceModel
|
||||
}
|
||||
|
||||
// NsfwModel returns the matching nsfw model instance, if any.
|
||||
func (m *Model) NsfwModel() *nsfw.Model {
|
||||
// Use mutex to prevent models from being loaded and
|
||||
// initialized twice by different indexing workers.
|
||||
modelMutex.Lock()
|
||||
defer modelMutex.Unlock()
|
||||
|
||||
// Return the existing model instance if it has already been created.
|
||||
if m.nsfwModel != nil {
|
||||
return m.nsfwModel
|
||||
}
|
||||
|
||||
switch m.Name {
|
||||
case "":
|
||||
log.Warnf("vision: missing name, model instance cannot be created")
|
||||
return nil
|
||||
case NsfwModel.Name, "nsfw":
|
||||
// Load and initialize the Nasnet image classification model.
|
||||
if model := nsfw.NewModel(NsfwModelPath, NsfwModel.Meta, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
return nil
|
||||
} else {
|
||||
m.nsfwModel = model
|
||||
}
|
||||
default:
|
||||
// Set model path from model name if no path is configured.
|
||||
if m.Path == "" {
|
||||
m.Path = clean.TypeLowerUnderscore(m.Name)
|
||||
}
|
||||
|
||||
// Set default thumbnail resolution if no tags are configured.
|
||||
if m.Resolution <= 0 {
|
||||
m.Resolution = DefaultResolution
|
||||
}
|
||||
|
||||
if m.Meta.Input == nil {
|
||||
m.Meta.Input = new(tensorflow.PhotoInput)
|
||||
}
|
||||
|
||||
m.Meta.Input.SetResolution(m.Resolution)
|
||||
|
||||
if m.Meta == nil {
|
||||
m.Meta = &tensorflow.ModelInfo{}
|
||||
}
|
||||
|
||||
// Set default tag if no tags are configured.
|
||||
if len(m.Meta.Tags) == 0 {
|
||||
m.Meta.Tags = []string{"serve"}
|
||||
}
|
||||
|
||||
// Try to load custom model based on the configuration values.
|
||||
if model := nsfw.NewModel(filepath.Join(AssetsPath, m.Path), m.Meta, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
return nil
|
||||
} else {
|
||||
m.nsfwModel = model
|
||||
}
|
||||
}
|
||||
|
||||
return m.nsfwModel
|
||||
}
|
||||
@@ -19,9 +19,11 @@ var (
|
||||
Name: "input_1",
|
||||
Height: 224,
|
||||
Width: 224,
|
||||
Interval: &tensorflow.Interval{
|
||||
Start: -1.0,
|
||||
End: 1.0,
|
||||
Intervals: []tensorflow.Interval{
|
||||
{
|
||||
Start: -1.0,
|
||||
End: 1.0,
|
||||
},
|
||||
},
|
||||
OutputIndex: 0,
|
||||
},
|
||||
|
||||
4
internal/ai/vision/testdata/vision.yml
vendored
4
internal/ai/vision/testdata/vision.yml
vendored
@@ -8,8 +8,8 @@ Models:
|
||||
- photoprism
|
||||
Input:
|
||||
Name: input_1
|
||||
Interval:
|
||||
Start: -1
|
||||
Intervals:
|
||||
- Start: -1
|
||||
End: 1
|
||||
Height: 224
|
||||
Width: 224
|
||||
|
||||
Reference in New Issue
Block a user