mirror of
https://github.com/photoprism/photoprism.git
synced 2025-12-12 00:34:13 +01:00
Removed parameter Channels
It seems to be standarized, so it is now used as an additional check for input signatures.
This commit is contained in:
@@ -28,13 +28,15 @@ func GuessInputAndOutput(model *tf.SavedModel) (input *PhotoInput, output *Model
|
||||
modelOps := model.Graph.Operations()
|
||||
|
||||
for i := range modelOps {
|
||||
if strings.HasPrefix(modelOps[i].Type(), "Placeholder") && modelOps[i].NumOutputs() == 1 && modelOps[i].Output(0).Shape().NumDimensions() == 4 {
|
||||
if strings.HasPrefix(modelOps[i].Type(), "Placeholder") &&
|
||||
modelOps[i].NumOutputs() == 1 &&
|
||||
modelOps[i].Output(0).Shape().NumDimensions() == 4 &&
|
||||
modelOps[i].Output(0).Shape().Size(3) == ExpectedChannels { // check the channels are 3
|
||||
shape := modelOps[i].Output(0).Shape()
|
||||
input = &PhotoInput{
|
||||
Name: modelOps[i].Name(),
|
||||
Height: shape.Size(1),
|
||||
Width: shape.Size(2),
|
||||
Channels: shape.Size(3),
|
||||
Name: modelOps[i].Name(),
|
||||
Height: shape.Size(1),
|
||||
Width: shape.Size(2),
|
||||
}
|
||||
} else if (modelOps[i].Type() == "Softmax" || strings.HasPrefix(modelOps[i].Type(), "StatefulPartitionedCall")) &&
|
||||
modelOps[i].NumOutputs() == 1 && modelOps[i].Output(0).Shape().NumDimensions() == 2 {
|
||||
@@ -57,7 +59,7 @@ func GetInputAndOutputFromSavedModel(model *tf.SavedModel) (*PhotoInput, *ModelO
|
||||
return nil, nil, fmt.Errorf("GetInputAndOutputFromSavedModel: nil input")
|
||||
}
|
||||
|
||||
for _, v := range model.Signatures {
|
||||
for k, v := range model.Signatures {
|
||||
inputs := v.Inputs
|
||||
outputs := v.Outputs
|
||||
|
||||
@@ -66,7 +68,13 @@ func GetInputAndOutputFromSavedModel(model *tf.SavedModel) (*PhotoInput, *ModelO
|
||||
outputVarName, outputTensor := GetOne(outputs)
|
||||
|
||||
if inputTensor != nil && outputTensor != nil {
|
||||
if inputTensor.Shape.Size(3) != ExpectedChannels {
|
||||
log.Warnf("tensorflow: skipping signature %v because channels are expected to be %d, have %d",
|
||||
k, ExpectedChannels, inputTensor.Shape.Size(3))
|
||||
}
|
||||
|
||||
if inputTensor.Shape.NumDimensions() == 4 &&
|
||||
inputTensor.Shape.Size(3) == ExpectedChannels &&
|
||||
outputTensor.Shape.NumDimensions() == 2 {
|
||||
var inputIdx, outputIdx = 0, 0
|
||||
var err error
|
||||
@@ -92,7 +100,6 @@ func GetInputAndOutputFromSavedModel(model *tf.SavedModel) (*PhotoInput, *ModelO
|
||||
OutputIndex: inputIdx,
|
||||
Height: inputTensor.Shape.Size(1),
|
||||
Width: inputTensor.Shape.Size(2),
|
||||
Channels: inputTensor.Shape.Size(3),
|
||||
}, &ModelOutput{
|
||||
Name: outputName,
|
||||
OutputIndex: outputIdx,
|
||||
|
||||
Reference in New Issue
Block a user