Files
photoprism/internal/ai/tensorflow/info.go
raystlin 519a6ab34a AI: Add TensorFlow model shape detection #127 #5164
* AI: Added support for non BHWC models

TensorFlow models use BHWC by default; however, converted models may
expect BCHW input instead. The input layout is now configurable via the
Shape parameter on the input definition (although the restriction to
4 dimensions is still there). Also, model introspection will try to
deduce the input shape from the model signature.

* AI: Added more tests for enum parsing

ShapeComponent was missing from the tests

* AI: Modified external tests to the new url

The path has been moved from tensorflow/vision to tensorflow/models

* AI: Moved the builder to the model to reuse it

It should reduce the amount of allocations done

* AI: fixed errors after merge

Mainly incorrect paths and duplicated variables
2025-08-16 15:55:59 +02:00

460 lines
10 KiB
Go

package tensorflow
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
pb "github.com/wamuir/graft/tensorflow/core/protobuf/for_core_protos_go_proto"
"google.golang.org/protobuf/proto"
"github.com/photoprism/photoprism/pkg/clean"
)
// ExpectedChannels defines the expected number of channels of an input image.
// This is a fixed value because a standard seems to have been
// defined for input images as "what decodeImage returns".
const ExpectedChannels = 3
// Interval defines an interval of allowed values.
//
// Start and End delimit the range. Mean and StdDev are optional; when set,
// they take precedence over the range in Size and Offset respectively.
type Interval struct {
	Start  float32  `yaml:"Start,omitempty" json:"start,omitempty"`
	End    float32  `yaml:"End,omitempty" json:"end,omitempty"`
	Mean   *float32 `yaml:"Mean,omitempty" json:"mean,omitempty"`
	StdDev *float32 `yaml:"StdDev,omitempty" json:"stdDev,omitempty"`
}

// Size returns the size/mean of the interval: the configured Mean when one
// is set, otherwise the width of the range (End - Start).
func (i Interval) Size() float32 {
	// Mirror Offset: an explicitly configured value overrides the range.
	if i.Mean != nil {
		return *i.Mean
	}

	return i.End - i.Start
}

// Offset returns the offset of the interval: the configured StdDev when one
// is set, otherwise the Start of the range.
func (i Interval) Offset() float32 {
	if i.StdDev != nil {
		return *i.StdDev
	}

	return i.Start
}

// StandardInterval returns the standard interval, i.e.
// the range of values returned by decodeImage in [0, 1].
// Every call returns a newly allocated Interval.
func StandardInterval() *Interval {
	return &Interval{
		Start: 0.0,
		End:   1.0,
	}
}
// ResizeOperation enumerates the resizing operations that can be applied to images.
// It implements the JSON and YAML (un)marshalling hooks so that configuration
// files can refer to an operation by name instead of by number.
type ResizeOperation int

const (
	UndefinedResizeOperation ResizeOperation = iota
	ResizeBreakAspectRatio
	CenterCrop
	Padding
)

// resizeOperationNames holds the textual name of every known operation,
// indexed by the operation's numeric value.
var resizeOperationNames = [...]string{
	UndefinedResizeOperation: "Undefined",
	ResizeBreakAspectRatio:   "ResizeBreakAspectRatio",
	CenterCrop:               "CenterCrop",
	Padding:                  "Padding",
}

// String returns the name of the operation, or "Unknown" for values
// outside the declared set.
func (o ResizeOperation) String() string {
	if o < 0 || int(o) >= len(resizeOperationNames) {
		return "Unknown"
	}

	return resizeOperationNames[o]
}

// NewResizeOperation parses an operation name as produced by String.
// Names that do not match a declared operation (including "Unknown")
// yield UndefinedResizeOperation and an error.
func NewResizeOperation(s string) (ResizeOperation, error) {
	for value, name := range resizeOperationNames {
		if name == s {
			return ResizeOperation(value), nil
		}
	}

	return UndefinedResizeOperation, fmt.Errorf("Invalid operation %s", s)
}

// MarshalJSON encodes the operation as its name in a JSON string.
func (o ResizeOperation) MarshalJSON() ([]byte, error) {
	name := o.String()

	return json.Marshal(name)
}

// UnmarshalJSON decodes an operation from a JSON string holding its name.
// The receiver is left untouched when decoding or parsing fails.
func (o *ResizeOperation) UnmarshalJSON(data []byte) error {
	var name string

	if err := json.Unmarshal(data, &name); err != nil {
		return err
	}

	parsed, err := NewResizeOperation(name)
	if err == nil {
		*o = parsed
	}

	return err
}

// MarshalYAML encodes the operation as its name.
func (o ResizeOperation) MarshalYAML() (any, error) {
	return o.String(), nil
}

// UnmarshalYAML decodes an operation from a YAML scalar holding its name.
// The receiver is left untouched when decoding or parsing fails.
func (o *ResizeOperation) UnmarshalYAML(unmarshal func(any) error) error {
	var name string

	if err := unmarshal(&name); err != nil {
		return err
	}

	parsed, err := NewResizeOperation(name)
	if err == nil {
		*o = parsed
	}

	return err
}
// ColorChannelOrder represents the order of the model's input vectors.
// JSON and YAML functions are provided to make the configuration files user-friendly.
//
// The order is encoded as a decimal number with one digit per position,
// where 1 stands for red, 2 for green and 3 for blue (e.g. 321 is BGR).
type ColorChannelOrder int

// Every constant carries the ColorChannelOrder type explicitly. Constants
// that are given a value without a type would be untyped integers and would
// not have the ColorChannelOrder methods.
const (
	UndefinedOrder ColorChannelOrder = 0
	RGB            ColorChannelOrder = 123
	RBG            ColorChannelOrder = 132
	GRB            ColorChannelOrder = 213
	GBR            ColorChannelOrder = 231
	BRG            ColorChannelOrder = 312
	BGR            ColorChannelOrder = 321
)

// Indices returns the zero-based positions of the red, green and blue
// channels within the order. UndefinedOrder is treated as RGB.
func (o ColorChannelOrder) Indices() (r, g, b int) {
	i := int(o)

	if i == 0 {
		i = 123
	}

	// Digits are consumed from the least significant one, which describes
	// the last position, hence the 2 - idx.
	for idx := 0; i > 0 && idx < 3; idx++ {
		remainder := i % 10
		i /= 10

		switch remainder {
		case 1:
			r = 2 - idx
		case 2:
			g = 2 - idx
		case 3:
			b = 2 - idx
		}
	}

	return r, g, b
}

// String returns the order as channel letters, e.g. "BGR". UndefinedOrder
// is rendered as "RGB"; digits other than 1, 2 and 3 are rendered as "?".
func (o ColorChannelOrder) String() string {
	value := int(o)

	if value == 0 {
		value = 123
	}

	convert := func(remainder int) string {
		switch remainder {
		case 1:
			return "R"
		case 2:
			return "G"
		case 3:
			return "B"
		default:
			return "?"
		}
	}

	result := ""

	// Digits come out least significant first, so prepend each letter.
	for value > 0 {
		remainder := value % 10
		value /= 10
		result = convert(remainder) + result
	}

	return result
}

// NewColorChannelOrder parses a three-letter order such as "RGB" or "BGR".
// It returns UndefinedOrder and an error when val is not three bytes long
// or contains a character other than 'R', 'G' or 'B'. It does not check
// that every channel occurs exactly once.
func NewColorChannelOrder(val string) (ColorChannelOrder, error) {
	if len(val) != 3 {
		return UndefinedOrder, fmt.Errorf("Invalid length, expected 3")
	}

	convert := func(c rune) int {
		switch c {
		case 'R':
			return 1
		case 'G':
			return 2
		case 'B':
			return 3
		default:
			return 0
		}
	}

	result := 0

	for _, c := range val {
		index := convert(c)
		if index == 0 {
			return UndefinedOrder, fmt.Errorf("Invalid val %c", c)
		}
		result = result*10 + index
	}

	return ColorChannelOrder(result), nil
}

// MarshalJSON encodes the order as its letter form in a JSON string.
func (o ColorChannelOrder) MarshalJSON() ([]byte, error) {
	return json.Marshal(o.String())
}

// UnmarshalJSON decodes an order from a JSON string holding its letter form.
// The receiver is left untouched when decoding or parsing fails.
func (o *ColorChannelOrder) UnmarshalJSON(data []byte) error {
	var s string

	if err := json.Unmarshal(data, &s); err != nil {
		return err
	}

	val, err := NewColorChannelOrder(s)
	if err != nil {
		return err
	}

	*o = val

	return nil
}

// MarshalYAML encodes the order as its letter form.
func (o ColorChannelOrder) MarshalYAML() (any, error) {
	return o.String(), nil
}

// UnmarshalYAML decodes an order from a YAML scalar holding its letter form.
// The receiver is left untouched when decoding or parsing fails.
func (o *ColorChannelOrder) UnmarshalYAML(unmarshal func(any) error) error {
	var s string

	if err := unmarshal(&s); err != nil {
		return err
	}

	val, err := NewColorChannelOrder(s)
	if err != nil {
		return err
	}

	*o = val

	return nil
}
// ShapeComponent names one dimension of the expected shape for the input
// layer of a model. Usually this shape is
// (batch, resolution, resolution, channels) but sometimes it is not.
type ShapeComponent string

// Every constant carries the ShapeComponent type explicitly. Constants that
// are given a value without a type would be untyped strings.
const (
	ShapeBatch  ShapeComponent = "Batch"
	ShapeWidth  ShapeComponent = "Width"
	ShapeHeight ShapeComponent = "Height"
	ShapeColor  ShapeComponent = "Color"
)

// DefaultPhotoInputShape returns the default input shape:
// batch, height, width, color. Every call returns a new slice.
func DefaultPhotoInputShape() []ShapeComponent {
	return []ShapeComponent{
		ShapeBatch,
		ShapeHeight,
		ShapeWidth,
		ShapeColor,
	}
}
// PhotoInput represents an input description for a photo input for a model.
type PhotoInput struct {
	// Name identifies the input (presumably the name of the model's input
	// operation — confirm against the model loader).
	Name string `yaml:"Name,omitempty" json:"name,omitempty"`
	// Intervals holds the value intervals looked up per channel by GetInterval.
	Intervals []Interval `yaml:"Intervals,omitempty" json:"intervals,omitempty"`
	// ResizeOperation selects how images are resized. The zero value
	// (UndefinedResizeOperation) is omitted from JSON and YAML output.
	ResizeOperation ResizeOperation `yaml:"ResizeOperation,omitempty" json:"resizeOperation,omitempty"`
	// ColorChannelOrder is the channel order; note that its JSON key is "inputOrder".
	ColorChannelOrder ColorChannelOrder `yaml:"ColorChannelOrder,omitempty" json:"inputOrder,omitempty"`
	// OutputIndex is serialized under the key "Index"/"index".
	OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
	// Height and Width are the input image dimensions; both are -1 when
	// the input is dynamic (see IsDynamic).
	Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
	Width  int64 `yaml:"Width,omitempty" json:"width,omitempty"`
	// Shape lists the components of the input shape in order.
	Shape []ShapeComponent `yaml:"Shape,omitempty" json:"shape,omitempty"`
}
// IsDynamic reports whether the image dimensions are not defined, i.e. both
// Height and Width are -1, so the model accepts any size.
func (p PhotoInput) IsDynamic() bool {
	if p.Height != -1 {
		return false
	}

	return p.Width == -1
}
// Resolution returns the input image resolution: the Width when it is
// positive, otherwise the Height.
func (p PhotoInput) Resolution() int {
	side := p.Height

	if p.Width > 0 {
		side = p.Width
	}

	return int(side)
}
// SetResolution sets both the input image height and width to the given
// resolution in pixels (max width and height).
func (p *PhotoInput) SetResolution(resolution int) {
	side := int64(resolution)
	p.Height, p.Width = side, side
}
// GetInterval returns the interval for the given channel or a default one.
// When the channel has its own entry, that entry is returned. When exactly
// one interval is configured, it is assumed to apply to every channel.
// Otherwise the standard interval [0, 1] is returned.
func (p PhotoInput) GetInterval(channel int) *Interval {
	switch count := len(p.Intervals); {
	case channel < count:
		return &p.Intervals[channel]
	case count == 1:
		return &p.Intervals[0]
	default:
		return StandardInterval()
	}
}
// Merge fills every field of this input that still has its zero value with
// the corresponding value from other. Fields that are already set are kept.
// The Intervals and Shape slices are taken over from other without copying.
func (p *PhotoInput) Merge(other *PhotoInput) {
	// Scalar fields.
	if p.Name == "" {
		p.Name = other.Name
	}

	if p.OutputIndex == 0 {
		p.OutputIndex = other.OutputIndex
	}

	if p.Height == 0 {
		p.Height = other.Height
	}

	if p.Width == 0 {
		p.Width = other.Width
	}

	// Enumerations: the undefined value counts as "not set".
	if p.ResizeOperation == UndefinedResizeOperation {
		p.ResizeOperation = other.ResizeOperation
	}

	if p.ColorChannelOrder == UndefinedOrder {
		p.ColorChannelOrder = other.ColorChannelOrder
	}

	// Slices: only a nil slice counts as "not set". Taking over a nil
	// slice from other leaves the field nil.
	if p.Intervals == nil {
		p.Intervals = other.Intervals
	}

	if p.Shape == nil {
		p.Shape = other.Shape
	}
}
// ModelOutput represents the expected model output.
type ModelOutput struct {
	Name          string `yaml:"Name,omitempty" json:"name,omitempty"`
	OutputIndex   int    `yaml:"Index,omitempty" json:"index,omitempty"`
	NumOutputs    int64  `yaml:"Outputs,omitempty" json:"outputs,omitempty"`
	OutputsLogits bool   `yaml:"Logits,omitempty" json:"logits,omitempty"`
}

// Merge fills every field of this output that still has its zero value with
// the corresponding value from other. Fields that are already set are kept.
func (m *ModelOutput) Merge(other *ModelOutput) {
	if len(m.Name) == 0 {
		m.Name = other.Name
	}

	if m.OutputIndex == 0 {
		m.OutputIndex = other.OutputIndex
	}

	if m.NumOutputs == 0 {
		m.NumOutputs = other.NumOutputs
	}

	// A true flag is never reset; other is only consulted while it is false.
	m.OutputsLogits = m.OutputsLogits || other.OutputsLogits
}
// ModelInfo represents meta information for the model.
type ModelInfo struct {
	// TFVersion is excluded from both YAML and JSON serialization.
	TFVersion string `yaml:"-" json:"-"`
	// Tags holds the tags of the model.
	Tags []string `yaml:"Tags" json:"tags"`
	// Input describes the model input; nil when it is not defined.
	Input *PhotoInput `yaml:"Input" json:"input"`
	// Output describes the model output; nil when it is not defined.
	Output *ModelOutput `yaml:"Output" json:"output"`
}
// Merge merges other model info into this one. Where this model already
// has information for a field, it keeps its current value. A missing Input
// or Output is taken over from other as the same pointer, not a copy; an
// existing one is merged field by field.
func (m *ModelInfo) Merge(other *ModelInfo) {
	if m.TFVersion == "" {
		m.TFVersion = other.TFVersion
	}

	if len(m.Tags) == 0 {
		m.Tags = other.Tags
	}

	switch {
	case m.Input == nil:
		m.Input = other.Input
	case other.Input != nil:
		m.Input.Merge(other.Input)
	}

	switch {
	case m.Output == nil:
		m.Output = other.Output
	case other.Output != nil:
		m.Output.Merge(other.Output)
	}
}
// IsComplete reports whether the model input and output are both defined
// and the input has a shape.
func (m ModelInfo) IsComplete() bool {
	if m.Input == nil || m.Output == nil {
		return false
	}

	return m.Input.Shape != nil
}
// GetModelTagsInfo reads the saved_model.pb file in the savedModelPath
// directory and returns one ModelInfo per meta graph found in it. Only the
// TFVersion and Tags fields of each entry are populated.
//
// An error is returned when the file cannot be read or is not a valid
// SavedModel protobuf message. A model without meta graphs yields an empty,
// non-nil slice.
func GetModelTagsInfo(savedModelPath string) ([]ModelInfo, error) {
	savedModel := filepath.Join(savedModelPath, "saved_model.pb")

	data, err := os.ReadFile(savedModel)
	if err != nil {
		return nil, fmt.Errorf("vision: failed to read %s (%s)", clean.Path(savedModel), clean.Error(err))
	}

	model := new(pb.SavedModel)

	if err = proto.Unmarshal(data, model); err != nil {
		return nil, fmt.Errorf("vision: failed to unmarshal %s (%s)", clean.Path(savedModel), clean.Error(err))
	}

	metas := model.GetMetaGraphs()

	// The number of results is known up front, so allocate the slice once.
	models := make([]ModelInfo, 0, len(metas))

	for i := range metas {
		def := metas[i].GetMetaInfoDef()

		models = append(models, ModelInfo{
			TFVersion: def.GetTensorflowVersion(),
			Tags:      def.GetTags(),
		})
	}

	return models, nil
}