mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
Added utilities to inspect tensorflow models.
These new functions allows us to inspect the saved models, get the tags and try to guess the inputs and outputs.
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
package tensorflow
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
tf "github.com/wamuir/graft/tensorflow"
|
||||
|
||||
@@ -18,3 +21,90 @@ func SavedModel(modelPath string, tags []string) (model *tf.SavedModel, err erro
|
||||
|
||||
return tf.LoadSavedModel(modelPath, tags, nil)
|
||||
}
|
||||
|
||||
// GuessInputAndOutput tries to inspect a loaded saved model to build the
|
||||
// ModelInfo struct
|
||||
func GuessInputAndOutput(model *tf.SavedModel) (input *PhotoInput, output *ModelOutput, err error) {
|
||||
modelOps := model.Graph.Operations()
|
||||
|
||||
for i := range modelOps {
|
||||
if strings.HasPrefix(modelOps[i].Type(), "Placeholder") && modelOps[i].NumOutputs() == 1 && modelOps[i].Output(0).Shape().NumDimensions() == 4 {
|
||||
shape := modelOps[i].Output(0).Shape()
|
||||
input = &PhotoInput{
|
||||
Name: modelOps[i].Name(),
|
||||
Height: shape.Size(1),
|
||||
Width: shape.Size(2),
|
||||
Channels: shape.Size(3),
|
||||
}
|
||||
} else if (modelOps[i].Type() == "Softmax" || strings.HasPrefix(modelOps[i].Type(), "StatefulPartitionedCall")) &&
|
||||
modelOps[i].NumOutputs() == 1 && modelOps[i].Output(0).Shape().NumDimensions() == 2 {
|
||||
output = &ModelOutput{
|
||||
Name: modelOps[i].Name(),
|
||||
NumOutputs: modelOps[i].Output(0).Shape().Size(1),
|
||||
}
|
||||
}
|
||||
|
||||
if input != nil && output != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf("Could not guess the inputs and outputs")
|
||||
}
|
||||
|
||||
func GetInputAndOutputFromSavedModel(model *tf.SavedModel) (*PhotoInput, *ModelOutput, error) {
|
||||
if model == nil {
|
||||
return nil, nil, fmt.Errorf("GetInputAndOutputFromSavedModel: nil input")
|
||||
}
|
||||
|
||||
for _, v := range model.Signatures {
|
||||
inputs := v.Inputs
|
||||
outputs := v.Outputs
|
||||
|
||||
if len(inputs) == 1 && len(outputs) == 1 {
|
||||
_, inputTensor := GetOne(inputs)
|
||||
outputVarName, outputTensor := GetOne(outputs)
|
||||
|
||||
if inputTensor != nil && outputTensor != nil {
|
||||
if inputTensor.Shape.NumDimensions() == 4 &&
|
||||
outputTensor.Shape.NumDimensions() == 2 {
|
||||
var inputIdx, outputIdx = 0, 0
|
||||
var err error
|
||||
|
||||
inputName, inputIndex, found := strings.Cut(inputTensor.Name, ":")
|
||||
if found {
|
||||
inputIdx, err = strconv.Atoi(inputIndex)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("Could not parse index %s: %w", inputIndex, err)
|
||||
}
|
||||
}
|
||||
|
||||
outputName, outputIndex, found := strings.Cut(outputTensor.Name, ":")
|
||||
if found {
|
||||
outputIdx, err = strconv.Atoi(outputIndex)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("Could not parse index: %s: %w", outputIndex, err)
|
||||
}
|
||||
}
|
||||
|
||||
return &PhotoInput{
|
||||
Name: inputName,
|
||||
OutputIndex: inputIdx,
|
||||
Height: inputTensor.Shape.Size(1),
|
||||
Width: inputTensor.Shape.Size(2),
|
||||
Channels: inputTensor.Shape.Size(3),
|
||||
}, &ModelOutput{
|
||||
Name: outputName,
|
||||
OutputIndex: outputIdx,
|
||||
NumOutputs: outputTensor.Shape.Size(1),
|
||||
OutputsLogits: strings.Contains(Deref(outputVarName, ""), "logits"),
|
||||
}, nil
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf("GetInputAndOutputFromSignature: could not find valid signatures")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user