Skip to content

Commit 6bf8fe8

Browse files
committed
Merge pull request opencv#9384 from dkurt:torch_split
2 parents a391871 + 0ce7c33 commit 6bf8fe8

File tree

2 files changed

+9
-19
lines changed

2 files changed

+9
-19
lines changed

modules/dnn/src/layers/split_layer.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class SplitLayerImpl : public SplitLayer
7575

7676
Layer::getMemoryShapes(inputs, max(1, outputsCount >= 0 ? outputsCount : requiredOutputs),
7777
outputs, internals);
78-
return true;
78+
return false;
7979
}
8080

8181
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
@@ -86,8 +86,7 @@ class SplitLayerImpl : public SplitLayer
8686
for (size_t i = 0; i < outputs.size(); i++)
8787
{
8888
CV_Assert(inputs[0]->total() == outputs[i].total());
89-
if (outputs[i].data != inputs[0]->data)
90-
inputs[0]->copyTo(outputs[i]);
89+
inputs[0]->copyTo(outputs[i]);
9190
}
9291
}
9392
};

modules/dnn/src/torch/torch_importer.cpp

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -934,20 +934,18 @@ struct TorchImporter : public ::cv::dnn::Importer
934934
}
935935
else if (module->thName == "Concat")
936936
{
937-
int newId, splitId, mergeId;
938-
LayerParams mergeParams, splitParams;
937+
int newId, mergeId;
938+
LayerParams mergeParams;
939939
mergeParams.set("axis", module->params.get<int>("dimension") - 1);
940940

941-
splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams);
942-
net.connect(prevLayerId, prevOutNum, splitId, 0);
943-
944941
std::vector<int> branchIds;
945942
for (int i = 0; i < (int)module->modules.size(); i++)
946943
{
947-
newId = fill(module->modules[i], addedModules, splitId, i);
944+
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
948945
branchIds.push_back(newId);
949946
}
950947

948+
moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384.
951949
mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams);
952950

953951
for (int i = 0; i < branchIds.size(); i++)
@@ -1015,19 +1013,12 @@ struct TorchImporter : public ::cv::dnn::Importer
10151013
return mergeId;
10161014
}
10171015
else if (module->thName == "ConcatTable") {
1018-
int newId = -1, splitId;
1019-
LayerParams splitParams;
1020-
1021-
splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams);
1022-
net.connect(prevLayerId, prevOutNum, splitId, 0);
1023-
1024-
addedModules.push_back(std::make_pair(splitId, module));
1025-
1016+
int newId = -1;
1017+
moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384.
10261018
for (int i = 0; i < (int)module->modules.size(); i++)
10271019
{
1028-
newId = fill(module->modules[i], addedModules, splitId, i);
1020+
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
10291021
}
1030-
10311022
return newId;
10321023
}
10331024
else if (module->thName == "JoinTable") {

0 commit comments

Comments
 (0)