Skip to content

Commit 5e93c82

Browse files
committed
Merge pull request opencv#9491 from dkurt:tf_lstm
2 parents a6d634b + 84cec17 commit 5e93c82

File tree

7 files changed

+320
-68
lines changed

7 files changed

+320
-68
lines changed

modules/dnn/include/opencv2/dnn/all_layers.hpp

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,9 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
8484
/** Creates instance of LSTM layer */
8585
static Ptr<LSTMLayer> create(const LayerParams& params);
8686

87-
/** Set trained weights for LSTM layer.
87+
/** @deprecated Use LayerParams::blobs instead.
88+
@brief Set trained weights for LSTM layer.
89+
8890
LSTM behavior on each step is defined by current input, previous output, previous cell state and learned weights.
8991
9092
Let @f$x_t@f$ be current input, @f$h_t@f$ be current output, @f$c_t@f$ be current state.
@@ -114,28 +116,30 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
114116
@param Wx is matrix defining how current input is transformed to internal gates (i.e. according to abovementioned notation is @f$ W_x @f$)
115117
@param b is bias vector (i.e. according to abovementioned notation is @f$ b @f$)
116118
*/
117-
virtual void setWeights(const Mat &Wh, const Mat &Wx, const Mat &b) = 0;
119+
CV_DEPRECATED virtual void setWeights(const Mat &Wh, const Mat &Wx, const Mat &b) = 0;
118120

119121
/** @brief Specifies shape of output blob which will be [[`T`], `N`] + @p outTailShape.
120122
* @details If this parameter is empty or unset then @p outTailShape = [`Wh`.size(0)] will be used,
121123
* where `Wh` is parameter from setWeights().
122124
*/
123125
virtual void setOutShape(const MatShape &outTailShape = MatShape()) = 0;
124126

125-
/** @brief Specifies either interpet first dimension of input blob as timestamp dimenion either as sample.
127+
/** @deprecated Use flag `use_timestamp_dim` in LayerParams.
128+
* @brief Specifies whether to interpret the first dimension of the input blob as the timestamp dimension or as the sample dimension.
126129
*
127130
* If flag is set to true then shape of input blob will be interpreted as [`T`, `N`, `[data dims]`] where `T` specifies number of timestamps, `N` is number of independent streams.
128131
* In this case each forward() call will iterate through `T` timestamps and update layer's state `T` times.
129132
*
130133
* If flag is set to false then shape of input blob will be interpreted as [`N`, `[data dims]`].
131134
* In this case each forward() call will make one iteration and produce one timestamp with shape [`N`, `[out dims]`].
132135
*/
133-
virtual void setUseTimstampsDim(bool use = true) = 0;
136+
CV_DEPRECATED virtual void setUseTimstampsDim(bool use = true) = 0;
134137

135-
/** @brief If this flag is set to true then layer will produce @f$ c_t @f$ as second output.
138+
/** @deprecated Use flag `produce_cell_output` in LayerParams.
139+
* @brief If this flag is set to true then layer will produce @f$ c_t @f$ as second output.
136140
* @details Shape of the second output is the same as first output.
137141
*/
138-
virtual void setProduceCellOutput(bool produce = false) = 0;
142+
CV_DEPRECATED virtual void setProduceCellOutput(bool produce = false) = 0;
139143

140144
/* In the common case it uses a single input with @f$x_t@f$ values to compute output(s) @f$h_t@f$ (and @f$c_t@f$).
141145
* @param input should contain packed values @f$x_t@f$
@@ -323,11 +327,41 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
323327
static Ptr<SplitLayer> create(const LayerParams &params);
324328
};
325329

330+
/**
331+
* Slice layer has several modes:
332+
* 1. Caffe mode
333+
* @param[in] axis Axis of split operation
334+
* @param[in] slice_point Array of split points
335+
*
336+
* The number of output blobs equals the number of split points plus one. The
337+
* first blob is a slice on input from 0 to @p slice_point[0] - 1 by @p axis,
338+
* the second output blob is a slice of input from @p slice_point[0] to
339+
* @p slice_point[1] - 1 by @p axis and the last output blob is a slice of
340+
* input from @p slice_point[-1] up to the end of @p axis size.
341+
*
342+
* 2. TensorFlow mode
343+
* @param begin Vector of start indices
344+
* @param size Vector of sizes
345+
*
346+
* More convenient numpy-like slice. One and only output blob
347+
* is a slice `input[begin[0]:begin[0]+size[0], begin[1]:begin[1]+size[1], ...]`
348+
*
349+
* 3. Torch mode
350+
* @param axis Axis of split operation
351+
*
352+
* Split input blob into equal parts along @p axis.
353+
*/
326354
class CV_EXPORTS SliceLayer : public Layer
327355
{
328356
public:
357+
/**
358+
* @brief Vector of slice ranges.
359+
*
360+
* The first dimension equals number of output blobs.
361+
* Inner vector has slice ranges for the first number of input dimensions.
362+
*/
363+
std::vector<std::vector<Range> > sliceRanges;
329364
int axis;
330-
std::vector<int> sliceIndices;
331365

332366
static Ptr<SliceLayer> create(const LayerParams &params);
333367
};

modules/dnn/src/init.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ void initializeLayerFactory()
117117
CV_DNN_REGISTER_LAYER_CLASS(Shift, ShiftLayer);
118118
CV_DNN_REGISTER_LAYER_CLASS(Padding, PaddingLayer);
119119
CV_DNN_REGISTER_LAYER_CLASS(Scale, ScaleLayer);
120+
121+
CV_DNN_REGISTER_LAYER_CLASS(LSTM, LSTMLayer);
120122
}
121123

122124
CV__DNN_EXPERIMENTAL_NS_END

modules/dnn/src/layers/recurrent_layers.cpp

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,49 @@ class LSTMLayerImpl : public LSTMLayer
9090

9191
bool useTimestampDim;
9292
bool produceCellOutput;
93+
float forgetBias, cellClip;
94+
bool useCellClip, usePeephole;
9395

9496
public:
9597

9698
LSTMLayerImpl(const LayerParams& params)
9799
: numTimeStamps(0), numSamples(0)
98100
{
99101
setParamsFrom(params);
100-
type = "LSTM";
101-
useTimestampDim = true;
102-
produceCellOutput = false;
102+
103+
if (!blobs.empty())
104+
{
105+
CV_Assert(blobs.size() >= 3);
106+
107+
blobs[2] = blobs[2].reshape(1, 1);
108+
109+
const Mat& Wh = blobs[0];
110+
const Mat& Wx = blobs[1];
111+
const Mat& bias = blobs[2];
112+
CV_Assert(Wh.dims == 2 && Wx.dims == 2);
113+
CV_Assert(Wh.rows == Wx.rows);
114+
CV_Assert(Wh.rows == 4*Wh.cols);
115+
CV_Assert(Wh.rows == (int)bias.total());
116+
CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
117+
118+
// Peephole weights.
119+
if (blobs.size() > 3)
120+
{
121+
CV_Assert(blobs.size() == 6);
122+
for (int i = 3; i < 6; ++i)
123+
{
124+
CV_Assert(blobs[i].rows == Wh.cols && blobs[i].cols == Wh.cols);
125+
CV_Assert(blobs[i].type() == bias.type());
126+
}
127+
}
128+
}
129+
useTimestampDim = params.get<bool>("use_timestamp_dim", true);
130+
produceCellOutput = params.get<bool>("produce_cell_output", false);
131+
forgetBias = params.get<float>("forget_bias", 0.0f);
132+
cellClip = params.get<float>("cell_clip", 0.0f);
133+
useCellClip = params.get<bool>("use_cell_clip", false);
134+
usePeephole = params.get<bool>("use_peephole", false);
135+
103136
allocated = false;
104137
outTailShape.clear();
105138
}
@@ -141,7 +174,7 @@ class LSTMLayerImpl : public LSTMLayer
141174
std::vector<MatShape> &outputs,
142175
std::vector<MatShape> &internals) const
143176
{
144-
CV_Assert(blobs.size() == 3);
177+
CV_Assert(!usePeephole && blobs.size() == 3 || usePeephole && blobs.size() == 6);
145178
CV_Assert(inputs.size() == 1);
146179
const MatShape& inp0 = inputs[0];
147180

@@ -186,7 +219,7 @@ class LSTMLayerImpl : public LSTMLayer
186219

187220
void finalize(const std::vector<Mat*> &input, std::vector<Mat> &output)
188221
{
189-
CV_Assert(blobs.size() == 3);
222+
CV_Assert(!usePeephole && blobs.size() == 3 || usePeephole && blobs.size() == 6);
190223
CV_Assert(input.size() == 1);
191224
const Mat& inp0 = *input[0];
192225

@@ -251,20 +284,45 @@ class LSTMLayerImpl : public LSTMLayer
251284
gemm(hInternal, Wh, 1, gates, 1, gates, GEMM_2_T); //+Wh * h_{t-1}
252285
gemm(dummyOnes, bias, 1, gates, 1, gates); //+b
253286

254-
Mat getesIFO = gates.colRange(0, 3*numOut);
255287
Mat gateI = gates.colRange(0*numOut, 1*numOut);
256288
Mat gateF = gates.colRange(1*numOut, 2*numOut);
257289
Mat gateO = gates.colRange(2*numOut, 3*numOut);
258290
Mat gateG = gates.colRange(3*numOut, 4*numOut);
259291

260-
sigmoid(getesIFO, getesIFO);
292+
if (forgetBias)
293+
add(gateF, forgetBias, gateF);
294+
295+
if (usePeephole)
296+
{
297+
Mat gatesIF = gates.colRange(0, 2*numOut);
298+
gemm(cInternal, blobs[3], 1, gateI, 1, gateI);
299+
gemm(cInternal, blobs[4], 1, gateF, 1, gateF);
300+
sigmoid(gatesIF, gatesIF);
301+
}
302+
else
303+
{
304+
Mat gatesIFO = gates.colRange(0, 3*numOut);
305+
sigmoid(gatesIFO, gatesIFO);
306+
}
307+
261308
tanh(gateG, gateG);
262309

263310
//compute c_t
264311
multiply(gateF, cInternal, gateF); // f_t (*) c_{t-1}
265312
multiply(gateI, gateG, gateI); // i_t (*) g_t
266313
add(gateF, gateI, cInternal); // c_t = f_t (*) c_{t-1} + i_t (*) g_t
267314

315+
if (useCellClip)
316+
{
317+
min(cInternal, cellClip, cInternal);
318+
max(cInternal, -cellClip, cInternal);
319+
}
320+
if (usePeephole)
321+
{
322+
gemm(cInternal, blobs[5], 1, gateO, 1, gateO);
323+
sigmoid(gateO, gateO);
324+
}
325+
268326
//compute h_t
269327
tanh(cInternal, hInternal);
270328
multiply(gateO, hInternal, hInternal);

modules/dnn/src/layers/slice_layer.cpp

Lines changed: 83 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,40 @@ class SliceLayerImpl : public SliceLayer
5656
{
5757
setParamsFrom(params);
5858
axis = params.get<int>("axis", 1);
59-
6059
if (params.has("slice_point"))
6160
{
61+
CV_Assert(!params.has("begin") && !params.has("size"));
6262
const DictValue &indicesValue = params.get("slice_point");
63-
int i, n = indicesValue.size();
64-
sliceIndices.resize(n);
65-
for (i = 0; i < n; i++)
66-
sliceIndices[i] = indicesValue.get<int>(i);
63+
sliceRanges.resize(indicesValue.size() + 1,
64+
std::vector<Range>(axis + 1, Range::all()));
65+
int prevSlice = 0;
66+
for (int i = 0; i < indicesValue.size(); ++i)
67+
{
68+
sliceRanges[i][axis].start = prevSlice;
69+
sliceRanges[i][axis].end = indicesValue.get<int>(i);
70+
prevSlice = sliceRanges[i][axis].end;
71+
}
72+
sliceRanges.back()[axis].start = prevSlice;
73+
}
74+
else if (params.has("begin") && params.has("size"))
75+
{
76+
const DictValue &begins = params.get("begin");
77+
const DictValue &sizes = params.get("size");
78+
CV_Assert(begins.size() == sizes.size());
79+
80+
sliceRanges.resize(1);
81+
sliceRanges[0].resize(begins.size(), Range::all());
82+
for (int i = 0; i < begins.size(); ++i)
83+
{
84+
int start = begins.get<int>(i);
85+
int size = sizes.get<int>(i);
86+
CV_Assert(start >= 0);
87+
CV_Assert(size == -1 || size > 0); // -1 value means range [start, axis_size).
88+
89+
sliceRanges[0][i].start = start;
90+
if (size > 0)
91+
sliceRanges[0][i].end = start + size;
92+
}
6793
}
6894
}
6995

@@ -73,47 +99,68 @@ class SliceLayerImpl : public SliceLayer
7399
std::vector<MatShape> &internals) const
74100
{
75101
CV_Assert(inputs.size() == 1);
76-
77-
outputs.clear();
78-
79102
MatShape inpShape = inputs[0];
80-
int cAxis = clamp(axis, inpShape.size());
81-
int axisSize = inpShape[cAxis];
82103

83-
if (sliceIndices.size()) //divide blob with respect to passed parameters
104+
if (!sliceRanges.empty())
84105
{
85-
std::vector<int> outAxisSize;
86-
int prevSlice = 0;
87-
88-
for (size_t i = 0; i < sliceIndices.size(); i++)
89-
{
90-
if (!(prevSlice < sliceIndices[i] && sliceIndices[i] < axisSize))
91-
CV_Error(Error::StsBadArg, "Slice indices should be positive, increased and don't exceed size of sliced dimension");
92-
93-
outAxisSize.push_back(sliceIndices[i] - prevSlice);
94-
prevSlice = sliceIndices[i];
95-
}
96-
outAxisSize.push_back(axisSize - prevSlice);
97-
98-
for (size_t i = 0; i < outAxisSize.size(); i++)
106+
outputs.resize(sliceRanges.size(), inpShape);
107+
for (int i = 0; i < outputs.size(); ++i)
99108
{
100-
inpShape[cAxis] = outAxisSize[i];
101-
outputs.push_back(inpShape);
109+
CV_Assert(sliceRanges[i].size() <= inpShape.size());
110+
for (int j = 0; j < sliceRanges[i].size(); ++j)
111+
{
112+
outputs[i][j] = std::min(sliceRanges[i][j].end, inpShape[j]) -
113+
std::max(sliceRanges[i][j].start, 0);
114+
}
102115
}
103116
}
104-
else //divide blob with respect to count of output blobs
117+
else // Divide input blob on equal parts by axis.
105118
{
106-
CV_Assert(requiredOutputs > 0 && axisSize % requiredOutputs == 0);
107-
int outAxisSize = axisSize / (int)requiredOutputs;
119+
CV_Assert(0 < axis && axis < inpShape.size());
120+
CV_Assert(requiredOutputs > 0 && inpShape[axis] % requiredOutputs == 0);
121+
inpShape[axis] /= requiredOutputs;
122+
outputs.resize(requiredOutputs, inpShape);
123+
}
124+
return false;
125+
}
108126

109-
for (size_t i = 0; i < requiredOutputs; i++)
127+
void finalize(const std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
128+
{
129+
CV_Assert(inputs.size() == 1);
130+
const MatSize& inpShape = inputs[0]->size;
131+
132+
if (sliceRanges.empty())
133+
{
134+
// Divide input blob on equal parts by axis.
135+
int outAxisSize = inpShape[axis] / outputs.size();
136+
sliceRanges.resize(outputs.size(),
137+
std::vector<Range>(axis + 1, Range::all()));
138+
int prevSlice = 0;
139+
for (int i = 0; i < outputs.size(); ++i)
110140
{
111-
inpShape[cAxis] = outAxisSize;
112-
outputs.push_back(inpShape);
141+
sliceRanges[i][axis].start = prevSlice;
142+
sliceRanges[i][axis].end = sliceRanges[i][axis].start + outAxisSize;
143+
prevSlice = sliceRanges[i][axis].end;
113144
}
114145
}
146+
else
147+
CV_Assert(outputs.size() == sliceRanges.size());
115148

116-
return false;
149+
for (int i = 0; i < outputs.size(); ++i)
150+
{
151+
CV_Assert(sliceRanges[i].size() <= inpShape[-1]);
152+
// Clamp.
153+
for (int j = 0; j < sliceRanges[i].size(); ++j)
154+
{
155+
sliceRanges[i][j].start = std::max(0, sliceRanges[i][j].start);
156+
sliceRanges[i][j].end = std::min(sliceRanges[i][j].end, inpShape[j]);
157+
}
158+
// Fill the rest of ranges.
159+
for (int j = sliceRanges[i].size(); j < inpShape[-1]; ++j)
160+
{
161+
sliceRanges[i].push_back(Range::all());
162+
}
163+
}
117164
}
118165

119166
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
@@ -122,15 +169,10 @@ class SliceLayerImpl : public SliceLayer
122169
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
123170

124171
const Mat& inpMat = *inputs[0];
125-
std::vector<Range> ranges(inpMat.dims, Range::all());
126-
int cAxis = clamp(axis, inpMat.dims);
127-
128-
ranges[cAxis].start = 0;
172+
CV_Assert(outputs.size() == sliceRanges.size());
129173
for (size_t i = 0; i < outputs.size(); i++)
130174
{
131-
ranges[cAxis].end = ranges[cAxis].start + outputs[i].size[cAxis];
132-
inpMat(&ranges[0]).copyTo(outputs[i]);
133-
ranges[cAxis].start = ranges[cAxis].end;
175+
inpMat(sliceRanges[i]).copyTo(outputs[i]);
134176
}
135177
}
136178
};

0 commit comments

Comments
 (0)