Skip to content

Commit bbbec30

Browse files
committed
nn.BatchNormalization and nn.Dropout layers from Torch
1 parent fc9e031 commit bbbec30

File tree

3 files changed

+22
-8
lines changed

3 files changed

+22
-8
lines changed

modules/dnn/src/layers/batch_norm_layer.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,9 @@ class BatchNormLayerImpl : public BatchNormLayer
119119
CV_Assert(inputs.size() == 1);
120120

121121
Mat &inpBlob = *inputs[0];
122-
int rows = inpBlob.size[2];
123-
int cols = inpBlob.size[3];
122+
CV_Assert(inpBlob.dims == 2 || inpBlob.dims == 4);
123+
int rows = inpBlob.dims > 2 ? inpBlob.size[2] : 1;
124+
int cols = inpBlob.dims > 2 ? inpBlob.size[3] : 1;
124125

125126
for (size_t ii = 0; ii < outputs.size(); ii++)
126127
{

modules/dnn/src/torch/torch_importer.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,8 @@ struct TorchImporter : public ::cv::dnn::Importer
617617
curModule->modules.push_back(cv::Ptr<Module>(new Module(nnName, "Sigmoid")));
618618
readObject();
619619
}
620-
else if (nnName == "SpatialBatchNormalization" || nnName == "InstanceNormalization")
620+
else if (nnName == "SpatialBatchNormalization" || nnName == "InstanceNormalization" ||
621+
nnName == "BatchNormalization")
621622
{
622623
newModule->apiType = "BatchNorm";
623624
readTorchTable(scalarParams, tensorParams);
@@ -700,17 +701,24 @@ struct TorchImporter : public ::cv::dnn::Importer
700701

701702
curModule->modules.push_back(newModule);
702703
}
703-
else if (nnName == "SpatialDropout")
704+
else if (nnName == "SpatialDropout" || nnName == "Dropout")
704705
{
705706
readTorchTable(scalarParams, tensorParams);
706707
CV_Assert(scalarParams.has("p"));
707708

708-
float scale = 1 - scalarParams.get<double>("p");
709+
if (scalarParams.has("v2") && scalarParams.get<bool>("v2"))
710+
{
711+
newModule->apiType = "Identity";
712+
}
713+
else
714+
{
715+
float scale = 1 - scalarParams.get<double>("p");
709716

710-
CV_Assert(scale > 0);
717+
CV_Assert(scale > 0);
711718

712-
newModule->apiType = "Power";
713-
layerParams.set("scale", scale);
719+
newModule->apiType = "Power";
720+
layerParams.set("scale", scale);
721+
}
714722
curModule->modules.push_back(newModule);
715723
}
716724
// TotalVariation layer is from fast-neural-style project: https://github.com/jcjohnson/fast-neural-style

modules/dnn/test/test_torch_importer.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,11 @@ TEST(Torch_Importer, net_padding)
234234
runTorchNet("net_spatial_reflection_padding", DNN_TARGET_CPU, "", false, true);
235235
}
236236

237+
TEST(Torch_Importer, net_non_spatial)
238+
{
239+
runTorchNet("net_non_spatial", DNN_TARGET_CPU, "", false, true);
240+
}
241+
237242
TEST(Torch_Importer, ENet_accuracy)
238243
{
239244
Net net;

0 commit comments

Comments
 (0)