mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
* AI: Added support for non BHWC models Tensorflow models use BHWC by default, however, if we are using converted models, we can find that the expected input is BCHW. Now the input is configurable (although the restriction of being dimesion 4 is still there) via Shape parameter on the input definition. Also, the model instrospection will try to deduce the input shape from the model signature. * AI: Added more tests for enum parsing ShapeComponent was missing from the tests * AI: Modified external tests to the new url The path has been moved from tensorflow/vision to tensorflow/models * AI: Moved the builder to the model to reuse it It should reduce the amount of allocations done * AI: fixed errors after merge Mainly incorrect paths and duplicated variables
This commit is contained in:
@@ -5,8 +5,6 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
pb "github.com/wamuir/graft/tensorflow/core/protobuf/for_core_protos_go_proto"
|
||||
"google.golang.org/protobuf/proto"
|
||||
@@ -263,6 +261,26 @@ func (o *ColorChannelOrder) UnmarshalYAML(unmarshal func(interface{}) error) err
|
||||
return nil
|
||||
}
|
||||
|
||||
// The expected shape for the input layer of a mode. Usually this shape is
|
||||
// (batch, resolution, resolution, channels) but sometimes it is not.
|
||||
type ShapeComponent string
|
||||
|
||||
const (
|
||||
ShapeBatch ShapeComponent = "Batch"
|
||||
ShapeWidth = "Width"
|
||||
ShapeHeight = "Height"
|
||||
ShapeColor = "Color"
|
||||
)
|
||||
|
||||
func DefaultPhotoInputShape() []ShapeComponent {
|
||||
return []ShapeComponent{
|
||||
ShapeBatch,
|
||||
ShapeHeight,
|
||||
ShapeWidth,
|
||||
ShapeColor,
|
||||
}
|
||||
}
|
||||
|
||||
// PhotoInput represents an input description for a photo input for a model.
|
||||
type PhotoInput struct {
|
||||
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
|
||||
@@ -272,6 +290,7 @@ type PhotoInput struct {
|
||||
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
|
||||
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
|
||||
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
|
||||
Shape []ShapeComponent `yaml:"Shape,omitempty" json:"shape,omitempty"`
|
||||
}
|
||||
|
||||
// IsDynamic checks if image dimensions are not defined, so the model accepts any size.
|
||||
@@ -331,6 +350,10 @@ func (p *PhotoInput) Merge(other *PhotoInput) {
|
||||
p.Width = other.Width
|
||||
}
|
||||
|
||||
if p.Shape == nil && other.Shape != nil {
|
||||
p.Shape = other.Shape
|
||||
}
|
||||
|
||||
if p.ResizeOperation == UndefinedResizeOperation {
|
||||
p.ResizeOperation = other.ResizeOperation
|
||||
}
|
||||
@@ -401,83 +424,10 @@ func (m *ModelInfo) Merge(other *ModelInfo) {
|
||||
|
||||
// IsComplete checks if the model input and output are defined.
|
||||
func (m ModelInfo) IsComplete() bool {
|
||||
return m.Input != nil && m.Output != nil
|
||||
return m.Input != nil && m.Output != nil && m.Input.Shape != nil
|
||||
}
|
||||
|
||||
// GetInputAndOutputFromMetaSignature returns the signatures from a MetaGraphDef
|
||||
// and uses them to build PhotoInput and ModelOutput structs, that will complete
|
||||
// ModelInfo struct.
|
||||
func GetInputAndOutputFromMetaSignature(meta *pb.MetaGraphDef) (*PhotoInput, *ModelOutput, error) {
|
||||
if meta == nil {
|
||||
return nil, nil, fmt.Errorf("GetInputAndOutputFromSignature: nil input")
|
||||
}
|
||||
|
||||
sig := meta.GetSignatureDef()
|
||||
for k, v := range sig {
|
||||
inputs := v.GetInputs()
|
||||
outputs := v.GetOutputs()
|
||||
|
||||
if len(inputs) == 1 && len(outputs) == 1 {
|
||||
_, inputTensor := GetOne(inputs)
|
||||
outputVarName, outputTensor := GetOne(outputs)
|
||||
|
||||
if inputTensor != nil && (*inputTensor).GetTensorShape() != nil &&
|
||||
outputTensor != nil && (*outputTensor).GetTensorShape() != nil {
|
||||
inputDims := (*inputTensor).GetTensorShape().Dim
|
||||
outputDims := (*outputTensor).GetTensorShape().Dim
|
||||
|
||||
if inputDims[3].GetSize() != ExpectedChannels {
|
||||
log.Warnf("tensorflow: skipping signature %v because channels are expected to be %d, have %d",
|
||||
k, ExpectedChannels, inputDims[3].GetSize())
|
||||
}
|
||||
|
||||
if len(inputDims) == 4 &&
|
||||
inputDims[3].GetSize() == ExpectedChannels &&
|
||||
len(outputDims) == 2 {
|
||||
var err error
|
||||
var inputIdx, outputIdx = 0, 0
|
||||
|
||||
inputName, inputIndex, found := strings.Cut((*inputTensor).GetName(), ":")
|
||||
if found {
|
||||
|
||||
inputIdx, err = strconv.Atoi(inputIndex)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("could not parse index %s (%s)", inputIndex, clean.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
outputName, outputIndex, found := strings.Cut((*outputTensor).GetName(), ":")
|
||||
if found {
|
||||
|
||||
outputIdx, err = strconv.Atoi(outputIndex)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("could not parse index: %s (%s)", outputIndex, clean.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return &PhotoInput{
|
||||
Name: inputName,
|
||||
OutputIndex: inputIdx,
|
||||
Height: inputDims[1].GetSize(),
|
||||
Width: inputDims[2].GetSize(),
|
||||
}, &ModelOutput{
|
||||
Name: outputName,
|
||||
OutputIndex: outputIdx,
|
||||
NumOutputs: outputDims[1].GetSize(),
|
||||
OutputsLogits: strings.Contains(Deref(outputVarName, ""), "logits"),
|
||||
}, nil
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf("GetInputAndOutputFromMetaSignature: Could not find a valid signature")
|
||||
}
|
||||
|
||||
func GetModelInfo(savedModelPath string) ([]ModelInfo, error) {
|
||||
func GetModelTagsInfo(savedModelPath string) ([]ModelInfo, error) {
|
||||
savedModel := filepath.Join(savedModelPath, "saved_model.pb")
|
||||
|
||||
data, err := os.ReadFile(savedModel)
|
||||
@@ -499,20 +449,10 @@ func GetModelInfo(savedModelPath string) ([]ModelInfo, error) {
|
||||
|
||||
for i := range metas {
|
||||
def := metas[i].GetMetaInfoDef()
|
||||
input, output, modelErr := GetInputAndOutputFromMetaSignature(metas[i])
|
||||
|
||||
newModel := ModelInfo{
|
||||
models = append(models, ModelInfo{
|
||||
TFVersion: def.GetTensorflowVersion(),
|
||||
Tags: def.GetTags(),
|
||||
Input: input,
|
||||
Output: output,
|
||||
}
|
||||
|
||||
if modelErr != nil {
|
||||
log.Errorf("vision: could not determine model inputs and outputs from TensorFlow %s signatures (%s)", newModel.TFVersion, clean.Error(modelErr))
|
||||
}
|
||||
|
||||
models = append(models, newModel)
|
||||
})
|
||||
}
|
||||
|
||||
return models, nil
|
||||
|
||||
Reference in New Issue
Block a user