Added utilities to inspect tensorflow models.

These new functions allows us to inspect the saved models, get the tags
and try to guess the inputs and outputs.
This commit is contained in:
raystlin
2025-04-11 14:30:48 +00:00
parent bd634c828b
commit d0f6d903e2
3 changed files with 281 additions and 0 deletions

View File

@@ -1,7 +1,10 @@
package tensorflow
import (
"fmt"
"path/filepath"
"strconv"
"strings"
tf "github.com/wamuir/graft/tensorflow"
@@ -18,3 +21,90 @@ func SavedModel(modelPath string, tags []string) (model *tf.SavedModel, err erro
return tf.LoadSavedModel(modelPath, tags, nil)
}
// GuessInputAndOutput tries to inspect a loaded saved model to build the
// ModelInfo struct
func GuessInputAndOutput(model *tf.SavedModel) (input *PhotoInput, output *ModelOutput, err error) {
modelOps := model.Graph.Operations()
for i := range modelOps {
if strings.HasPrefix(modelOps[i].Type(), "Placeholder") && modelOps[i].NumOutputs() == 1 && modelOps[i].Output(0).Shape().NumDimensions() == 4 {
shape := modelOps[i].Output(0).Shape()
input = &PhotoInput{
Name: modelOps[i].Name(),
Height: shape.Size(1),
Width: shape.Size(2),
Channels: shape.Size(3),
}
} else if (modelOps[i].Type() == "Softmax" || strings.HasPrefix(modelOps[i].Type(), "StatefulPartitionedCall")) &&
modelOps[i].NumOutputs() == 1 && modelOps[i].Output(0).Shape().NumDimensions() == 2 {
output = &ModelOutput{
Name: modelOps[i].Name(),
NumOutputs: modelOps[i].Output(0).Shape().Size(1),
}
}
if input != nil && output != nil {
return
}
}
return nil, nil, fmt.Errorf("Could not guess the inputs and outputs")
}
func GetInputAndOutputFromSavedModel(model *tf.SavedModel) (*PhotoInput, *ModelOutput, error) {
if model == nil {
return nil, nil, fmt.Errorf("GetInputAndOutputFromSavedModel: nil input")
}
for _, v := range model.Signatures {
inputs := v.Inputs
outputs := v.Outputs
if len(inputs) == 1 && len(outputs) == 1 {
_, inputTensor := GetOne(inputs)
outputVarName, outputTensor := GetOne(outputs)
if inputTensor != nil && outputTensor != nil {
if inputTensor.Shape.NumDimensions() == 4 &&
outputTensor.Shape.NumDimensions() == 2 {
var inputIdx, outputIdx = 0, 0
var err error
inputName, inputIndex, found := strings.Cut(inputTensor.Name, ":")
if found {
inputIdx, err = strconv.Atoi(inputIndex)
if err != nil {
return nil, nil, fmt.Errorf("Could not parse index %s: %w", inputIndex, err)
}
}
outputName, outputIndex, found := strings.Cut(outputTensor.Name, ":")
if found {
outputIdx, err = strconv.Atoi(outputIndex)
if err != nil {
return nil, nil, fmt.Errorf("Could not parse index: %s: %w", outputIndex, err)
}
}
return &PhotoInput{
Name: inputName,
OutputIndex: inputIdx,
Height: inputTensor.Shape.Size(1),
Width: inputTensor.Shape.Size(2),
Channels: inputTensor.Shape.Size(3),
}, &ModelOutput{
Name: outputName,
OutputIndex: outputIdx,
NumOutputs: outputTensor.Shape.Size(1),
OutputsLogits: strings.Contains(Deref(outputVarName, ""), "logits"),
}, nil
}
}
}
}
return nil, nil, fmt.Errorf("GetInputAndOutputFromSignature: could not find valid signatures")
}