mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-11 16:24:11 +01:00
488 lines
12 KiB
Go
488 lines
12 KiB
Go
package tensorflow
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
|
|
pb "github.com/wamuir/graft/tensorflow/core/protobuf/for_core_protos_go_proto"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"github.com/photoprism/photoprism/pkg/clean"
|
|
)
|
|
|
|
// ExpectedChannels defines the expected number of channels.
|
|
// This is a fixed value because a standard seems to have been
|
|
// defined for input images as "what decodeImage returns".
|
|
const ExpectedChannels = 3
|
|
|
|
// Interval of allowed values.
|
|
type Interval struct {
|
|
Start float32 `yaml:"Start,omitempty" json:"start,omitempty"`
|
|
End float32 `yaml:"End,omitempty" json:"end,omitempty"`
|
|
Mean *float32 `yaml:"Mean,omitempty" json:"mean,omitempty"`
|
|
StdDev *float32 `yaml:"StdDev,omitempty" json:"stdDev,omitempty"`
|
|
}
|
|
|
|
// Size returns the size/mean of the interval.
|
|
func (i Interval) Size() float32 {
|
|
return i.End - i.Start
|
|
}
|
|
|
|
// Offset returns the offset of the interval.
|
|
func (i Interval) Offset() float32 {
|
|
if i.StdDev == nil {
|
|
return i.Start
|
|
} else {
|
|
return *i.StdDev
|
|
}
|
|
}
|
|
|
|
// StandardInterval returns the standard interval, i.e.
|
|
// the range of values returned by decodeImage in [0, 1].
|
|
func StandardInterval() *Interval {
|
|
return &Interval{
|
|
Start: 0.0,
|
|
End: 1.0,
|
|
}
|
|
}
|
|
|
|
// ResizeOperation represents resizing operations for images.
|
|
// JSON and YAML functions are provided to make configuration files user-friendly.
|
|
type ResizeOperation int
|
|
|
|
const (
|
|
// UndefinedResizeOperation indicates that no resize strategy was specified.
|
|
UndefinedResizeOperation ResizeOperation = iota
|
|
// ResizeBreakAspectRatio resizes without preserving aspect ratio.
|
|
ResizeBreakAspectRatio
|
|
// CenterCrop crops the center region after resizing to fill the target size.
|
|
CenterCrop
|
|
// Padding resizes while preserving aspect ratio and pads the rest.
|
|
Padding
|
|
)
|
|
|
|
func (o ResizeOperation) String() string {
|
|
switch o {
|
|
case UndefinedResizeOperation:
|
|
return "Undefined"
|
|
case ResizeBreakAspectRatio:
|
|
return "ResizeBreakAspectRatio"
|
|
case CenterCrop:
|
|
return "CenterCrop"
|
|
case Padding:
|
|
return "Padding"
|
|
default:
|
|
return "Unknown"
|
|
}
|
|
}
|
|
|
|
// NewResizeOperation parses a string into a ResizeOperation.
|
|
func NewResizeOperation(s string) (ResizeOperation, error) {
|
|
switch s {
|
|
case "Undefined":
|
|
return UndefinedResizeOperation, nil
|
|
case "ResizeBreakAspectRatio":
|
|
return ResizeBreakAspectRatio, nil
|
|
case "CenterCrop":
|
|
return CenterCrop, nil
|
|
case "Padding":
|
|
return Padding, nil
|
|
default:
|
|
return UndefinedResizeOperation, fmt.Errorf("invalid operation %s", s)
|
|
}
|
|
}
|
|
|
|
// MarshalJSON encodes the resize operation as its string name.
|
|
func (o ResizeOperation) MarshalJSON() ([]byte, error) {
|
|
return json.Marshal(o.String())
|
|
}
|
|
|
|
// UnmarshalJSON decodes a resize operation from its string representation.
|
|
func (o *ResizeOperation) UnmarshalJSON(data []byte) error {
|
|
var s string
|
|
if err := json.Unmarshal(data, &s); err != nil {
|
|
return err
|
|
}
|
|
|
|
val, err := NewResizeOperation(s)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*o = val
|
|
|
|
return nil
|
|
}
|
|
|
|
// MarshalYAML encodes the resize operation for YAML output.
|
|
func (o ResizeOperation) MarshalYAML() (any, error) {
|
|
return o.String(), nil
|
|
}
|
|
|
|
// UnmarshalYAML decodes the resize operation from YAML input.
|
|
func (o *ResizeOperation) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
|
var s string
|
|
if err := unmarshal(&s); err != nil {
|
|
return err
|
|
}
|
|
|
|
val, err := NewResizeOperation(s)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*o = val
|
|
return nil
|
|
}
|
|
|
|
// ColorChannelOrder represents the order of the model's input vectors.
|
|
// JSON and YAML functions are provided to make the configuration files user-friendly.
|
|
type ColorChannelOrder int
|
|
|
|
const (
|
|
// UndefinedOrder leaves channel order unspecified, defaulting to RGB.
|
|
UndefinedOrder ColorChannelOrder = 0
|
|
// RGB represents Red-Green-Blue channel order.
|
|
RGB = 123
|
|
// RBG represents Red-Blue-Green channel order.
|
|
RBG = 132
|
|
// GRB represents Green-Red-Blue channel order.
|
|
GRB = 213
|
|
// GBR represents Green-Blue-Red channel order.
|
|
GBR = 231
|
|
// BRG represents Blue-Red-Green channel order.
|
|
BRG = 312
|
|
// BGR represents Blue-Green-Red channel order.
|
|
BGR = 321
|
|
)
|
|
|
|
// Indices returns the zero-based indices of the R, G, and B channels.
|
|
func (o ColorChannelOrder) Indices() (r, g, b int) {
|
|
i := int(o)
|
|
|
|
if i == 0 {
|
|
i = 123
|
|
}
|
|
|
|
for idx := 0; i > 0 && idx < 3; idx++ {
|
|
remainder := i % 10
|
|
i /= 10
|
|
|
|
switch remainder {
|
|
case 1:
|
|
r = 2 - idx
|
|
case 2:
|
|
g = 2 - idx
|
|
case 3:
|
|
b = 2 - idx
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (o ColorChannelOrder) String() string {
|
|
value := int(o)
|
|
|
|
if value == 0 {
|
|
value = 123
|
|
}
|
|
|
|
convert := func(remainder int) string {
|
|
switch remainder {
|
|
case 1:
|
|
return "R"
|
|
case 2:
|
|
return "G"
|
|
case 3:
|
|
return "B"
|
|
default:
|
|
return "?"
|
|
}
|
|
}
|
|
|
|
result := ""
|
|
for value > 0 {
|
|
remainder := value % 10
|
|
value /= 10
|
|
|
|
result = convert(remainder) + result
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// NewColorChannelOrder parses a string (e.g., "RGB") into a ColorChannelOrder.
|
|
func NewColorChannelOrder(val string) (ColorChannelOrder, error) {
|
|
if len(val) != 3 {
|
|
return UndefinedOrder, fmt.Errorf("invalid length, expected 3")
|
|
}
|
|
|
|
convert := func(c rune) int {
|
|
switch c {
|
|
case 'R':
|
|
return 1
|
|
case 'G':
|
|
return 2
|
|
case 'B':
|
|
return 3
|
|
default:
|
|
return 0
|
|
}
|
|
}
|
|
|
|
result := 0
|
|
for _, c := range val {
|
|
index := convert(c)
|
|
if index == 0 {
|
|
return UndefinedOrder, fmt.Errorf("invalid val %c", c)
|
|
}
|
|
result = result*10 + index
|
|
}
|
|
return ColorChannelOrder(result), nil
|
|
}
|
|
|
|
// MarshalJSON encodes the channel order as its string name.
|
|
func (o ColorChannelOrder) MarshalJSON() ([]byte, error) {
|
|
return json.Marshal(o.String())
|
|
}
|
|
|
|
// UnmarshalJSON decodes a channel order from its string representation.
|
|
func (o *ColorChannelOrder) UnmarshalJSON(data []byte) error {
|
|
var s string
|
|
if err := json.Unmarshal(data, &s); err != nil {
|
|
return err
|
|
}
|
|
|
|
val, err := NewColorChannelOrder(s)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*o = val
|
|
|
|
return nil
|
|
}
|
|
|
|
// MarshalYAML encodes the channel order for YAML output.
|
|
func (o ColorChannelOrder) MarshalYAML() (any, error) {
|
|
return o.String(), nil
|
|
}
|
|
|
|
// UnmarshalYAML decodes the channel order from YAML input.
|
|
func (o *ColorChannelOrder) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
|
var s string
|
|
if err := unmarshal(&s); err != nil {
|
|
return err
|
|
}
|
|
|
|
val, err := NewColorChannelOrder(s)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*o = val
|
|
return nil
|
|
}
|
|
|
|
// ShapeComponent describes a single dimension of a model input shape.
|
|
// Usually this shape is (batch, resolution, resolution, channels) but sometimes it is not.
|
|
type ShapeComponent string
|
|
|
|
const (
|
|
// ShapeBatch represents the batch dimension.
|
|
ShapeBatch ShapeComponent = "Batch"
|
|
// ShapeWidth represents the width dimension.
|
|
ShapeWidth = "Width"
|
|
// ShapeHeight represents the height dimension.
|
|
ShapeHeight = "Height"
|
|
// ShapeColor represents the color/channel dimension.
|
|
ShapeColor = "Color"
|
|
)
|
|
|
|
// DefaultPhotoInputShape returns the standard BHWC input shape.
|
|
func DefaultPhotoInputShape() []ShapeComponent {
|
|
return []ShapeComponent{
|
|
ShapeBatch,
|
|
ShapeHeight,
|
|
ShapeWidth,
|
|
ShapeColor,
|
|
}
|
|
}
|
|
|
|
// PhotoInput represents an input description for a photo input for a model.
|
|
type PhotoInput struct {
|
|
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
|
Intervals []Interval `yaml:"Intervals,omitempty" json:"intervals,omitempty"`
|
|
ResizeOperation ResizeOperation `yaml:"ResizeOperation,omitempty" json:"resizeOperation,omitempty"`
|
|
ColorChannelOrder ColorChannelOrder `yaml:"ColorChannelOrder,omitempty" json:"inputOrder,omitempty"`
|
|
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
|
|
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
|
|
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
|
|
Shape []ShapeComponent `yaml:"Shape,omitempty" json:"shape,omitempty"`
|
|
}
|
|
|
|
// IsDynamic checks if image dimensions are not defined, so the model accepts any size.
|
|
func (p PhotoInput) IsDynamic() bool {
|
|
return p.Height == -1 && p.Width == -1
|
|
}
|
|
|
|
// Resolution returns the input image resolution based on the image width or height if the width is undefined.
|
|
func (p PhotoInput) Resolution() int {
|
|
if p.Width > 0 {
|
|
return int(p.Width)
|
|
}
|
|
|
|
return int(p.Height)
|
|
}
|
|
|
|
// SetResolution sets the input image width and height based on the resolution in pixels (max width and height).
|
|
func (p *PhotoInput) SetResolution(resolution int) {
|
|
p.Height = int64(resolution)
|
|
p.Width = int64(resolution)
|
|
}
|
|
|
|
// GetInterval returns the interval or the default one.
|
|
// If just one interval has been fixed, then we assume
|
|
// it is the same for every channel. If no intervals
|
|
// have been defined, the default [0, 1] is returned
|
|
func (p PhotoInput) GetInterval(channel int) *Interval {
|
|
if len(p.Intervals) <= channel {
|
|
if len(p.Intervals) == 1 {
|
|
return &p.Intervals[0]
|
|
}
|
|
return StandardInterval()
|
|
} else {
|
|
return &p.Intervals[channel]
|
|
}
|
|
}
|
|
|
|
// Merge other input with this.
|
|
func (p *PhotoInput) Merge(other *PhotoInput) {
|
|
if p.Name == "" {
|
|
p.Name = other.Name
|
|
}
|
|
|
|
if p.Intervals == nil && other.Intervals != nil {
|
|
p.Intervals = other.Intervals
|
|
}
|
|
|
|
if p.OutputIndex == 0 {
|
|
p.OutputIndex = other.OutputIndex
|
|
}
|
|
|
|
if p.Height == 0 {
|
|
p.Height = other.Height
|
|
}
|
|
|
|
if p.Width == 0 {
|
|
p.Width = other.Width
|
|
}
|
|
|
|
if p.Shape == nil && other.Shape != nil {
|
|
p.Shape = other.Shape
|
|
}
|
|
|
|
if p.ResizeOperation == UndefinedResizeOperation {
|
|
p.ResizeOperation = other.ResizeOperation
|
|
}
|
|
|
|
if p.ColorChannelOrder == UndefinedOrder {
|
|
p.ColorChannelOrder = other.ColorChannelOrder
|
|
}
|
|
}
|
|
|
|
// ModelOutput represents the expected model output.
|
|
type ModelOutput struct {
|
|
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
|
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
|
|
NumOutputs int64 `yaml:"Outputs,omitempty" json:"outputs,omitempty"`
|
|
OutputsLogits bool `yaml:"Logits,omitempty" json:"logits,omitempty"`
|
|
}
|
|
|
|
// Merge merges other outputs with this output.
|
|
func (m *ModelOutput) Merge(other *ModelOutput) {
|
|
if m.Name == "" {
|
|
m.Name = other.Name
|
|
}
|
|
|
|
if m.OutputIndex == 0 {
|
|
m.OutputIndex = other.OutputIndex
|
|
}
|
|
|
|
if m.NumOutputs == 0 {
|
|
m.NumOutputs = other.NumOutputs
|
|
}
|
|
|
|
if !m.OutputsLogits {
|
|
m.OutputsLogits = other.OutputsLogits
|
|
}
|
|
}
|
|
|
|
// ModelInfo represents meta information for the model.
|
|
type ModelInfo struct {
|
|
TFVersion string `yaml:"-" json:"-"`
|
|
Tags []string `yaml:"Tags" json:"tags"`
|
|
Input *PhotoInput `yaml:"Input" json:"input"`
|
|
Output *ModelOutput `yaml:"Output" json:"output"`
|
|
}
|
|
|
|
// Merge other model info. In case of having information
|
|
// for a field, the current model will keep its current value
|
|
func (m *ModelInfo) Merge(other *ModelInfo) {
|
|
if m.TFVersion == "" {
|
|
m.TFVersion = other.TFVersion
|
|
}
|
|
|
|
if len(m.Tags) == 0 {
|
|
m.Tags = other.Tags
|
|
}
|
|
|
|
if m.Input == nil {
|
|
m.Input = other.Input
|
|
} else if other.Input != nil {
|
|
m.Input.Merge(other.Input)
|
|
}
|
|
|
|
if m.Output == nil {
|
|
m.Output = other.Output
|
|
} else if other.Output != nil {
|
|
m.Output.Merge(other.Output)
|
|
}
|
|
}
|
|
|
|
// IsComplete checks if the model input and output are defined.
|
|
func (m ModelInfo) IsComplete() bool {
|
|
return m.Input != nil && m.Output != nil && m.Input.Shape != nil
|
|
}
|
|
|
|
// GetModelTagsInfo reads a SavedModel and returns its available meta graph tags.
|
|
func GetModelTagsInfo(savedModelPath string) ([]ModelInfo, error) {
|
|
savedModel := filepath.Join(savedModelPath, "saved_model.pb")
|
|
|
|
data, err := os.ReadFile(savedModel) //nolint:gosec // savedModel path derived from trusted model directory
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("vision: failed to read %s (%s)", clean.Path(savedModel), clean.Error(err))
|
|
}
|
|
|
|
model := new(pb.SavedModel)
|
|
|
|
err = proto.Unmarshal(data, model)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("vision: failed to unmarshal %s (%s)", clean.Path(savedModel), clean.Error(err))
|
|
}
|
|
|
|
models := make([]ModelInfo, 0)
|
|
metas := model.GetMetaGraphs()
|
|
|
|
for i := range metas {
|
|
def := metas[i].GetMetaInfoDef()
|
|
models = append(models, ModelInfo{
|
|
TFVersion: def.GetTensorflowVersion(),
|
|
Tags: def.GetTags(),
|
|
})
|
|
}
|
|
|
|
return models, nil
|
|
}
|