Skip to content

Commit 325cbd7

Browse files
committed
Merge pull request opencv#10364 from dkurt:dnn_smooth_tf_data_layout
2 parents 1bc1f3d + 7e48fa5 commit 325cbd7

File tree

2 files changed

+139
-10
lines changed

2 files changed

+139
-10
lines changed

modules/dnn/src/tensorflow/tf_importer.cpp

Lines changed: 137 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ namespace
4242

4343
static int toNCHW[] = {0, 2, 3, 1};
4444

45+
// These values indicate the data layout of a layer's output, where it can be determined.
// They are stored as plain ints in the importer's name -> layout maps.
// NOTE(review): DATA_LAYOUT_NHWC has the implicit value 0, which is also what
// std::map<String, int>::operator[] yields for a key that was never inserted, so
// lookups like data_layouts[x] == DATA_LAYOUT_NHWC treat missing entries as NHWC -
// confirm this is intended.
46+
enum DataLayout
47+
{
48+
DATA_LAYOUT_NHWC,  // matches TensorFlow's "NHWC" data_format
49+
DATA_LAYOUT_NCHW,  // matches TensorFlow's "NCHW" data_format
50+
DATA_LAYOUT_UNKNOWN  // layout could not be determined (e.g. inputs disagree)
51+
};
52+
4553
typedef std::vector<std::pair<String, int> > StrIntVector;
4654

4755
struct Pin
@@ -608,6 +616,31 @@ static void addConstNodes(const tensorflow::GraphDef& net, std::map<String, int>
608616
}
609617
}
610618

619+
// If all inputs of specific layer have the same data layout we can say that
620+
// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
621+
static int predictOutputDataLayout(const tensorflow::NodeDef& layer, const std::map<String, int>& data_layouts)
622+
{
623+
int layout = DATA_LAYOUT_UNKNOWN;
624+
std::map<String, int>::const_iterator it;
625+
for (int i = 0, n = layer.input_size(); i < n; ++i)
626+
{
627+
it = data_layouts.find(layer.input(i));
628+
if (it != data_layouts.end())
629+
{
630+
if (it->second == DATA_LAYOUT_UNKNOWN)
631+
return DATA_LAYOUT_UNKNOWN;
632+
else if (it->second != layout)
633+
{
634+
if (layout == DATA_LAYOUT_UNKNOWN)
635+
layout = it->second;
636+
else
637+
return DATA_LAYOUT_UNKNOWN;
638+
}
639+
}
640+
}
641+
return layout;
642+
}
643+
611644
void TFImporter::populateNet(Net dstNet)
612645
{
613646
RemoveIdentityOps(netBin);
@@ -619,6 +652,8 @@ void TFImporter::populateNet(Net dstNet)
619652

620653
int layersSize = net.node_size();
621654

655+
std::map<String, int> data_layouts;
656+
622657
// find all Const layers for params
623658
std::map<String, int> value_id;
624659
addConstNodes(netBin, value_id, layers_to_ignore);
@@ -636,6 +671,8 @@ void TFImporter::populateNet(Net dstNet)
636671
if(layers_to_ignore.find(name) != layers_to_ignore.end())
637672
continue;
638673

674+
data_layouts[name] = predictOutputDataLayout(layer, data_layouts);
675+
639676
if (type == "Conv2D" || type == "SpaceToBatchND" || type == "DepthwiseConv2dNative")
640677
{
641678
// The first node of dilated convolution subgraph.
@@ -731,6 +768,19 @@ void TFImporter::populateNet(Net dstNet)
731768

732769
// one input only
733770
connect(layer_id, dstNet, parsePin(input), id, 0);
771+
772+
if (hasLayerAttr(layer, "data_format"))
773+
{
774+
std::string format = getLayerAttr(layer, "data_format").s();
775+
if (format == "NHWC")
776+
data_layouts[name] = DATA_LAYOUT_NHWC;
777+
else if (format == "NCHW")
778+
data_layouts[name] = DATA_LAYOUT_NCHW;
779+
else
780+
CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
781+
}
782+
else
783+
data_layouts[name] = DATA_LAYOUT_NHWC;
734784
}
735785
else if (type == "BiasAdd" || type == "Add")
736786
{
@@ -806,22 +856,55 @@ void TFImporter::populateNet(Net dstNet)
806856
// one input only
807857
int input_blob_index = kernel_blob_index == 0 ? 1 : 0;
808858
connect(layer_id, dstNet, parsePin(layer.input(input_blob_index)), id, 0);
859+
data_layouts[name] = DATA_LAYOUT_UNKNOWN;
809860
}
810861
else if (type == "Reshape")
811862
{
812-
layerParams.set("dim", parseDims(getConstBlob(layer, value_id, 1)));
863+
Pin inpId = parsePin(layer.input(0));
864+
DictValue newShape = parseDims(getConstBlob(layer, value_id, 1));
865+
866+
if (newShape.size() != 4 && data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
867+
{
868+
LayerParams permLP;
869+
int order[] = {0, 2, 3, 1}; // From OpenCV's NCHW to NHWC.
870+
permLP.set("order", DictValue::arrayInt<int*>(order, 4));
871+
872+
std::string permName = name + "/nchw";
873+
CV_Assert(layer_id.find(permName) == layer_id.end());
874+
int permId = dstNet.addLayer(permName, "Permute", permLP);
875+
layer_id[permName] = permId;
876+
connect(layer_id, dstNet, inpId, permId, 0);
877+
inpId = Pin(permName);
878+
}
879+
layerParams.set("dim", newShape);
813880

814881
int id = dstNet.addLayer(name, "Reshape", layerParams);
815882
layer_id[name] = id;
816883

817884
// one input only
818-
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
885+
connect(layer_id, dstNet, inpId, id, 0);
886+
data_layouts[name] = DATA_LAYOUT_UNKNOWN;
819887
}
820888
else if (type == "Flatten")
821889
{
890+
Pin inpId = parsePin(layer.input(0));
891+
if (data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
892+
{
893+
LayerParams permLP;
894+
int order[] = {0, 2, 3, 1}; // From OpenCV's NCHW to NHWC.
895+
permLP.set("order", DictValue::arrayInt<int*>(order, 4));
896+
897+
std::string permName = name + "/nchw";
898+
CV_Assert(layer_id.find(permName) == layer_id.end());
899+
int permId = dstNet.addLayer(permName, "Permute", permLP);
900+
layer_id[permName] = permId;
901+
connect(layer_id, dstNet, inpId, permId, 0);
902+
inpId = Pin(permName);
903+
}
822904
int id = dstNet.addLayer(name, "Flatten", layerParams);
823905
layer_id[name] = id;
824-
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
906+
connect(layer_id, dstNet, inpId, id, 0);
907+
data_layouts[name] = DATA_LAYOUT_UNKNOWN;
825908
}
826909
else if (type == "Transpose")
827910
{
@@ -830,16 +913,57 @@ void TFImporter::populateNet(Net dstNet)
830913
int* permData = (int*)perm.data;
831914
if (perm.total() == 4)
{
    // Only NHWC <-> NCHW permutations (and the identity permutation) are allowed.
    // OpenCV always keeps the NCHW layout, so such a transpose only changes the
    // tracked TensorFlow layout and is imported as an Identity layer.
    // NOTE(review): operator[] default-inserts 0 (== DATA_LAYOUT_NHWC) for an input
    // with no recorded layout - confirm that treating it as NHWC is intended.
    const int inpLayout = data_layouts[layer.input(0)];
    if (inpLayout == DATA_LAYOUT_NHWC)
    {
        if (permData[0] == 0 && permData[1] == 3 && permData[2] == 1 && permData[3] == 2)
        {
            // in TensorFlow: NHWC->NCHW
            // in OpenCV: NCHW->NCHW
            data_layouts[name] = DATA_LAYOUT_NCHW;
        }
        else if (permData[0] == 0 && permData[1] == 1 && permData[2] == 2 && permData[3] == 3)
        {
            // in TensorFlow: NHWC->NHWC
            // in OpenCV: NCHW->NCHW
            data_layouts[name] = DATA_LAYOUT_NHWC;
        }
        else
        {
            // CV_Error, not CV_Assert: CV_Assert(code, "msg") is a comma expression
            // that evaluates to the non-null string literal and therefore never fails.
            CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
        }
    }
    else if (inpLayout == DATA_LAYOUT_NCHW)
    {
        if (permData[0] == 0 && permData[1] == 2 && permData[2] == 3 && permData[3] == 1)
        {
            // in TensorFlow: NCHW->NHWC
            // in OpenCV: NCHW->NCHW
            data_layouts[name] = DATA_LAYOUT_NHWC;
        }
        else if (permData[0] == 0 && permData[1] == 1 && permData[2] == 2 && permData[3] == 3)
        {
            // in TensorFlow: NCHW->NCHW
            // in OpenCV: NCHW->NCHW
            data_layouts[name] = DATA_LAYOUT_NCHW;
        }
        else
        {
            CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
        }
    }
    // NOTE(review): when the input layout is DATA_LAYOUT_UNKNOWN the permutation is
    // not checked and the node still becomes an Identity layer - confirm intended.
    int id = dstNet.addLayer(name, "Identity", layerParams);
    layer_id[name] = id;
    connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
}
836-
layerParams.set("order", DictValue::arrayInt<int*>(permData, perm.total()));
956+
else
957+
{
958+
layerParams.set("order", DictValue::arrayInt<int*>(permData, perm.total()));
837959

838-
int id = dstNet.addLayer(name, "Permute", layerParams);
839-
layer_id[name] = id;
960+
int id = dstNet.addLayer(name, "Permute", layerParams);
961+
layer_id[name] = id;
840962

841-
// one input only
842-
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
963+
// one input only
964+
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
965+
data_layouts[name] = DATA_LAYOUT_UNKNOWN;
966+
}
843967
}
844968
else if (type == "Const")
845969
{
@@ -1207,6 +1331,7 @@ void TFImporter::populateNet(Net dstNet)
12071331

12081332
// one input only
12091333
connect(layer_id, dstNet, parsePin(layer.input(1)), id, 0);
1334+
data_layouts[name] = DATA_LAYOUT_UNKNOWN;
12101335
}
12111336
else if (type == "ResizeNearestNeighbor")
12121337
{
@@ -1258,6 +1383,7 @@ void TFImporter::populateNet(Net dstNet)
12581383
layer_id[name] = id;
12591384
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
12601385
connect(layer_id, dstNet, parsePin(layer.input(1)), id, 1);
1386+
data_layouts[name] = DATA_LAYOUT_UNKNOWN;
12611387
}
12621388
else if (type == "DetectionOutput")
12631389
{
@@ -1288,6 +1414,7 @@ void TFImporter::populateNet(Net dstNet)
12881414
layer_id[name] = id;
12891415
for (int i = 0; i < 3; ++i)
12901416
connect(layer_id, dstNet, parsePin(layer.input(i)), id, i);
1417+
data_layouts[name] = DATA_LAYOUT_UNKNOWN;
12911418
}
12921419
else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
12931420
type == "Relu" || type == "Elu" || type == "Softmax" ||

modules/dnn/test/test_tf_importer.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ TEST(Test_TensorFlow, deconvolution)
159159
// Calls runTensorFlowNet on each of the model names below, in order.
TEST(Test_TensorFlow, matmul)
{
    static const char* const kModelNames[] = {
        "matmul",
        "nhwc_reshape_matmul",
        "nhwc_transpose_reshape_matmul"
    };
    const size_t numModels = sizeof(kModelNames) / sizeof(kModelNames[0]);
    for (size_t idx = 0; idx < numModels; ++idx)
        runTensorFlowNet(kModelNames[idx]);
}
163165

164166
TEST(Test_TensorFlow, defun)

0 commit comments

Comments
 (0)