Files
2025-11-22 11:47:17 +01:00

50 lines
1.1 KiB
Go

package tensorflow
import (
tf "github.com/wamuir/graft/tensorflow"
)
// AddSoftmax appends a Softmax operation to the graph for the configured model output.
func AddSoftmax(graph *tf.Graph, info *ModelInfo) (*tf.Operation, error) {
randomName := randomString(10)
logits := graph.Operation(info.Output.Name).Output(info.Output.OutputIndex)
reshapeOpSpec := tf.OpSpec{
Type: "EnsureShape",
Name: randomString(10),
Input: []tf.Input{
logits,
},
Attrs: map[string]any{
"shape": tf.MakeShape(-1, info.Output.NumOutputs),
},
}
// We add this reshape operation becase TF seems unable to infere the input
// shape for softmax operation, eventhough it is perfectly recoverable by
// inspecting the models.
reshapeOp, err := graph.AddOperation(reshapeOpSpec)
if err != nil {
return nil, err
}
opspec := tf.OpSpec{
Type: "Softmax",
Name: randomName,
Input: []tf.Input{
reshapeOp.Output(0),
},
}
op, err := graph.AddOperation(opspec)
if err != nil {
return nil, err
}
info.Output.Name = randomName
info.Output.OutputIndex = 0
return op, nil
}