1 /**
   High-level cuDNN wrapper for grain.autograd.Variable
3 
4    TODO: support global workspace instead of frequent allocation
5  */
6 module grain.cudnn;
7 
8 version (grain_cuda):
9 
10 public import grain.cuda : cudnnHandle, checkCUDNN, CuPtr, CuArray, isDeviceMemory;
11 import grain.autograd; //  : Variable, DeviceStorage;
12 import grain.utility : castArray;
13 public import derelict.cuda;
14 public import derelict.cudnn7;
15 
16 // TODO make shared
17 __gshared bool deterministic = false;
18 __gshared bool nanProp = true;
19 
20 /// return global cudnn option
21 auto isDeterministic() {
22     return deterministic ? CUDNN_DETERMINISTIC : CUDNN_NON_DETERMINISTIC;
23 }
24 
25 /// ditto
26 auto isNanProp() {
27     return nanProp ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN;
28 }
29 
30 
31 /// convert floating point types (float, double) into cudnn enum
32 auto cudnnDataType(T)() {
33     // TODO support half
34     static if(is(T == float)) return CUDNN_DATA_FLOAT;
35     else static if(is(T == double)) return CUDNN_DATA_DOUBLE;
36     else static assert(false, "unsupported type");
37 }
38 
/// RAII wrapper of cudnnTensorDescriptor_t that also carries the device pointer of the described variable-like struct
40 struct TensorDesc {
41     cudnnTensorDescriptor_t desc;
42     CUdeviceptr ptr;
43     alias desc this;
44 
45     /// no copy
46     @disable this(this);
47     /// no allocation on heap
48     @disable new(size_t);
49 
50     ~this() {
51         checkCUDNN( cudnnDestroyTensorDescriptor(desc) );
52     }
53 }
54 
/// convert a variable into a cuDNN tensor descriptor object
56 auto makeCudnnTensor(T, size_t dim)(Variable!(T, dim, DeviceStorage) x) {
57     static assert(dim < CUDNN_DIM_MAX);
58     static if (dim < 4) {
59         enum int ddim = 4;
60         int[ddim] shape;
61         int[ddim] strides;
62         shape[] = 1;
63         strides[] = 1;
64         foreach (d; 0 .. dim) {
65             assert(x.shape[d] < int.max);
66             shape[d] = cast(int) x.shape[d];
67         }
68         // shape[0..dim] = x.shape;
69         strides[0..dim] = x.strides;
70     } else {
71         enum int ddim = cast(int) dim;
72         int[ddim] shape;
73         foreach (d; 0 .. dim) {
74             assert(x.shape[d] < int.max);
75             shape[d] = cast(int) x.shape[d];
76         }
77         auto strides = x.strides;
78     }
79 
80     TensorDesc tdesc;
81     tdesc.ptr = x.data.ptr;
82     checkCUDNN(cudnnCreateTensorDescriptor(&tdesc.desc));
83     checkCUDNN(cudnnSetTensorNdDescriptor(tdesc.desc,
84                                           cudnnDataType!T,
85                                           ddim,
86                                           shape.ptr,
87                                           strides.ptr));
88     return tdesc;
89 }
90 
/// convert contiguous cuda storage to a 1-D tensor descriptor
92 auto makeCudnnTensor(T)(ref T storage) if (isDeviceMemory!T) {
93     import grain.cuda : CudaElementType;
94     assert(storage.length <= int.max);
95     int[1] shape = [cast(int) storage.length];
96     int[1] strides = [1];
97     int ddim = 1;
98     TensorDesc tdesc;
99     tdesc.ptr = storage.ptr;
100     checkCUDNN(cudnnCreateTensorDescriptor(&tdesc.desc));
101     checkCUDNN(cudnnSetTensorNdDescriptor(tdesc.desc,
102                                           cudnnDataType!(CudaElementType!T),
103                                           ddim,
104                                           shape.ptr,
105                                           strides.ptr));
106     return tdesc;
107 }
108 
109 
110 /// y = alpha * f(x) + beta * y
111 void activationForward(cudnnActivationMode_t A, T, size_t dim)(
112     Variable!(T, dim, DeviceStorage) x, Variable!(T, dim, DeviceStorage) y,
113     T alpha=1.0, T beta=0.0, double coeff=0.0) {
    static assert(dim <= 5, "cuDNN activation only supports tensors with dim <= 5 (packed dims are not supported yet)");
115     // init descriptors
116     cudnnActivationDescriptor_t  activDesc;
117     checkCUDNN( cudnnCreateActivationDescriptor(&activDesc) );
    scope(exit) checkCUDNN( cudnnDestroyActivationDescriptor(activDesc) );
119     checkCUDNN( cudnnSetActivationDescriptor(activDesc,
120                                              A, // CUDNN_ACTIVATION_RELU,
                                             isNanProp(), // CUDNN_PROPAGATE_NAN,
122                                              coeff) );
123     auto tx = x.makeCudnnTensor;
124     auto ty = y.makeCudnnTensor;
125     checkCUDNN( cudnnActivationForward(cudnnHandle,
126                                        activDesc,
127                                        &alpha,
128                                        tx,
129                                        cast(void*) tx.ptr,
130                                        &beta,
131                                        ty,
132                                        cast(void*) ty.ptr) );
133 }
134 
/// backward (gradient) wrapper of activations such as sigmoid, tanh, and ReLU
136 void activationBackward(cudnnActivationMode_t A, T, size_t dim)(
137     Variable!(T, dim, DeviceStorage) gx, Variable!(T, dim, DeviceStorage) gy,
138     Variable!(T, dim, DeviceStorage) x, Variable!(T, dim, DeviceStorage) y,
139     T alpha=1.0, T beta=0.0, double coeff=0.0) {
    static assert(dim <= 5, "cuDNN activation only supports tensors with dim <= 5 (packed dims are not supported yet)");
141     // init descriptors
142     cudnnActivationDescriptor_t  activDesc;
143     checkCUDNN( cudnnCreateActivationDescriptor(&activDesc) );
    scope(exit) checkCUDNN( cudnnDestroyActivationDescriptor(activDesc) );
145     checkCUDNN( cudnnSetActivationDescriptor(activDesc,
146                                              A, // CUDNN_ACTIVATION_RELU,
147                                              isNanProp(), // CUDNN_PROPAGATE_NAN,
148                                              coeff) );
149     auto tgx = gx.makeCudnnTensor;
150     auto tgy = gy.makeCudnnTensor;
151     auto tx = x.makeCudnnTensor;
152     auto ty = y.makeCudnnTensor;
153     checkCUDNN( cudnnActivationBackward(cudnnHandle,
154                                         activDesc,
155                                         &alpha,
156                                         ty,
157                                         cast(void*) ty.ptr,
158                                         tgy,
159                                         cast(void*) tgy.ptr,
160                                         tx,
161                                         cast(void*) tx.ptr,
162                                         &beta,
163                                         tgx,
164                                         cast(void*) tgx.ptr,
165                     ) );
166 }
167 
168 /// compute the softmax over all C for each H, W, N
169 void softmaxForward(cudnnSoftmaxAlgorithm_t A, T, size_t dim)(
170     Variable!(T, dim, DeviceStorage) x, Variable!(T, dim, DeviceStorage) y, T alpha=1.0, T beta=0.0) {
    static assert(dim <= 4, "cuDNN softmax only supports tensors with dim <= 4 (packed dims are not supported yet)");
172     checkCUDNN( cudnnSoftmaxForward(cudnnHandle,
173                                     A,
174                                     CUDNN_SOFTMAX_MODE_CHANNEL,
175                                     &alpha,
176                                     x.makeCudnnTensor,
177                                     cast(void*) x.data.ptr,
178                                     &beta,
179                                     y.makeCudnnTensor,
180                                     cast(void*) y.data.ptr));
181 }
182 
183 /// grad of softmax
184 void softmaxBackward(cudnnSoftmaxAlgorithm_t A, T, size_t dim)(
185     Variable!(T, dim, DeviceStorage) gx, Variable!(T, dim, DeviceStorage) gy,
186     Variable!(T, dim, DeviceStorage) y, T alpha=1.0, T beta=0.0) {
    static assert(dim <= 4, "cuDNN softmax only supports tensors with dim <= 4 (packed dims are not supported yet)");
188     checkCUDNN( cudnnSoftmaxBackward(cudnnHandle,
189                                      A,
190                                      CUDNN_SOFTMAX_MODE_CHANNEL,
191                                      &alpha,
192                                      y.makeCudnnTensor,
193                                      cast(const void*) y.data.ptr,
194                                      gy.makeCudnnTensor,
195                                      cast(const void*) gy.data.ptr,
196                                      &beta,
197                                      gx.makeCudnnTensor,
198                                      cast(void*) gx.data.ptr
199                     ));
200 }
201 
202 /**
203    Tensor operation : C = op( alpha1 * A, alpha2 * B ) + beta * C
204 
205    - list of ops
206     CUDNN_OP_TENSOR_ADD  = 0,
207     CUDNN_OP_TENSOR_MUL  = 1,
208     CUDNN_OP_TENSOR_MIN  = 2,
209     CUDNN_OP_TENSOR_MAX  = 3,
210     CUDNN_OP_TENSOR_SQRT = 4,
211     CUDNN_OP_TENSOR_NOT  = 5,
212 
213    B tensor is ignored for CUDNN_OP_TENSOR_SQRT, CUDNN_OP_TENSOR_NOT.
214 */
215 void tensorOp(cudnnOpTensorOp_t op, T, size_t dim)(
216     Variable!(T, dim, DeviceStorage) c, Variable!(T, dim, DeviceStorage) a, Variable!(T, dim, DeviceStorage) b,
217     T alpha1 = 1, T alpha2 = 1, T beta = 0
218 ) {
219     import grain.functions.common : broadcastable;
220     assert(broadcastable(a, b).ok);
    cudnnOpTensorDescriptor_t opDesc;
    checkCUDNN( cudnnCreateOpTensorDescriptor(&opDesc) );
    scope(exit) cudnnDestroyOpTensorDescriptor(opDesc);
    checkCUDNN( cudnnSetOpTensorDescriptor(opDesc, op, cudnnDataType!T, isNanProp()) );
225     checkCUDNN( cudnnOpTensor(cudnnHandle, opDisc,
226                               &alpha1, a.makeCudnnTensor, cast(const void*) a.data.ptr,
227                               &alpha2, b.makeCudnnTensor, cast(const void*) b.data.ptr,
228                               &beta, c.makeCudnnTensor, cast(void*) c.data.ptr) );
229 }
230 
/// x = alpha * x
232 void scale(T, size_t dim)(Variable!(T, dim, DeviceStorage) x, T alpha) {
233     checkCUDNN( cudnnScaleTensor(cudnnHandle, x.makeCudnnTensor, cast(void*) x.data.ptr, &alpha) );
234 }
235 
236 /**
237    Tensor operation : C = reduce op( alpha * A ) + beta * C
238 
239    - list of op
240     CUDNN_REDUCE_TENSOR_ADD          = 0,
241     CUDNN_REDUCE_TENSOR_MUL          = 1,
242     CUDNN_REDUCE_TENSOR_MIN          = 2,
243     CUDNN_REDUCE_TENSOR_MAX          = 3,
244     CUDNN_REDUCE_TENSOR_AMAX         = 4,
245     CUDNN_REDUCE_TENSOR_AVG          = 5,
246     CUDNN_REDUCE_TENSOR_NORM1        = 6,
247     CUDNN_REDUCE_TENSOR_NORM2        = 7,
248     CUDNN_REDUCE_TENSOR_MUL_NO_ZEROS = 8,
249 
250    The NaN propagation enum applies to only the min and max reduce ops;
251    the other reduce ops propagate NaN as usual.
252    The indices space is ignored for reduce ops other than min or max.
253 */
254 void reduce(cudnnReduceTensorOp_t op, T, size_t dim)(
255     Variable!(T, dim, DeviceStorage) src, Variable!(T, dim, DeviceStorage) dst, T alpha=1, T beta=0)
256 {
    // create tensor descriptors
258     auto srcDesc = src.makeCudnnTensor;
259     auto dstDesc = dst.makeCudnnTensor;
260 
261     // create descriptor
262     cudnnReduceTensorDescriptor_t opDesc;
263     checkCUDNN( cudnnCreateReduceTensorDescriptor(&opDesc) );
264     scope(exit) cudnnDestroyReduceTensorDescriptor(opDesc);
265     checkCUDNN( cudnnSetReduceTensorDescriptor(
266                     opDesc, op, cudnnDataType!T, isNanProp(),
267                     CUDNN_REDUCE_TENSOR_NO_INDICES, // CUDNN_REDUCE_TENSOR_FLATTENED_INDICES for backprop?
268                     CUDNN_32BIT_INDICES // only uint is supported in cudnn7
269                     ) );
270 
    // indices buffer (unused here since CUDNN_REDUCE_TENSOR_NO_INDICES is set; min/max backprop would need flattened indices)
272     size_t indicesBytes;
273     checkCUDNN( cudnnGetReductionIndicesSize(cudnnHandle, opDesc, srcDesc, dstDesc, &indicesBytes) );
274     auto indices = CuPtr!uint(indicesBytes / uint.sizeof);
275 
276     // create workspace
277     size_t workspaceBytes;
278     checkCUDNN( cudnnGetReductionWorkspaceSize(cudnnHandle, opDesc, srcDesc, dstDesc, &workspaceBytes) );
279     auto workspace = CuPtr!byte(workspaceBytes);
280 
281     checkCUDNN( cudnnReduceTensor(
282                     cudnnHandle, opDesc,
283                     cast(void*) indices.ptr, indicesBytes,
284                     cast(void*) workspace.ptr, workspaceBytes,
285                     cast(const void*) &alpha, srcDesc, cast(const void*) srcDesc.ptr,
286                     cast(const void*) &beta, dstDesc, cast(void*) dstDesc.ptr
287                     ) );
288 }
289 
290 /// x[] = value (WARNING: not tested)
291 void fill(T, size_t dim)(Variable!(T, dim, DeviceStorage) x, T value) {
292     checkCUDNN( cudnnSetTensor(cudnnHandle, x.makeCudnnTensor, cast(void*) x.data.ptr, cast(const void*) &value) );
293 }
294 
/// check whether the variable is C-contiguous (WIP)
bool isContiguous(T, size_t dim, alias Storage)(Variable!(T, dim, Storage) x) {
    // FIXME reconsider this when reshape, reversed and transposed views are supported
    bool ret = x.strides[$-1] == 1;
    foreach_reverse(i; 0..dim-1) {
        ret &= x.strides[i] == x.strides[i + 1] * x.shape[i+1];
    }
    return ret;
}
305 
306 ///
307 unittest {
308     {
309         auto x = [[0.1f, 0.2f], [0.3f, 0.4f]].variable;
310         assert(x.isContiguous);
311         x.strides = [2, 2];
312         assert(!x.isContiguous);
313     }
314     version (grain_cuda) {
315         auto x = [[0.1f, 0.2f], [0.3f, 0.4f]].variable.to!DeviceStorage;
316         assert(x.isContiguous);
317         x.strides = [2, 2];
318         assert(!x.isContiguous);
319     }
320 }
321 
/// copy src into dst: dst = alpha * src + beta * dst (shapes must match, but strides may differ)
323 void transform(T, size_t dim)(Variable!(T, dim, DeviceStorage) src, ref Variable!(T, dim, DeviceStorage) dst, T alpha=1, T beta=0) {
324     assert(src.shape == dst.shape);
325 
326     if (src.isContiguous && dst.isContiguous && beta == 1) {
327         import grain.cuda : axpy;
328         axpy(src.data, dst.data, alpha);
329         return;
330     }
331 
332     checkCUDNN(
333         cudnnTransformTensor(
334             cudnnHandle,
335             cast(const void*) &alpha, src.makeCudnnTensor, cast(const void*) src.data.ptr,
336             cast(const void*) &beta, dst.makeCudnnTensor, cast(void*) dst.data.ptr
337             ) );
338 }
339 
/// return a new contiguous variable copied from a (possibly strided) variable x
auto contiguous(T, size_t dim)(Variable!(T, dim, DeviceStorage) x) {
341     auto y = x.uninit;
342     y.bprop = x.bprop;
343     transform(x, y);
344     return y;
345 }
346 
347 /// test cudnnTransformTensor with array ptr manipulations
348 unittest {
349     import std.stdio;
350     // skipping stride 2
351     {
352         auto x = [1f, 0f, 2f, 0f, 3f].variable;
353         x.strides = [2];
354         x.shape = [3];
355         auto y = x.to!DeviceStorage.contiguous.to!HostStorage;
356         assert(y.data == [1f, 2f, 3f]);
357         assert(y.strides == [1]);
358         assert(y.shape == [3]);
359     }
360     // reverse skipping stride -2
361     {
362         auto x = [1f, 0f, 2f, 0f, 3f].variable;
363         x.strides = [-2];
364         x.shape = [3];
365         auto dx = x.to!DeviceStorage;
366         dx.data.ptr += 4 * float.sizeof;
367         scope(exit) dx.data.ptr -= 4 * float.sizeof;
368         auto y = dx.contiguous.to!HostStorage;
369         assert(y.data == [3f, 2f, 1f]);
370         assert(y.strides == [1]);
371         assert(y.shape == [3]);
372     }
    // multi-dim transposed stride [1, 3]
374     {
375         auto x = [[1f, 0f, 2f],
376                   [0f, 3f, 0f]].variable;
377         x.strides = [1, 3];
378         x.shape = [3, 2];
379         auto dx = x.to!DeviceStorage;
380         auto y = dx.contiguous.to!HostStorage;
381         assert(y.sliced == [[1f, 0f], [0f, 3f], [2f, 0f]]);
382         assert(y.strides == [2, 1]);
383         assert(y.shape == [3, 2]);
384     }
385     // multi-dim skipping stride [3, 2]
386     {
387         auto x = [[1f, 0f, 2f],
388                   [0f, 3f, 0f]].variable;
389         x.strides = [3, 2];
390         x.shape = [2, 2];
391         auto dx = x.to!DeviceStorage;
392         auto y = dx.contiguous.to!HostStorage;
393         assert(y.sliced == [[1f, 2f],  [0f, 0f]]);
394         assert(y.strides == [2, 1]);
395         assert(y.shape == [2, 2]);
396     }
397     // multi-dim transposed skipping stride [2, 3]
398     {
399         auto x = [[1f, 0f, 2f],
400                   [0f, 3f, 0f]].variable;
401         x.strides = [2, 3];
402         x.shape = [2, 2];
403         auto dx = x.to!DeviceStorage;
404         // dx.data.ptr += (2 * 3 - 1) * float.sizeof;
405         // scope(exit) dx.data.ptr -= (2 * 3 - 1) * float.sizeof;
406         auto y = dx.contiguous.to!HostStorage;
407         assert(y.sliced == [[1f, 0f],  [2f, 0f]]);
408         assert(y.strides == [2, 1]);
409         assert(y.shape == [2, 2]);
410     }
411     // multi-dim transposed reverse skipping stride [-2, -3]
412     {
413         auto x = [[1f, 0f, 2f],
414                   [0f, 3f, 0f]].variable;
415         x.strides = [-2, -3];
416         x.shape = [2, 2];
417         auto dx = x.to!DeviceStorage;
418         dx.data.ptr += (2 * 3 - 1) * float.sizeof;
419         scope(exit) dx.data.ptr -= (2 * 3 - 1) * float.sizeof;
420         auto y = dx.contiguous.to!HostStorage;
421         assert(y.sliced == [[0f, 2f],  [0f, 1f]]);
422         assert(y.strides == [2, 1]);
423         assert(y.shape == [2, 2]);
424     }
425 
426 }
427 
428 /// wrapper of cudnnConvolutionForward for Variable
429 void convForward(bool isConv, bool isNchw, T, size_t dim, size_t imDims)
430     (Variable!(T, dim, DeviceStorage) input,      // [N, CI, HI, WI]
431      Variable!(T, dim, DeviceStorage) filter,     // [CO, CI/G, KH, KW]
432      ref Variable!(T, dim, DeviceStorage) output, // [N, CO, HO, WO]
433      const int[imDims]   stride,
434      const int[imDims]   pad,
435      const int[imDims]   dilation,
436      int ngroup = 1,
437      cudnnConvolutionFwdAlgo_t algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
438      float alpha = 1,
439      float beta = 0
440         ) {
441     static assert(dim < CUDNN_DIM_MAX);
    static assert(dim == imDims + 2, "dim must be imDims + 2: N (batch), C (channel), then imDims spatial dims");
443     enum cudnnConvolutionMode_t mode = isConv ? CUDNN_CONVOLUTION : CUDNN_CROSS_CORRELATION;
    enum cudnnTensorFormat_t format = isNchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
445 
446     static if (imDims == 1) {
447         enum nbDim_ = 4;
448         enum imDim_ = 2;
449         const stride_ = stride ~ [1];
450         const pad_ = pad ~ [0];
451         const dilation_ = dilation ~ [1];
452         const fshape_ = filter.shape.castArray!int ~ [1];
453     } else {
454         enum nbDim_ = dim;
455         enum imDim_ = imDims;
456         const pad_ = pad;
457         const stride_ = stride;
458         const dilation_ = dilation;
459         const fshape_ = filter.shape.castArray!int;
460     }
461 
467     // TODO cache these?
468     cudnnFilterDescriptor_t cudnnFdesc;
469     checkCUDNN( cudnnCreateFilterDescriptor(&cudnnFdesc) );
470     scope(exit) cudnnDestroyFilterDescriptor(cudnnFdesc);
471     checkCUDNN( cudnnSetFilterNdDescriptor(cudnnFdesc, cudnnDataType!T, format,
472                                            cast(int) nbDim_, fshape_.ptr
473                                            ) );
474 
475     cudnnConvolutionDescriptor_t cudnnConvDesc;
476     checkCUDNN( cudnnCreateConvolutionDescriptor(&cudnnConvDesc) );
477     scope(exit) cudnnDestroyConvolutionDescriptor(cudnnConvDesc);
478     checkCUDNN( cudnnSetConvolutionGroupCount(cudnnConvDesc, ngroup) );
479     checkCUDNN( cudnnSetConvolutionNdDescriptor(cudnnConvDesc, cast(int) imDim_,
480                                                 pad_.ptr, stride_.ptr, dilation_.ptr,
481                                                 mode, cudnnDataType!T
482                                                 ) );
483 
484     auto cudnnIdesc = input.makeCudnnTensor;
485     auto cudnnOdesc = output.makeCudnnTensor;
486     size_t workSpaceSize;
487     checkCUDNN ( cudnnGetConvolutionForwardWorkspaceSize
488                  (cudnnHandle, cudnnIdesc, cudnnFdesc, cudnnConvDesc,
489                   cudnnOdesc, algo, &workSpaceSize) );
490     auto workSpace = CuPtr!byte(workSpaceSize);
491 
492     checkCUDNN ( cudnnConvolutionForward (cudnnHandle,
493                                              cast(const void*) &alpha,
494                                              cudnnIdesc, cast(const void*) input.data.ptr,
495                                              cudnnFdesc, cast(const void*) filter.data.ptr,
496                                              cudnnConvDesc,
497                                              algo,
498                                              cast(void*) workSpace.ptr, workSpaceSize,
499                                              cast(const void*) &beta,
500                                              cudnnOdesc, cast(void*) output.data.ptr) );
501 }
502 
503 /// wrapper of cudnnConvolutionBackwardData and Weight for Variable
504 void convBackward(bool isConv, bool isNchw, T, size_t dim, size_t imDims
505     )
506     (
507      ref Variable!(T, dim, DeviceStorage) gradInput,      // [N, CI, HI, WI]
508      Variable!(T, dim, DeviceStorage) input,      // [N, CI, HI, WI]
509      ref Variable!(T, dim, DeviceStorage) gradFilter,     // [CO, CI/G, KH, KW]
510      Variable!(T, dim, DeviceStorage) filter,     // [CO, CI/G, KH, KW]
511      Variable!(T, dim, DeviceStorage) gradOutput, // [N, CO, HO, WO]
512      const int[imDims]   stride,
513      const int[imDims]   pad,
514      const int[imDims]   dilation,
515      int ngroup = 1,
516      cudnnConvolutionBwdDataAlgo_t algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
517      float alpha = 1,
518      float beta = 0
519 )  {
520     static assert(dim < CUDNN_DIM_MAX);
    static assert(dim == imDims + 2, "dim must be imDims + 2: N (batch), C (channel), then imDims spatial dims");
522     enum cudnnConvolutionMode_t mode = isConv ? CUDNN_CONVOLUTION : CUDNN_CROSS_CORRELATION;
    enum cudnnTensorFormat_t format = isNchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
524 
525     static if (imDims == 1) {
526         enum nbDim_ = 4;
527         enum imDim_ = 2;
528         const stride_ = stride ~ [1];
529         const pad_ = pad ~ [0];
530         const dilation_ = dilation ~ [1];
531         const fshape_ = filter.shape.castArray!int ~ [1];
532     } else {
533         enum nbDim_ = dim;
534         enum imDim_ = imDims;
535         const pad_ = pad;
536         const stride_ = stride;
537         const dilation_ = dilation;
538         const fshape_ = filter.shape.castArray!int;
539     }
540 
541     // TODO cache these?
542     cudnnFilterDescriptor_t cudnnFdesc;
543     checkCUDNN( cudnnCreateFilterDescriptor(&cudnnFdesc) );
544     scope(exit) cudnnDestroyFilterDescriptor(cudnnFdesc);
545     checkCUDNN( cudnnSetFilterNdDescriptor(cudnnFdesc, cudnnDataType!T, format,
546                                            cast(int) nbDim_, fshape_.ptr
547                                            ) );
548 
549     cudnnConvolutionDescriptor_t cudnnConvDesc;
550     checkCUDNN( cudnnCreateConvolutionDescriptor(&cudnnConvDesc) );
551     scope(exit) cudnnDestroyConvolutionDescriptor(cudnnConvDesc);
552     checkCUDNN( cudnnSetConvolutionGroupCount(cudnnConvDesc, ngroup) );
553     checkCUDNN( cudnnSetConvolutionNdDescriptor(cudnnConvDesc, cast(int) imDim_,
554                                                 pad_.ptr, stride_.ptr, dilation_.ptr,
555                                                 mode, cudnnDataType!T
556                                                 ) );
557 
558     auto cudnnIdesc = input.makeCudnnTensor;
559     auto cudnnGIdesc = gradInput.makeCudnnTensor;
560     auto cudnnGOdesc = gradOutput.makeCudnnTensor;
561 
562     size_t dworkSpaceSize;
563     checkCUDNN ( cudnnGetConvolutionBackwardDataWorkspaceSize
564                     (cudnnHandle, cudnnFdesc, cudnnGOdesc, cudnnConvDesc,
565                      cudnnGIdesc, algo, &dworkSpaceSize) );
566     auto dworkSpace = CuPtr!byte(dworkSpaceSize);
567     checkCUDNN ( cudnnConvolutionBackwardData (cudnnHandle,
568                                                   cast(const void*)(&alpha),
569                                                   cudnnFdesc, cast(const void*) filter.data.ptr,
570                                                   cudnnGOdesc, cast(const void*) gradOutput.data.ptr,
571                                                   cudnnConvDesc,
572                                                   algo,
573                                                   cast(void*) dworkSpace.ptr, dworkSpaceSize,
574                                                   cast(const void*)(&beta),
575                                                   cudnnGIdesc, cast(void*) gradInput.data.ptr) );
576 
    // NOTE: the `algo` parameter selects the backward-data algorithm; the filter pass
    // takes a cudnnConvolutionBwdFilterAlgo_t, so use a fixed deterministic choice here.
    const filterAlgo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
    size_t fworkSpaceSize;
    checkCUDNN ( cudnnGetConvolutionBackwardFilterWorkspaceSize
                    (cudnnHandle, cudnnIdesc, cudnnGOdesc, cudnnConvDesc,
                     cudnnFdesc, filterAlgo, &fworkSpaceSize) );
    auto fworkSpace = CuPtr!byte(fworkSpaceSize);
    checkCUDNN ( cudnnConvolutionBackwardFilter (cudnnHandle,
                                                    cast(const void*)(&alpha),
                                                    cudnnIdesc,  cast(const void*) input.data.ptr,
                                                    cudnnGOdesc, cast(const void*) gradOutput.data.ptr,
                                                    cudnnConvDesc,
                                                    filterAlgo,
                                                    cast(void*) fworkSpace.ptr, fworkSpaceSize,
                                                    cast(const void*)(&beta),
                                                    cudnnFdesc, cast(void*) gradFilter.data.ptr) );
591 }
592 
593 
594 /// wrapper of cudnnPoolingForward for Variable
595 auto poolForward(bool isMax = true, bool isAveragePad = false,
596                  T, size_t _tensorDims, size_t _poolDims)
597     (Variable!(T, _tensorDims, DeviceStorage) input,      // [N, C, HI, WI]
598      int[_poolDims] windowDim,
599      int[_poolDims] padding,
600      int[_poolDims] stride,
601      T alpha = 1,
602      T beta = 0
603         ) {
604     static assert(_tensorDims < CUDNN_DIM_MAX);
605     static assert(_tensorDims == _poolDims + 2);
606 
607     static if (_poolDims == 1) {
608         enum tensorDims = 4;
609         enum poolDims = 2;
610         const strideA = stride ~ [1];
611         const paddingA = padding ~ [0];
612         const windowDimA = windowDim ~ [1];
613     } else {
614         enum tensorDims = _tensorDims;
615         enum poolDims = _poolDims;
616         const strideA = stride;
617         const paddingA = padding;
618         const windowDimA = windowDim;
619     }
620 
621     cudnnPoolingDescriptor_t     poolingDesc;
622     checkCUDNN( cudnnCreatePoolingDescriptor(&poolingDesc) );
623     scope(exit) checkCUDNN( cudnnDestroyPoolingDescriptor(poolingDesc) );
624 
625     static if (isMax) {
626         immutable mode = isDeterministic() == CUDNN_DETERMINISTIC
627             ? CUDNN_POOLING_MAX_DETERMINISTIC
628             : CUDNN_POOLING_MAX;
629     } else {
        enum mode = isAveragePad
            ? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
            : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
633     }
634     checkCUDNN( cudnnSetPoolingNdDescriptor(poolingDesc,
635                                             mode,
636                                             isNanProp(),
637                                             cast(int) poolDims,
638                                             windowDimA.ptr,
639                                             paddingA.ptr,
640                                             strideA.ptr ) );
641 
642     const inputDesc = input.makeCudnnTensor;
643     int[tensorDims] tensorOutputDimA;
644     checkCUDNN( cudnnGetPoolingNdForwardOutputDim(poolingDesc,
645                                                   inputDesc,
646                                                   cast(int) tensorDims,
647                                                   tensorOutputDimA.ptr) );
    // allocate the output with the shape computed by cuDNN
650     auto output = uninitVariable!(T, DeviceStorage, tensorDims)(tensorOutputDimA.castArray!uint, input.requiresGrad);
651 
652     checkCUDNN( cudnnPoolingForward(cudnnHandle,
653                                     poolingDesc,
654                                     cast(const void*) &alpha,
655                                     inputDesc,
656                                     cast(const void*) input.data.ptr,
657                                     cast(const void*) &beta,
658                                     output.makeCudnnTensor,
659                                     cast(void*) output.data.ptr) );
660     return output;
661 }
662 
663 
664 /// wrapper of cudnnPoolingBackward for Variable
665 void poolBackward(bool isMax = true, bool isAveragePad = false,
666                   T, size_t _tensorDims, size_t _poolDims)
667     (ref Variable!(T, _tensorDims, DeviceStorage) gradInput,
668      Variable!(T, _tensorDims, DeviceStorage) input,
669      Variable!(T, _tensorDims, DeviceStorage) gradOutput,
670      Variable!(T, _tensorDims, DeviceStorage) output,
671      int[_poolDims] windowDim,
672      int[_poolDims] padding,
673      int[_poolDims] stride,
674      T alpha = 1,
675      T beta = 0
676         ) {
677     static assert(_tensorDims < CUDNN_DIM_MAX);
678     static assert(_tensorDims == _poolDims + 2);
679 
680     static if (_poolDims == 1) {
681         enum tensorDims = 4;
682         enum poolDims = 2;
683         const strideA = stride ~ [1];
684         const paddingA = padding ~ [0];
685         const windowDimA = windowDim ~ [1];
686     } else {
687         enum tensorDims = _tensorDims;
688         enum poolDims = _poolDims;
689         const strideA = stride;
690         const paddingA = padding;
691         const windowDimA = windowDim;
692     }
693 
694     cudnnPoolingDescriptor_t     poolingDesc;
695     checkCUDNN( cudnnCreatePoolingDescriptor(&poolingDesc) );
696     scope(exit) checkCUDNN( cudnnDestroyPoolingDescriptor(poolingDesc) );
697 
698     static if (isMax) {
699         immutable mode = isDeterministic() == CUDNN_DETERMINISTIC
700             ? CUDNN_POOLING_MAX_DETERMINISTIC
701             : CUDNN_POOLING_MAX;
702     } else {
        enum mode = isAveragePad
            ? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
            : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
706     }
707     checkCUDNN( cudnnSetPoolingNdDescriptor(poolingDesc,
708                                             mode,
709                                             isNanProp(),
710                                             cast(int) poolDims,
711                                             windowDimA.ptr,
712                                             paddingA.ptr,
713                                             strideA.ptr ) );
714 
715     checkCUDNN( cudnnPoolingBackward(cudnnHandle,
716                                      poolingDesc,
717 
718                                      cast(const void*) &alpha,
719                                      output.makeCudnnTensor,
720                                      cast(const void*) output.data.ptr,
721                                      gradOutput.makeCudnnTensor,
722                                      cast(const void*) gradOutput.data.ptr,
723                                      input.makeCudnnTensor,
724                                      cast(const void*) input.data.ptr,
725 
726                                      cast(const void*) &beta,
727                                      gradInput.makeCudnnTensor,
728                                      cast(void*) gradInput.data.ptr) );
729 }
730 
731 
732 /// Global (thread local) dropout state with descriptor and state array
733 struct ThreadLocalDropout {
734     static shared size_t count = 0;
735     static cudnnDropoutDescriptor_t _dropoutDesc = null;
736     static CuPtr!byte _stateArray;
737 
738     @disable this(this);
739     @disable new(size_t);
740 
741     /// FIXME set global seed
742     static void init(size_t seed=0) {
743         if (_dropoutDesc != null) return;
744 
745         checkCUDNN(cudnnCreateDropoutDescriptor(&_dropoutDesc));
        // How much memory does dropout need for its RNG states?
        // These states are used to generate random numbers internally
        // and must not be freed while the dropout descriptor is in use.
749         size_t stateSize;
750         checkCUDNN(cudnnDropoutGetStatesSize(cudnnHandle, &stateSize));
751         _stateArray = CuPtr!byte(stateSize);
752         checkCUDNN(cudnnSetDropoutDescriptor(_dropoutDesc,
753                                              cudnnHandle,
754                                              0f,
755                                              cast(void*) _stateArray.ptr,
756                                              stateSize,
757                                              seed));
758 
759         import core.atomic : atomicOp;
760         count.atomicOp!"+="(1);
761     }
762 
    /// return the global dropout descriptor, updating its dropout ratio when nonzero
    static descriptor(float ratio=0.0) {
764         init();
765         if (ratio != 0.0) {
766             checkCUDNN(cudnnSetDropoutDescriptor(
767                            _dropoutDesc,
768                            cudnnHandle,
769                            ratio,
770                            null, // if state is null, state won't be updated
771                            0,
772                            0));
773         }
        return _dropoutDesc;
775     }
776 
777     static ref state() {
778         init();
779         return _stateArray;
780     }
781 }
782 
/// cuDNN dropout function object; keeps the reserve space that backward reuses
struct CudnnDropout {
784     CuPtr!byte reserved;
785 
786     auto forward(size_t dim)(Variable!(float, dim, DeviceStorage) x, float ratio) {
787         import std.algorithm : move;
788         auto y = x.uninit;
789         auto xt = x.makeCudnnTensor;
790 
791         size_t reservedSize;
        checkCUDNN(cudnnDropoutGetReserveSpaceSize(xt, &reservedSize));
793         this.reserved = CuPtr!byte(reservedSize);
794 
795         checkCUDNN(cudnnDropoutForward(
796                        cudnnHandle,
797                        ThreadLocalDropout.descriptor(ratio),
798                        xt,
799                        cast(void*) x.ptr,
800                        y.makeCudnnTensor,
801                        cast(void*) y.ptr,
802                        cast(void*) this.reserved.ptr,
803                        this.reserved.length
804                        ));
805         return y;
806     }
807 
808     auto backward(size_t dim)(Variable!(float, dim, DeviceStorage) gy) {
809         auto gx = gy.uninit;
810         checkCUDNN(cudnnDropoutBackward(
811                        cudnnHandle,
812                        ThreadLocalDropout.descriptor,
813                        gy.makeCudnnTensor,
814                        cast(void*) gy.ptr,
815                        gx.makeCudnnTensor,
816                        cast(void*) gx.ptr,
817                        cast(void*) this.reserved.ptr,
818                        this.reserved.length
819                        ));
820         return gx;
821     }
822 }