Skip to content

Commit 54f0616

Browse files
committed
Deconvolution layer from TensorFlow
1 parent 89172c0 commit 54f0616

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

modules/dnn/src/tensorflow/tf_importer.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,50 @@ void TFImporter::populateNet(Net dstNet)
863863
// one input only
864864
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
865865
}
866+
else if (type == "Conv2DBackpropInput")
867+
{
868+
// op: "Conv2DBackpropInput"
869+
// input: "conv2d_transpose/output_shape"
870+
// input: "weights"
871+
// input: "input"
872+
if (layer.input_size() != 3)
873+
CV_Error(Error::StsNotImplemented,
874+
"Expected output shape, weights and input nodes");
875+
876+
layerParams.set("bias_term", false);
877+
layerParams.blobs.resize(1);
878+
879+
StrIntVector next_layers = getNextLayers(net, name, "BiasAdd");
880+
if (next_layers.size() == 1)
881+
{
882+
layerParams.set("bias_term", true);
883+
layerParams.blobs.resize(2);
884+
885+
int weights_layer_index = next_layers[0].second;
886+
887+
blobFromTensor(getConstBlob(net.node(weights_layer_index), value_id), layerParams.blobs[1]);
888+
ExcludeLayer(net, weights_layer_index, 0, false);
889+
layers_to_ignore[weights_layer_index] = next_layers[0].first;
890+
}
891+
892+
kernelFromTensor(getConstBlob(layer, value_id, 1), layerParams.blobs[0]);
893+
// Swap just numbers of input and output channels.
894+
std::swap(layerParams.blobs[0].size[0], layerParams.blobs[0].size[1]);
895+
896+
const int* kshape = layerParams.blobs[0].size.p;
897+
layerParams.set("kernel_h", kshape[2]);
898+
layerParams.set("kernel_w", kshape[3]);
899+
layerParams.set("num_output", kshape[0]);
900+
901+
setStrides(layerParams, layer);
902+
setPadding(layerParams, layer);
903+
904+
int id = dstNet.addLayer(name, "Deconvolution", layerParams);
905+
layer_id[name] = id;
906+
907+
// one input only
908+
connect(layer_id, dstNet, parsePin(layer.input(2)), id, 0);
909+
}
866910
else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
867911
type == "Relu" || type == "Elu" || type == "Softmax" ||
868912
type == "Identity")

modules/dnn/test/test_tf_importer.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,4 +125,9 @@ TEST(Test_TensorFlow, pooling)
125125
runTensorFlowNet("max_pool_odd_same");
126126
}
127127

128+
TEST(Test_TensorFlow, deconvolution)
129+
{
130+
runTensorFlowNet("deconvolution");
131+
}
132+
128133
}

0 commit comments

Comments
 (0)