1 module grain.cudnn;
2 
3 version (grain_cuda):
4 
5 public import grain.cuda : cudnnHandle, checkCUDNN;
6 import grain.autograd; //  : Variable, DeviceStorage;
7 public import derelict.cuda;
8 public import derelict.cudnn7;
9 
/// Map a D element type to the corresponding cuDNN data-type enum.
/// Only float and double are supported; other types fail at compile time.
auto cudnnDataType(T)() {
    // TODO support half
    static if (is(T == float)) {
        return CUDNN_DATA_FLOAT;
    } else static if (is(T == double)) {
        return CUDNN_DATA_DOUBLE;
    } else {
        static assert(false, "unsupported type");
    }
}
16 
17 
/// RAII owner of a cuDNN tensor descriptor paired with the device pointer
/// it describes. `alias desc this` lets a TensorDesc be passed directly
/// where a cudnnTensorDescriptor_t is expected.
private struct TensorDesc {
    cudnnTensorDescriptor_t desc; // created/filled by makeCudnnTensor
    CUdeviceptr ptr;              // device memory this descriptor refers to (not owned)
    alias desc this;

    @disable this(this); // no copy: a copy would double-destroy desc

    // NOTE(review): assumes desc was created before destruction runs —
    // a default-constructed TensorDesc would destroy a null descriptor; confirm
    // all instances come from makeCudnnTensor.
    ~this() {
        checkCUDNN( cudnnDestroyTensorDescriptor(desc) );
    }
}
29 
30 
/// Build a cuDNN Nd tensor descriptor for a device variable.
/// cuDNN requires at least 4 dimensions, so tensors with fewer dims are
/// padded with trailing size-1 / stride-1 entries.
auto makeCudnnTensor(T, size_t dim)(Variable!(T, dim, DeviceStorage) x) {
    static assert(dim < CUDNN_DIM_MAX);
    static if (dim < 4) {
        // pad up to the 4-dim minimum that cuDNN expects
        enum int nbDims = 4;
        int[nbDims] dims = 1;
        int[nbDims] strideArr = 1;
        dims[0 .. dim] = x.shape;
        strideArr[0 .. dim] = x.strides;
    } else {
        enum int nbDims = cast(int) dim;
        auto dims = x.shape;
        auto strideArr = x.strides;
    }

    TensorDesc result;
    result.ptr = x.data.ptr;
    checkCUDNN(cudnnCreateTensorDescriptor(&result.desc));
    checkCUDNN(cudnnSetTensorNdDescriptor(result.desc,
                                          cudnnDataType!T,
                                          nbDims,
                                          dims.ptr,
                                          strideArr.ptr));
    return result;
}
56 
57 /// y = alpha * f(x) + beta * y
58 void activationForward(cudnnActivationMode_t A, T, size_t dim)(
59     Variable!(T, dim, DeviceStorage) x, Variable!(T, dim, DeviceStorage) y,
60     T alpha=1.0, T beta=0.0, double coeff=0.0) {
61     static assert(dim <= 5, "cuDNN only supports <= 5 dim tensors. and pack dim is not supported yet.");
62     // init descriptors
63     cudnnActivationDescriptor_t  activDesc;
64     checkCUDNN( cudnnCreateActivationDescriptor(&activDesc) );
65     scope(exit) checkCUDNN( cudnnCreateActivationDescriptor(&activDesc) );
66     checkCUDNN( cudnnSetActivationDescriptor(activDesc,
67                                              A, // CUDNN_ACTIVATION_RELU,
68                                              CUDNN_PROPAGATE_NAN,
69                                              coeff) );
70     auto tx = x.makeCudnnTensor;
71     auto ty = y.makeCudnnTensor;
72     checkCUDNN( cudnnActivationForward(cudnnHandle,
73                                        activDesc,
74                                        &alpha,
75                                        tx,
76                                        cast(void*) tx.ptr,
77                                        &beta,
78                                        ty,
79                                        cast(void*) ty.ptr) );
80 }
81 
82 ///
/// Backward pass of activation A: gx = alpha * f'(x) * gy + beta * gx.
///
/// Params:
///     A = cuDNN activation mode, must match the forward pass
///     gx = gradient w.r.t. the input (output of this call)
///     gy = incoming gradient w.r.t. the activation output
///     x = input saved from the forward pass
///     y = output saved from the forward pass
///     alpha = scaling of the computed gradient
///     beta = scaling of the existing gx contents
///     coeff = mode-dependent coefficient, must match the forward pass
void activationBackward(cudnnActivationMode_t A, T, size_t dim)(
    Variable!(T, dim, DeviceStorage) gx, Variable!(T, dim, DeviceStorage) gy,
    Variable!(T, dim, DeviceStorage) x, Variable!(T, dim, DeviceStorage) y,
    T alpha=1.0, T beta=0.0, double coeff=0.0) {
    static assert(dim <= 5, "cuDNN only supports <= 5 dim tensors. and pack dim is not supported yet.");
    // init descriptors
    cudnnActivationDescriptor_t  activDesc;
    checkCUDNN( cudnnCreateActivationDescriptor(&activDesc) );
    // BUG FIX: was cudnnCreateActivationDescriptor again, which leaked the
    // descriptor on every call instead of releasing it.
    scope(exit) checkCUDNN( cudnnDestroyActivationDescriptor(activDesc) );
    checkCUDNN( cudnnSetActivationDescriptor(activDesc,
                                             A, // CUDNN_ACTIVATION_RELU,
                                             CUDNN_PROPAGATE_NAN,
                                             coeff) );
    auto tgx = gx.makeCudnnTensor;
    auto tgy = gy.makeCudnnTensor;
    auto tx = x.makeCudnnTensor;
    auto ty = y.makeCudnnTensor;
    checkCUDNN( cudnnActivationBackward(cudnnHandle,
                                        activDesc,
                                        &alpha,
                                        ty,
                                        cast(void*) ty.ptr,
                                        tgy,
                                        cast(void*) tgy.ptr,
                                        tx,
                                        cast(void*) tx.ptr,
                                        &beta,
                                        tgx,
                                        cast(void*) tgx.ptr,
                    ) );
}
114 
115 /// compute the softmax over all C for each H, W, N
116 void softmaxForward(cudnnSoftmaxAlgorithm_t A, T, size_t dim)(
117     Variable!(T, dim, DeviceStorage) x, Variable!(T, dim, DeviceStorage) y, T alpha=1.0, T beta=0.0) {
118     static assert(dim <= 4, "cuDNN only supports <= 4 dim tensors. and pack dim is not supported yet.");
119     checkCUDNN( cudnnSoftmaxForward(cudnnHandle,
120                                     A,
121                                     CUDNN_SOFTMAX_MODE_CHANNEL,
122                                     &alpha,
123                                     x.makeCudnnTensor,
124                                     cast(void*) x.data.ptr,
125                                     &beta,
126                                     y.makeCudnnTensor,
127                                     cast(void*) y.data.ptr));
128 }
129 
130 
/// Backward pass of channel-wise softmax:
/// gx = alpha * softmax'(y) * gy + beta * gx.
void softmaxBackward(cudnnSoftmaxAlgorithm_t A, T, size_t dim)(
    Variable!(T, dim, DeviceStorage) gx, Variable!(T, dim, DeviceStorage) gy,
    Variable!(T, dim, DeviceStorage) y, T alpha=1.0, T beta=0.0) {
    static assert(dim <= 4, "cuDNN only supports <= 4 dim tensors. and pack dim is not supported yet.");
    // keep descriptors in named locals so their lifetime spans the call
    auto ty = y.makeCudnnTensor;
    auto tgy = gy.makeCudnnTensor;
    auto tgx = gx.makeCudnnTensor;
    checkCUDNN( cudnnSoftmaxBackward(cudnnHandle,
                                     A,
                                     CUDNN_SOFTMAX_MODE_CHANNEL,
                                     &alpha,
                                     ty,
                                     cast(const void*) ty.ptr,
                                     tgy,
                                     cast(const void*) tgy.ptr,
                                     &beta,
                                     tgx,
                                     cast(void*) tgx.ptr
                    ));
}