1 /**
2    A module for binary autograd functions.
3 
4    TODO:
5    - support add tensor with cudnnAddTensor and mir.ndslice
6    - support opBinary(add, mul, min, max) with cudnnOpTensor
7    - convolution
8    - batchnorm
9 */
10 module grain.functions.binary;
11 
12 import grain.functions.common; // : FunctionCommon, TypeChecker;
13 import grain.autograd; //  : variable, Variable, UntypedVariable, HostStorage, DeviceStorage;
14 import grain.cuda;
15 import grain.utility;
16 import std.traits : isFloatingPoint;
17 import std.typecons : tuple;
18 import mir.ndslice : isSlice;
19 import std.format : format;
20 
21 
22 /// c = op(alpha1 * a, alpha2 * b) + beta * c;
23 struct OpBinary(T, size_t dim, string ops) if (isFloatingPoint!T) {
24     import mir.ndslice;
25 
26     T alpha1 = 1, alpha2 = 1;
27 
28     uint[dim] shape1, shape2;
29     Variable!(T, dim, HostStorage) ha, hb;
30 
31     auto forward(Variable!(T, dim, HostStorage) a, Variable!(T, dim, HostStorage) b) {
32         auto info = broadcastable(a, b);
33         assert(info.ok);
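        // broadcast both inputs to a common shape before applying the element-wise op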
34         auto abx = broadcast(a.sliced, b.sliced);
35         auto ax = abx[0];
36         auto bx = abx[1];
37         auto c = slice(this.alpha1 * ax);
38 
39         // TODO if train
40         this.shape1 = a.shape;
41         this.shape2 = b.shape;
42         // ops
43         static if (ops == "+") {
44             c[] += this.alpha2 * bx;
45         } else static if (ops == "*") {
46             this.ha = ax.variable;
47             this.hb = bx.variable;
48             c[] *= this.alpha2 * bx;
49         } else {
50             static assert(false, "unknown operator: " ~ ops);
51         }
52         return c.variable(a.requiresGrad || b.requiresGrad);
53     }
54 
55     auto backward(Variable!(T, dim, HostStorage) gc) {
56         import numir;
57         import mir.math : sum;
58         import mir.ndslice;
59         static if (ops == "+") {
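            // d(alpha1 * a + alpha2 * b)/da = alpha1, so ga = alpha1 * gc, summed back over any broadcast axes (likewise for gb)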
60             auto ga = this.alpha1 == 1 ? gc.sliced.slice.universal : slice(this.alpha1 * gc.sliced).universal;
61             if (ga.shape != this.shape1) {
62                 ga = reduceShape!(sum!"fast")(ga, this.shape1.castArray!size_t).universal;
63             }
64             auto gb = this.alpha2 == 1 ? gc.sliced.slice.universal : slice(this.alpha2 * gc.sliced).universal;
65             if (gb.shape != this.shape2) {
66                 gb = reduceShape!(sum!"fast")(gb, this.shape2.castArray!size_t).universal;
67             }
68             return tuple(ga.variable, gb.variable);
69         } else static if (ops == "*") {
70             assert(this.ha.defined);
71             assert(this.hb.defined);
72             auto ga = gc.sliced.slice.universal;
73             ga[] *= this.alpha1 * this.alpha2 * this.hb.sliced;
74             if (ga.shape != this.shape1) {
75                 ga = reduceShape!(sum!"fast")(ga, this.shape1.castArray!size_t).universal;
76             }
77             auto gb = gc.sliced.slice.universal;
78             gb[] *= this.alpha1 * this.alpha2 * this.ha.sliced;
79             if (gb.shape != this.shape2) {
80                 gb = reduceShape!(sum!"fast")(gb, this.shape2.castArray!size_t).universal;
81             }
82             return tuple(ga.variable, gb.variable);
83         } else {
84             static assert(false, "unknown operator: " ~ ops);
85         }
86     }
87 
88     version (grain_cuda) {
89         import grain.cudnn;
90         import derelict.cudnn7;
91         import std.algorithm : canFind;
92 
93         enum opBinaryDict = [
94                              "+": CUDNN_OP_TENSOR_ADD,
95                              "*": CUDNN_OP_TENSOR_MUL,
96                              "min": CUDNN_OP_TENSOR_MIN,
97                              "max": CUDNN_OP_TENSOR_MAX
98                              ];
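        // cudnnOpTensor computes c = op(alpha1 * a, alpha2 * b) + beta * c element-wise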
99 
100         static if (opBinaryDict.keys.canFind(ops)) {
101             static if (ops == "*") {
102                 Variable!(T, dim, DeviceStorage) da, db;
103             }
104 
105             auto forward(Variable!(T, dim, DeviceStorage) a, Variable!(T, dim, DeviceStorage) b) {
106                 // TODO implement non-cudnn case
107                 foreach (d; 0 .. dim) {
108                     assert(a.shape[d] == b.shape[d] || b.shape[d] == 1,
109                            "cuDNN does not support complete broadcasting");
110                 }
111                 // TODO if train
112                 this.shape1 = a.shape;
113                 this.shape2 = b.shape;
114                 static if (ops == "*") {
115                     this.da = a;
116                     this.db = b;
117                 }
118 
119                 auto c = a.uninit;
120                 c.requiresGrad = a.requiresGrad || b.requiresGrad;
121                 import grain.cudnn;
122                 tensorOp!(opBinaryDict[ops], T, dim)(c, a, b, this.alpha1, this.alpha2);
123                 return c;
124             }
125         } else {
126             static assert(false, "unknown operator: " ~ ops);
127         }
128 
129         static if (ops == "+") {
130             auto backward(Variable!(T, dim, DeviceStorage) gc) {
131                 Variable!(T, dim, DeviceStorage) ga, gb;
132                 if (this.shape1 == gc.shape) {
133                     ga = gc.dup;
134                     if (this.alpha1 != 1.0) grain.cudnn.scale(ga, this.alpha1);
135                 } else {
136                     ga = uninitVariable!(T, DeviceStorage)(this.shape1);
137                     grain.cudnn.reduce!CUDNN_REDUCE_TENSOR_ADD(gc, ga, this.alpha1);
138                 }
139 
140                 if (this.shape2 == gc.shape) {
141                     gb = gc.dup;
142                     if (this.alpha2 != 1.0) grain.cudnn.scale(gb, this.alpha2);
143                 } else {
144                     gb = uninitVariable!(T, DeviceStorage)(this.shape2);
145                     grain.cudnn.reduce!CUDNN_REDUCE_TENSOR_ADD(gc, gb, this.alpha2);
146                 }
147                 return tuple(ga, gb);
148             }
149         } else static if (ops == "*") {
150             auto backward(Variable!(T, dim, DeviceStorage) gc) {
151                 auto gax = uninitVariable!(T, DeviceStorage)(gc.shape);
152                 auto gbx = uninitVariable!(T, DeviceStorage)(gc.shape);
153                 auto alpha = this.alpha1 * this.alpha2;
154                 grain.cudnn.tensorOp!CUDNN_OP_TENSOR_MUL(gax, gc, this.db, alpha);
155                 grain.cudnn.tensorOp!CUDNN_OP_TENSOR_MUL(gbx, gc, this.da, alpha);
156 
157                 Variable!(T, dim, DeviceStorage) ga, gb;
158                 if (this.shape1 == gc.shape) {
159                     ga = gax;
160                 } else {
161                     ga = uninitVariable!(T, DeviceStorage)(this.shape1);
162                     grain.cudnn.reduce!CUDNN_REDUCE_TENSOR_ADD(gax, ga);
163                 }
164 
165                 if (this.shape2 == gc.shape) {
166                     gb = gbx;
167                 } else {
168                     gb = uninitVariable!(T, DeviceStorage)(this.shape2);
169                     grain.cudnn.reduce!CUDNN_REDUCE_TENSOR_ADD(gbx, gb);
170                 }
171                 return tuple(ga, gb);
172             }
173         }
174     }
175 
176     mixin FunctionCommon;
177 }
178 
179 ///
180 unittest {
181     static foreach (op; ["+", "*"]) {
182         foreach (j; [1, 2]) {
183             import std.typecons : tuple;
184             import numir : uniform, approxEqual;
185             import mir.ndslice : slice;
186             import grain.testing;
187 
188             auto a = uniform!float(3, 2).slice.variable;
189             auto b = uniform!float(3, j).slice.variable;
190             auto gc = uniform!float(3, 2).slice.variable;
191             auto func = OpBinary!(float, 2, op)(1, 2);
192             gradCheck(func, tuple(a, b), gc);
193 
194             auto c = func.forward(a, b);
195             auto gab = func.backward(gc);
196             version (grain_cuda) {
197                 auto dfunc = OpBinary!(float, 2, op)(1, 2);
198                 auto dc = dfunc.forward(a.to!DeviceStorage, b.to!DeviceStorage);
199                 assert(approxEqual(dc.to!HostStorage.sliced, c.sliced));
200                 auto dgab = dfunc.backward(gc.to!DeviceStorage);
201                 assert(approxEqual(dgab[0].to!HostStorage.sliced, gab[0].sliced));
202                 assert(approxEqual(dgab[1].to!HostStorage.sliced, gab[1].sliced));
203             }
204         }
205     }
206 }
207 
208 ///
209 unittest {
210     foreach (i; [1, 2]) {
211         foreach (j; [1, 2]) {
212             import std.typecons : tuple;
213             import numir : uniform;
214             import mir.ndslice : slice;
215             import grain.testing;
216 
217             auto a = uniform!float(i, 2).slice.variable;
218             auto b = uniform!float(2, j).slice.variable;
219             auto gc = uniform!float(2, 2).slice.variable;
220             auto func = OpBinary!(float, 2, "*")(1, 2);
221             gradCheck(func, tuple(a, b), gc);
222         }
223     }
224 }
225 
226 
227 /// a and b have the same shape
228 unittest {
229     import mir.ndslice;
230 
231     auto plus = OpBinary!(float, 2, "+")(1.0f, 2.0f);
232     auto a = [[1.0f, 2.0f, 3.0f], [4.0f, 5.0f, 3.0f]].variable;
233     auto b = [[-1.0f, 4.0f, 0.0f], [1.0f, 2.0f, 3.0f]].variable;
234     auto hc = plus.forward(a, b);
235     assert(hc.sliced == [[-1.0f, 10.0f, 3.0f], [6.0f, 9.0f, 9.0f]]);
236 
237     version (grain_cuda) {
238         auto dplus = OpBinary!(float, 2, "+")(1.0f, 2.0f);
239         auto dc = dplus.forward(a.to!DeviceStorage, b.to!DeviceStorage);
240         assert(dc.to!HostStorage.sliced == [[-1.0f, 10.0f, 3.0f], [6.0f, 9.0f, 9.0f]]);
241     }
242 }
243 
244 
245 /// a and b have different shapes
246 unittest {
247     import mir.ndslice;
248 
249     auto plus = OpBinary!(float, 2, "+")(1.0f, 2.0f);
250     auto a = [[1.0f, 2.0f, 3.0f], [4.0f, 5.0f, 3.0f]].variable;
251     auto b = [[-1.0f, 4.0f, 0.0f]].variable;
252     auto hc = plus.forward(a, b);
253     assert(hc.sliced == [[-1.0f, 10.0f, 3.0f], [2.0f, 13.0f, 3.0f]]);
254 
255     version (grain_cuda) {
256         auto dc = plus.forward(a.to!DeviceStorage, b.to!DeviceStorage);
257         assert(dc.to!HostStorage.sliced == [[-1.0f, 10.0f, 3.0f], [2.0f, 13.0f, 3.0f]]);
258     }
259 }
260 
261 /// a and b have different shapes (element-wise multiplication)
262 unittest {
263     import mir.ndslice;
264 
265     auto mul = OpBinary!(float, 2, "*")(1.0f, 2.0f);
266     auto a = [[1.0f, 2.0f, 3.0f], [4.0f, 5.0f, 3.0f]].variable;
267     auto b = [[-1.0f, 4.0f, 0.0f]].variable;
268     auto hc = mul.forward(a, b);
269     assert(hc.sliced == [[1*2*-1, 2*2*4, 0], [4*2*-1, 5*2*4, 0]]);
270
271     version (grain_cuda) {
272         auto dc = mul.forward(a.to!DeviceStorage, b.to!DeviceStorage);
273         assert(dc.to!HostStorage.sliced == [[1*2*-1, 2*2*4, 0], [4*2*-1, 5*2*4, 0]]);
274     }
275 }
276 
277 
278 /++
279  Matrix-matrix multiplication (lubeck.mtimes on the host, cuBLAS on CUDA devices)
280 
281  See_Also: https://github.com/chainer/chainer/blob/v1/chainer/functions/connection/linear.py#L11
282  +/
283 struct MatMul(T) {
284     import mir.ndslice : transposed, universal;
285     import std.typecons : tuple;
286     import lubeck : mtimes;
287     T alpha = 1;
288     T beta = 0;
289     Variable!(T, 2, HostStorage) ha, hb;
290 
292     mixin FunctionCommon;
293 
294     auto forward(Variable!(T, 2, HostStorage) a, Variable!(T, 2, HostStorage) b) {
295         // TODO if training
296         this.ha = a;
297         this.hb = b;
298         return mtimes(a.sliced, b.sliced).variable(a.requiresGrad || b.requiresGrad);
299     }
300 
301     auto backward(Variable!(T, 2, HostStorage) gy) {
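        // for c = a * b: ga = gy * b^T and gb = a^T * gy (see the mtimes calls below)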
302         auto ga = mtimes(gy.sliced, this.hb.sliced.transposed).variable;
303         auto gb = mtimes(this.ha.sliced.transposed, gy.sliced).variable;
304         return tuple(ga, gb);
305     }
306 
307     version(grain_cuda) {
308         Variable!(T, 2, DeviceStorage) da, db;
309 
310         auto forward(Variable!(T, 2, DeviceStorage) a, Variable!(T, 2, DeviceStorage) b) {
311             import grain.cublas;
312             static if (is(T == float)) {
313                 alias gemm = cublasSgemm_v2;
314             } else static if (is(T == double)) {
315                 alias gemm = cublasDgemm_v2;
316             } else {
317                 static assert(false, "unsupported type");
318             }
319 
320             import std.typecons : RefCounted;
321             assert(a.shape[1] == b.shape[0]);
322             auto cdata = CuArray!T(a.shape[0] * b.shape[1]);
323             auto c = Variable!(T, 2, DeviceStorage)(
324                 a.requiresGrad || b.requiresGrad,
325                 [a.shape[0], b.shape[1]], [b.shape[1], 1], cdata);
326             // C = A x B = (BT x AT)T
327             // TODO support transposed (CUBLAS_OP_T)
328             // see https://github.com/libmir/mir-blas/blob/master/source/mir/blas.d#L299
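            // cuBLAS expects column-major data while grain stores row-major;
            // computing B^T * A^T in cuBLAS's column-major view (i.e. passing the
            // row-major buffers as-is with CUBLAS_OP_N, B first) yields C^T in
            // column-major, which is exactly C in row-major memory.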
329             // TODO if train
330             this.da = a;
331             this.db = b;
332             checkCublasErrors(gemm(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N,
333                                    cast(int) b.shape[1],
334                                    cast(int) a.shape[0], cast(int) a.shape[1],
335                                    &alpha,
336                                    cast(const T*) b.data.ptr, cast(int) b.strides[0],
337                                    cast(const T*) a.data.ptr, cast(int) a.strides[0],
338                                    &beta,
339                                    cast(T*) c.data.ptr, cast(int) c.strides[0]));
340             return c;
341         }
342 
343         auto backward(Variable!(T, 2, DeviceStorage) gc) {
344             import grain.cublas;
345             static if (is(T == float)) {
346                 alias gemm = cublasSgemm_v2;
347             } else static if (is(T == double)) {
348                 alias gemm = cublasDgemm_v2;
349             } else {
350                 static assert(false, "unsupported type");
351             }
352             auto ga = this.da.dup;
353             auto gb = this.db.dup;
354             // auto ga = mtimes(gc.sliced, this.hb.sliced.transposed).variable;
355             checkCublasErrors(gemm(cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N,
356                                    cast(int) ga.shape[1],
357                                    cast(int) ga.shape[0], cast(int) gc.shape[1],
358                                    &alpha,
359                                    cast(const T*) db.data.ptr, cast(int) db.strides[0],
360                                    cast(const T*) gc.data.ptr, cast(int) gc.strides[0],
361                                    &beta,
362                                    cast(T*) ga.data.ptr, cast(int) ga.strides[0]));
363             // auto gb = mtimes(this.ha.sliced.transposed, gc.sliced).variable;
364             checkCublasErrors(gemm(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_T,
365                                    cast(int) gb.shape[1],
366                                    cast(int) gb.shape[0], cast(int) da.shape[0],
367                                    &alpha,
368                                    cast(const T*) gc.data.ptr, cast(int) gc.strides[0],
369                                    cast(const T*) da.data.ptr, cast(int) da.strides[0],
370                                    &beta,
371                                    cast(T*) gb.data.ptr, cast(int) gb.strides[0]));
372             return tuple(ga, gb);
373         }
374     }
375 }
376 
377 /// test matmul gradcheck and cpu/cuda equality
378 unittest {
379     foreach (i; [2, 3, 4]) {
380         foreach (j; [2, 3, 4]) {
381             import std.typecons : tuple;
382             import numir : uniform;
383             import mir.ndslice : slice;
384             import grain.testing;
385 
386             auto k = 3;
387             auto a = uniform!float(i, k).slice.variable;
388             auto b = uniform!float(k, j).slice.variable;
389             auto gc = uniform!float(i, j).slice.variable;
390             MatMul!float func;
391             gradCheck(func, tuple(a, b), gc, 1e-3, 1e-3, 1e-3);
392 
393             version (grain_cuda) {
394                 import numir.testing;
395                 MatMul!float func2;
396                 auto hc = func.forward(a, b);
397                 auto dc = func2.forward(a.to!DeviceStorage, b.to!DeviceStorage);
398                 assert(approxEqual(dc.to!HostStorage.sliced, hc.sliced));
399                 auto hgab = func.backward(gc);
400                 auto dgab = func2.backward(gc.to!DeviceStorage);
401                 // writefln!"%s vs %s"(dgab[0].to!HostStorage.sliced, hgab[0].sliced);
402                 assert(approxEqual(dgab[0].to!HostStorage.sliced, hgab[0].sliced));
403                 assert(approxEqual(dgab[1].to!HostStorage.sliced, hgab[1].sliced));
404             }
405         }
406     }
407 }
408 
409 /// matmul with variable.backward
410 unittest {
411     import std.typecons;
412     import numir;
413     import mir.ndslice;
414     static import grain.config;
415 
416     grain.config.backprop = true;
417     auto func = new MatMul!float;
418     auto a = uniform!float(3, 4).slice.variable(true);
419     auto b = uniform!float(4, 2).slice.variable(true);
420     auto c = func.applyForward(a, b);
421     auto gc = uniform!float(3, 2).slice.variable;
422     auto ugc = UntypedVariable(gc);
423     c.backward(&ugc);
424     auto gab = func.backward(gc);
425     assert(a.gradSlice == gab[0].sliced);
426     assert(b.gradSlice == gab[1].sliced);
427 }
428 
429 /**
430    Add a bias vector to each row of a matrix; used inside grain.chain.Linear
431    TODO: generalize to broadcastable addition (use cudnnAddTensor)
432 */
433 struct AddBias(T) {
434     mixin FunctionCommon;
435 
436     import mir.ndslice : map, slice;
437     import std.typecons : tuple;
438     auto forward(Variable!(T, 2, HostStorage) a, Variable!(T, 1, HostStorage) b) {
439         assert(a.shape[1] == b.shape[0]);
440         auto ret = a.dup;
441         foreach (i; 0 .. a.shape[0]) {
442             ret.sliced[i][] += b.sliced;
443         }
444         return ret;
445     }
446 
447     auto backward(Variable!(T, 2, HostStorage) gy) {
448         import numir : alongDim;
449         import mir.math : sum;
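        // the bias is broadcast over rows in forward, so its gradient sums gy over the row (batch) dimension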
450         auto gb = gy.sliced.alongDim!0.map!sum.slice.variable;
451         return tuple(gy, gb);
452     }
453 
454     version (grain_cuda) {
455         import grain.kernel : addBias, addBiasGrad;
456 
457         auto forward(Variable!(T, 2, DeviceStorage) a, Variable!(T, 1, DeviceStorage) b) {
458             assert(a.shape[1] == b.shape[0]);
459             auto y = a.dup;
460             auto n = cast(uint) y.data.length;
461             auto blen = cast(uint) b.data.length;
462             Global.kernel!addBias
463                 .call(y.data.ptr, b.data.ptr, blen, n).launch(n);
464             return y;
465         }
466 
467         auto backward(Variable!(T, 2, DeviceStorage) gy) {
468             auto gb = CuArray!T(gy.shape[1]);
469             gb.zero_();
470             auto n = cast(uint) gy.data.length;
471             auto blen = cast(uint) gb.length;
472             Global.kernel!addBiasGrad
473                 .call(gy.data.ptr, gb.ptr, blen, n).launch(n);
474             return tuple(gy, Variable!(T, 1, DeviceStorage)(false, [cast(int) blen], [1], gb));
475         }
476     }
477 }
478 
479 ///
480 unittest {
481     import std.typecons;
482     import grain.testing;
483     import numir;
484     import mir.ndslice;
485 
486     AddBias!float func;
487     auto hx = [[0f, 1f], [2f, 3f], [4f, 5f]].variable; // 3x2
488     auto hb = [-1f, 1f].variable; // 2
489     auto hy = func.forward(hx, hb);
490     assert(hy.sliced == [[-1f, 2f], [1f, 4f], [3f, 6f]]);
491 
492     auto hgy = uniform!float(hy.shape.castArray!size_t).slice.variable;
493     auto hgxb = func.backward(hgy);
494     assert(hgxb[0].sliced == hgy.sliced);
495     assert(hgxb[1].sliced == [hgy.sliced[0, 0] + hgy.sliced[1, 0] + hgy.sliced[2, 0],
496                               hgy.sliced[0, 1] + hgy.sliced[1, 1] + hgy.sliced[2, 1]]);
497     gradCheck(func, tuple(hx, hb), hgy);
498 
499     version (grain_cuda) {
500         auto dx = hx.to!DeviceStorage;
501         auto db = hb.to!DeviceStorage;
502         auto dy = func.forward(dx, db);
503         assert(dy.to!HostStorage.sliced == [[-1f, 2f], [1f, 4f], [3f, 6f]]);
504         auto dgy = hgy.to!DeviceStorage;
505         auto dgxb = func.backward(dgy);
506         assert(dgxb[0].to!HostStorage.sliced == hgxb[0].sliced);
507         assert(dgxb[1].to!HostStorage.sliced == hgxb[1].sliced);
508     }
509 }
510 
511 /// Embed integer IDs into vectors (table lookup). TODO: support N-dim input and a sparse weight matrix.
512 struct Embedding(T) {
513     import std.range : enumerate;
514     import numir : view, empty, zeros;
515 
516     Variable!(int, 1, HostStorage) hx;
517     uint[2] wshape;
518 
519     auto forward(Variable!(T, 2, HostStorage) weight, Variable!(int, 1, HostStorage) ids) {
520         this.hx = ids; // TODO if train
521         this.wshape = weight.shape; // TODO if train
522         auto ys = empty!T(ids.shape[0], weight.shape[1]);
523         foreach (i, id; ids.sliced.enumerate) {
524             ys[i, 0..$] = weight.sliced[id, 0..$];
525         }
526         return ys.variable(weight.requiresGrad);
527     }
528 
529     auto backward(Variable!(T, 2, HostStorage) gy) {
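        // scatter-add: each row of gy accumulates into the weight-gradient row selected by the corresponding input ID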
530         auto gw = zeros!T(this.wshape.castArray!size_t);
531         foreach (i, id; this.hx.sliced.enumerate) {
532             gw[id, 0..$] += gy.sliced[i];
533         }
534         return tuple(gw.variable(gy.requiresGrad), typeof(this.hx)());
535     }
536 
537     version (grain_cuda) {
538         Variable!(int, 1, DeviceStorage) dx;
539 
540         auto forward(Variable!(T, 2, DeviceStorage) weight, Variable!(int, 1, DeviceStorage) ids) {
541             import grain.kernel : embedding;
542             this.dx = ids; // TODO if train
543             this.wshape = weight.shape; // TODO if train
544             auto ys = uninitVariable!(T, DeviceStorage, 2)([ids.shape[0], weight.shape[1]], weight.requiresGrad);
545             Global.kernel!embedding
546                 .call(weight.data.ptr, ids.data.ptr, ys.data.ptr, weight.shape[0], weight.shape[1], ids.shape[0])
547                 .launch(weight.shape[1] * ids.shape[0]);
548             return ys;
549         }
550 
551         auto backward(Variable!(T, 2, DeviceStorage) gy) {
552             import grain.kernel : embeddingGrad;
553             auto gw = uninitVariable!(T, DeviceStorage, 2)(this.wshape, gy.requiresGrad);
554             Global.kernel!embeddingGrad
555                 .call(gw.data.ptr, this.dx.data.ptr, gy.data.ptr, this.wshape[0], this.wshape[1], this.dx.shape[0])
556                 .launch(this.wshape[1] * this.dx.shape[0]);
557             return tuple(gw, typeof(this.dx)());
558         }
559     }
560 
561     mixin FunctionCommon;
562 }
563 
564 ///
565 unittest {
566     import numir;
567 
568     Embedding!float embed;
569     auto w = [[1.0f, 2.0f], [3.0f, 4.0f]].nparray.variable;
570     auto x = [0, 1, 0].variable;
571     auto y = embed.forward(w, x);
572     assert(y.sliced == [[1,2],[3,4],[1,2]]);
573 
574     auto gy = [[1f, 2f], [-1f, -2f], [1f, 0f]].nparray.variable;
575     auto gw = embed.backward(gy)[0];
576     assert(gw.sliced == [[2f, 2f], [-1f, -2f]]);
577 
578     version (grain_cuda) {
579         Embedding!float dembed;
580         auto dy = dembed.forward(w.to!DeviceStorage, x.to!DeviceStorage);
581         assert(dy.to!HostStorage.sliced == y.sliced);
582         auto dgw = dembed.backward(gy.to!DeviceStorage)[0];
583         assert(dgw.to!HostStorage.sliced == gw.sliced);
584     }
585 }
586 
587 
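/// Compute packed strides for `dimA[0 .. nbDims]`: innermost-last (NCHW) when `isNchw`
/// is true, channels-innermost (NHWC) otherwise.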
588 void generateStrides(const int* dimA, int* strideA, int nbDims, bool isNchw) {
589     if (isNchw) {
590         strideA[nbDims-1] = 1 ;
591         for(int d = nbDims-2 ; d >= 0 ; d--) {
592             strideA[d] = strideA[d+1] * dimA[d+1] ;
593         }
594     } else {
595         strideA[1] = 1;
596         strideA[nbDims-1] = strideA[1]*dimA[1];
597         for(int d = nbDims-2 ; d >= 2 ; d--) {
598             strideA[d] = strideA[d+1] * dimA[d+1] ;
599         }
600         strideA[0] = strideA[2]*dimA[2];
601     }
602 }
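
/// Example: packed strides for dims [2, 3, 4, 5] in both layouts.
unittest {
    int[4] dims = [2, 3, 4, 5];
    int[4] nchw, nhwc;
    generateStrides(dims.ptr, nchw.ptr, 4, true);
    assert(nchw == [60, 20, 5, 1]);
    generateStrides(dims.ptr, nhwc.ptr, 4, false);
    assert(nhwc == [60, 1, 15, 3]);
}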
603 
604 /** Convert a linear (row-major) index
605     i = d_1 (s_2 ... s_n) + d_2 (s_3 ... s_n) + ... + d_{n-1} s_n + d_n
606     into a multidimensional index (d_1, d_2, ..., d_n),
607     where s_k is the size of dimension k.
608 */
609 void lin2dim(size_t length)(int id, scope ref int[length] ids, const ref int[length] dims) {
610     int idrem = id ;
611     int prod  = 1 ; // accumulates the product of the dimensions
612     foreach_reverse(i; 0 .. length) {
613         ids[i] = (idrem / prod) % dims[i] ;
614         idrem = id - ids[i] * prod ;
615         prod *= dims[i] ;
616     }
617 }
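
/// Example: the linear index 11 in a (2, 3, 4)-shaped tensor corresponds to (0, 2, 3).
unittest {
    int[3] ids;
    const int[3] dims = [2, 3, 4];
    lin2dim(11, ids, dims);
    assert(ids == [0, 2, 3]);
}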
618 
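/// Write `alphaAcc` into `o[idx]`, blending with the previous value as `alphaAcc + o[idx] * beta` when `beta != 0`.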
619 void doEpilog(float[] o, int idx, float alphaAcc, float beta) {
620     if( beta == 0f ) {
621         o[idx] = alphaAcc;
622     } else {
623         o[idx] = alphaAcc + o[idx]*beta;
624     }
625 }
626 
627 @nogc pure @safe int dim2lin(size_t length)(const ref int[length] ids, const int[] strides) {
628     assert(length == strides.length);
629     import mir.ndslice;
630     import mir.math;
631     return sum(ids.sliced * strides.sliced);
632 }
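
/// Example: the inverse mapping; (0, 2, 3) with row-major strides [12, 4, 1] gives 11.
unittest {
    const int[3] ids = [0, 2, 3];
    assert(dim2lin(ids, [12, 4, 1]) == 11);
}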
633 
634 /// Reference CPU implementation of Convolution function
635 static struct ConvolutionRefImpl(T, size_t imDims, bool isConv=false, bool isNchw = true) {
636     enum int nbDims = imDims + 2;
637 
638     static void forward(const T[] inputData,
639                         const T[] filterData,
640                         T[] outputData,
641                         T alpha,
642                         T beta,
643                         const int[nbDims] inDims,
644                         const int[nbDims] filDims,
645                         const int[nbDims] outDims,
646                         const int[nbDims] inStride,
647                         const int[nbDims] filStride,
648                         const int[nbDims] outStride,
649                         const int[imDims] stride,
650                         const int[imDims] pad,
651                         const int[imDims] dilation,
652                         )
653     in {
654         // Sanity checks
655         // in     is n x c x h x w
656         // out    is n x k x p x q
657         // filter is k x c x r x s
658         assert(inDims[0] == outDims[0]); // n
659         assert(inDims[1] == filDims[1]); // c
660         assert(outDims[1] == filDims[0]); // k
661     } do {
662         import std.algorithm : reduce;
663 
664         immutable nPixelsOut = outDims[2..$].reduce!"a * b";
665         immutable nPixelsFil = filDims[2..$].reduce!"a * b";
666 
667         // Used to store coordinates
668         int[imDims] filIds, outIds, inIds, tmpIds;
669         // For each image in the output
670         // TODO these three loops can be parallelized without atomic ops
671         foreach (ni; 0 .. outDims[0]) {
672             // For each feature layer of the output
673             foreach (ki; 0 .. outDims[1]) {
674                 immutable outputOffset = ni * outStride[0] + ki * outStride[1] ;
675                 // Loop over all entries of the result
676                 foreach (outId; 0 .. nPixelsOut) {
677                     // Get output pixel ids
678                     lin2dim(outId, outIds, outDims[2..$]) ; // Skip n and k dimensions
679                     // Now we get the coordinates in input space of
680                     // the "top left" corner of the filter: multiply by stride and remove pad
681                     inIds[] = outIds[] * stride[] - pad[];
682                     // We then accumulate
683                     T tmp = 0;
684                     foreach (ci; 0 .. inDims[1]) {
685                         immutable inputOffset = ni * inStride[0] + ci * inStride[1] ;
686                         immutable filterOffset = ki * filStride[0] + ci * filStride[1] ;
687                         foreach (filId; 0 .. nPixelsFil) {
688                             // Get the position of the pixel
689                             lin2dim(filId, filIds, filDims[2..$]) ;
690                             // Compute the corresponding output pixel
691                             // and check whether we are in the padding area on the fly too
692                             // (note that for convolution, we flip the image patch
693                             // (equivalent to flipping the filter patch))
694                             bool inside = true ;
695                             for (int d = 0; d < imDims && inside; d++) {
696                                 if (isConv) {
697                                     tmpIds[d] = inIds[d] + dilation[d] * (filDims[2+d]-1 - filIds[d]);
698                                 } else {
699                                     tmpIds[d] = inIds[d] + dilation[d] * filIds[d];
700                                 }
701                                 // If we are in the padding area: stop and skip computations
702                                 inside &= (tmpIds[d] >= 0 && tmpIds[d] < inDims[2+d]) ;
703                             }
704                             if (inside) {
705                                 immutable actualTmpId = inputOffset + dim2lin(tmpIds, inStride[2..$]);
706                                 immutable actualFilId = filterOffset + dim2lin(filIds, filStride[2..$]);
707                                 tmp += filterData[actualFilId] * inputData [actualTmpId];
708                             }
709                         }
710                     }
711                     // We put the result in the output
712                     immutable actualOutId = outputOffset + dim2lin(outIds, outStride[2..$]);
713                     doEpilog(outputData, actualOutId, alpha*tmp, beta);
714                 }
715             }
716         }
717     }
718 
719     static void backwardData(const T[] weight,
720                              const T[] top_diff,
721                              scope T[] output,
722                              float alpha,
723                              float beta,
724 
725                              const int[nbDims] inDims,
726                              const int[nbDims] filDims,
727                              const int[nbDims] outDims,
728 
729                              const int[nbDims] inStride,
730                              const int[nbDims] filterStride,
731                              const int[nbDims] outStride,
732 
733                              const int[imDims] stride,
734                              const int[imDims] pad,
735                              const int[imDims] dilation)
736     in {
737         // Sanity checks
738         // output is n x c x h x w
739         // diff   is n x k x p x q
740         // filter is k x c x r x s
741         assert(inDims[0] == outDims[0]); // n
742         assert(inDims[1] == filDims[0]); // k
743         assert(outDims[1] == filDims[1]); // c
744     } do {
745         import std.algorithm : map, any, all, reduce;
746         import std.range : iota;
747 
748         // Number of pixels in output
749         immutable nPixelsOut = outDims[2..$].reduce!"a * b";
750         // Number of pixels in filter
751         immutable nPixelsFil = filDims[2..$].reduce!"a * b";
752 
753         int[imDims] outIds, filIds, accIds;
754         // TODO these three loops can be parallelized without atomic ops
755         foreach (ni; 0 .. outDims[0]) {
756             foreach (ci; 0 .. outDims[1]) {
757                 foreach (outIdx; 0 .. nPixelsOut) {
758                     lin2dim(outIdx, outIds, outDims[2..$]);
759                     T val = 0;
760                     // For every diff channel (k)
761                     foreach (ki; 0 .. inDims[1]) {
762                         immutable offsetFilter = ki * filterStride[0] + ci * filterStride[1];
763                         immutable offsetDiff = ni * inStride[0] + ki * inStride[1];
764 
765                         foreach (filIdx; 0 .. nPixelsFil) {
766                             lin2dim(filIdx, filIds, filDims[2..$]);
767 
768                             // Fetch the value in filter and diff, product and accumulate
769                             // So basically, for the convolution,
770                             // we replace r by dim-1-r and s by dim-1-s to "flip" the filter
771                             // We can then just reason in terms of correlation
772                             accIds[] = outIds[] + pad[];
773                             if (isConv){
774                                 accIds[] -= (filDims[2..$] - 1 - filIds[]) * dilation[];
775                             } else {
776                                 accIds[] -= filIds[] * dilation[];
777                             }
778                             immutable outtaStride = iota(imDims).map!(i => accIds[i] % stride[i]).any;
779                             if (outtaStride) {
780                                 continue;
781                             }
782                             accIds[] /= stride[];
783 
784                             immutable inBounds = iota(imDims).map!(i => 0 <= accIds[i] && accIds[i] < inDims[i+2]).all;
785                             if (inBounds) {
786                                 immutable filterIdx = offsetFilter + dim2lin(filIds, filterStride[2..$]);
787                                 immutable diffIdx = offsetDiff + dim2lin(accIds, inStride[2..$]);
788                                 val += top_diff[diffIdx] * weight[filterIdx];
789                             }
790                         }
791                     }
792                     immutable offsetOut = ni * outStride[0] + ci * outStride[1];
793                     doEpilog(output, offsetOut + outIdx, alpha*val, beta);
794                 }
795             }
796         }
797     }
798 
799     static void backwardWeight(/*const TensorNdTestDesc_t *tensorInputDesc,*/
800                                const T[] image,
801                                /*const TensorNdTestDesc_t *tensorDiffDesc,*/
802                                const T[] diffData,
803                                /*const ConvNdTestDesc_t *convDesc,*/
804                                /*const TensorNdTestDesc_t *filterOutputDesc,*/
805                                float alpha,
806                                float beta,
807                                scope T[] output,
808 
809                                const int[nbDims] inDims,
810                                const int[nbDims] filDims,
811                                const int[nbDims] diffDims,
812                                const int[nbDims] inStride,
813                                const int[nbDims] filterStride,
814                                const int[nbDims] diffStride,
815                                const int[imDims] stride,
816                                const int[imDims] pad,
817                                const int[imDims] dilation)
818     in {
819         // Some sanity checks
820         // image   is n x c x h x w
821         // diff    is n x k x p x q
822         // filter  is k x c x r x s
823         assert(inDims[0] == diffDims[0]) ;
824         assert(inDims[1] == filDims[1]) ;
825         assert(diffDims[1]  == filDims[0]) ;
826 
827     } do {
828         import std.algorithm : all, sum, map, reduce;
829         import std.range : iota;
830 
831         // Number of pixels in output
832         immutable nPixelsDiff = diffDims[2..$].reduce!"a * b";
833         // Number of pixels in filter
834         immutable nPixelsFil = filDims[2..$].reduce!"a * b";
835 
836         // For every filter pixel (k x c x r x s)
837         int[imDims] filIds, diffIds, accIds;
838         // TODO these three loops can be parallelized without atomic ops
839         foreach (ki; 0 .. filDims[0]){
840             foreach (ci; 0 .. filDims[1]) {
841                 foreach (filIdx; 0 .. nPixelsFil) {
842                     lin2dim(filIdx, filIds, filDims[2..$]);
843                     T val = 0;
844                     // For every image (n)
845                     foreach (ni; 0 .. inDims[0]) { // Sum over the batch
846                         immutable offsetIn = ni * inStride[0] + ci * inStride[1] ;
847                         immutable offsetDiff = ni * diffStride[0] + ki * diffStride[1] ;
848                         // For every pixel in diff
849                         foreach (diffIdx; 0 .. nPixelsDiff) {
850                             lin2dim(diffIdx, diffIds, diffDims[2..$]);
851                             // Fetch the value in image and diff, product and accumulate
852                             accIds[] = diffIds[] * stride[] - pad[];
853 
854                             // Convolution = Correlation with a flipped filter
855                             // So basically, for the convolution, we replace r by dim-1-r and s
856                             // by dim-1-s to "flip" the filter
857                             // We can then just reason in terms of correlation
858                             if (isConv){
859                                 accIds[] += (filDims[2..$] - 1 - filIds[]) * dilation[];
860                             } else {
861                                 // The effect of dilation on the gradient is to start the "zone of influence"
862                                 // of a given pixel further into the image, so dilation
863                                 // only produces a shift in x and y
864                                 accIds[] += filIds[] * dilation[];
865                             }
866                             // Image value
867                             immutable inBounds = iota(imDims).map!(i => 0 <= accIds[i] && accIds[i] < inDims[i+2]).all;
868                             if (inBounds) {
869                                 immutable imId = offsetIn + dim2lin(accIds, inStride[2..$]);
870                                 // Diff value
871                                 immutable diffId  = offsetDiff + dim2lin(diffIds, diffStride[2..$]);
872                                 // Prod and accumulate
873                                 val += image[imId] * diffData[diffId];
874                             }
875                         }
876                     }
877                     immutable offsetFilter = ki * filterStride[0] + ci * filterStride[1];
878                     doEpilog(output, offsetFilter + filIdx, alpha*val, beta);
879                 }
880             }
881         }
882     }
883 }
884 
885 /++
886  Convolution/cross-correlation function
887 
888  TODO add cudnn wrapped functions
889  +/
890 struct Convolution(T, size_t imDims, bool isConv=false, bool isNchw = true) {
891     int[imDims] stride;
892     int[imDims] pad;
893     int[imDims] dilation;
894     enum int nbDims = imDims + 2;
895     enum int ngroup=1; // TODO support ngroup > 1?
896     alias RefImpl = ConvolutionRefImpl!(T, imDims, isConv, isNchw);
897 
898     Variable!(T, nbDims, HostStorage) hx, hw;
899 
900     /// https://pytorch.org/docs/master/nn.html#convolution-layers
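    /// Each spatial dim follows (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
    /// e.g. a length-4 input with a size-3 kernel, stride 1, and no padding/dilation gives 2.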
901     auto outShape(uint[nbDims] inShape, uint[nbDims] weightShape) {
902         uint[nbDims] ret;
903         ret[0] = inShape[0]; // batchsize
904         ret[1] = weightShape[0]; // output ch size
905         assert(inShape[1] == weightShape[1]);
906         auto kernel = weightShape[2..$];
907         foreach (d; 0 .. imDims) {
908             ret[d+2] = cast(uint)
909                 ((inShape[d+2] + 2 * pad[d] - dilation[d] * (kernel[d] - 1) - 1)
910                  / stride[d] + 1);
911         }
912         return ret;
913     }
914 
915     void setDefault() {
916         foreach (d; 0..imDims) {
917             if (this.stride[d] == 0) {
918                 this.stride[d] = 1;
919             }
920             if (this.dilation[d] == 0) {
921                 this.dilation[d] = 1;
922             }
923         }
924     }
925 
926     auto forward(Variable!(T, nbDims, HostStorage) x, Variable!(T, nbDims, HostStorage) w) {
927         this.setDefault();
928         // TODO if train
929         this.hx = x;
930         this.hw = w;
931         auto y = uninitVariable!(T, HostStorage)(outShape(x.shape, w.shape));
932         RefImpl.forward(x.data, w.data, y.data,
933                         1f, 0f,
934                         x.shape.castArray!int, w.shape.castArray!int, y.shape.castArray!int,
935                         x.strides, w.strides, y.strides,
936                         this.stride, this.pad, this.dilation);
937         return y;
938     }
939 
940 
941     auto backward(Variable!(T, nbDims, HostStorage) gy) {
942         // TODO use requires_grad for skipping grad calc
943         static assert(3 <= nbDims && nbDims < 6, "cudnn7 only supports 3, 4, and 5 dimensional inputs");
944         auto gx = this.hx.uninit;
945         gx.data.zero_();
946         RefImpl.backwardData(this.hw.data, gy.data, gx.data,
947                              1f, 0f,
948                              gy.shape.castArray!int, this.hw.shape.castArray!int, gx.shape.castArray!int,
949                              gy.strides, this.hw.strides, gx.strides,
950                              this.stride, this.pad, this.dilation);
951 
952         auto gw = this.hw.uninit;
953         gw.data.zero_();
954         RefImpl.backwardWeight(this.hx.data, gy.data,
955                                1f, 0f,
956                                gw.data,
957                                this.hx.shape.castArray!int, gw.shape.castArray!int, gy.shape.castArray!int,
958                                this.hx.strides, gw.strides, gy.strides,
959                                this.stride, this.pad, this.dilation);
960         return tuple(gx, gw);
961     }
962 
963     version (grain_cuda) {
964         import derelict.cudnn7;
965         import grain.cudnn;
966         // TODO implement benchmark mode to search the best algo
967         cudnnConvolutionFwdAlgo_t forwardAlgo = // CUDNN_CONVOLUTION_FWD_ALGO_DIRECT;
968             CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
969         cudnnConvolutionBwdDataAlgo_t backwardAlgo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
971 
972         Variable!(T, nbDims, DeviceStorage) dx, dw;
973 
974         auto forward(Variable!(T, nbDims, DeviceStorage) x, Variable!(T, nbDims, DeviceStorage) w) {
975             this.setDefault();
976             // TODO if train
977             this.dx = x;
978             this.dw = w;
979             auto y = uninitVariable!(T, DeviceStorage)(outShape(x.shape, w.shape));
980             grain.cudnn.convForward!(isConv, isNchw)(x, w, y, this.stride, this.pad, this.dilation,
981                                                      this.ngroup, this.forwardAlgo);
982             return y;
983         }
984 
985 
986         auto backward(Variable!(T, nbDims, DeviceStorage) gy) {
987             // TODO use requires_grad for skipping grad calc
988             auto gx = this.dx.uninit;
989             gx.data.zero_();
990             auto gw = this.dw.uninit;
991             gw.data.zero_();
992             // TODO separate data/weight backward
993             grain.cudnn.convBackward!(isConv, isNchw)
994                 (gx, this.dx, gw, this.dw, gy, this.stride, this.pad, this.dilation,
995                  this.ngroup, this.backwardAlgo);
996             return tuple(gx, gw);
997         }
998 
999     }
1000 
1001     mixin FunctionCommon;
1002 }
1003 
1004 /** Conv1d pytorch equality test
1005    ``` python
1006    >>> iota = lambda s: torch.arange(torch.prod(torch.tensor(s))).view(s)
1007    >>> torch.nn.functional.conv1d(iota([2, 3, 4]), iota([5, 3, 3]))
1008    tensor([[[  258.,   294.],
1009             [  663.,   780.],
1010             [ 1068.,  1266.],
1011             [ 1473.,  1752.],
1012             [ 1878.,  2238.]],
1013 
1014            [[  690.,   726.],
1015             [ 2067.,  2184.],
1016             [ 3444.,  3642.],
1017             [ 4821.,  5100.],
1018             [ 6198.,  6558.]]])
1019    >>> y.shape
1020    torch.Size([2, 5, 2])
1021 
1022    >>> x = iota([2, 3, 4])
1023    >>> x.requires_grad = True
1024    >>> w = iota([5, 3, 3])
1025    >>> w.requires_grad = True
1026    >>> y = torch.nn.functional.conv1d(x, w)
1027    >>> y.backward(torch.ones_like(y))
1028    >>> x.grad
1029    tensor(
1030        [[[  90.,  185.,  195.,  100.],
1031          [ 105.,  215.,  225.,  115.],
1032          [ 120.,  245.,  255.,  130.]],
1033 
1034         [[  90.,  185.,  195.,  100.],
1035          [ 105.,  215.,  225.,  115.],
1036          [ 120.,  245.,  255.,  130.]]])
1037    >>> w.grad
1038    tensor([[[ 26.,  30.,  34.],
1039             [ 42.,  46.,  50.],
1040             [ 58.,  62.,  66.]],
1041 
1042            [[ 26.,  30.,  34.],
1043             [ 42.,  46.,  50.],
1044             [ 58.,  62.,  66.]],
1045 
1046            [[ 26.,  30.,  34.],
1047             [ 42.,  46.,  50.],
1048             [ 58.,  62.,  66.]],
1049 
1050            [[ 26.,  30.,  34.],
1051             [ 42.,  46.,  50.],
1052             [ 58.,  62.,  66.]],
1053 
1054            [[ 26.,  30.,  34.],
1055             [ 42.,  46.,  50.],
1056             [ 58.,  62.,  66.]]])
1057 
1058    ```
1059 */
1060 unittest {
1061     import std.stdio;
1062     import mir.ndslice;
1063     import numir;
1064     auto x = iota(2, 3, 4).as!float.slice.variable;
1065     auto w = iota(5, 3, 3).as!float.slice.variable;
1066     Convolution!(float, 1) conv;
1067     auto y = conv.forward(x, w);
1068     auto yx = [[[  258.,   294.],
1069                 [  663.,   780.],
1070                 [ 1068.,  1266.],
1071                 [ 1473.,  1752.],
1072                 [ 1878.,  2238.]],
1073 
1074                [[  690.,   726.],
1075                 [ 2067.,  2184.],
1076                 [ 3444.,  3642.],
1077                 [ 4821.,  5100.],
1078                 [ 6198.,  6558.]]];
1079     assert(y.sliced == yx);
1080 
1081     // test backward
1082     auto gy = y.uninit;
1083     gy.data[] = 1;
1084     auto gs = conv.backward(gy);
1085     auto gx = gs[0];
1086     auto gw = gs[1];
1087 
1088     auto gxx = [[[  90.,  185.,  195.,  100.],
1089                  [ 105.,  215.,  225.,  115.],
1090                  [ 120.,  245.,  255.,  130.]],
1091 
1092                 [[  90.,  185.,  195.,  100.],
1093                  [ 105.,  215.,  225.,  115.],
1094                  [ 120.,  245.,  255.,  130.]]];
1095     assert(gx.sliced == gxx);
1096 
1097     auto gwx = [[[ 26.,  30.,  34.],
1098                  [ 42.,  46.,  50.],
1099                  [ 58.,  62.,  66.]],
1100 
1101                 [[ 26.,  30.,  34.],
1102                  [ 42.,  46.,  50.],
1103                  [ 58.,  62.,  66.]],
1104 
1105                 [[ 26.,  30.,  34.],
1106                  [ 42.,  46.,  50.],
1107                  [ 58.,  62.,  66.]],
1108 
1109                 [[ 26.,  30.,  34.],
1110                  [ 42.,  46.,  50.],
1111                  [ 58.,  62.,  66.]],
1112 
1113                 [[ 26.,  30.,  34.],
1114                  [ 42.,  46.,  50.],
1115                  [ 58.,  62.,  66.]]];
1116     assert(gw.sliced == gwx);
1117 
1118     import grain.testing : gradCheck;
1119     auto hx = uniform!float(x.shape.castArray!size_t).slice.variable;
1120     auto hw = uniform!float(w.shape.castArray!size_t).slice.variable;
1121     auto hgy = uniform!float(y.shape.castArray!size_t).slice.variable;
1122     auto hy = conv.forward(hx, hw);
1123     auto hgx = conv.backward(hgy);
1124     gradCheck(conv, tuple(hx, hw), hgy, 1e-3, 1e-3, 1e-2);
1125 
1126     version (grain_cuda) {
1127         auto dy = conv.forward(hx.to!DeviceStorage, hw.to!DeviceStorage);
1128         auto dgx = conv.backward(hgy.to!DeviceStorage);
1129         assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
1130         assert(approxEqual(dgx[0].to!HostStorage.sliced, hgx[0].sliced));
1131         assert(approxEqual(dgx[1].to!HostStorage.sliced, hgx[1].sliced));
1132     }
1133 }
1134 
1135 /** Conv2d pytorch equality test
1136    ``` python
1137    >>> import torch
1138    >>> iota = lambda s: torch.arange(torch.prod(torch.tensor(s))).view(s)
1139    >>> x = iota([2, 3, 4, 4])
1140    >>> x.requires_grad = True
1141    >>> w = iota([2, 3, 3, 3])
1142    >>> w.requires_grad = True
1143    >>> y = torch.nn.functional.conv2d(x, w)
1144    >>> y
1145    tensor([[[[ 10197.,  10548.],
1146              [ 11601.,  11952.]],
1147 
1148             [[ 25506.,  26586.],
1149              [ 29826.,  30906.]]],
1150 
1151 
1152             [[[ 27045.,  27396.],
1153               [ 28449.,  28800.]],
1154 
1155              [[ 77346.,  78426.],
1156               [ 81666.,  82746.]]]])
1157 
1158    >>> y = torch.nn.functional.conv2d(x, w)
1159    >>> y.backward(torch.ones_like(y))
1160    >>> x.grad
1161    tensor(
1162        [[[[  27.,   56.,   60.,   31.],
1163           [  60.,  124.,  132.,   68.],
1164           [  72.,  148.,  156.,   80.],
1165           [  39.,   80.,   84.,   43.]],
1166 
1167          [[  45.,   92.,   96.,   49.],
1168           [  96.,  196.,  204.,  104.],
1169           [ 108.,  220.,  228.,  116.],
1170           [  57.,  116.,  120.,   61.]],
1171 
1172          [[  63.,  128.,  132.,   67.],
1173           [ 132.,  268.,  276.,  140.],
1174           [ 144.,  292.,  300.,  152.],
1175           [  75.,  152.,  156.,   79.]]],
1176 
1177 
1178         [[[  27.,   56.,   60.,   31.],
1179           [  60.,  124.,  132.,   68.],
1180           [  72.,  148.,  156.,   80.],
1181           [  39.,   80.,   84.,   43.]],
1182 
1183          [[  45.,   92.,   96.,   49.],
1184           [  96.,  196.,  204.,  104.],
1185           [ 108.,  220.,  228.,  116.],
1186           [  57.,  116.,  120.,   61.]],
1187 
1188          [[  63.,  128.,  132.,   67.],
1189           [ 132.,  268.,  276.,  140.],
1190           [ 144.,  292.,  300.,  152.],
1191           [  75.,  152.,  156.,   79.]]]])
1192    >>> w.grad
1193    tensor(
1194        [[[[ 212.,  220.,  228.],
1195           [ 244.,  252.,  260.],
1196           [ 276.,  284.,  292.]],
1197 
1198          [[ 340.,  348.,  356.],
1199           [ 372.,  380.,  388.],
1200           [ 404.,  412.,  420.]],
1201 
1202          [[ 468.,  476.,  484.],
1203           [ 500.,  508.,  516.],
1204           [ 532.,  540.,  548.]]],
1205 
1206 
1207         [[[ 212.,  220.,  228.],
1208           [ 244.,  252.,  260.],
1209           [ 276.,  284.,  292.]],
1210 
1211          [[ 340.,  348.,  356.],
1212           [ 372.,  380.,  388.],
1213           [ 404.,  412.,  420.]],
1214 
1215          [[ 468.,  476.,  484.],
1216           [ 500.,  508.,  516.],
1217           [ 532.,  540.,  548.]]]])
1218    ```
1219 */
1220 unittest {
1221     import std.stdio;
1222     import mir.ndslice;
1223     import numir;
1224     auto x = iota(2, 3, 4, 4).as!float.slice.variable;
1225     auto w = iota(2, 3, 3, 3).as!float.slice.variable;
1226     Convolution!(float, 2) conv;
1227     auto y = conv.forward(x, w);
1228     auto yx = [[[[ 10197.,  10548.],
1229                  [ 11601.,  11952.]],
1230                 [[ 25506.,  26586.],
1231                  [ 29826.,  30906.]]],
1232                [[[ 27045.,  27396.],
1233                  [ 28449.,  28800.]],
1234                 [[ 77346.,  78426.],
1235                  [ 81666.,  82746.]]]];
1236     assert(y.sliced == yx);
1237 
1238     // test backward
1239     auto gy = y.uninit;
1240     gy.data[] = 1;
1241     auto gs = conv.backward(gy);
1242     auto gx = gs[0];
1243     auto gw = gs[1];
1244 
1245     auto gxx = [[[[  27.,   56.,   60.,   31.],
1246                   [  60.,  124.,  132.,   68.],
1247                   [  72.,  148.,  156.,   80.],
1248                   [  39.,   80.,   84.,   43.]],
1249 
1250                  [[  45.,   92.,   96.,   49.],
1251                   [  96.,  196.,  204.,  104.],
1252                   [ 108.,  220.,  228.,  116.],
1253                   [  57.,  116.,  120.,   61.]],
1254 
1255                  [[  63.,  128.,  132.,   67.],
1256                   [ 132.,  268.,  276.,  140.],
1257                   [ 144.,  292.,  300.,  152.],
1258                   [  75.,  152.,  156.,   79.]]],
1259 
1260 
1261                 [[[  27.,   56.,   60.,   31.],
1262                   [  60.,  124.,  132.,   68.],
1263                   [  72.,  148.,  156.,   80.],
1264                   [  39.,   80.,   84.,   43.]],
1265 
1266                  [[  45.,   92.,   96.,   49.],
1267                   [  96.,  196.,  204.,  104.],
1268                   [ 108.,  220.,  228.,  116.],
1269                   [  57.,  116.,  120.,   61.]],
1270 
1271                  [[  63.,  128.,  132.,   67.],
1272                   [ 132.,  268.,  276.,  140.],
1273                   [ 144.,  292.,  300.,  152.],
1274                   [  75.,  152.,  156.,   79.]]]];
1275     assert(gx.sliced == gxx);
1276 
1277     auto gwx = [[[[ 212.,  220.,  228.],
1278                   [ 244.,  252.,  260.],
1279                   [ 276.,  284.,  292.]],
1280                  [[ 340.,  348.,  356.],
1281                   [ 372.,  380.,  388.],
1282                   [ 404.,  412.,  420.]],
1283                  [[ 468.,  476.,  484.],
1284                   [ 500.,  508.,  516.],
1285                   [ 532.,  540.,  548.]]],
1286                 [[[ 212.,  220.,  228.],
1287                   [ 244.,  252.,  260.],
1288                   [ 276.,  284.,  292.]],
1289                  [[ 340.,  348.,  356.],
1290                   [ 372.,  380.,  388.],
1291                   [ 404.,  412.,  420.]],
1292                  [[ 468.,  476.,  484.],
1293                   [ 500.,  508.,  516.],
1294                   [ 532.,  540.,  548.]]]];
1295     assert(gw.sliced == gwx);
1296 
1297     import grain.testing : gradCheck;
1298     auto hx = uniform!float(x.shape.castArray!size_t).slice.variable;
1299     auto hw = uniform!float(w.shape.castArray!size_t).slice.variable;
1300     auto hgy = uniform!float(y.shape.castArray!size_t).slice.variable;
1301     auto hy = conv.forward(hx, hw);
1302     auto hgx = conv.backward(hgy);
1303     gradCheck(conv, tuple(hx, hw), hgy, 1e-3, 1e-3, 1e-2);
1304 
1305     version (grain_cuda) {
1306         auto dy = conv.forward(hx.to!DeviceStorage, hw.to!DeviceStorage);
1307         auto dgx = conv.backward(hgy.to!DeviceStorage);
1308         assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
1309         assert(approxEqual(dgx[0].to!HostStorage.sliced, hgx[0].sliced));
1310         assert(approxEqual(dgx[1].to!HostStorage.sliced, hgx[1].sliced));
1311     }
1312 }
1313 
1314 
1315 unittest {
1316     import numir;
1317     import mir.ndslice;
1318     import grain.testing : gradCheck;
1319     import derelict.cudnn7;
1320     static foreach (i; 1..4) {{
1321         size_t[i+2] xshape, wshape;
1322         xshape[] = 3;
1323         wshape[] = 2;
1324         xshape[1] = 2;
1325         Convolution!(float, i) conv;
1326         // conv.forwardAlgo = CUDNN_CONVOLUTION_FWD_ALGO_GEMM;
1327 
1328         auto hx = uniform!float(xshape).slice.variable;
1329         auto hw = uniform!float(wshape).slice.variable;
1330         auto hy = conv.forward(hx, hw);
1331         auto hgy = uniform!float(hy.shape.castArray!size_t).slice.variable;
1332         auto hgx = conv.backward(hgy);
1333         gradCheck(conv, tuple(hx, hw), hgy, 1e-3, 1e-3, 1e-2);
1334 
1335         version (grain_cuda) {
1336             auto dy = conv.forward(hx.to!DeviceStorage, hw.to!DeviceStorage);
1337             auto dgx = conv.backward(hgy.to!DeviceStorage);
1338             assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
1339             assert(approxEqual(dgx[0].to!HostStorage.sliced, hgx[0].sliced));
1340             assert(approxEqual(dgx[1].to!HostStorage.sliced, hgx[1].sliced));
1341         }
1342     }}
1343 }