@@ -171,16 +171,16 @@ cudnnStatus_t cudnnGetConvolutionNdForwardOutputDim(
171
171
convDesc, inputTensorDesc, filterDesc, nbDims, tensorOuputDimA);
172
172
}
173
173
174
- cudnnStatus_t cudnnGetConvolutionForwardAlgorithm (
175
- cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc,
176
- const cudnnFilterDescriptor_t wDesc ,
177
- const cudnnConvolutionDescriptor_t convDesc,
178
- const cudnnTensorDescriptor_t yDesc,
179
- cudnnConvolutionFwdPreference_t preference, size_t memoryLimitInBytes,
180
- cudnnConvolutionFwdAlgo_t *algo) {
181
- return getCudnnPlugin (). cudnnGetConvolutionForwardAlgorithm (
182
- handle, xDesc, wDesc, convDesc, yDesc, preference, memoryLimitInBytes,
183
- algo );
174
+ cudnnStatus_t cudnnGetConvolutionForwardAlgorithmMaxCount (cudnnHandle_t handle,
175
+ int *count) {
176
+ return getCudnnPlugin (). cudnnGetConvolutionForwardAlgorithmMaxCount (handle ,
177
+ count);
178
+ }
179
+
180
+ cudnnStatus_t cudnnGetConvolutionBackwardFilterAlgorithmMaxCount (
181
+ cudnnHandle_t handle, int *count) {
182
+ return getCudnnPlugin (). cudnnGetConvolutionBackwardFilterAlgorithmMaxCount (
183
+ handle, count );
184
184
}
185
185
186
186
cudnnStatus_t cudnnGetConvolutionForwardWorkspaceSize (
@@ -193,16 +193,57 @@ cudnnStatus_t cudnnGetConvolutionForwardWorkspaceSize(
193
193
handle, xDesc, wDesc, convDesc, yDesc, algo, sizeInBytes);
194
194
}
195
195
196
- cudnnStatus_t cudnnConvolutionForward (
197
- cudnnHandle_t handle, const void *alpha,
198
- const cudnnTensorDescriptor_t xDesc, const void *x,
199
- const cudnnFilterDescriptor_t wDesc, const void *w,
200
- const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionFwdAlgo_t algo,
201
- void *workSpace, size_t workSpaceSizeInBytes, const void *beta,
202
- const cudnnTensorDescriptor_t yDesc, void *y) {
203
- return getCudnnPlugin ().cudnnConvolutionForward (
204
- handle, alpha, xDesc, x, wDesc, w, convDesc, algo, workSpace,
205
- workSpaceSizeInBytes, beta, yDesc, y);
196
+ cudnnStatus_t cudnnGetConvolutionBackwardFilterWorkspaceSize (
197
+ cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc,
198
+ const cudnnTensorDescriptor_t dyDesc,
199
+ const cudnnConvolutionDescriptor_t convDesc,
200
+ const cudnnFilterDescriptor_t gradDesc,
201
+ cudnnConvolutionBwdFilterAlgo_t algo, size_t *sizeInBytes) {
202
+ return getCudnnPlugin ().cudnnGetConvolutionBackwardFilterWorkspaceSize (
203
+ handle, xDesc, dyDesc, convDesc, gradDesc, algo, sizeInBytes);
204
+ }
205
+
206
+ cudnnStatus_t cudnnFindConvolutionForwardAlgorithm (
207
+ cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc,
208
+ const cudnnFilterDescriptor_t wDesc,
209
+ const cudnnConvolutionDescriptor_t convDesc,
210
+ const cudnnTensorDescriptor_t yDesc, const int requestedAlgoCount,
211
+ int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults) {
212
+ return getCudnnPlugin ().cudnnFindConvolutionForwardAlgorithm (
213
+ handle, xDesc, wDesc, convDesc, yDesc, requestedAlgoCount,
214
+ returnedAlgoCount, perfResults);
215
+ }
216
+
217
+ cudnnStatus_t cudnnFindConvolutionBackwardFilterAlgorithm (
218
+ cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc,
219
+ const cudnnTensorDescriptor_t dyDesc,
220
+ const cudnnConvolutionDescriptor_t convDesc,
221
+ const cudnnFilterDescriptor_t dwDesc, const int requestedAlgoCount,
222
+ int *returnedAlgoCount, cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) {
223
+ return getCudnnPlugin ().cudnnFindConvolutionBackwardFilterAlgorithm (
224
+ handle, xDesc, dyDesc, convDesc, dwDesc, requestedAlgoCount,
225
+ returnedAlgoCount, perfResults);
226
+ }
227
+
228
+ cudnnStatus_t cudnnGetConvolutionForwardAlgorithm (
229
+ cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc,
230
+ const cudnnFilterDescriptor_t wDesc,
231
+ const cudnnConvolutionDescriptor_t convDesc,
232
+ const cudnnTensorDescriptor_t yDesc,
233
+ cudnnConvolutionFwdPreference_t preference, size_t memoryLimitInBytes,
234
+ cudnnConvolutionFwdAlgo_t *algo) {
235
+ auto version = getCudnnPlugin ().getVersion ();
236
+ if (std::get<0 >(version) < 8 ) {
237
+ return getCudnnPlugin ().cudnnGetConvolutionForwardAlgorithm (
238
+ handle, xDesc, wDesc, convDesc, yDesc, preference,
239
+ memoryLimitInBytes, algo);
240
+ } else {
241
+ AF_ERROR (
242
+ " cudnnGetConvolutionForwardAlgorithm has been removed since cuDNN "
243
+ " 8" ,
244
+ AF_ERR_NOT_SUPPORTED);
245
+ return CUDNN_STATUS_SUCCESS;
246
+ }
206
247
}
207
248
208
249
cudnnStatus_t cudnnGetConvolutionBackwardFilterAlgorithm (
@@ -212,19 +253,30 @@ cudnnStatus_t cudnnGetConvolutionBackwardFilterAlgorithm(
212
253
const cudnnFilterDescriptor_t dwDesc,
213
254
cudnnConvolutionBwdFilterPreference_t preference, size_t memoryLimitInBytes,
214
255
cudnnConvolutionBwdFilterAlgo_t *algo) {
215
- return getCudnnPlugin ().cudnnGetConvolutionBackwardFilterAlgorithm (
216
- handle, xDesc, dyDesc, convDesc, dwDesc, preference, memoryLimitInBytes,
217
- algo);
256
+ auto version = getCudnnPlugin ().getVersion ();
257
+ if (std::get<0 >(version) < 8 ) {
258
+ return getCudnnPlugin ().cudnnGetConvolutionBackwardFilterAlgorithm (
259
+ handle, xDesc, dyDesc, convDesc, dwDesc, preference,
260
+ memoryLimitInBytes, algo);
261
+ } else {
262
+ AF_ERROR (
263
+ " cudnnGetConvolutionBackwardFilterAlgorithm has been removed since "
264
+ " cuDNN 8" ,
265
+ AF_ERR_NOT_SUPPORTED);
266
+ return CUDNN_STATUS_SUCCESS;
267
+ }
218
268
}
219
269
220
- cudnnStatus_t cudnnGetConvolutionBackwardFilterWorkspaceSize (
221
- cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc,
222
- const cudnnTensorDescriptor_t dyDesc,
223
- const cudnnConvolutionDescriptor_t convDesc,
224
- const cudnnFilterDescriptor_t gradDesc,
225
- cudnnConvolutionBwdFilterAlgo_t algo, size_t *sizeInBytes) {
226
- return getCudnnPlugin ().cudnnGetConvolutionBackwardFilterWorkspaceSize (
227
- handle, xDesc, dyDesc, convDesc, gradDesc, algo, sizeInBytes);
270
+ cudnnStatus_t cudnnConvolutionForward (
271
+ cudnnHandle_t handle, const void *alpha,
272
+ const cudnnTensorDescriptor_t xDesc, const void *x,
273
+ const cudnnFilterDescriptor_t wDesc, const void *w,
274
+ const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionFwdAlgo_t algo,
275
+ void *workSpace, size_t workSpaceSizeInBytes, const void *beta,
276
+ const cudnnTensorDescriptor_t yDesc, void *y) {
277
+ return getCudnnPlugin ().cudnnConvolutionForward (
278
+ handle, alpha, xDesc, x, wDesc, w, convDesc, algo, workSpace,
279
+ workSpaceSizeInBytes, beta, yDesc, y);
228
280
}
229
281
230
282
cudnnStatus_t cudnnConvolutionBackwardFilter (
0 commit comments