Added utilities to inspect tensorflow models.

These new functions allows us to inspect the saved models, get the tags
and try to guess the inputs and outputs.
This commit is contained in:
raystlin
2025-04-11 14:30:48 +00:00
parent bd634c828b
commit d0f6d903e2
3 changed files with 281 additions and 0 deletions

View File

@@ -0,0 +1,164 @@
package tensorflow
import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
pb "github.com/wamuir/graft/tensorflow/core/protobuf/for_core_protos_go_proto"
"google.golang.org/protobuf/proto"
)
// Input description for a photo input for a model
type PhotoInput struct {
Name string
OutputIndex int
Height int64
Width int64
Channels int64
}
// When dimensions are not defined, it means the model accepts any size of
// photo
func (f PhotoInput) IsDynamic() bool {
return f.Height == -1 && f.Width == -1
}
// Get the resolution
func (f PhotoInput) Resolution() int {
return int(f.Height)
}
// Set the resolution: same height and width
func (f *PhotoInput) SetResolution(resolution int) {
f.Height = int64(resolution)
f.Width = int64(resolution)
}
// The output expected for a model
type ModelOutput struct {
Name string
OutputIndex int
NumOutputs int64
OutputsLogits bool
}
// The meta information for the model
type ModelInfo struct {
TFVersion string
Tags []string
Input *PhotoInput
Output *ModelOutput
}
// We consider a model complete if we know its inputs and outputs
func (m ModelInfo) IsComplete() bool {
return m.Input != nil && m.Output != nil
}
// GetInputAndOutputFromSignature gets the signatures from a MetaGraphDef and
// uses them to build PhotoInput and ModelOutput structs, that will complete
// ModelInfo struct
func GetInputAndOutputFromMetaSignature(meta *pb.MetaGraphDef) (*PhotoInput, *ModelOutput, error) {
if meta == nil {
return nil, nil, fmt.Errorf("GetInputAndOutputFromSignature: nil input")
}
sig := meta.GetSignatureDef()
for _, v := range sig {
inputs := v.GetInputs()
outputs := v.GetOutputs()
if len(inputs) == 1 && len(outputs) == 1 {
_, inputTensor := GetOne(inputs)
outputVarName, outputTensor := GetOne(outputs)
if inputTensor != nil && (*inputTensor).GetTensorShape() != nil &&
outputTensor != nil && (*outputTensor).GetTensorShape() != nil {
inputDims := (*inputTensor).GetTensorShape().Dim
outputDims := (*outputTensor).GetTensorShape().Dim
if len(inputDims) == 4 && len(outputDims) == 2 {
var err error
var inputIdx, outputIdx = 0, 0
inputName, inputIndex, found := strings.Cut((*inputTensor).GetName(), ":")
if found {
inputIdx, err = strconv.Atoi(inputIndex)
if err != nil {
return nil, nil, fmt.Errorf("Could not parse index %s: %w", inputIndex, err)
}
}
outputName, outputIndex, found := strings.Cut((*outputTensor).GetName(), ":")
if found {
outputIdx, err = strconv.Atoi(outputIndex)
if err != nil {
return nil, nil, fmt.Errorf("Could not parse index: %s: %w", outputIndex, err)
}
}
return &PhotoInput{
Name: inputName,
OutputIndex: inputIdx,
Height: inputDims[1].GetSize(),
Width: inputDims[2].GetSize(),
Channels: inputDims[3].GetSize(),
}, &ModelOutput{
Name: outputName,
OutputIndex: outputIdx,
NumOutputs: outputDims[1].GetSize(),
OutputsLogits: strings.Contains(Deref(outputVarName, ""), "logits"),
}, nil
}
}
}
}
return nil, nil, fmt.Errorf("GetInputAndOutputFromMetaSignature: Could not find a valid signature")
}
func GetModelInfo(path string) ([]ModelInfo, error) {
path = filepath.Join(path, "saved_model.pb")
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("Could not read the file %s: %w", path, err)
}
model := new(pb.SavedModel)
err = proto.Unmarshal(data, model)
if err != nil {
return nil, fmt.Errorf("Could not unmarshal the file %s: %w", path, err)
}
models := make([]ModelInfo, 0)
metas := model.GetMetaGraphs()
for i := range metas {
def := metas[i].GetMetaInfoDef()
input, output, err := GetInputAndOutputFromMetaSignature(metas[i])
newModel := ModelInfo{
TFVersion: def.GetTensorflowVersion(),
Tags: def.GetTags(),
Input: input,
Output: output,
}
if err != nil {
log.Printf("Could not get the inputs and outputs from signatures. (TF Version %s): %w", newModel.TFVersion, err)
}
models = append(models, newModel)
}
return models, nil
}

View File

@@ -1,7 +1,10 @@
package tensorflow
import (
"fmt"
"path/filepath"
"strconv"
"strings"
tf "github.com/wamuir/graft/tensorflow"
@@ -18,3 +21,90 @@ func SavedModel(modelPath string, tags []string) (model *tf.SavedModel, err erro
return tf.LoadSavedModel(modelPath, tags, nil)
}
// GuessInputAndOutput tries to inspect a loaded saved model to build the
// ModelInfo struct
func GuessInputAndOutput(model *tf.SavedModel) (input *PhotoInput, output *ModelOutput, err error) {
modelOps := model.Graph.Operations()
for i := range modelOps {
if strings.HasPrefix(modelOps[i].Type(), "Placeholder") && modelOps[i].NumOutputs() == 1 && modelOps[i].Output(0).Shape().NumDimensions() == 4 {
shape := modelOps[i].Output(0).Shape()
input = &PhotoInput{
Name: modelOps[i].Name(),
Height: shape.Size(1),
Width: shape.Size(2),
Channels: shape.Size(3),
}
} else if (modelOps[i].Type() == "Softmax" || strings.HasPrefix(modelOps[i].Type(), "StatefulPartitionedCall")) &&
modelOps[i].NumOutputs() == 1 && modelOps[i].Output(0).Shape().NumDimensions() == 2 {
output = &ModelOutput{
Name: modelOps[i].Name(),
NumOutputs: modelOps[i].Output(0).Shape().Size(1),
}
}
if input != nil && output != nil {
return
}
}
return nil, nil, fmt.Errorf("Could not guess the inputs and outputs")
}
func GetInputAndOutputFromSavedModel(model *tf.SavedModel) (*PhotoInput, *ModelOutput, error) {
if model == nil {
return nil, nil, fmt.Errorf("GetInputAndOutputFromSavedModel: nil input")
}
for _, v := range model.Signatures {
inputs := v.Inputs
outputs := v.Outputs
if len(inputs) == 1 && len(outputs) == 1 {
_, inputTensor := GetOne(inputs)
outputVarName, outputTensor := GetOne(outputs)
if inputTensor != nil && outputTensor != nil {
if inputTensor.Shape.NumDimensions() == 4 &&
outputTensor.Shape.NumDimensions() == 2 {
var inputIdx, outputIdx = 0, 0
var err error
inputName, inputIndex, found := strings.Cut(inputTensor.Name, ":")
if found {
inputIdx, err = strconv.Atoi(inputIndex)
if err != nil {
return nil, nil, fmt.Errorf("Could not parse index %s: %w", inputIndex, err)
}
}
outputName, outputIndex, found := strings.Cut(outputTensor.Name, ":")
if found {
outputIdx, err = strconv.Atoi(outputIndex)
if err != nil {
return nil, nil, fmt.Errorf("Could not parse index: %s: %w", outputIndex, err)
}
}
return &PhotoInput{
Name: inputName,
OutputIndex: inputIdx,
Height: inputTensor.Shape.Size(1),
Width: inputTensor.Shape.Size(2),
Channels: inputTensor.Shape.Size(3),
}, &ModelOutput{
Name: outputName,
OutputIndex: outputIdx,
NumOutputs: outputTensor.Shape.Size(1),
OutputsLogits: strings.Contains(Deref(outputVarName, ""), "logits"),
}, nil
}
}
}
}
return nil, nil, fmt.Errorf("GetInputAndOutputFromSignature: could not find valid signatures")
}

View File

@@ -0,0 +1,27 @@
package tensorflow
import "math/rand"
func randomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
for i := range result {
result[i] = charset[rand.Intn(len(charset))]
}
return string(result)
}
func GetOne[K comparable, V any](input map[K]V) (*K, *V) {
for k, v := range input {
return &k, &v
}
return nil, nil
}
func Deref[V any](input *V, defval V) V {
if input == nil {
return defval
}
return *input
}