@@ -42,6 +42,14 @@ namespace
42
42
43
43
static int toNCHW[] = {0 , 2 , 3 , 1 };
44
44
45
+ // This values are used to indicate layer output's data layout where it's possible.
46
+ enum DataLayout
47
+ {
48
+ DATA_LAYOUT_NHWC,
49
+ DATA_LAYOUT_NCHW,
50
+ DATA_LAYOUT_UNKNOWN
51
+ };
52
+
45
53
typedef std::vector<std::pair<String, int > > StrIntVector;
46
54
47
55
struct Pin
@@ -608,6 +616,31 @@ static void addConstNodes(const tensorflow::GraphDef& net, std::map<String, int>
608
616
}
609
617
}
610
618
619
+ // If all inputs of specific layer have the same data layout we can say that
620
+ // this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
621
+ static int predictOutputDataLayout (const tensorflow::NodeDef& layer, const std::map<String, int >& data_layouts)
622
+ {
623
+ int layout = DATA_LAYOUT_UNKNOWN;
624
+ std::map<String, int >::const_iterator it;
625
+ for (int i = 0 , n = layer.input_size (); i < n; ++i)
626
+ {
627
+ it = data_layouts.find (layer.input (i));
628
+ if (it != data_layouts.end ())
629
+ {
630
+ if (it->second == DATA_LAYOUT_UNKNOWN)
631
+ return DATA_LAYOUT_UNKNOWN;
632
+ else if (it->second != layout)
633
+ {
634
+ if (layout == DATA_LAYOUT_UNKNOWN)
635
+ layout = it->second ;
636
+ else
637
+ return DATA_LAYOUT_UNKNOWN;
638
+ }
639
+ }
640
+ }
641
+ return layout;
642
+ }
643
+
611
644
void TFImporter::populateNet (Net dstNet)
612
645
{
613
646
RemoveIdentityOps (netBin);
@@ -619,6 +652,8 @@ void TFImporter::populateNet(Net dstNet)
619
652
620
653
int layersSize = net.node_size ();
621
654
655
+ std::map<String, int > data_layouts;
656
+
622
657
// find all Const layers for params
623
658
std::map<String, int > value_id;
624
659
addConstNodes (netBin, value_id, layers_to_ignore);
@@ -636,6 +671,8 @@ void TFImporter::populateNet(Net dstNet)
636
671
if (layers_to_ignore.find (name) != layers_to_ignore.end ())
637
672
continue ;
638
673
674
+ data_layouts[name] = predictOutputDataLayout (layer, data_layouts);
675
+
639
676
if (type == " Conv2D" || type == " SpaceToBatchND" || type == " DepthwiseConv2dNative" )
640
677
{
641
678
// The first node of dilated convolution subgraph.
@@ -731,6 +768,19 @@ void TFImporter::populateNet(Net dstNet)
731
768
732
769
// one input only
733
770
connect (layer_id, dstNet, parsePin (input), id, 0 );
771
+
772
+ if (hasLayerAttr (layer, " data_format" ))
773
+ {
774
+ std::string format = getLayerAttr (layer, " data_format" ).s ();
775
+ if (format == " NHWC" )
776
+ data_layouts[name] = DATA_LAYOUT_NHWC;
777
+ else if (format == " NCHW" )
778
+ data_layouts[name] = DATA_LAYOUT_NCHW;
779
+ else
780
+ CV_Error (Error::StsParseError, " Unknown data_format value: " + format);
781
+ }
782
+ else
783
+ data_layouts[name] = DATA_LAYOUT_NHWC;
734
784
}
735
785
else if (type == " BiasAdd" || type == " Add" )
736
786
{
@@ -806,22 +856,55 @@ void TFImporter::populateNet(Net dstNet)
806
856
// one input only
807
857
int input_blob_index = kernel_blob_index == 0 ? 1 : 0 ;
808
858
connect (layer_id, dstNet, parsePin (layer.input (input_blob_index)), id, 0 );
859
+ data_layouts[name] = DATA_LAYOUT_UNKNOWN;
809
860
}
810
861
else if (type == " Reshape" )
811
862
{
812
- layerParams.set (" dim" , parseDims (getConstBlob (layer, value_id, 1 )));
863
+ Pin inpId = parsePin (layer.input (0 ));
864
+ DictValue newShape = parseDims (getConstBlob (layer, value_id, 1 ));
865
+
866
+ if (newShape.size () != 4 && data_layouts[layer.input (0 )] == DATA_LAYOUT_NHWC)
867
+ {
868
+ LayerParams permLP;
869
+ int order[] = {0 , 2 , 3 , 1 }; // From OpenCV's NCHW to NHWC.
870
+ permLP.set (" order" , DictValue::arrayInt<int *>(order, 4 ));
871
+
872
+ std::string permName = name + " /nchw" ;
873
+ CV_Assert (layer_id.find (permName) == layer_id.end ());
874
+ int permId = dstNet.addLayer (permName, " Permute" , permLP);
875
+ layer_id[permName] = permId;
876
+ connect (layer_id, dstNet, inpId, permId, 0 );
877
+ inpId = Pin (permName);
878
+ }
879
+ layerParams.set (" dim" , newShape);
813
880
814
881
int id = dstNet.addLayer (name, " Reshape" , layerParams);
815
882
layer_id[name] = id;
816
883
817
884
// one input only
818
- connect (layer_id, dstNet, parsePin (layer.input (0 )), id, 0 );
885
+ connect (layer_id, dstNet, inpId, id, 0 );
886
+ data_layouts[name] = DATA_LAYOUT_UNKNOWN;
819
887
}
820
888
else if (type == " Flatten" )
821
889
{
890
+ Pin inpId = parsePin (layer.input (0 ));
891
+ if (data_layouts[layer.input (0 )] == DATA_LAYOUT_NHWC)
892
+ {
893
+ LayerParams permLP;
894
+ int order[] = {0 , 2 , 3 , 1 }; // From OpenCV's NCHW to NHWC.
895
+ permLP.set (" order" , DictValue::arrayInt<int *>(order, 4 ));
896
+
897
+ std::string permName = name + " /nchw" ;
898
+ CV_Assert (layer_id.find (permName) == layer_id.end ());
899
+ int permId = dstNet.addLayer (permName, " Permute" , permLP);
900
+ layer_id[permName] = permId;
901
+ connect (layer_id, dstNet, inpId, permId, 0 );
902
+ inpId = Pin (permName);
903
+ }
822
904
int id = dstNet.addLayer (name, " Flatten" , layerParams);
823
905
layer_id[name] = id;
824
- connect (layer_id, dstNet, parsePin (layer.input (0 )), id, 0 );
906
+ connect (layer_id, dstNet, inpId, id, 0 );
907
+ data_layouts[name] = DATA_LAYOUT_UNKNOWN;
825
908
}
826
909
else if (type == " Transpose" )
827
910
{
@@ -830,16 +913,57 @@ void TFImporter::populateNet(Net dstNet)
830
913
int * permData = (int *)perm.data ;
831
914
if (perm.total () == 4 )
832
915
{
833
- for (int i = 0 ; i < 4 ; ++i)
834
- permData[i] = toNCHW[permData[i]];
916
+ // Only NHWC <-> NCHW permutations are allowed. OpenCV is always
917
+ // keep NCHW layout this way.
918
+ if (data_layouts[layer.input (0 )] == DATA_LAYOUT_NHWC)
919
+ {
920
+ if (permData[0 ] == 0 && permData[1 ] == 3 && permData[2 ] == 1 && permData[3 ] == 2 )
921
+ {
922
+ // in TensorFlow: NHWC->NCHW
923
+ // in OpenCV: NCHW->NCHW
924
+ data_layouts[name] = DATA_LAYOUT_NCHW;
925
+ }
926
+ else if (permData[0 ] == 0 && permData[1 ] == 1 && permData[2 ] == 2 && permData[3 ] == 3 )
927
+ {
928
+ // in TensorFlow: NHWC->NHWC
929
+ // in OpenCV: NCHW->NCHW
930
+ data_layouts[name] = DATA_LAYOUT_NHWC;
931
+ }
932
+ else
933
+ CV_Assert (Error::StsParseError, " Only NHWC <-> NCHW permutations are allowed." );
934
+ }
935
+ else if (data_layouts[layer.input (0 )] == DATA_LAYOUT_NCHW)
936
+ {
937
+ if (permData[0 ] == 0 && permData[1 ] == 2 && permData[2 ] == 3 && permData[3 ] == 1 )
938
+ {
939
+ // in TensorFlow: NCHW->NHWC
940
+ // in OpenCV: NCHW->NCHW
941
+ data_layouts[name] = DATA_LAYOUT_NHWC;
942
+ }
943
+ else if (permData[0 ] == 0 && permData[1 ] == 1 && permData[2 ] == 2 && permData[3 ] == 3 )
944
+ {
945
+ // in TensorFlow: NCHW->NCHW
946
+ // in OpenCV: NCHW->NCHW
947
+ data_layouts[name] = DATA_LAYOUT_NCHW;
948
+ }
949
+ else
950
+ CV_Assert (Error::StsParseError, " Only NHWC <-> NCHW permutations are allowed." );
951
+ }
952
+ int id = dstNet.addLayer (name, " Identity" , layerParams);
953
+ layer_id[name] = id;
954
+ connect (layer_id, dstNet, parsePin (layer.input (0 )), id, 0 );
835
955
}
836
- layerParams.set (" order" , DictValue::arrayInt<int *>(permData, perm.total ()));
956
+ else
957
+ {
958
+ layerParams.set (" order" , DictValue::arrayInt<int *>(permData, perm.total ()));
837
959
838
- int id = dstNet.addLayer (name, " Permute" , layerParams);
839
- layer_id[name] = id;
960
+ int id = dstNet.addLayer (name, " Permute" , layerParams);
961
+ layer_id[name] = id;
840
962
841
- // one input only
842
- connect (layer_id, dstNet, parsePin (layer.input (0 )), id, 0 );
963
+ // one input only
964
+ connect (layer_id, dstNet, parsePin (layer.input (0 )), id, 0 );
965
+ data_layouts[name] = DATA_LAYOUT_UNKNOWN;
966
+ }
843
967
}
844
968
else if (type == " Const" )
845
969
{
@@ -1207,6 +1331,7 @@ void TFImporter::populateNet(Net dstNet)
1207
1331
1208
1332
// one input only
1209
1333
connect (layer_id, dstNet, parsePin (layer.input (1 )), id, 0 );
1334
+ data_layouts[name] = DATA_LAYOUT_UNKNOWN;
1210
1335
}
1211
1336
else if (type == " ResizeNearestNeighbor" )
1212
1337
{
@@ -1258,6 +1383,7 @@ void TFImporter::populateNet(Net dstNet)
1258
1383
layer_id[name] = id;
1259
1384
connect (layer_id, dstNet, parsePin (layer.input (0 )), id, 0 );
1260
1385
connect (layer_id, dstNet, parsePin (layer.input (1 )), id, 1 );
1386
+ data_layouts[name] = DATA_LAYOUT_UNKNOWN;
1261
1387
}
1262
1388
else if (type == " DetectionOutput" )
1263
1389
{
@@ -1288,6 +1414,7 @@ void TFImporter::populateNet(Net dstNet)
1288
1414
layer_id[name] = id;
1289
1415
for (int i = 0 ; i < 3 ; ++i)
1290
1416
connect (layer_id, dstNet, parsePin (layer.input (i)), id, i);
1417
+ data_layouts[name] = DATA_LAYOUT_UNKNOWN;
1291
1418
}
1292
1419
else if (type == " Abs" || type == " Tanh" || type == " Sigmoid" ||
1293
1420
type == " Relu" || type == " Elu" || type == " Softmax" ||
0 commit comments