Skip to content

Commit c7049ca

Browse files
committed
Merge pull request opencv#8293 from alalek:update_rng_in_parallel_for
2 parents 5f99056 + 649bb7a commit c7049ca

File tree

4 files changed

+65
-3
lines changed

4 files changed

+65
-3
lines changed

modules/core/include/opencv2/core.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2834,6 +2834,8 @@ class CV_EXPORTS RNG
28342834
double gaussian(double sigma);
28352835

28362836
uint64 state;
2837+
2838+
bool operator ==(const RNG& other) const;
28372839
};
28382840

28392841
/** @brief Mersenne Twister random number generator

modules/core/include/opencv2/core/operations.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,8 @@ inline int RNG::uniform(int a, int b) { return a == b ? a : (int)(next(
349349
inline float RNG::uniform(float a, float b) { return ((float)*this)*(b - a) + a; }
350350
inline double RNG::uniform(double a, double b) { return ((double)*this)*(b - a) + a; }
351351

352+
// Two generators compare equal when their internal 64-bit states are identical.
inline bool RNG::operator ==(const RNG& other) const
{
    return other.state == state;
}
353+
352354
inline unsigned RNG::next()
353355
{
354356
state = (uint64)(unsigned)state* /*CV_RNG_COEFF*/ 4164903690U + (unsigned)(state >> 32);

modules/core/src/parallel.cpp

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,25 +166,39 @@ namespace
166166
class ParallelLoopBodyWrapper : public cv::ParallelLoopBody
167167
{
168168
public:
169-
ParallelLoopBodyWrapper(const cv::ParallelLoopBody& _body, const cv::Range& _r, double _nstripes)
169+
ParallelLoopBodyWrapper(const cv::ParallelLoopBody& _body, const cv::Range& _r, double _nstripes) :
170+
is_rng_used(false)
170171
{
171172

172173
body = &_body;
173174
wholeRange = _r;
174175
double len = wholeRange.end - wholeRange.start;
175176
nstripes = cvRound(_nstripes <= 0 ? len : MIN(MAX(_nstripes, 1.), len));
176177

178+
// propagate main thread state
179+
rng = cv::theRNG();
180+
177181
#ifdef ENABLE_INSTRUMENTATION
178182
pThreadRoot = cv::instr::getInstrumentTLSStruct().pCurrentNode;
179183
#endif
180184
}
181-
#ifdef ENABLE_INSTRUMENTATION
182185
~ParallelLoopBodyWrapper()
183186
{
187+
#ifdef ENABLE_INSTRUMENTATION
184188
for(size_t i = 0; i < pThreadRoot->m_childs.size(); i++)
185189
SyncNodes(pThreadRoot->m_childs[i]);
186-
}
187190
#endif
191+
if (is_rng_used)
192+
{
193+
// Some parallel backends execute nested jobs in the main thread,
194+
// so we need to restore initial RNG state here.
195+
cv::theRNG() = rng;
196+
// We can't properly update RNG state based on RNG usage in worker threads,
197+
// so let's just change main thread RNG state to the next value.
198+
// Note: this behaviour is not equal to single-threaded mode.
199+
cv::theRNG().next();
200+
}
201+
}
188202
void operator()(const cv::Range& sr) const
189203
{
190204
#ifdef ENABLE_INSTRUMENTATION
@@ -195,19 +209,27 @@ namespace
195209
#endif
196210
CV_INSTRUMENT_REGION()
197211

212+
// propagate main thread state
213+
cv::theRNG() = rng;
214+
198215
cv::Range r;
199216
r.start = (int)(wholeRange.start +
200217
((uint64)sr.start*(wholeRange.end - wholeRange.start) + nstripes/2)/nstripes);
201218
r.end = sr.end >= nstripes ? wholeRange.end : (int)(wholeRange.start +
202219
((uint64)sr.end*(wholeRange.end - wholeRange.start) + nstripes/2)/nstripes);
203220
(*body)(r);
221+
222+
if (!is_rng_used && !(cv::theRNG() == rng))
223+
is_rng_used = true;
204224
}
205225
// Returns the range of stripe indices, [0, nstripes).
cv::Range stripeRange() const
{
    return cv::Range(0, nstripes);
}
206226

207227
protected:
208228
const cv::ParallelLoopBody* body;
209229
cv::Range wholeRange;
210230
int nstripes;
231+
cv::RNG rng;
232+
mutable bool is_rng_used;
211233
#ifdef ENABLE_INSTRUMENTATION
212234
cv::instr::InstrNode *pThreadRoot;
213235
#endif

modules/core/test/test_rand.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,3 +382,39 @@ TEST(Core_Rand, Regression_Stack_Corruption)
382382
ASSERT_EQ(param1, -9);
383383
ASSERT_EQ(param2, 2);
384384
}
385+
386+
namespace {
387+
388+
class RandRowFillParallelLoopBody : public cv::ParallelLoopBody
389+
{
390+
public:
391+
RandRowFillParallelLoopBody(Mat& dst) : dst_(dst) {}
392+
~RandRowFillParallelLoopBody() {}
393+
void operator()(const cv::Range& r) const
394+
{
395+
cv::RNG rng = cv::theRNG(); // copy state
396+
for (int y = r.start; y < r.end; y++)
397+
{
398+
cv::theRNG() = cv::RNG(rng.state + y); // seed is based on processed row
399+
cv::randu(dst_.row(y), Scalar(-100), Scalar(100));
400+
}
401+
// theRNG() state is changed here (but state collision has low probability, so we don't check this)
402+
}
403+
protected:
404+
Mat& dst_;
405+
};
406+
407+
// Checks that two parallel_for_() runs started from the same main-thread RNG
// state produce identical matrices.
TEST(Core_Rand, parallel_for_stable_results)
{
    // remember the RNG state before the first run
    const cv::RNG initialState = cv::theRNG();

    Mat dst1(1000, 100, CV_8SC1);
    RandRowFillParallelLoopBody fillFirst(dst1);
    parallel_for_(cv::Range(0, dst1.rows), fillFirst);

    // rewind the RNG so the second run starts from the same state
    cv::theRNG() = initialState;

    Mat dst2(1000, 100, CV_8SC1);
    RandRowFillParallelLoopBody fillSecond(dst2);
    parallel_for_(cv::Range(0, dst2.rows), fillSecond);

    // no element may differ between the two runs
    ASSERT_EQ(0, countNonZero(dst1 != dst2));
}
419+
420+
} // namespace

0 commit comments

Comments
 (0)