Files
photoprism/internal/ai/tensorflow/op.go
raystlin 8bc7121394 Added a comment to clarify
EnsureShape operation may seem unnecessary, but PredictCosts fails if it
is not added to the graph.
2025-04-11 16:57:27 +00:00

49 lines
1005 B
Go

package tensorflow
import (
	"fmt"

	tf "github.com/wamuir/graft/tensorflow"
)
// AddSoftmax appends a Softmax operation to graph on top of the model output
// described by info.Output and returns the newly added operation.
//
// The current output is first routed through an EnsureShape operation with
// shape (-1, info.Output.NumOutputs), and the Softmax operation consumes that
// result. Both operations are named with randomString(10).
//
// On success, info.Output.Name and info.Output.OutputIndex are updated to
// refer to output 0 of the Softmax operation. On failure an error is returned
// and info is left unmodified.
func AddSoftmax(graph *tf.Graph, info *ModelInfo) (*tf.Operation, error) {
	// graph.Operation yields nil for an unknown name; using that result as an
	// input would panic inside AddOperation, so report it as an error instead.
	outputOp := graph.Operation(info.Output.Name)
	if outputOp == nil {
		return nil, fmt.Errorf("output operation %q not found in graph", info.Output.Name)
	}
	logits := outputOp.Output(info.Output.OutputIndex)

	softmaxName := randomString(10)

	// We add this EnsureShape operation because TF seems unable to infer the
	// input shape for the Softmax operation, even though it is perfectly
	// recoverable by inspecting the model.
	ensureShapeOp, err := graph.AddOperation(tf.OpSpec{
		Type:  "EnsureShape",
		Name:  randomString(10),
		Input: []tf.Input{logits},
		Attrs: map[string]any{
			"shape": tf.MakeShape(-1, info.Output.NumOutputs),
		},
	})
	if err != nil {
		return nil, fmt.Errorf("adding EnsureShape operation: %w", err)
	}

	softmaxOp, err := graph.AddOperation(tf.OpSpec{
		Type:  "Softmax",
		Name:  softmaxName,
		Input: []tf.Input{ensureShapeOp.Output(0)},
	})
	if err != nil {
		return nil, fmt.Errorf("adding Softmax operation: %w", err)
	}

	// Only repoint the model output once both operations exist in the graph.
	info.Output.Name = softmaxName
	info.Output.OutputIndex = 0

	return softmaxOp, nil
}