Skip to content

Commit ec49eb8

Browse files
committed
Merge pull request opencv#8314 from chacha21:fix_cuda_absdiff
2 parents f6b6fbf + bfd8003 commit ec49eb8

File tree

1 file changed

+44
-20
lines changed

1 file changed

+44
-20
lines changed

modules/cudaarithm/src/cuda/absdiff_scalar.cu

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,14 @@ void absDiffScalar(const GpuMat& src, cv::Scalar val, bool, GpuMat& dst, const G
5656

5757
namespace
5858
{
59-
template <typename T, typename S> struct AbsDiffScalarOp : unary_function<T, T>
59+
template <typename SrcType, typename ScalarType, typename DstType> struct AbsDiffScalarOp : unary_function<SrcType, DstType>
6060
{
61-
S val;
61+
ScalarType val;
6262

63-
__device__ __forceinline__ T operator ()(T a) const
63+
__device__ __forceinline__ DstType operator ()(SrcType a) const
6464
{
65-
abs_func<S> f;
66-
return saturate_cast<T>(f(a - val));
65+
abs_func<ScalarType> f;
66+
return saturate_cast<DstType>(f(saturate_cast<ScalarType>(a) - val));
6767
}
6868
};
6969

@@ -78,33 +78,57 @@ namespace
7878
};
7979

8080
template <typename SrcType, typename ScalarDepth>
81-
void absDiffScalarImpl(const GpuMat& src, double value, GpuMat& dst, Stream& stream)
81+
void absDiffScalarImpl(const GpuMat& src, cv::Scalar value, GpuMat& dst, Stream& stream)
8282
{
83-
AbsDiffScalarOp<SrcType, ScalarDepth> op;
84-
op.val = static_cast<ScalarDepth>(value);
83+
typedef typename MakeVec<ScalarDepth, VecTraits<SrcType>::cn>::type ScalarType;
84+
85+
cv::Scalar_<ScalarDepth> value_ = value;
86+
87+
AbsDiffScalarOp<SrcType, ScalarType, SrcType> op;
88+
op.val = VecTraits<ScalarType>::make(value_.val);
8589
gridTransformUnary_< TransformPolicy<ScalarDepth> >(globPtr<SrcType>(src), globPtr<SrcType>(dst), op, stream);
8690
}
8791
}
8892

8993
void absDiffScalar(const GpuMat& src, cv::Scalar val, bool, GpuMat& dst, const GpuMat&, double, Stream& stream, int)
9094
{
91-
typedef void (*func_t)(const GpuMat& src, double val, GpuMat& dst, Stream& stream);
92-
static const func_t funcs[] =
95+
typedef void (*func_t)(const GpuMat& src, cv::Scalar val, GpuMat& dst, Stream& stream);
96+
static const func_t funcs[7][4] =
9397
{
94-
absDiffScalarImpl<uchar, float>,
95-
absDiffScalarImpl<schar, float>,
96-
absDiffScalarImpl<ushort, float>,
97-
absDiffScalarImpl<short, float>,
98-
absDiffScalarImpl<int, float>,
99-
absDiffScalarImpl<float, float>,
100-
absDiffScalarImpl<double, double>
98+
{
99+
absDiffScalarImpl<uchar, float>, absDiffScalarImpl<uchar2, float>, absDiffScalarImpl<uchar3, float>, absDiffScalarImpl<uchar4, float>
100+
},
101+
{
102+
absDiffScalarImpl<schar, float>, absDiffScalarImpl<char2, float>, absDiffScalarImpl<char3, float>, absDiffScalarImpl<char4, float>
103+
},
104+
{
105+
absDiffScalarImpl<ushort, float>, absDiffScalarImpl<ushort2, float>, absDiffScalarImpl<ushort3, float>, absDiffScalarImpl<ushort4, float>
106+
},
107+
{
108+
absDiffScalarImpl<short, float>, absDiffScalarImpl<short2, float>, absDiffScalarImpl<short3, float>, absDiffScalarImpl<short4, float>
109+
},
110+
{
111+
absDiffScalarImpl<int, float>, absDiffScalarImpl<int2, float>, absDiffScalarImpl<int3, float>, absDiffScalarImpl<int4, float>
112+
},
113+
{
114+
absDiffScalarImpl<float, float>, absDiffScalarImpl<float2, float>, absDiffScalarImpl<float3, float>, absDiffScalarImpl<float4, float>
115+
},
116+
{
117+
absDiffScalarImpl<double, double>, absDiffScalarImpl<double2, double>, absDiffScalarImpl<double3, double>, absDiffScalarImpl<double4, double>
118+
}
101119
};
102120

103-
const int depth = src.depth();
121+
const int sdepth = src.depth();
122+
const int ddepth = dst.depth();
123+
const int cn = src.channels();
124+
125+
CV_DbgAssert( sdepth <= CV_64F && ddepth <= CV_64F && cn <= 4 && src.type() == dst.type());
104126

105-
CV_DbgAssert( depth <= CV_64F );
127+
const func_t func = funcs[sdepth][cn - 1];
128+
if (!func)
129+
CV_Error(cv::Error::StsUnsupportedFormat, "Unsupported combination of source and destination types");
106130

107-
funcs[depth](src, val[0], dst, stream);
131+
func(src, val, dst, stream);
108132
}
109133

110134
#endif

0 commit comments

Comments
 (0)