Removed parameter Channels

It seems to be standarized, so it is now used as an additional check for
input signatures.
This commit is contained in:
raystlin
2025-04-16 08:19:58 +00:00
parent eca0bc5205
commit d082929dee
7 changed files with 37 additions and 33 deletions

View File

@@ -28,13 +28,15 @@ func GuessInputAndOutput(model *tf.SavedModel) (input *PhotoInput, output *Model
modelOps := model.Graph.Operations()
for i := range modelOps {
if strings.HasPrefix(modelOps[i].Type(), "Placeholder") && modelOps[i].NumOutputs() == 1 && modelOps[i].Output(0).Shape().NumDimensions() == 4 {
if strings.HasPrefix(modelOps[i].Type(), "Placeholder") &&
modelOps[i].NumOutputs() == 1 &&
modelOps[i].Output(0).Shape().NumDimensions() == 4 &&
modelOps[i].Output(0).Shape().Size(3) == ExpectedChannels { // check the channels are 3
shape := modelOps[i].Output(0).Shape()
input = &PhotoInput{
Name: modelOps[i].Name(),
Height: shape.Size(1),
Width: shape.Size(2),
Channels: shape.Size(3),
Name: modelOps[i].Name(),
Height: shape.Size(1),
Width: shape.Size(2),
}
} else if (modelOps[i].Type() == "Softmax" || strings.HasPrefix(modelOps[i].Type(), "StatefulPartitionedCall")) &&
modelOps[i].NumOutputs() == 1 && modelOps[i].Output(0).Shape().NumDimensions() == 2 {
@@ -57,7 +59,7 @@ func GetInputAndOutputFromSavedModel(model *tf.SavedModel) (*PhotoInput, *ModelO
return nil, nil, fmt.Errorf("GetInputAndOutputFromSavedModel: nil input")
}
for _, v := range model.Signatures {
for k, v := range model.Signatures {
inputs := v.Inputs
outputs := v.Outputs
@@ -66,7 +68,13 @@ func GetInputAndOutputFromSavedModel(model *tf.SavedModel) (*PhotoInput, *ModelO
outputVarName, outputTensor := GetOne(outputs)
if inputTensor != nil && outputTensor != nil {
if inputTensor.Shape.Size(3) != ExpectedChannels {
log.Warnf("tensorflow: skipping signature %v because channels are expected to be %d, have %d",
k, ExpectedChannels, inputTensor.Shape.Size(3))
}
if inputTensor.Shape.NumDimensions() == 4 &&
inputTensor.Shape.Size(3) == ExpectedChannels &&
outputTensor.Shape.NumDimensions() == 2 {
var inputIdx, outputIdx = 0, 0
var err error
@@ -92,7 +100,6 @@ func GetInputAndOutputFromSavedModel(model *tf.SavedModel) (*PhotoInput, *ModelO
OutputIndex: inputIdx,
Height: inputTensor.Shape.Size(1),
Width: inputTensor.Shape.Size(2),
Channels: inputTensor.Shape.Size(3),
}, &ModelOutput{
Name: outputName,
OutputIndex: outputIdx,