Added new parameters to model input.

New parameters have been added to define the input of the models:
* ResizeOperation: by default center-crop was being performed, now it is
  configurable.
* InputOrder: by default RGB was being used as the order for the array
  values of the input tensor, now it can be configured.
* InputInterval has been changed to InputIntervals (an slice). This
  means that every channel can have its own interval conversion.
* InputInterval can define now stddev and mean, because sometimes
  instead of adjusting the interval, the stddev and mean of the training
data should be use.
This commit is contained in:
raystlin
2025-07-15 13:31:31 +00:00
parent c682a94a07
commit adc4dc0f74
10 changed files with 800 additions and 40 deletions

View File

@@ -3,6 +3,7 @@ package classify
import (
"bytes"
"fmt"
"image/color"
"math"
"os"
"path"
@@ -53,12 +54,16 @@ func NewNasnet(assetsPath string, disabled bool) *Model {
TFVersion: "1.12.0",
Tags: []string{"photoprism"},
Input: &tensorflow.PhotoInput{
Name: "input_1",
Height: 224,
Width: 224,
Interval: &tensorflow.Interval{
Start: -1,
End: 1,
Name: "input_1",
Height: 224,
Width: 224,
ResizeOperation: tensorflow.CenterCrop,
InputOrder: tensorflow.RGB,
Intervals: []tensorflow.Interval{
{
Start: -1,
End: 1,
},
},
OutputIndex: 0,
},
@@ -290,7 +295,18 @@ func (m *Model) createTensor(image []byte) (*tf.Tensor, error) {
// Resize the image only if its resolution does not match the model.
if img.Bounds().Dx() != m.meta.Input.Resolution() || img.Bounds().Dy() != m.meta.Input.Resolution() {
img = imaging.Fill(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Center, imaging.Lanczos)
switch m.meta.Input.ResizeOperation {
case tensorflow.ResizeBreakAspectRatio:
imaging.Resize(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Lanczos)
case tensorflow.CenterCrop:
img = imaging.Fill(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Center, imaging.Lanczos)
case tensorflow.Padding:
resized := imaging.Fit(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Lanczos)
dst := imaging.New(m.meta.Input.Resolution(), m.meta.Input.Resolution(), color.NRGBA{0, 0, 0, 255})
img = imaging.PasteCenter(dst, resized)
default:
img = imaging.Fill(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Center, imaging.Lanczos)
}
}
return tensorflow.Image(img, m.meta.Input)

View File

@@ -68,9 +68,11 @@ var modelsInfo = map[string]*tensorflow.ModelInfo{
},
"vision-transformer-tensorflow2-vit-b16-classification-v1.tar.gz": {
Input: &tensorflow.PhotoInput{
Interval: &tensorflow.Interval{
Start: -1.0,
End: 1.0,
Intervals: []tensorflow.Interval{
{
Start: -1.0,
End: 1.0,
},
},
},
Output: &tensorflow.ModelOutput{

View File

@@ -140,9 +140,10 @@ func TestModel_Run(t *testing.T) {
assert.IsType(t, Labels{}, result)
assert.Equal(t, 1, len(result))
assert.Equal(t, "chameleon", result[0].Name)
assert.Equal(t, 100-93, result[0].Uncertainty)
if len(result) > 0 {
assert.Equal(t, "chameleon", result[0].Name)
assert.Equal(t, 100-93, result[0].Uncertainty)
}
}
})
t.Run("dog_orange.jpg", func(t *testing.T) {
@@ -164,9 +165,10 @@ func TestModel_Run(t *testing.T) {
assert.IsType(t, Labels{}, result)
assert.Equal(t, 1, len(result))
assert.Equal(t, "dog", result[0].Name)
assert.Equal(t, 34, result[0].Uncertainty)
if len(result) > 0 {
assert.Equal(t, "dog", result[0].Name)
assert.Equal(t, 34, result[0].Uncertainty)
}
}
})
t.Run("Random.docx", func(t *testing.T) {

View File

@@ -61,6 +61,7 @@ func Image(img image.Image, input *PhotoInput) (tfTensor *tf.Tensor, err error)
}
var tfImage [1][][][3]float32
rIndex, gIndex, bIndex := input.InputOrder.Indices()
for j := 0; j < input.Resolution(); j++ {
tfImage[0] = append(tfImage[0], make([][3]float32, input.Resolution()))
@@ -69,9 +70,11 @@ func Image(img image.Image, input *PhotoInput) (tfTensor *tf.Tensor, err error)
for i := 0; i < input.Resolution(); i++ {
for j := 0; j < input.Resolution(); j++ {
r, g, b, _ := img.At(i, j).RGBA()
tfImage[0][j][i][0] = convertValue(r, input.GetInterval())
tfImage[0][j][i][1] = convertValue(g, input.GetInterval())
tfImage[0][j][i][2] = convertValue(b, input.GetInterval())
//Although RGB can be disordered, we assume the input intervals are
//given in RGB order.
tfImage[0][j][i][rIndex] = convertValue(r, input.GetInterval(0))
tfImage[0][j][i][gIndex] = convertValue(g, input.GetInterval(1))
tfImage[0][j][i][bIndex] = convertValue(b, input.GetInterval(2))
}
}
@@ -137,8 +140,14 @@ func transformImageGraph(imageFormat fs.Type, resolution int) (graph *tf.Graph,
}
func convertValue(value uint32, interval *Interval) float32 {
scale := interval.Size() / 255.0
offset := interval.Start
var scale float32
if interval.Mean != nil {
scale = *interval.Mean
} else {
scale = interval.Size() / 255.0
}
offset := interval.Offset()
return (float32(value>>8))*scale + offset
}

View File

@@ -20,6 +20,14 @@ func TestConvertValue(t *testing.T) {
assert.Equal(t, float32(3024.8982), result)
}
func TestConvertStdMean(t *testing.T) {
mean := float32(1.0 / 127.5)
stdDev := float32(-1.0)
result := convertValue(uint32(98765432), &Interval{Mean: &mean, StdDev: &stdDev})
assert.Equal(t, float32(3024.8982), result)
}
func TestImageFromBytes(t *testing.T) {
var assetsPath = fs.Abs("../../../assets")
var examplesPath = assetsPath + "/examples"

View File

@@ -1,6 +1,7 @@
package tensorflow
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
@@ -18,15 +19,26 @@ const ExpectedChannels = 3
// Interval of allowed values
type Interval struct {
Start float32 `yaml:"Start,omitempty" json:"start,omitempty"`
End float32 `yaml:"End,omitempty" json:"end,omitempty"`
Start float32 `yaml:"Start,omitempty" json:"start,omitempty"`
End float32 `yaml:"End,omitempty" json:"end,omitempty"`
Mean *float32 `yaml:"Mean,omitempty" json:"mean,omitempty"`
StdDev *float32 `yaml:"StdDev,omitempty" json:"stdDev,omitempty"`
}
// The size of the interval
// The size/mean of the interval
func (i Interval) Size() float32 {
return i.End - i.Start
}
// The offset of the interval
func (i Interval) Offset() float32 {
if i.StdDev == nil {
return i.Start
} else {
return *i.StdDev
}
}
// The standard interval returned by decodeImage is [0, 1]
func StandardInterval() *Interval {
return &Interval{
@@ -35,13 +47,230 @@ func StandardInterval() *Interval {
}
}
// How should we resize the images
// JSON and YAML functions are given to make it
// user friendly from the configuration files
type ResizeOperation int
const (
UndefinedResizeOperation ResizeOperation = iota
ResizeBreakAspectRatio
CenterCrop
Padding
)
func (o ResizeOperation) String() string {
switch o {
case UndefinedResizeOperation:
return "Undefined"
case ResizeBreakAspectRatio:
return "ResizeBreakAspectRatio"
case CenterCrop:
return "CenterCrop"
case Padding:
return "Padding"
default:
return "Unknown"
}
}
func NewResizeOperation(s string) (ResizeOperation, error) {
switch s {
case "Undefined":
return UndefinedResizeOperation, nil
case "ResizeBreakAspectRatio":
return ResizeBreakAspectRatio, nil
case "CenterCrop":
return CenterCrop, nil
case "Padding":
return Padding, nil
default:
return UndefinedResizeOperation, fmt.Errorf("Invalid operation %s", s)
}
}
func (o ResizeOperation) MarshalJSON() ([]byte, error) {
return json.Marshal(o.String())
}
func (o *ResizeOperation) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
val, err := NewResizeOperation(s)
if err != nil {
return err
}
*o = val
return nil
}
func (o ResizeOperation) MarshalYAML() (any, error) {
return o.String(), nil
}
func (o *ResizeOperation) UnmarshalYAML(unmarshal func(interface{}) error) error {
var s string
if err := unmarshal(&s); err != nil {
return err
}
val, err := NewResizeOperation(s)
if err != nil {
return err
}
*o = val
return nil
}
// How should we order the input vectors
// JSON and YAML functions are given to make it
// user friendly from the configuration files
type InputOrder int
const (
UndefinedOrder InputOrder = 0
RGB = 123
RBG = 132
GRB = 213
GBR = 231
BRG = 312
BGR = 321
)
func (o InputOrder) Indices() (r, g, b int) {
i := int(o)
if i == 0 {
i = 123
}
for idx := 0; i > 0 && idx < 3; idx += 1 {
remainder := i % 10
i /= 10
switch remainder {
case 1:
r = 2 - idx
case 2:
g = 2 - idx
case 3:
b = 2 - idx
}
}
return
}
func (o InputOrder) String() string {
value := int(o)
if value == 0 {
value = 123
}
convert := func(remainder int) string {
switch remainder {
case 1:
return "R"
case 2:
return "G"
case 3:
return "B"
default:
return "?"
}
}
result := ""
for value > 0 {
remainder := value % 10
value /= 10
result = convert(remainder) + result
}
return result
}
func NewInputOrder(val string) (InputOrder, error) {
if len(val) != 3 {
return UndefinedOrder, fmt.Errorf("Invalid length, expected 3")
}
convert := func(c rune) int {
switch c {
case 'R':
return 1
case 'G':
return 2
case 'B':
return 3
default:
return 0
}
}
result := 0
for _, c := range val {
index := convert(c)
if index == 0 {
return UndefinedOrder, fmt.Errorf("Invalid val %c", c)
}
result = result*10 + index
}
return InputOrder(result), nil
}
func (o InputOrder) MarshalJSON() ([]byte, error) {
return json.Marshal(o.String())
}
func (o *InputOrder) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
val, err := NewInputOrder(s)
if err != nil {
return err
}
*o = val
return nil
}
func (o InputOrder) MarshalYAML() (any, error) {
return o.String(), nil
}
func (o *InputOrder) UnmarshalYAML(unmarshal func(interface{}) error) error {
var s string
if err := unmarshal(&s); err != nil {
return err
}
val, err := NewInputOrder(s)
if err != nil {
return err
}
*o = val
return nil
}
// Input description for a photo input for a model
type PhotoInput struct {
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
Interval *Interval `yaml:"Interval,omitempty" json:"interval,omitempty"`
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
Intervals []Interval `yaml:"Intervals,omitempty" json:"intervals,omitempty"`
ResizeOperation ResizeOperation `yaml:"ResizeOperation,omitempty" json:"resizeOperation,omitemty"`
InputOrder InputOrder `yaml:"InputOrder,omitempty" json:"inputOrder,omitempty"`
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
}
// When dimensions are not defined, it means the model accepts any size of
@@ -61,12 +290,18 @@ func (p *PhotoInput) SetResolution(resolution int) {
p.Width = int64(resolution)
}
// Get the interval or the default one
func (p PhotoInput) GetInterval() *Interval {
if p.Interval == nil {
// Get the interval or the default one.
// If just one interval has been fixed, then we assume
// it is the same for every channel. If no intervals
// have been defined, the default [0, 1] is returned
func (p PhotoInput) GetInterval(channel int) *Interval {
if len(p.Intervals) <= channel {
if len(p.Intervals) == 1 {
return &p.Intervals[0]
}
return StandardInterval()
} else {
return p.Interval
return &p.Intervals[channel]
}
}
@@ -76,8 +311,8 @@ func (p *PhotoInput) Merge(other *PhotoInput) {
p.Name = other.Name
}
if p.Interval == nil && other.Interval != nil {
p.Interval = other.Interval
if p.Intervals == nil && other.Intervals != nil {
p.Intervals = other.Intervals
}
if p.OutputIndex == 0 {
@@ -91,6 +326,14 @@ func (p *PhotoInput) Merge(other *PhotoInput) {
if p.Width == 0 {
p.Width = other.Width
}
if p.ResizeOperation == UndefinedResizeOperation {
p.ResizeOperation = other.ResizeOperation
}
if p.InputOrder == UndefinedOrder {
p.InputOrder = other.InputOrder
}
}
// The output expected for a model

View File

@@ -0,0 +1,195 @@
package tensorflow
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
)
// Resize Operation Tests
var allOperations = []ResizeOperation{
UndefinedResizeOperation,
ResizeBreakAspectRatio,
CenterCrop,
Padding,
}
func TestResizeOperations(t *testing.T) {
for i := range allOperations {
text := allOperations[i].String()
op, err := NewResizeOperation(text)
if err != nil {
t.Fatalf("Invalid operation %s: %v", text, err)
}
assert.Equal(t, op, allOperations[i])
}
}
const exampleOperationJSON = `"CenterCrop"`
func TestResizeOperationJSON(t *testing.T) {
var op ResizeOperation
err := json.Unmarshal(
[]byte(exampleOperationJSON), &op)
if err != nil {
t.Fatal("Could not unmarshal the example operation")
}
for i := range allOperations {
serialized, err := json.Marshal(allOperations[i])
if err != nil {
t.Fatalf("Could not marshal %v: %v",
allOperations[i], err)
}
err = json.Unmarshal(serialized, &op)
if err != nil {
t.Fatalf("Could not unmarshal %s: %v",
string(serialized), err)
}
assert.Equal(t, op, allOperations[i])
}
}
const exampleOperationYAML = "CenterCrop"
func TestResizeOperationYAML(t *testing.T) {
var op ResizeOperation
err := yaml.Unmarshal(
[]byte(exampleOperationYAML), &op)
if err != nil {
t.Fatal("Could not unmarshal the example operation")
}
for i := range allOperations {
serialized, err := yaml.Marshal(allOperations[i])
if err != nil {
t.Fatalf("Could not marshal %v: %v",
allOperations[i], err)
}
err = yaml.Unmarshal(serialized, &op)
if err != nil {
t.Fatalf("Could not unmarshal %s: %v",
string(serialized), err)
}
assert.Equal(t, op, allOperations[i])
}
}
// Resize Operation Tests
var allInputOrders = []InputOrder{
RGB,
RBG,
GRB,
GBR,
BRG,
BGR,
}
func TestInputOrders(t *testing.T) {
for i := range allInputOrders {
text := allInputOrders[i].String()
order, err := NewInputOrder(text)
if err != nil {
t.Fatalf("Invalid order %s: %v", text, err)
}
assert.Equal(t, order, allInputOrders[i])
}
}
const exampleOrderJSON = `"RGB"`
func TestInputOrderJSON(t *testing.T) {
var order InputOrder
err := json.Unmarshal(
[]byte(exampleOrderJSON), &order)
if err != nil {
t.Fatal("Could not unmarshal the example operation")
}
for i := range allInputOrders {
serialized, err := json.Marshal(allInputOrders[i])
if err != nil {
t.Fatalf("Could not marshal %v: %v",
allInputOrders[i], err)
}
err = json.Unmarshal(serialized, &order)
if err != nil {
t.Fatalf("Could not unmarshal %s: %v",
string(serialized), err)
}
assert.Equal(t, order, allInputOrders[i])
}
}
const exampleOrderYAML = "RGB"
func TestInputOrderYAML(t *testing.T) {
var order InputOrder
err := yaml.Unmarshal(
[]byte(exampleOrderYAML), &order)
if err != nil {
t.Fatal("Could not unmarshal the example operation")
}
for i := range allInputOrders {
serialized, err := yaml.Marshal(allInputOrders[i])
if err != nil {
t.Fatalf("Could not marshal %v: %v",
allInputOrders[i], err)
}
err = yaml.Unmarshal(serialized, &order)
if err != nil {
t.Fatalf("Could not unmarshal %s: %v",
string(serialized), err)
}
assert.Equal(t, order, allInputOrders[i])
}
}
func TestOrderIndices(t *testing.T) {
r, g, b := UndefinedOrder.Indices()
assert.Equal(t, r, 0)
assert.Equal(t, g, 1)
assert.Equal(t, b, 2)
powerFx := func(i int) int {
switch i {
case 0:
return 100
case 1:
return 10
case 2:
return 1
default:
return -1
}
}
for i := range allInputOrders {
r, g, b = allInputOrders[i].Indices()
assert.Equal(t, powerFx(r)+2*powerFx(g)+3*powerFx(b), int(allInputOrders[i]))
}
}

View File

@@ -0,0 +1,283 @@
package vision
import (
"fmt"
"path/filepath"
"sync"
"github.com/photoprism/photoprism/internal/ai/classify"
"github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/internal/ai/tensorflow"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
)
var modelMutex = sync.Mutex{}
// Model represents a computer vision model configuration.
type Model struct {
<<<<<<< HEAD
Type ModelType `yaml:"Type,omitempty" json:"type,omitempty"`
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
Resolution int `yaml:"Resolution,omitempty" json:"resolution,omitempty"`
Meta *tensorflow.ModelInfo `yaml:"Meta,omitempty" json:"meta,omitempty"`
Service Service `yaml:"Service,omitempty" json:"Service,omitempty"`
Path string `yaml:"Path,omitempty" json:"-"`
Disabled bool `yaml:"Disabled,omitempty" json:"disabled,omitempty"`
=======
Type ModelType `yaml:"Type,omitempty" json:"type,omitempty"`
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
Prompt string `yaml:"Prompt,omitempty" json:"prompt,omitempty"`
Resolution int `yaml:"Resolution,omitempty" json:"resolution,omitempty"`
Service Service `yaml:"Service,omitempty" json:"Service,omitempty"`
Path string `yaml:"Path,omitempty" json:"-"`
Tags []string `yaml:"Tags,omitempty" json:"-"`
Disabled bool `yaml:"Disabled,omitempty" json:"disabled,omitempty"`
>>>>>>> upstream/develop
classifyModel *classify.Model
faceModel *face.Model
nsfwModel *nsfw.Model
}
// Models represents a set of computer vision models.
type Models []*Model
// Endpoint returns the remote service request method and endpoint URL, if any.
func (m *Model) Endpoint() (uri, method string) {
if uri, method = m.Service.Endpoint(); uri != "" && method != "" {
return uri, method
} else if ServiceUri == "" {
return "", ""
} else if serviceType := clean.TypeLowerUnderscore(m.Type); serviceType == "" {
return "", ""
} else {
return fmt.Sprintf("%s/%s", ServiceUri, serviceType), ServiceMethod
}
}
// EndpointKey returns the access token belonging to the remote service endpoint, if any.
func (m *Model) EndpointKey() (key string) {
if key = m.Service.EndpointKey(); key != "" {
return key
} else {
return ServiceKey
}
}
// EndpointFileScheme returns the endpoint API request file scheme type.
func (m *Model) EndpointFileScheme() (fileScheme scheme.Type) {
if fileScheme = m.Service.EndpointFileScheme(); fileScheme != "" {
return fileScheme
}
return ServiceFileScheme
}
// EndpointRequestFormat returns the endpoint API request format.
func (m *Model) EndpointRequestFormat() (format ApiFormat) {
if format = m.Service.EndpointRequestFormat(); format != "" {
return format
}
return ServiceRequestFormat
}
// EndpointResponseFormat returns the endpoint API response format.
func (m *Model) EndpointResponseFormat() (format ApiFormat) {
if format = m.Service.EndpointResponseFormat(); format != "" {
return format
}
return ServiceResponseFormat
}
// ClassifyModel returns the matching classify model instance, if any.
func (m *Model) ClassifyModel() *classify.Model {
// Use mutex to prevent models from being loaded and
// initialized twice by different indexing workers.
modelMutex.Lock()
defer modelMutex.Unlock()
// Return the existing model instance if it has already been created.
if m.classifyModel != nil {
return m.classifyModel
}
switch m.Name {
case "":
log.Warnf("vision: missing name, model instance cannot be created")
return nil
case NasnetModel.Name, "nasnet":
// Load and initialize the Nasnet image classification model.
if model := classify.NewNasnet(AssetsPath, m.Disabled); model == nil {
return nil
} else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init nasnet model)", err)
return nil
} else {
m.classifyModel = model
}
default:
// Set model path from model name if no path is configured.
if m.Path == "" {
m.Path = clean.TypeLowerUnderscore(m.Name)
}
if m.Meta == nil {
m.Meta = &tensorflow.ModelInfo{}
}
// Set default thumbnail resolution if no tags are configured.
if m.Resolution <= 0 {
m.Resolution = DefaultResolution
}
if m.Meta.Input == nil {
m.Meta.Input = new(tensorflow.PhotoInput)
}
m.Meta.Input.SetResolution(m.Resolution)
// Try to load custom model based on the configuration values.
defaultPath := filepath.Join(AssetsPath, "nasnet")
if model := classify.NewModel(AssetsPath, m.Path, defaultPath, m.Meta, m.Disabled); model == nil {
return nil
} else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path)
return nil
} else {
m.classifyModel = model
}
}
return m.classifyModel
}
// FaceModel returns the matching face model instance, if any.
func (m *Model) FaceModel() *face.Model {
// Use mutex to prevent models from being loaded and
// initialized twice by different indexing workers.
modelMutex.Lock()
defer modelMutex.Unlock()
// Return the existing model instance if it has already been created.
if m.faceModel != nil {
return m.faceModel
}
switch m.Name {
case "":
log.Warnf("vision: missing name, model instance cannot be created")
return nil
case FacenetModel.Name, "facenet":
// Load and initialize the Nasnet image classification model.
if model := face.NewModel(FaceNetModelPath, CachePath, m.Resolution, m.Meta, m.Disabled); model == nil {
return nil
} else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path)
return nil
} else {
m.faceModel = model
}
default:
// Set model path from model name if no path is configured.
if m.Path == "" {
m.Path = clean.TypeLowerUnderscore(m.Name)
}
// Set default thumbnail resolution if no tags are configured.
if m.Resolution <= 0 {
m.Resolution = DefaultResolution
}
if m.Meta == nil {
m.Meta = &tensorflow.ModelInfo{}
}
// Set default tag if no tags are configured.
if len(m.Meta.Tags) == 0 {
m.Meta.Tags = []string{"serve"}
}
// Try to load custom model based on the configuration values.
if model := face.NewModel(filepath.Join(AssetsPath, m.Path), CachePath, m.Resolution, m.Meta, m.Disabled); model == nil {
return nil
} else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path)
return nil
} else {
m.faceModel = model
}
}
return m.faceModel
}
// NsfwModel returns the matching nsfw model instance, if any.
func (m *Model) NsfwModel() *nsfw.Model {
// Use mutex to prevent models from being loaded and
// initialized twice by different indexing workers.
modelMutex.Lock()
defer modelMutex.Unlock()
// Return the existing model instance if it has already been created.
if m.nsfwModel != nil {
return m.nsfwModel
}
switch m.Name {
case "":
log.Warnf("vision: missing name, model instance cannot be created")
return nil
case NsfwModel.Name, "nsfw":
// Load and initialize the Nasnet image classification model.
if model := nsfw.NewModel(NsfwModelPath, NsfwModel.Meta, m.Disabled); model == nil {
return nil
} else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path)
return nil
} else {
m.nsfwModel = model
}
default:
// Set model path from model name if no path is configured.
if m.Path == "" {
m.Path = clean.TypeLowerUnderscore(m.Name)
}
// Set default thumbnail resolution if no tags are configured.
if m.Resolution <= 0 {
m.Resolution = DefaultResolution
}
if m.Meta.Input == nil {
m.Meta.Input = new(tensorflow.PhotoInput)
}
m.Meta.Input.SetResolution(m.Resolution)
if m.Meta == nil {
m.Meta = &tensorflow.ModelInfo{}
}
// Set default tag if no tags are configured.
if len(m.Meta.Tags) == 0 {
m.Meta.Tags = []string{"serve"}
}
// Try to load custom model based on the configuration values.
if model := nsfw.NewModel(filepath.Join(AssetsPath, m.Path), m.Meta, m.Disabled); model == nil {
return nil
} else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path)
return nil
} else {
m.nsfwModel = model
}
}
return m.nsfwModel
}

View File

@@ -19,9 +19,11 @@ var (
Name: "input_1",
Height: 224,
Width: 224,
Interval: &tensorflow.Interval{
Start: -1.0,
End: 1.0,
Intervals: []tensorflow.Interval{
{
Start: -1.0,
End: 1.0,
},
},
OutputIndex: 0,
},

View File

@@ -8,8 +8,8 @@ Models:
- photoprism
Input:
Name: input_1
Interval:
Start: -1
Intervals:
- Start: -1
End: 1
Height: 224
Width: 224