Files
photoprism/internal/ai/tensorflow/info.go
2025-11-22 11:47:17 +01:00

488 lines
12 KiB
Go

package tensorflow
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
pb "github.com/wamuir/graft/tensorflow/core/protobuf/for_core_protos_go_proto"
"google.golang.org/protobuf/proto"
"github.com/photoprism/photoprism/pkg/clean"
)
// ExpectedChannels is the number of color channels (red, green and blue) a
// photo input is expected to have. It is a fixed value because input images
// are standardized on the three-channel format that decodeImage returns.
const ExpectedChannels = 3
// Interval describes the range of values a model expects for one color
// channel of its input (see PhotoInput.GetInterval). Start and End delimit the
// range; Mean and StdDev are optional and stay nil when not configured.
// NOTE(review): how Mean and StdDev are applied during input normalization is
// not visible here (Offset returns StdDev when it is set) — confirm with the
// code that converts pixel values.
type Interval struct {
Start float32 `yaml:"Start,omitempty" json:"start,omitempty"`
End float32 `yaml:"End,omitempty" json:"end,omitempty"`
Mean *float32 `yaml:"Mean,omitempty" json:"mean,omitempty"`
StdDev *float32 `yaml:"StdDev,omitempty" json:"stdDev,omitempty"`
}
// Size returns the width of the interval, i.e. End minus Start.
// Mean and StdDev are not taken into account.
func (i Interval) Size() float32 {
	width := i.End - i.Start
	return width
}
// Offset returns the interval's standard deviation when one is configured,
// and its start value otherwise.
// NOTE(review): using StdDev (rather than Mean) as the offset looks unusual —
// confirm against the code that applies the interval to pixel values.
func (i Interval) Offset() float32 {
	if i.StdDev != nil {
		return *i.StdDev
	}

	return i.Start
}
// StandardInterval returns a newly allocated interval covering [0, 1], the
// range of values returned by decodeImage. Mean and StdDev are left nil.
func StandardInterval() *Interval {
	standard := Interval{End: 1.0} // Start keeps its zero value 0.0.
	return &standard
}
// ResizeOperation selects how an image is resized to the model's input
// dimensions. JSON and YAML functions are provided so that configuration files
// can refer to the operations by name.
type ResizeOperation int
// The supported resize operations. The zero value, UndefinedResizeOperation,
// means that no operation was configured.
const (
// UndefinedResizeOperation indicates that no resize strategy was specified.
// PhotoInput.Merge treats it as unset.
UndefinedResizeOperation ResizeOperation = iota
// ResizeBreakAspectRatio resizes without preserving aspect ratio.
ResizeBreakAspectRatio
// CenterCrop crops the center region after resizing to fill the target size.
CenterCrop
// Padding resizes while preserving aspect ratio and pads the rest.
Padding
)
// String returns the configuration name of the resize operation, or "Unknown"
// for values outside the defined set.
func (o ResizeOperation) String() string {
	names := [...]string{
		UndefinedResizeOperation: "Undefined",
		ResizeBreakAspectRatio:   "ResizeBreakAspectRatio",
		CenterCrop:               "CenterCrop",
		Padding:                  "Padding",
	}

	if o < 0 || int(o) >= len(names) {
		return "Unknown"
	}

	return names[o]
}
// NewResizeOperation parses the configuration name of a resize operation
// (case-sensitive). Unknown names yield UndefinedResizeOperation and an error.
func NewResizeOperation(s string) (ResizeOperation, error) {
	byName := map[string]ResizeOperation{
		"Undefined":              UndefinedResizeOperation,
		"ResizeBreakAspectRatio": ResizeBreakAspectRatio,
		"CenterCrop":             CenterCrop,
		"Padding":                Padding,
	}

	op, found := byName[s]
	if !found {
		return UndefinedResizeOperation, fmt.Errorf("invalid operation %s", s)
	}

	return op, nil
}
// MarshalJSON encodes the resize operation as a JSON string holding its name
// as returned by String.
func (o ResizeOperation) MarshalJSON() ([]byte, error) {
	name := o.String()
	return json.Marshal(name)
}
// UnmarshalJSON decodes a resize operation from its string name.
//
// A JSON null leaves the current value untouched, following the encoding/json
// convention for Unmarshaler implementations. Previously null decoded to an
// empty name and failed with "invalid operation". Non-string values and unknown
// names return an error and leave the receiver unchanged.
func (o *ResizeOperation) UnmarshalJSON(data []byte) error {
	if string(data) == "null" {
		return nil
	}

	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return err
	}

	val, err := NewResizeOperation(s)
	if err != nil {
		return err
	}

	*o = val
	return nil
}
// MarshalYAML encodes the resize operation for YAML output as its name (a
// plain string), so configuration files stay human-readable.
func (o ResizeOperation) MarshalYAML() (any, error) {
	var out any = o.String()
	return out, nil
}
// UnmarshalYAML decodes a resize operation from its name in YAML input.
// Decoding errors and unknown names are returned unchanged; the receiver is
// only updated on success.
func (o *ResizeOperation) UnmarshalYAML(unmarshal func(any) error) error {
	var name string
	if err := unmarshal(&name); err != nil {
		return err
	}

	parsed, err := NewResizeOperation(name)
	if err == nil {
		*o = parsed
	}

	return err
}
// ColorChannelOrder represents the order of the color channels in the model's
// input. A value is a three-digit decimal number read from left to right, in
// which 1 stands for red, 2 for green and 3 for blue (for example, 123 is RGB
// and 321 is BGR). The zero value UndefinedOrder is treated as RGB by String
// and Indices.
// JSON and YAML functions are provided to make the configuration files user-friendly.
type ColorChannelOrder int

// Each constant carries an explicit ColorChannelOrder type. Inside a const
// block, only specs with an omitted expression inherit the previous type, so
// writing "RGB = 123" would declare an untyped integer constant instead.
const (
	// UndefinedOrder leaves channel order unspecified, defaulting to RGB.
	UndefinedOrder ColorChannelOrder = 0
	// RGB represents Red-Green-Blue channel order.
	RGB ColorChannelOrder = 123
	// RBG represents Red-Blue-Green channel order.
	RBG ColorChannelOrder = 132
	// GRB represents Green-Red-Blue channel order.
	GRB ColorChannelOrder = 213
	// GBR represents Green-Blue-Red channel order.
	GBR ColorChannelOrder = 231
	// BRG represents Blue-Red-Green channel order.
	BRG ColorChannelOrder = 312
	// BGR represents Blue-Green-Red channel order.
	BGR ColorChannelOrder = 321
)
// Indices returns the zero-based positions of the R, G, and B channels.
// UndefinedOrder (0) is treated as RGB. At most the three least significant
// decimal digits are inspected; a channel whose digit does not occur keeps
// position 0.
func (o ColorChannelOrder) Indices() (r, g, b int) {
	code := int(o)
	if code == 0 {
		code = 123
	}

	// Walk the digits from right to left; the rightmost digit names the
	// channel at position 2, the leftmost the channel at position 0.
	for pos := 2; pos >= 0 && code > 0; pos-- {
		digit := code % 10
		code /= 10

		switch digit {
		case 1:
			r = pos
		case 2:
			g = pos
		case 3:
			b = pos
		}
	}

	return r, g, b
}
// String returns the channel order as letters, e.g. "RGB" or "BGR".
// UndefinedOrder (0) is rendered as "RGB"; any digit other than 1, 2 or 3 is
// rendered as '?', and a negative value yields an empty string.
func (o ColorChannelOrder) String() string {
	code := int(o)
	if code == 0 {
		code = 123
	}

	// Collect letters from the least significant digit upwards, then reverse.
	var letters []byte
	for ; code > 0; code /= 10 {
		var c byte
		switch code % 10 {
		case 1:
			c = 'R'
		case 2:
			c = 'G'
		case 3:
			c = 'B'
		default:
			c = '?'
		}
		letters = append(letters, c)
	}

	for lo, hi := 0, len(letters)-1; lo < hi; lo, hi = lo+1, hi-1 {
		letters[lo], letters[hi] = letters[hi], letters[lo]
	}

	return string(letters)
}
// NewColorChannelOrder parses a three-letter channel order such as "RGB" or
// "BGR" into a ColorChannelOrder. The string must contain each of the
// upper-case letters R, G and B exactly once. Otherwise UndefinedOrder and an
// error are returned.
func NewColorChannelOrder(val string) (ColorChannelOrder, error) {
	if len(val) != 3 {
		return UndefinedOrder, fmt.Errorf("invalid length, expected 3")
	}

	result := 0
	var seen [4]bool // indexed by channel digit 1 (R), 2 (G), 3 (B)

	for _, c := range val {
		var digit int
		switch c {
		case 'R':
			digit = 1
		case 'G':
			digit = 2
		case 'B':
			digit = 3
		default:
			return UndefinedOrder, fmt.Errorf("invalid val %c", c)
		}

		// A repeated channel (e.g. "RRG") is not a permutation. Indices could
		// not assign distinct positions to R, G and B, so reject it here.
		if seen[digit] {
			return UndefinedOrder, fmt.Errorf("duplicate channel %c", c)
		}
		seen[digit] = true

		result = result*10 + digit
	}

	return ColorChannelOrder(result), nil
}
// MarshalJSON encodes the channel order as a JSON string of channel letters
// as returned by String (for example "BGR").
func (o ColorChannelOrder) MarshalJSON() ([]byte, error) {
	letters := o.String()
	return json.Marshal(letters)
}
// UnmarshalJSON decodes a channel order from its letter representation
// (for example "RGB").
//
// A JSON null leaves the current value untouched, following the encoding/json
// convention for Unmarshaler implementations. Previously null decoded to an
// empty string and failed with "invalid length". Non-string values and invalid
// orders return an error and leave the receiver unchanged.
func (o *ColorChannelOrder) UnmarshalJSON(data []byte) error {
	if string(data) == "null" {
		return nil
	}

	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return err
	}

	val, err := NewColorChannelOrder(s)
	if err != nil {
		return err
	}

	*o = val
	return nil
}
// MarshalYAML encodes the channel order for YAML output as a plain string of
// channel letters, as returned by String.
func (o ColorChannelOrder) MarshalYAML() (any, error) {
	var out any = o.String()
	return out, nil
}
// UnmarshalYAML decodes the channel order from its letter representation in
// YAML input. Errors from decoding or parsing are returned as-is, and the
// receiver is only updated on success.
func (o *ColorChannelOrder) UnmarshalYAML(unmarshal func(any) error) error {
	var text string
	err := unmarshal(&text)
	if err != nil {
		return err
	}

	order, err := NewColorChannelOrder(text)
	if err != nil {
		return err
	}

	*o = order
	return nil
}
// ShapeComponent describes a single dimension of a model input shape.
// Usually this shape is (batch, resolution, resolution, channels) but sometimes it is not.
type ShapeComponent string

// Each constant carries an explicit ShapeComponent type. Without it, only
// ShapeBatch would be typed and the others would be untyped string constants,
// because an explicit value in a const spec does not inherit the previous type.
const (
	// ShapeBatch represents the batch dimension.
	ShapeBatch ShapeComponent = "Batch"
	// ShapeWidth represents the width dimension.
	ShapeWidth ShapeComponent = "Width"
	// ShapeHeight represents the height dimension.
	ShapeHeight ShapeComponent = "Height"
	// ShapeColor represents the color/channel dimension.
	ShapeColor ShapeComponent = "Color"
)
// DefaultPhotoInputShape returns a new slice describing the standard BHWC
// input layout: batch, height, width, color channels.
func DefaultPhotoInputShape() []ShapeComponent {
	shape := make([]ShapeComponent, 0, 4)
	shape = append(shape, ShapeBatch, ShapeHeight, ShapeWidth, ShapeColor)
	return shape
}
// PhotoInput describes the photo input of a model: its name, the expected
// value range per color channel, the resize operation, the channel order, the
// input dimensions, and the meaning of each input tensor dimension.
type PhotoInput struct {
// Name of the input.
// NOTE(review): presumably the TensorFlow operation name — confirm.
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
// Intervals holds the value range per color channel; see GetInterval.
Intervals []Interval `yaml:"Intervals,omitempty" json:"intervals,omitempty"`
ResizeOperation ResizeOperation `yaml:"ResizeOperation,omitempty" json:"resizeOperation,omitempty"`
// The JSON key is "inputOrder", unlike the YAML key "ColorChannelOrder".
ColorChannelOrder ColorChannelOrder `yaml:"ColorChannelOrder,omitempty" json:"inputOrder,omitempty"`
// OutputIndex is serialized as "Index" (YAML) / "index" (JSON).
// NOTE(review): presumably the index of the input operation's output — confirm.
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
// Height and Width are the input dimensions in pixels; -1 in both means the
// model accepts any size (see IsDynamic).
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
// Shape names the meaning of each input tensor dimension; see
// DefaultPhotoInputShape for the common BHWC layout.
Shape []ShapeComponent `yaml:"Shape,omitempty" json:"shape,omitempty"`
}
// IsDynamic reports whether the model accepts images of any size, which is
// signalled by both Height and Width being -1.
func (p PhotoInput) IsDynamic() bool {
	const dynamic = -1
	return p.Width == dynamic && p.Height == dynamic
}
// Resolution returns the input image resolution in pixels. It is the Width
// when that is positive, and the Height otherwise (which may itself be 0 or -1).
func (p PhotoInput) Resolution() int {
	res := p.Height
	if p.Width > 0 {
		res = p.Width
	}

	return int(res)
}
// SetResolution sets both the input width and height to the given resolution
// in pixels (a square input).
func (p *PhotoInput) SetResolution(resolution int) {
	size := int64(resolution)
	p.Width, p.Height = size, size
}
// GetInterval returns the value interval for the given color channel index.
//
// If an interval is configured for the channel, a pointer into p.Intervals is
// returned, so it shares storage with the caller's slice. If exactly one
// interval is configured, it is used for every channel. Otherwise a fresh
// StandardInterval ([0, 1]) is returned. A negative channel index panics.
func (p PhotoInput) GetInterval(channel int) *Interval {
	switch n := len(p.Intervals); {
	case channel < n:
		return &p.Intervals[channel]
	case n == 1:
		return &p.Intervals[0]
	default:
		return StandardInterval()
	}
}
// Merge fills every unset field of p with the corresponding value from other.
//
// A field counts as unset when it holds its zero value: empty Name, nil
// Intervals or Shape, 0 for OutputIndex, Height and Width,
// UndefinedResizeOperation, or UndefinedOrder. Fields that are already set are
// kept, which means an explicit 0 (e.g. OutputIndex 0) is still replaced by
// other's value. Slices are adopted by reference, so p and other share backing
// arrays afterwards. A nil other leaves p unchanged instead of panicking.
func (p *PhotoInput) Merge(other *PhotoInput) {
	if other == nil {
		return
	}

	if p.Name == "" {
		p.Name = other.Name
	}

	if p.Intervals == nil && other.Intervals != nil {
		p.Intervals = other.Intervals
	}

	if p.OutputIndex == 0 {
		p.OutputIndex = other.OutputIndex
	}

	if p.Height == 0 {
		p.Height = other.Height
	}

	if p.Width == 0 {
		p.Width = other.Width
	}

	if p.Shape == nil && other.Shape != nil {
		p.Shape = other.Shape
	}

	if p.ResizeOperation == UndefinedResizeOperation {
		p.ResizeOperation = other.ResizeOperation
	}

	if p.ColorChannelOrder == UndefinedOrder {
		p.ColorChannelOrder = other.ColorChannelOrder
	}
}
// ModelOutput describes the expected model output.
type ModelOutput struct {
// Name of the output.
// NOTE(review): presumably the TensorFlow operation name — confirm.
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
// OutputIndex is serialized as "Index" (YAML) / "index" (JSON).
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
// NumOutputs is serialized as "Outputs" / "outputs".
// NOTE(review): presumably the number of output values (e.g. labels) — confirm.
NumOutputs int64 `yaml:"Outputs,omitempty" json:"outputs,omitempty"`
// OutputsLogits is serialized as "Logits" / "logits".
// NOTE(review): presumably true when the model emits raw logits rather than
// probabilities — confirm with the consumer.
OutputsLogits bool `yaml:"Logits,omitempty" json:"logits,omitempty"`
}
// Merge fills every unset field of m with the corresponding value from other.
// Empty Name and 0 for OutputIndex and NumOutputs count as unset. OutputsLogits
// becomes true if it is true in either m or other. A nil other leaves m
// unchanged instead of panicking.
func (m *ModelOutput) Merge(other *ModelOutput) {
	if other == nil {
		return
	}

	if m.Name == "" {
		m.Name = other.Name
	}

	if m.OutputIndex == 0 {
		m.OutputIndex = other.OutputIndex
	}

	if m.NumOutputs == 0 {
		m.NumOutputs = other.NumOutputs
	}

	if !m.OutputsLogits {
		m.OutputsLogits = other.OutputsLogits
	}
}
// ModelInfo holds meta information for a model: the TensorFlow version
// recorded in its meta graph, its meta graph tags, and descriptions of its
// input and output.
type ModelInfo struct {
// TFVersion is never serialized to YAML or JSON.
TFVersion string `yaml:"-" json:"-"`
Tags []string `yaml:"Tags" json:"tags"`
Input *PhotoInput `yaml:"Input" json:"input"`
Output *ModelOutput `yaml:"Output" json:"output"`
}
// Merge fills missing information in m from other. Fields that already hold a
// value in m keep it: TFVersion when non-empty, Tags when non-empty. A non-nil
// Input or Output is merged field by field.
//
// When m has no Input or Output but other does, a copy of other's struct is
// adopted instead of the pointer itself. Later changes to m.Input or m.Output
// (e.g. a subsequent Merge or SetResolution) therefore no longer write through
// to other, which is typically a shared set of defaults. The copy is shallow,
// so slices inside it still share backing arrays with other.
// A nil other leaves m unchanged instead of panicking.
func (m *ModelInfo) Merge(other *ModelInfo) {
	if other == nil {
		return
	}

	if m.TFVersion == "" {
		m.TFVersion = other.TFVersion
	}

	if len(m.Tags) == 0 {
		m.Tags = other.Tags
	}

	switch {
	case other.Input == nil:
		// Nothing to merge.
	case m.Input == nil:
		input := *other.Input
		m.Input = &input
	default:
		m.Input.Merge(other.Input)
	}

	switch {
	case other.Output == nil:
		// Nothing to merge.
	case m.Output == nil:
		output := *other.Output
		m.Output = &output
	default:
		m.Output.Merge(other.Output)
	}
}
// IsComplete reports whether the model info describes both an input and an
// output, and the input has a (non-nil) shape.
func (m ModelInfo) IsComplete() bool {
	if m.Input == nil || m.Output == nil {
		return false
	}

	return m.Input.Shape != nil
}
// GetModelTagsInfo reads saved_model.pb from the given SavedModel directory
// and returns one ModelInfo per meta graph. Each entry has only TFVersion and
// Tags filled in. A model without meta graphs yields an empty, non-nil slice.
// Read and decode failures are returned as errors that include the sanitized
// file path.
func GetModelTagsInfo(savedModelPath string) ([]ModelInfo, error) {
	pbFile := filepath.Join(savedModelPath, "saved_model.pb")

	raw, err := os.ReadFile(pbFile) //nolint:gosec // savedModel path derived from trusted model directory
	if err != nil {
		return nil, fmt.Errorf("vision: failed to read %s (%s)", clean.Path(pbFile), clean.Error(err))
	}

	saved := new(pb.SavedModel)
	if err = proto.Unmarshal(raw, saved); err != nil {
		return nil, fmt.Errorf("vision: failed to unmarshal %s (%s)", clean.Path(pbFile), clean.Error(err))
	}

	graphs := saved.GetMetaGraphs()
	infos := make([]ModelInfo, 0, len(graphs))

	for _, graph := range graphs {
		// The generated protobuf getters are nil-safe, so a missing meta info
		// def yields empty values rather than a panic.
		meta := graph.GetMetaInfoDef()
		infos = append(infos, ModelInfo{
			TFVersion: meta.GetTensorflowVersion(),
			Tags:      meta.GetTags(),
		})
	}

	return infos, nil
}