1 /**
2    A module for unary functions
3 
4    TODO: support cudnn functions (see PDF manual in .deb for detail https://developer.nvidia.com/cudnn)
5    - activation (e.g., clipped-relu, elu), cudnnActivationForward/Backward
6    - (non-log) softmax, cudnnSoftmaxForward/Backward
7    - sqrt not, cudnnOpTensor
8    - transform (e.g., contiguous or permute strides), cudnnTransformTensor
9    - reshape (i.e., view), ...???
10    - reduce (sum, prod, min, max, amax, avg), cudnnReduceTensor
11    - pool (max, average), cudnnPoolingForward/Backward
12    - dropout, cudnnDropoutForward/Backward
13  */
14 module grain.functions.unary;
15 
16 import grain.autograd;
17 import grain.cuda;
18 import grain.utility;
19 import grain.functions.common;
20 
21 version (grain_cuda) {
22     import grain.cudnn;
23 
24     // FIXME do not know why this mixin won't work
25     // mixin template CudnnActivation(T, size_t dim, cudnnActivationMode_t mode) {
26 
27     ///
28     enum CUDNN_ACTIVATION_IMPL_MIXIN = q{
29         // TODO support inplace
30         Variable!(T, dim, DeviceStorage) dx, dy;
31         ///
32         auto forward(Variable!(T, dim, DeviceStorage) x) {
33             // FIXME if train
34             this.dx = x.dup;
35             auto y = x.uninit;
36             activationForward!mode(x, y);
37             this.dy = y;
38             return y;
39         }
40         ///
41         auto backward(Variable!(T, dim, DeviceStorage) gy) {
42             auto gx = gy.uninit;
43             activationBackward!mode(gx, gy, this.dx, this.dy);
44             return gx;
45         }
46     };
47 }
48 
/// sigmoid function
struct Sigmoid(T, size_t dim) {
    import mir.math : exp;
    import std.math : tanh;
    import mir.ndslice : sliced, slice, map;

    /// cached forward output y = sigmoid(x), reused by backward
    Variable!(T, dim, HostStorage) hy;

    /// computes y = sigmoid(x) via the numerically stable identity
    /// sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5
    auto forward(Variable!(T, dim, HostStorage) x) {
        enum z = T(0.5);
        auto ys = x.sliced.map!(a => tanh(a * z) * z + z);
        auto y = ys.slice.variable(x.requiresGrad);
        this.hy = y;
        return y;
    }

    /// computes gx = gy * y * (1 - y) from the cached output
    auto backward(Variable!(T, dim, HostStorage) gy) {
        auto gx = this.hy.dup;
        gx.sliced[] *= T(1) - this.hy.sliced;
        gx.sliced[] *= gy.sliced;
        return gx;
    }

    version (grain_cuda) {
        // mixin CudnnActivation!(T, dim, CUDNN_ACTIVATION_SIGMOID);
        enum mode = CUDNN_ACTIVATION_SIGMOID;
        mixin(CUDNN_ACTIVATION_IMPL_MIXIN);
    }

    mixin FunctionCommon;
}
82 
83 ///
84 unittest {
85     // test CPU
86     import grain.testing;
87     import std.math : tanh;
88     import numir;
89 
90     auto func = new Sigmoid!(float, 1);
91     auto hx = [-1.0f, 1.0f, 0.0f].variable;
92     gradCheck(func, hx, [0.1f, 0.1f, 0.1f].variable);
93 
94     auto hy = func.forward(hx);
95     // assert(hy.data == [tanh(-1.0f), tanh(1.0f), tanh(0.0f)]);
96     auto hgy = [1.0f, 2.0f, 3.0f].variable;
97     auto hgx = func.backward(hgy);
98 
99     // test CUDA
100     version (grain_cuda) {
101         auto dfunc = new Sigmoid!(float, 1);
102         auto dx = hx.to!DeviceStorage;
103         auto dy = dfunc.forward(dx);
104         assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
105         auto dgx = dfunc.backward(hgy.to!DeviceStorage);
106         assert(approxEqual(dgx.to!HostStorage.sliced, hgx.sliced));
107     }
108 }
109 
/// hyperbolic tangent
struct Tanh(T, size_t dim) {
    import std.math : tanh;
    import mir.ndslice : sliced, slice, map;

    /// cached forward output y = tanh(x), reused by backward
    Variable!(T, dim, HostStorage) hy;

    /// computes y = tanh(x), elementwise
    auto forward(Variable!(T, dim, HostStorage) x) {
        auto ys = x.sliced.map!tanh;
        auto y = ys.slice.variable(x.requiresGrad);
        this.hy = y;
        return y;
    }

    /// computes gx = gy * (1 - tanh(x)^2) from the cached output
    auto backward(Variable!(T, dim, HostStorage) gy) {
        auto gx = this.hy.dup;
        gx.sliced[] *= gx.sliced; // hy ^^ 2
        gx.sliced[] = T(1.0) - gx.sliced;
        gx.sliced[] *= gy.sliced;
        return gx; // .slice.variable;
    }

    version (grain_cuda) {
        // mixin CudnnActivation!(T, dim, CUDNN_ACTIVATION_TANH);
        enum mode = CUDNN_ACTIVATION_TANH;
        mixin(CUDNN_ACTIVATION_IMPL_MIXIN);
    }

    mixin FunctionCommon;
}
142 
143 ///
144 unittest {
145     // test CPU
146     import grain.testing;
147     import std.math : tanh;
148     import numir;
149 
150     auto func = new Tanh!(float, 1);
151     auto hx = [-1.0f, 1.0f, 0.0f].variable;
152     gradCheck(func, hx, [0.1f, 0.1f, 0.1f].variable);
153 
154     auto hy = func.forward(hx);
155     assert(hy.data == [tanh(-1.0f), tanh(1.0f), tanh(0.0f)]);
156     auto hgy = [1.0f, 2.0f, 3.0f].variable;
157     auto hgx = func.backward(hgy);
158 
159     // test CUDA
160     version (grain_cuda) {
161         auto dfunc = new Tanh!(float, 1);
162         auto dx = hx.to!DeviceStorage;
163         auto dy = dfunc.forward(dx);
164         assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
165         auto dgx = dfunc.backward(hgy.to!DeviceStorage);
166         assert(approxEqual(dgx.to!HostStorage.sliced, hgx.sliced));
167     }
168 }
169 
170 /// TODO implement scale with cudnnScaleTensor
171 
/// rectified linear unit nonlinearity (using cuDNN)
struct ReLU(T, size_t dim) {
    mixin FunctionCommon;
    // if true, forward overwrites its input instead of allocating a copy
    bool inplace = false;
    // if true, the CUDA path uses cuDNN; otherwise the custom relu kernel
    bool useCuDNN = true;
    // cached copy of the forward input (taken before any in-place mutation)
    Variable!(T, dim, HostStorage) hx;

    /// y = max(x, 0), elementwise
    auto forward(Variable!(T, dim, HostStorage) x) {
        import mir.ndslice : each;

        // FIXME if train
        this.hx = x.dup;
        auto y = this.inplace ? x : x.dup;
        y.sliced.each!((ref a) {
            if (a < 0)
                a = 0;
        });
        return y;
    }

    /// gx = gy where the cached input was >= 0, else 0
    /// (note: gradient at exactly 0 is passed through here)
    auto backward(Variable!(T, dim, HostStorage) gy) {
        auto gx = gy.dup;
        foreach (i; 0 .. gx.data.length) {
            if (this.hx.data[i] < 0.0)
                gx.data[i] = 0.0;
        }
        return gx;
    }

    // TODO use cudnn
    version (grain_cuda) {
        import grain.cudnn;

        // cached device-side input (dx) and, for cuDNN, output (dy)
        Variable!(T, dim, DeviceStorage) dx, dy;

        auto forward(Variable!(T, dim, DeviceStorage) x) {
            // FIXME if train
            this.dx = x.dup;
            auto y = this.inplace ? x : x.dup;

            if (this.useCuDNN) {
                activationForward!CUDNN_ACTIVATION_RELU(x, y);
                this.dy = y;
            }
            else {
                import grain.kernel : relu;

                auto n = cast(uint) y.data.length; // FIXME use y.nElement
                Global.kernel!relu.call(y.data.ptr, n).launch(n);
            }
            return y;
        }

        auto backward(Variable!(T, dim, DeviceStorage) gy) {
            auto gx = gy.uninit;
            if (this.useCuDNN) {
                activationBackward!CUDNN_ACTIVATION_RELU(gx, gy, dx, dy);
            }
            else {
                import grain.kernel : reluGrad;

                auto n = cast(uint) gy.data.length;
                Global.kernel!reluGrad.call(gx.data.ptr, gy.data.ptr,
                        this.dx.data.ptr, n).launch(n);
            }
            return gx;
        }
    }
}
241 
/// test relu
unittest {
    import grain.testing : gradCheck;

    // exercise all four inplace x useCuDNN configurations
    foreach (inplace; [true, false]) {
        foreach (useCuDNN; [true, false]) {
            auto func = new ReLU!(float, 1);
            func.inplace = inplace;
            func.useCuDNN = useCuDNN;

            // test CPU
            {
                auto x = [-1.0f, 1.0f, 0.0f].variable;
                // fail because of non-smooth function?
                // gradCheck(func, x, [0.1f, 0.1f, 0.1f].variable);

                auto y = func.forward(x);
                assert(x.data == (inplace ? y.data : [-1.0f, 1.0f, 0.0f]));
                assert(y.data == [0.0f, 1.0f, 0.0f]);

                auto gy = [1.0f, 2.0f, 3.0f].variable;
                auto gx = func.backward(gy);
                assert(gx.data == [0.0f, 2.0f, 3.0f]);
            }

            // test CUDA
            version (grain_cuda) {
                auto x = [-1.0f, 1.0f, 0.0f].variable;
                auto xd = x.to!DeviceStorage;
                auto yd = func.forward(xd);
                x = xd.to!HostStorage;
                auto y = yd.to!HostStorage;
                assert(x.data == (inplace ? y.data : [-1.0f, 1.0f, 0.0f]));
                assert(y.data == [0.0f, 1.0f, 0.0f]);

                x = [-1.0f, 1.0f, 0.0f].variable;
                auto gy = [1.0f, 2.0f, 3.0f].variable;
                auto gxd = func.backward(gy.to!DeviceStorage);
                auto gx = gxd.to!HostStorage;
                // NOTE(review): the device path yields grad 0 at x == 0 while
                // the host path above passes it through (3.0) — confirm this
                // host/device discrepancy is intended
                assert(gx.data == [0.0, 2.0, 0.0]);
            }
        }
    }
}
286 
// forward two chained functions and check gradients accumulate through bprop
unittest {
    import std.typecons;
    static import grain.config;
    grain.config.backprop = true;
    {
        auto x = [-1.0f, 2.0f, 3.0f].variable(true);
        x.requiresGrad = true;
        Variable!(float, 1) y, h;
        y.requiresGrad = true;
        h.requiresGrad = true;
        // bprop will survive even if deeper scope
        {
            // FIXME cannot use RefCounted instead of new here
            // RefCounted!(ReLU!(float, 1)) func0 = ReLU!(float, 1)();
            auto func0 = new ReLU!(float, 1);
            h = func0.applyForward(x);
            // assert(h.bprop.inputs[0].data == x.data);
            auto func1 = new ReLU!(float, 1);
            y = func1.applyForward(h);
            // test the chain to backprop
            // assert(y.bprop.inputs[0].data == h.data);
            // assert(y.bprop.inputs[0].bprop.inputs[0].data == x.data);
        }
        auto gy = [1.0f, 2.0f, 3.0f].variable;
        auto ugy = UntypedVariable(gy);
        y.backward(&ugy);
        assert(x.grad == [0, 2, 3]);

        auto func2 = new ReLU!(float, 1);
        auto y2 = func2.applyForward(x);
        y2.backward(&ugy);
        assert(x.grad == [0, 4, 6]); // summation
    }
    version (grain_cuda) {
        auto func = new ReLU!(float, 1);
        auto x = [-1.0f, 2.0f, 3.0f].variable(true).to!DeviceStorage;
        auto y = func.applyForward(x);
        auto gy = [1.0f, 2.0f, 3.0f].variable.to!DeviceStorage;
        auto ugy = UntypedVariable(gy);
        y.backward(&ugy);
        assert(x.grad.toHost() == [0, 2, 3]);

        // BUG fix: previously called `func.applyForward` here, leaving the
        // freshly created func2 unused and re-driving the same instance;
        // mirror the CPU section and use func2 (fresh instance per forward)
        auto func2 = new ReLU!(float, 1);
        auto y2 = func2.applyForward(x);
        y2.backward(&ugy);
        assert(x.grad.toHost() == [0, 4, 6]); // summation
    }
}
336 
337 // TODO add to numir
338 import mir.ndslice : isSlice;
339 import mir.primitives : DimensionCount;
340 
/// numerically stable log(sum(exp(x))) over a 1-d slice:
/// shifts by the maximum element so that exp never overflows
pure nothrow @nogc logsumexp(S)(S x) if (isSlice!S && DimensionCount!S == 1) {
    import mir.ndslice : map, maxIndex;
    import mir.math : log, sum, exp;

    immutable shift = x[x.maxIndex];
    return shift + log(sum!"fast"(map!exp(x - shift)));
}
349 
350 ///
351 pure nothrow // @nogc due to bug dmd 2.082.0
352 unittest {
353     import numir;
354     import mir.ndslice;
355 
356     // import mir.math;
357     import std.math;
358 
359     // FIXME: add static after dmd 2.082.0 fixed
360     immutable x = [-1.0, 2.0, 3.0];
361     immutable e = log(exp(-1.0) + exp(2.0) + exp(3.0));
362     assert(approxEqual(x.sliced.logsumexp, e));
363     immutable xs = [-1.0, 2.0, 3.0, -1.0, 2.0, 3.0, -1.0, 2.0, 3.0];
364     immutable es = [e, e, e];
365     assert(approxEqual(xs.sliced(3, 3).alongDim!1
366             .map!logsumexp, es));
367 }
368 
369 /++
370 See_also: https://github.com/chainer/chainer/blob/v1/chainer/functions/activation/log_softmax.py
371  +/
372 struct LogSoftmax(T, size_t dim = 2) {
373     // TODO support custom dim to compute softmax over (now only dim=1)
374     mixin FunctionCommon;
375 
376     Variable!(T, dim, HostStorage) hy;
377 
378     auto forward(Variable!(T, dim, HostStorage) x) {
379         import mir.ndslice;
380         import numir;
381 
382         // return slice(x.sliced.alongDim!0.map!(e => e - e.logsumexp)).variable;
383         auto y = x.dup;
384         foreach (i; 0 .. y.shape[0]) {
385             y.sliced[i][] -= x.sliced[i].logsumexp;
386         }
387         // TODO if train
388         this.hy = y;
389         return y;
390     }
391 
392     auto backward(Variable!(T, dim, HostStorage) gy) {
393         import mir.math;
394         import numir;
395         import mir.ndslice;
396 
397         auto gx = gy.dup;
398         auto m = gy.sliced
399             .alongDim!1
400             .map!(sum!"fast");
401         foreach (i; 0 .. gx.shape[0]) {
402             gx.sliced[i][] -= this.hy.sliced[i].map!exp * m[i];
403         }
404         return gx;
405     }
406 
407     version (grain_cuda) {
408         import grain.cudnn;
409 
410         Variable!(T, dim, DeviceStorage) dy;
411 
412         auto forward(Variable!(T, dim, DeviceStorage) x) {
413             auto y = x.dup;
414             softmaxForward!CUDNN_SOFTMAX_LOG(x, y);
415             // TODO if train
416             this.dy = y;
417             return y;
418         }
419 
420         auto backward(Variable!(T, dim, DeviceStorage) gy) {
421             auto gx = gy.dup;
422             softmaxBackward!CUDNN_SOFTMAX_LOG(gx, gy, this.dy);
423             return gx;
424         }
425     }
426 }
427 
/// test logsoftmax simple case, gradcheck and cpu/cuda equality
unittest {
    import grain.testing;
    import std.typecons;
    import numir;
    import mir.ndslice;
    import mir.math;

    // every row of xs is the same, so x - e is the expected log_softmax
    auto e = log(exp(-1.0) + exp(2.0) + exp(3.0));
    auto xs = [[-1.0f, 2.0f, 3.0f], [-1.0f, 2.0f, 3.0f], [-1.0f, 2.0f, 3.0f]]
        .nparray;
    LogSoftmax!float hfunc;
    auto _hx = xs.variable;
    auto _hy = hfunc.forward(_hx);
    assert(approxEqual(_hy.sliced, xs - e));

    auto hx = uniform!float(2, 2).slice.variable;
    auto hy = hfunc.forward(hx);
    auto hgy = uniform!float(2, 2).slice.variable;
    auto hgx = hfunc.backward(hgy);
    gradCheck(hfunc, hx, hgy, 1e-3, 1e-3, 1e-3);

    version (grain_cuda) {
        alias Storage = DeviceStorage;
        auto func = LogSoftmax!float();
        auto dx = hx.to!Storage;
        auto dy = func.forward(dx);
        assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
        auto dgy = hgy.to!Storage;
        auto dgx = func.backward(dgy);
        assert(approxEqual(dgx.to!HostStorage.sliced, hgx.sliced));
    }
}
461 
/// wrapper of CUDA kernel unary functions: applies kernel K in place over
/// the flattened elements of x, passing shape/stride metadata to the device
void unaryFunc(alias K, size_t dim)(Variable!(float, dim, DeviceStorage) x) {
    auto devShape = CuPtr!uint(x.shape[0 .. $]);
    auto devStrides = CuPtr!int(x.strides[0 .. $]);
    immutable nd = cast(uint) dim;
    immutable n = cast(uint) x.data.length;
    Global.kernel!K.call(x.data.ptr, n, nd, devShape.ptr, devStrides.ptr).launch(n);
}
470 
/// test neg kernel
version (grain_cuda) unittest {
    import numir;
    import grain.kernel;

    // neg is applied in place on the device buffer
    auto x = [[1f, 2f, 3f], [4f, 5f, 6f]].variable.to!DeviceStorage;
    unaryFunc!neg(x);
    assert(x.to!HostStorage.sliced == -[[1f, 2f, 3f], [4f, 5f, 6f]].nparray);
}
480 
/// test reciprocal kernel
version (grain_cuda) unittest {
    import grain.kernel;

    // reciprocal is applied in place on the device buffer
    auto x = [[1f, 2f, 3f], [4f, 5f, 6f]].variable.to!DeviceStorage;
    unaryFunc!reciprocal(x);
    assert(x.to!HostStorage.sliced == [[1f, 1f / 2f, 1f / 3f], [1f / 4f, 1f / 5f, 1f / 6f]]);
}
489 
490 /**
491    test math function kernels. these functions are available in mir.math (as LDC intrinsic) or CUDA fast math
492 
493    See_also:
494    - http://mir.dlang.io/mir_math_common.html
495    - https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#intrinsic-functions
496 */
version (grain_cuda) unittest {
    // FIXME mir.math.log will exit 1
    import std.math : log, tan; // , log2, log10, exp, exp2, cos, sin, tan;
    import mir.math : log2, log10, exp, exp2, cos, sin;
    import grain.kernel : log, log2, log10, exp, exp2, cos, sin, tan;
    import std.format : format;
    import numir : approxEqual;
    import mir.ndslice : iota, as, slice, map;

    // for each math function name, run the CUDA kernel in place and compare
    // against the host implementation imported above
    static foreach (name; ["log", "log2", "log10", "exp", "exp2", "cos", "sin", "tan"]) {
        {
            auto x = iota([2, 3], 1).as!float
                .slice
                .variable
                .to!DeviceStorage;
            mixin(format!q{  unaryFunc!(grain.kernel.%s)(x);  }(name));
            mixin(format!q{  alias func = %s;  }(name));
            assert(approxEqual(x.to!HostStorage.sliced, iota([2, 3], 1).as!float
                    .map!func));
        }
    }
}
519 
/// wrapper of CUDA kernel pow function: raises every element of x to the
/// given power, in place, passing shape/stride metadata to the device
void unaryPow(size_t dim)(Variable!(float, dim, DeviceStorage) x, float power) {
    import grain.kernel : pow;

    auto devShape = CuPtr!uint(x.shape[0 .. $]);
    auto devStrides = CuPtr!int(x.strides[0 .. $]);
    immutable nd = cast(uint) dim;
    immutable n = cast(uint) x.data.length;
    Global.kernel!pow.call(power, x.data.ptr, n, nd, devShape.ptr, devStrides.ptr)
        .launch(n);
}
531 
/// test pow kernel
version (grain_cuda) unittest {
    import numir;
    import mir.ndslice;
    import grain.kernel;
    import mir.math : pow;

    // square each element on the device and compare with host pow
    auto x = iota([2, 3], 1).as!float
        .slice
        .variable
        .to!DeviceStorage;
    unaryPow(x, 2f);
    assert(approxEqual(x.to!HostStorage.sliced, iota([2, 3], 1).as!float
            .map!(x => pow(x, 2))));
}
547 
/// y = 1 / x
struct Reciprocal(T, size_t dim) {
    mixin FunctionCommon;

    /// cached forward output y = 1/x, reused by backward
    Variable!(T, dim, HostStorage) hy;

    /// computes y = 1 / x, elementwise
    auto forward(Variable!(T, dim, HostStorage) x) {
        import mir.ndslice : map, slice;

        auto y = x.sliced.map!(a => T(1) / a).slice.variable(x.requiresGrad);
        this.hy = y; // TODO if train
        return y;
    }

    /// computes gx = -gy * y^2 (== -gy / x^2) from the cached output
    auto backward(Variable!(T, dim, HostStorage) gy) {
        auto gx = this.hy.dup;
        gx.sliced[] *= gx.sliced;
        gx.sliced[] *= -gy.sliced;
        return gx;
    }

    version (grain_cuda) {
        /// cached device-side forward output, reused by backward
        Variable!(T, dim, DeviceStorage) dy;

        auto forward(Variable!(T, dim, DeviceStorage) x) {
            import grain.kernel : reciprocal;

            auto y = x.dup;
            unaryFunc!reciprocal(y);
            this.dy = y; // TODO if train
            return y;
        }

        auto backward(Variable!(T, dim, DeviceStorage) gy) {
            import grain.cudnn : tensorOp, CUDNN_OP_TENSOR_MUL;

            // gx = (-y) * y * gy, matching the host formula above
            auto gx = this.dy.dup;
            tensorOp!CUDNN_OP_TENSOR_MUL(gx, gx, this.dy, T(-1));
            tensorOp!CUDNN_OP_TENSOR_MUL(gx, gx, gy);
            return gx;
        }
    }
}
591 
/// test reciprocal simple case, gradcheck and cpu/cuda equality
unittest {
    import grain.testing;
    import std.typecons;
    import numir;
    import mir.ndslice;

    // simple case
    auto e = [[-1.0f, 1f / 2f, 1f / 3f], [1f, 10f, 1f / 3f]].nparray;
    auto xs = [[-1.0f, 2.0f, 3.0f], [1.0f, 0.1f, 3.0f]].nparray;
    Reciprocal!(float, 2) hfunc;
    auto _hx = xs.variable;
    auto _hy = hfunc.forward(_hx);
    assert(approxEqual(_hy.sliced, e));

    auto hx = uniform!float(2, 2).slice.variable;
    auto hy = hfunc.forward(hx);
    auto hgy = uniform!float(2, 2).slice.variable;
    auto hgx = hfunc.backward(hgy);
    gradCheck(hfunc, hx, hgy, 1e-3, 5e-2, 5e-2);

    version (grain_cuda) {
        Reciprocal!(float, 2) dfunc;
        auto dy = dfunc.forward(hx.to!DeviceStorage);
        assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
        auto dgx = dfunc.backward(hgy.to!DeviceStorage);
        assert(approxEqual(dgx.to!HostStorage.sliced, hgx.sliced));
    }
}
621 
/// y = exp x
struct Exp(T, size_t dim) {
    import mir.ndslice : slice, map;

    mixin FunctionCommon;

    /// cached forward output y = exp(x), reused by backward
    Variable!(T, dim, HostStorage) hy;

    /// computes y = exp(x), elementwise
    auto forward(Variable!(T, dim, HostStorage) x) {
        import mir.math.common : exp;

        auto y = slice(x.sliced.map!exp).variable(x.requiresGrad);
        this.hy = y; // TODO if train
        return y;
    }

    /// computes gx = gy * exp(x) from the cached output
    auto backward(Variable!(T, dim, HostStorage) gy) {
        auto gx = gy.dup;
        gx.sliced[] *= this.hy.sliced;
        return gx;
    }

    version (grain_cuda) {
        /// cached device-side forward output, reused by backward
        Variable!(T, dim, DeviceStorage) dy;

        auto forward(Variable!(T, dim, DeviceStorage) x) {
            import grain.kernel : exp;

            auto y = x.dup;
            unaryFunc!exp(y);
            this.dy = y;
            return y;
        }

        auto backward(Variable!(T, dim, DeviceStorage) gy) {
            return this.dy * gy;
        }
    }
}
661 
662 ///
663 unittest {
664     import grain.testing;
665     import std.typecons;
666     import numir;
667     import mir.ndslice;
668     import mir.math : exp;
669 
670     Exp!(float, 2) hfunc;
671     auto hx = uniform!float(2, 3).slice.variable;
672     auto hy = hfunc.forward(hx);
673     auto hgy = uniform!float(2, 3).slice.variable;
674     auto hgx = hfunc.backward(hgy);
675     gradCheck(hfunc, hx, hgy);
676     assert(approxEqual(hy.sliced, hx.sliced.map!exp));
677 
678     version (grain_cuda) {
679         Exp!(float, 2) dfunc;
680         auto dy = dfunc.forward(hx.to!DeviceStorage);
681         assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
682         auto dgx = dfunc.backward(hgy.to!DeviceStorage);
683         assert(approxEqual(dgx.to!HostStorage.sliced, hgx.sliced));
684     }
685 }
686 
/// y = log x
struct Log(T, size_t dim) {
    import mir.ndslice : slice, map;

    mixin FunctionCommon;

    /// cached forward input x, reused by backward
    Variable!(T, dim, HostStorage) hx;

    /// computes y = log(x), elementwise
    auto forward(Variable!(T, dim, HostStorage) x) {
        import mir.math.common : log;

        auto y = slice(x.sliced.map!log).variable(x.requiresGrad);
        this.hx = x; // TODO if train
        return y;
    }

    /// computes gx = gy / x from the cached input
    auto backward(Variable!(T, dim, HostStorage) gy) {
        auto gx = gy.dup;
        gx.sliced[] /= this.hx.sliced;
        return gx;
    }

    version (grain_cuda) {
        /// cached device-side forward input, reused by backward
        Variable!(T, dim, DeviceStorage) dx;

        auto forward(Variable!(T, dim, DeviceStorage) x) {
            import grain.kernel : log;

            auto y = x.dup;
            unaryFunc!log(y);
            this.dx = x;
            return y;
        }

        auto backward(Variable!(T, dim, DeviceStorage) gy) {
            return gy / this.dx;
        }
    }
}
726 
727 ///
728 unittest {
729     import grain.testing;
730     import std.typecons;
731     import numir;
732     import mir.ndslice;
733     import mir.math : log;
734 
735     Log!(float, 2) hfunc;
736     auto hx = uniform!float(2, 3).slice.variable;
737     auto hy = hfunc.forward(hx);
738     auto hgy = uniform!float(2, 3).slice.variable;
739     auto hgx = hfunc.backward(hgy);
740     gradCheck(hfunc, hx, hgy, 1e-3, 5e-2, 5e-2);
741     assert(approxEqual(hy.sliced, hx.sliced.map!log));
742 
743     version (grain_cuda) {
744         Log!(float, 2) dfunc;
745         auto dy = dfunc.forward(hx.to!DeviceStorage);
746         assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
747         auto dgx = dfunc.backward(hgy.to!DeviceStorage);
748         assert(approxEqual(dgx.to!HostStorage.sliced, hgx.sliced));
749     }
750 }
751 
752 // TODO implement these autograd fast-math functions
753 // struct Log2
754 // struct Log10
755 
756 // struct Exp2
757 // struct Exp10
758 
/// y = sin x
struct Sin(T, size_t dim) {
    import mir.ndslice : slice, map;

    mixin FunctionCommon;

    /// cached forward input x, reused by backward
    Variable!(T, dim, HostStorage) hx;

    /// computes y = sin(x), elementwise
    auto forward(Variable!(T, dim, HostStorage) x) {
        import mir.math.common : sin;

        auto y = slice(x.sliced.map!sin).variable(x.requiresGrad);
        this.hx = x; // TODO if train
        return y;
    }

    /// computes gx = gy * cos(x) from the cached input
    auto backward(Variable!(T, dim, HostStorage) gy) {
        import mir.math.common : cos;

        auto gx = gy.dup;
        gx.sliced[] *= this.hx.sliced.map!cos;
        return gx;
    }

    version (grain_cuda) {
        /// cached device-side forward input, reused by backward
        Variable!(T, dim, DeviceStorage) dx;

        auto forward(Variable!(T, dim, DeviceStorage) x) {
            import grain.kernel : sin;

            auto y = x.dup;
            unaryFunc!sin(y);
            this.dx = x;
            return y;
        }

        auto backward(Variable!(T, dim, DeviceStorage) gy) {
            import grain.cudnn;
            import grain.kernel : cos;

            // gx = cos(x) * gy, matching the host formula above
            auto gx = this.dx.dup;
            unaryFunc!cos(gx);
            tensorOp!CUDNN_OP_TENSOR_MUL(gx, gx, gy);
            return gx;
        }
    }
}
806 
807 ///
808 unittest {
809     import grain.testing;
810     import std.typecons;
811     import numir;
812     import mir.ndslice;
813     import mir.math : sin;
814 
815     Sin!(float, 2) hfunc;
816     auto hx = uniform!float(2, 3).slice.variable;
817     auto hy = hfunc.forward(hx);
818     auto hgy = uniform!float(2, 3).slice.variable;
819     auto hgx = hfunc.backward(hgy);
820     gradCheck(hfunc, hx, hgy);
821     assert(approxEqual(hy.sliced, hx.sliced.map!sin));
822 
823     version (grain_cuda) {
824         Sin!(float, 2) dfunc;
825         auto dy = dfunc.forward(hx.to!DeviceStorage);
826         assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
827         auto dgx = dfunc.backward(hgy.to!DeviceStorage);
828         assert(approxEqual(dgx.to!HostStorage.sliced, hgx.sliced));
829     }
830 }
831 
/// y = cos x
struct Cos(T, size_t dim) {
    import mir.ndslice : slice, map;

    mixin FunctionCommon;

    /// cached forward input x, reused by backward
    Variable!(T, dim, HostStorage) hx;

    /// computes y = cos(x), elementwise
    auto forward(Variable!(T, dim, HostStorage) x) {
        import mir.math.common : cos;

        auto y = slice(x.sliced.map!cos).variable(x.requiresGrad);
        this.hx = x; // TODO if train
        return y;
    }

    /// computes gx = -gy * sin(x) from the cached input
    auto backward(Variable!(T, dim, HostStorage) gy) {
        import mir.math.common : sin;

        auto gx = gy.dup;
        gx.sliced[] *= -this.hx.sliced.map!sin;
        return gx;
    }

    version (grain_cuda) {
        /// cached device-side forward input, reused by backward
        Variable!(T, dim, DeviceStorage) dx;

        auto forward(Variable!(T, dim, DeviceStorage) x) {
            import grain.kernel : cos;

            auto y = x.dup;
            unaryFunc!cos(y);
            this.dx = x;
            return y;
        }

        auto backward(Variable!(T, dim, DeviceStorage) gy) {
            import grain.cudnn;
            import grain.kernel : sin;

            // gx = -sin(x) * gy, matching the host formula above
            auto gx = this.dx.dup;
            unaryFunc!sin(gx);
            tensorOp!CUDNN_OP_TENSOR_MUL(gx, gx, gy, -1);
            return gx;
        }
    }
}
879 
880 ///
881 unittest {
882     import grain.testing;
883     import std.typecons;
884     import numir;
885     import mir.ndslice;
886     import mir.math : cos;
887 
888     Cos!(float, 2) hfunc;
889     auto hx = uniform!float(2, 3).slice.variable;
890     auto hy = hfunc.forward(hx);
891     auto hgy = uniform!float(2, 3).slice.variable;
892     auto hgx = hfunc.backward(hgy);
893     gradCheck(hfunc, hx, hgy, 1e-3, 1e-3, 1e-3);
894     assert(approxEqual(hy.sliced, hx.sliced.map!cos));
895 
896     version (grain_cuda) {
897         Cos!(float, 2) dfunc;
898         auto dy = dfunc.forward(hx.to!DeviceStorage);
899         assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
900         auto dgx = dfunc.backward(hgy.to!DeviceStorage);
901         assert(approxEqual(dgx.to!HostStorage.sliced, hgx.sliced));
902     }
903 }
904 
/// y = tan x
struct Tan(T, size_t dim) {
    import mir.ndslice : slice, map, as;

    mixin FunctionCommon;

    /// cached forward input x, reused by backward
    Variable!(T, dim, HostStorage) hx;

    /// computes y = tan(x), elementwise
    auto forward(Variable!(T, dim, HostStorage) x) {
        import std.math : tan;

        auto y = slice(x.sliced
                .map!tan
                .as!T).variable(x.requiresGrad);
        this.hx = x; // TODO if train
        return y;
    }

    /// computes gx = gy / cos(x)^2 from the cached input
    auto backward(Variable!(T, dim, HostStorage) gy) {
        import mir.math.common : cos;

        auto gx = gy.dup;
        auto c = this.hx
            .sliced
            .map!cos
            .map!"a * a";
        gx.sliced[] /= c;
        return gx;
    }

    version (grain_cuda) {
        /// cached device-side forward input, reused by backward
        Variable!(T, dim, DeviceStorage) dx;

        auto forward(Variable!(T, dim, DeviceStorage) x) {
            import grain.kernel : tan;

            auto y = x.dup;
            unaryFunc!tan(y);
            this.dx = x;
            return y;
        }

        auto backward(Variable!(T, dim, DeviceStorage) gy) {
            import grain.cudnn;
            import grain.kernel : cos;

            // gx = gy / cos(x)^2, matching the host formula above
            auto gx = this.dx.dup;
            unaryFunc!cos(gx);
            tensorOp!CUDNN_OP_TENSOR_MUL(gx, gx, gx); // cos^2 x
            return gy / gx;
        }
    }
}
958 
959 ///
960 unittest {
961     import grain.testing;
962     import std.typecons;
963     import numir;
964     import mir.ndslice;
965     import std.math : tan;
966 
967     Tan!(float, 2) hfunc;
968     auto hx = uniform!float(2, 3).slice.variable;
969     auto hy = hfunc.forward(hx);
970     auto hgy = uniform!float(2, 3).slice.variable;
971     auto hgx = hfunc.backward(hgy);
972     gradCheck(hfunc, hx, hgy);
973     assert(approxEqual(hy.sliced, hx.sliced.map!tan));
974 
975     version (grain_cuda) {
976         Tan!(float, 2) dfunc;
977         auto dy = dfunc.forward(hx.to!DeviceStorage);
978         assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
979         auto dgx = dfunc.backward(hgy.to!DeviceStorage);
980         assert(approxEqual(dgx.to!HostStorage.sliced, hgx.sliced));
981     }
982 }
983 
/// y = alpha * x
struct Scale(T, size_t dim) {
    import mir.ndslice : slice;

    /// constant multiplier; stateless otherwise, so backward needs no cache
    T alpha = 1.0;

    /// computes y = alpha * x, elementwise
    auto forward(Variable!(T, dim, HostStorage) x) {
        return slice(this.alpha * x.sliced).variable(x.requiresGrad);
    }

    /// computes gx = alpha * gy (scaling is linear)
    auto backward(Variable!(T, dim, HostStorage) gy) {
        return slice(this.alpha * gy.sliced).variable(gy.requiresGrad);
    }

    version (grain_cuda) {
        import grain.cudnn : scale;

        auto forward(Variable!(T, dim, DeviceStorage) x) {
            auto y = x.dup;
            scale(y, this.alpha);
            return y;
        }

        auto backward(Variable!(T, dim, DeviceStorage) gy) {
            auto gx = gy.dup;
            scale(gx, this.alpha);
            return gx;
        }
    }

    mixin FunctionCommon;
}
1016 
/// test scale in simple case, gradcheck and cpu/cuda equality
unittest {
    import grain.testing;
    import std.typecons;
    import numir;
    import mir.ndslice;

    // simple case: 2.0 * x with hand-computed expected output e
    auto e = [[-2.0f, 4.0f, 6.0f], [2.0f, 0.2f, 0.0f]].nparray;
    auto xs = [[-1.0f, 2.0f, 3.0f], [1.0f, 0.1f, 0.0f]].nparray;
    auto hfunc = Scale!(float, 2)(2f);
    auto _hx = xs.variable;
    auto _hy = hfunc.forward(_hx);
    assert(approxEqual(_hy.sliced, e));

    // numeric gradient check on random input
    auto hx = uniform!float(2, 2).slice.variable;
    auto hy = hfunc.forward(hx);
    auto hgy = uniform!float(2, 2).slice.variable;
    auto hgx = hfunc.backward(hgy);
    gradCheck(hfunc, hx, hgy); // , 1e-3, 1e-3, 1e-3);

    version (grain_cuda) {
        // CUDA forward/backward must match the host results above
        auto dfunc = Scale!(float, 2)(2f);
        auto dy = dfunc.forward(hx.to!DeviceStorage);
        assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
        auto dgx = dfunc.backward(hgy.to!DeviceStorage);
        assert(approxEqual(dgx.to!HostStorage.sliced, hgx.sliced));
    }
}
1046 
/// y = -x (element-wise negation)
struct Neg(T, size_t dim) {
    import mir.ndslice : slice;

    /// forward: y_i = -x_i
    auto forward(Variable!(T, dim, HostStorage) x) {
        auto negated = -x.sliced;
        return negated.slice.variable(x.requiresGrad);
    }

    /// backward: gx_i = -gy_i, since d(-x)/dx = -1
    auto backward(Variable!(T, dim, HostStorage) gy) {
        auto negated = -gy.sliced;
        return negated.slice.variable(gy.requiresGrad);
    }

    version (grain_cuda) {
        import grain.kernel : neg;

        /// forward on CUDA: duplicate x and negate the copy in place
        auto forward(Variable!(T, dim, DeviceStorage) x) {
            auto result = x.dup;
            unaryFunc!neg(result);
            return result;
        }

        /// backward on CUDA: same in-place negation, applied to gy
        auto backward(Variable!(T, dim, DeviceStorage) gy) {
            auto result = gy.dup;
            unaryFunc!neg(result);
            return result;
        }
    }

    mixin FunctionCommon;
}
1077 
/// test neg simple case, gradcheck and cpu/cuda equality
unittest {
    import grain.testing;
    import std.typecons;
    import numir;
    import mir.ndslice;

    // simple case: -x (was mislabeled "2.0 * x", copy-pasted from the Scale test)
    auto xs = [[-1.0f, 2.0f, 3.0f], [1.0f, 0.1f, 0.0f]].nparray;
    auto hfunc = Neg!(float, 2)();
    auto _hx = xs.variable;
    auto _hy = hfunc.forward(_hx);
    assert(approxEqual(_hy.sliced, -xs));

    // numeric gradient check on random input
    auto hx = uniform!float(2, 2).slice.variable;
    auto hy = hfunc.forward(hx);
    auto hgy = uniform!float(2, 2).slice.variable;
    auto hgx = hfunc.backward(hgy);
    gradCheck(hfunc, hx, hgy); // , 1e-3, 1e-3, 1e-3);

    version (grain_cuda) {
        // CUDA forward/backward must match the host results above
        auto dfunc = Neg!(float, 2)();
        auto dy = dfunc.forward(hx.to!DeviceStorage);
        assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
        auto dgx = dfunc.backward(hgy.to!DeviceStorage);
        assert(approxEqual(dgx.to!HostStorage.sliced, hgx.sliced));
    }
}
1106 
/// y = |x| (element-wise absolute value)
struct Abs(T, size_t dim) {
    import mir.ndslice : slice, map;

    mixin FunctionCommon;
    // saved input; backward needs it to compute the sign of each element
    Variable!(T, dim, HostStorage) hx;

    /// forward: y_i = |x_i|
    auto forward(Variable!(T, dim, HostStorage) x) {
        import mir.math : fabs;

        this.hx = x; // TODO retain only when training
        auto y = x.sliced.map!fabs.slice;
        return y.variable(x.requiresGrad);
    }

    /// backward: gx_i = gy_i * sign(x_i), where sign(0) is defined as 0
    auto backward(Variable!(T, dim, HostStorage) gy) {
        auto gx = gy.dup;
        auto signs = this.hx.sliced.map!(a => a == 0f ? 0f : (a > 0f ? 1f : -1f));
        gx.sliced[] *= signs;
        return gx;
    }

    version (grain_cuda) {
        // saved input; backward needs it to compute the sign of each element
        Variable!(T, dim, DeviceStorage) dx;

        /// forward on CUDA: duplicate x, take |.| in place with the `abs` kernel
        auto forward(Variable!(T, dim, DeviceStorage) x) {
            import grain.kernel : abs;

            auto result = x.dup;
            unaryFunc!abs(result);
            this.dx = x; // TODO retain only when training
            return result;
        }

        /// backward on CUDA: gx = gy * sign(x), sign computed by `absGrad`
        auto backward(Variable!(T, dim, DeviceStorage) gy) {
            import grain.kernel : absGrad;

            auto signs = this.dx.dup;
            unaryFunc!absGrad(signs);
            return gy * signs;
        }
    }
}
1148 
/// test abs simple case, gradcheck and cpu/cuda equality
unittest {
    import grain.testing;
    import std.typecons;
    import numir;
    import mir.ndslice;

    auto xs = [[-1.0f, 2.0f, -3.0f], [1.0f, 0.0f, 0.0f]].nparray;
    auto ys = [[1.0f, 2.0f, 3.0f], [1.0f, 0.0f, 0.0f]].nparray;
    auto hfunc = Abs!(float, 2)();
    auto hx = xs.variable;
    auto hy = hfunc.forward(hx);
    assert(approxEqual(hy.sliced, ys));

    // where x == 0 the gradient is defined as 0 (sign(0) == 0), hence the
    // zeros in gxs even though gys is non-zero there
    auto gxs = [[-0.1f, 0.2f, -0.3f], [0.5f, 0.0f, 0.0f]].nparray;
    auto gys = [[0.1f, 0.2f, 0.3f], [0.5f, 0.6f, 0.7f]].nparray;
    auto hgy = gys.variable;
    auto hgx = hfunc.backward(hgy);
    assert(approxEqual(hgx.sliced, gxs));

    version (grain_cuda) {
        // CUDA forward/backward must match the host results above
        auto dfunc = Abs!(float, 2)();
        auto dy = dfunc.forward(hx.to!DeviceStorage);
        assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
        auto dgx = dfunc.backward(hgy.to!DeviceStorage);
        assert(approxEqual(dgx.to!HostStorage.sliced, hgx.sliced));
    }
}
1177 
/// y = x ^^ power (element-wise power with a fixed scalar exponent)
struct Pow(T, size_t dim) {
    import mir.ndslice : slice, map;

    mixin FunctionCommon;

    // fixed scalar exponent
    T power;
    // saved input; backward needs it to compute power * x ^^ (power - 1)
    Variable!(T, dim, HostStorage) hx;

    this(T power) {
        this.power = power;
    }

    /// forward: y_i = x_i ^^ power
    auto forward(Variable!(T, dim, HostStorage) x) {
        import mir.math.common : pow;

        auto y = slice(x.sliced.map!(a => pow(a, this.power))).variable(x
                .requiresGrad);
        this.hx = x; // TODO if train
        return y;
    }

    /// backward: gx_i = gy_i * power * x_i ^^ (power - 1)
    auto backward(Variable!(T, dim, HostStorage) gy) {
        import mir.math.common : pow;

        auto gx = gy.dup;
        gx.sliced[] *= this.hx.sliced.map!(a => this.power * pow(a, this.power - 1));
        return gx;
    }

    version (grain_cuda) {
        // saved input; backward needs it (see host backward above)
        Variable!(T, dim, DeviceStorage) dx;

        /// forward on CUDA: launch the elementwise `pow` kernel in place on a copy of x
        auto forward(Variable!(T, dim, DeviceStorage) _x) {
            // NOTE(review): stores the input variable itself, not a copy —
            // assumes _x is not mutated before backward runs; confirm
            this.dx = _x;
            auto x = _x.dup;
            import grain.kernel : pow;

            // device-side copies of the shape/stride metadata for the kernel
            auto shape = CuPtr!uint(x.shape[0 .. $]);
            auto strides = CuPtr!int(x.strides[0 .. $]);
            auto ndim = cast(uint) dim;
            auto len = cast(uint) x.data.length;

            Global.kernel!pow.call(this.power, x.data.ptr, len, ndim,
                    shape.ptr, strides.ptr).launch(len);
            return x;
        }

        /// backward on CUDA: `powGrad` rewrites a copy of the saved input to
        /// power * x ^^ (power - 1) in place, then multiply element-wise by gy
        auto backward(Variable!(T, dim, DeviceStorage) gy) {
            auto x = this.dx.dup;
            import grain.kernel : powGrad;

            // same metadata setup as forward
            auto shape = CuPtr!uint(x.shape[0 .. $]);
            auto strides = CuPtr!int(x.strides[0 .. $]);
            auto ndim = cast(uint) dim;
            auto len = cast(uint) x.data.length;

            Global.kernel!powGrad.call(this.power, x.data.ptr, len, ndim,
                    shape.ptr, strides.ptr).launch(len);
            return gy * x;
        }
    }
}
1241 
1242 ///
1243 unittest {
1244     import grain.testing;
1245     import std.typecons;
1246     import numir;
1247     import mir.ndslice;
1248     import mir.math : pow;
1249 
1250     auto p = 2.0f;
1251     auto hfunc = Pow!(float, 2)(p);
1252     auto hx = uniform!float(2, 3).slice.variable;
1253     auto hy = hfunc.forward(hx);
1254     auto hgy = uniform!float(2, 3).slice.variable;
1255     auto hgx = hfunc.backward(hgy);
1256     gradCheck(hfunc, hx, hgy, 1e-3, 1e-3, 1e-3);
1257     assert(approxEqual(hy.sliced, hx.sliced.map!(a => pow(a, p))));
1258 
1259     version (grain_cuda) {
1260         auto dfunc = Pow!(float, 2)(p);
1261         auto dy = dfunc.forward(hx.to!DeviceStorage);
1262         assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
1263         auto dgx = dfunc.backward(hgy.to!DeviceStorage);
1264         assert(approxEqual(dgx.to!HostStorage.sliced, hgx.sliced));
1265     }
1266 }
1267 
/// n-dimensional strided: applies mir's `strided!d(strides[d])` to every
/// dimension d of s, one dimension per recursion step
auto ndStrided(size_t d = 0, S, size_t dim)(S s, ptrdiff_t[dim] strides...)
        if (isSlice!S && DimensionCount!S >= dim) {
    static if (d == dim) {
        // all dimensions processed: recursion terminates
        return s;
    }
    else {
        import mir.ndslice.dynamic : strided;

        // stride dimension d, then recurse into dimension d + 1
        return ndStrided!(d + 1)(s.strided!d(strides[d]), strides);
    }
}
1280 
1281 ///
1282 unittest {
1283     import mir.ndslice;
1284 
1285     auto s = iota(3, 4);
1286     assert(s.ndStrided(2, 3) == s.strided(0, 2).strided(1, 3));
1287 }
1288 
/// inverse of symmetric padding: drops lengths[d] elements from BOTH ends of
/// every dimension d (only both-sides unpadding is supported — no pre/post mode)
auto unpad(size_t d = 0, S, size_t N)(S s, size_t[N] lengths...)
        if (isSlice!S && DimensionCount!S == N) {
    import mir.ndslice;

    static if (d == N) {
        // all dimensions processed: recursion terminates
        return s;
    }
    else {
        immutable p = lengths[d];
        // bring dimension d to the front, trim p from both ends, swap back
        auto s_ = s.swapped!(0, d)[p .. $ - p].swapped!(0, d);
        return unpad!(d + 1)(s_, lengths);
    }
}
1303 
1304 ///
1305 unittest {
1306     import mir.ndslice;
1307 
1308     auto s = iota(3, 4);
1309     assert(s.pad(0, [2, 1]).slicedNdField.unpad(2, 1) == s);
1310 }
1311 
/// scatter-add src back into dst along every dimension: for each index i
/// along dimension d, elements src[i .. i + strides[d]] (clipped to src's
/// extent) are accumulated into dst[i * strides[d]] (skipped when that index
/// falls outside dst). src and dst may have different lengths per dimension.
void sumNdStrided(size_t d = 0, S, D, size_t dim)(S src, D dst, ptrdiff_t[dim] strides...)
        if (isSlice!S && isSlice!D && DimensionCount!S == DimensionCount!D) {
    static if (d == 0) {
        // strides must supply one entry per dimension of src/dst
        static assert(dim == DimensionCount!S);
    }
    foreach (i; 0 .. src.length!0) {
        immutable j = i * strides[d];
        foreach (k; i .. i + strides[d]) {
            // clip at the boundary of either tensor
            if (j >= dst.length!0 || k >= src.length!0) {
                break;
            }
            static if (d + 1 == dim) {
                // innermost dimension: accumulate the element
                dst[j] += src[k];
            }
            else {
                // recurse into the next dimension
                sumNdStrided!(d + 1)(src[k], dst[j], strides);
            }
        }
    }
}
1332 
1333 ///
1334 unittest {
1335     import numir;
1336     import mir.ndslice;
1337 
1338     auto s = iota(3, 4);
1339     auto t = s.ndStrided(2, 3); // [[[[1, 2, 2]], [[4, 5, 5]]], [[[7, 8, 8]], [[10, 11, 11]]]]
1340     auto d = s.slice.zeros_like;
1341     sumNdStrided(s.ndStrided(2, 3).slice, d, [2, 3]);
1342     assert(d == [[22, 0, 0, 14], [0, 0, 0, 0], [19, 0, 0, 11]]);
1343 }
1344 
/// reference (host/mir) implementation of a pooling function: expands the
/// input into all pooling windows via pad -> windows -> ndStrided, then
/// reduces each window with `poolFun`
struct PoolRefImpl(alias poolFun) {
    import mir.ndslice;
    import numir;

    /// forward pass.
    /// x is a [batch, channel, spatial...] tensor; windowA/padA/strideA are
    /// given per spatial (pooling) dimension, i.e. tensorDims - 2 entries.
    static auto forward(T, size_t poolDims, size_t tensorDims)(Variable!(T,
            tensorDims, HostStorage) x,
            // ref Variable!(T, poolDims+2, HostStorage) y,
            int[poolDims] windowA, int[poolDims] padA, int[poolDims] strideA) {
        static assert(poolDims + 2 == tensorDims);
        // p: zero-padding per dim (none for the batch/channel dims 0-1)
        // w: window shape (full extent over the batch/channel dims)
        size_t[tensorDims] p, w;
        p[2 .. $] = padA.castArray!size_t[];
        w[0 .. 2] = x.shape[0 .. 2].castArray!size_t[];
        w[2 .. $] = windowA.castArray!size_t[];
        // s: window stride (1 over the batch/channel dims)
        ptrdiff_t[tensorDims] s;
        s[0 .. 2] = 1;
        s[2 .. $] = strideA.castArray!ptrdiff_t[];
        auto expanded = x.sliced.pad(0, p).slicedNdField.windows(w) // [wb=1, wc=1, wx, wy, ..., [b, c, kx, ky, ...]]
        .ndStrided(s) // [wb=1, wc=1, sx, sy, ..., [b, c, kx, ky, ...]]
        .unpack; // [wb=1, wc=1, sx, sy, ..., b, c, kx, ky, ...]

        // e: target view shape [wx, wy, ..., b, c, -1] where -1 flattens all
        // kernel elements into one trailing axis
        ptrdiff_t[poolDims + 3] e;
        static foreach (i; 0 .. poolDims) {
            e[i] = cast(ptrdiff_t) expanded.length!(i + 2);
        }
        e[poolDims] = x.shape[0];
        e[poolDims + 1] = x.shape[1];
        e[poolDims + 2] = -1;
        // reduce each flattened kernel with poolFun; the two transposed calls
        // presumably rotate the b, c axes back in front of the window axes —
        // TODO confirm the resulting axis order against cudnn's layout
        return expanded.slice // TODO how can i avoid this allocation
        .view(e).alongDim!(-1) // [wx, wy, b, c, kx * ky]

            .map!poolFun
            .transposed!(tensorDims - 1)
            .transposed!(tensorDims - 1)
            .slice
            .variable(x.requiresGrad);
    }

    /// backward pass.
    /// NOTE(review): this appears unfinished — it builds the expanded view,
    /// a zeroed padded gradient buffer (gxpad) and the flattened view (xe),
    /// but never writes anything into gx, which is left as passed in.
    static void backward(T, size_t poolDims, size_t tensorDims)(
            ref Variable!(T, tensorDims, HostStorage) gx, Variable!(T,
            tensorDims, HostStorage) x, Variable!(T, tensorDims,
            HostStorage) gy, Variable!(T,
            tensorDims, HostStorage) y, int[poolDims] windowA,
            int[poolDims] padA, int[poolDims] strideA) {

        static assert(poolDims + 2 == tensorDims);
        // same window/pad/stride setup as forward
        size_t[tensorDims] p, w;
        p[2 .. $] = padA.castArray!size_t[];
        w[0 .. 2] = x.shape[0 .. 2].castArray!size_t[];
        w[2 .. $] = windowA.castArray!size_t[];
        ptrdiff_t[tensorDims] s;
        s[0 .. 2] = 1;
        s[2 .. $] = strideA.castArray!ptrdiff_t[];
        auto expanded = x.sliced.pad(0, p).slice // TODO how can i avoid this allocation
        .windows(w) // [wb=1, wc=1, wx, wy, ..., [b, c, kx, ky, ...]]
        .ndStrided(s) // [wb=1, wc=1, sx, sy, ..., [b, c, kx, ky, ...]]
        .unpack; // [wb=1, wc=1, sx, sy, ..., b, c, kx, ky, ...]

        ptrdiff_t[poolDims + 3] e;
        size_t kernelSize = 1; // total number of elements per pooling window
        static foreach (i; 0 .. poolDims) {
            e[i] = cast(ptrdiff_t) expanded.length!(i + 2);
            kernelSize *= windowA[i];
        }
        e[poolDims] = x.shape[0];
        e[poolDims + 1] = x.shape[1];
        e[poolDims + 2] = -1;
        // zero-initialized gradient buffer in padded coordinates (unused so far)
        auto gxpad = x.uninit.sliced.pad(0, p).slice;
        gxpad[] = 0;
        auto xe = expanded.view(e); // [wx, wy, ..., b, c, kx * ky * ...]
    }

}
1420 
/// pooling function over the trailing poolDims spatial dimensions of a
/// [batch, channel, spatial...] tensor
struct Pool(bool isMax, T, size_t poolDims) {
    static assert(poolDims > 0);
    // full tensor rank: batch dim + channel dim + spatial pooling dims
    enum dim = poolDims + 2;
    // per-spatial-dimension pooling parameters
    int[poolDims] window, pad, stride;

    // saved input/output, needed by backward
    Variable!(T, dim, HostStorage) hx, hy;

    static if (isMax) {
        // import std.algorithm : max;
        import mir.ndslice;

        /// reduce one pooling window to its maximum element
        static auto pool(S)(S s) if (isSlice!S) {
            return maxPos(s).first;
        }

        alias CpuImpl = PoolRefImpl!pool;
    }
    // NOTE(review): there is no `else` branch, so when isMax is false
    // (AvgPool) neither `pool` nor `CpuImpl` exists and the host
    // forward/backward below cannot compile — only the CUDA path could
    // work for average pooling. TODO confirm / implement the avg case.

    /// host forward: delegate to the mir reference implementation
    auto forward(Variable!(T, dim, HostStorage) x) {
        auto y = CpuImpl.forward(x, this.window, this.pad, this.stride);
        // TODO if train
        this.hx = x;
        this.hy = y;
        return y;
    }

    /// host backward: delegate to the mir reference implementation
    auto backward(Variable!(T, dim, HostStorage) gy) {
        auto gx = this.hx.uninit;
        CpuImpl.backward(gx, this.hx, gy, this.hy, this.window, this.pad, this
                .stride);
        return gx;
    }

    version (grain_cuda) {
        // saved input/output, needed by cudnn's pooling backward
        Variable!(T, dim, DeviceStorage) dx, dy;

        import grain.cudnn;

        /// CUDA forward via cudnn.
        /// NOTE(review): isMax is not forwarded to poolForward — verify how
        /// grain.cudnn selects the pooling mode (max vs average).
        auto forward(Variable!(T, dim, DeviceStorage) x) {
            auto y = grain.cudnn.poolForward(x, this.window, this.pad, this
                    .stride);
            // TODO if train
            this.dx = x;
            this.dy = y;
            return y;
        }

        /// CUDA backward via cudnn, using the saved forward input/output
        auto backward(Variable!(T, dim, DeviceStorage) gy) {
            auto gx = this.dx.uninit;
            grain.cudnn.poolBackward(gx, this.dx, gy, this.dy, this.window, this
                    .pad, this.stride);
            return gx;
        }
    }
}

/// convenience aliases for the two pooling modes
alias MaxPool(T, size_t poolDims) = Pool!(true, T, poolDims);
alias AvgPool(T, size_t poolDims) = Pool!(false, T, poolDims);
1480 
1481 ///
1482 unittest {
1483     import std.stdio;
1484     import mir.ndslice;
1485     import numir;
1486 
1487     auto f = MaxPool!(float, 2)([3, 3], [1, 1], [1, 1]);
1488     auto x = iota(2, 2, 1, 3).as!float.slice.variable;
1489     // [[[[0, 1, 2]], [[3, 4, 5]]],
1490     //  [[[6, 7, 8]], [[9, 10, 11]]]]
1491     enum yex = [[[[1, 2, 2]], [[4, 5, 5]]], [[[7, 8, 8]], [[10, 11, 11]]]];
1492 
1493     // mir implementation
1494     auto hy = f.forward(x);
1495     assert(hy.sliced == yex);
1496 
1497     version (grain_cuda) {
1498         auto y = f.forward(x.to!DeviceStorage);
1499         assert(y.to!HostStorage.sliced == yex);
1500         auto gx = f.backward(y);
1501         enum gxex = [[[[0, 1, 4]], [[0, 4, 10]]], [[[0, 7, 16]], [[0, 10, 22]]]];
1502         assert(gx.to!HostStorage.sliced == gxex);
1503     }
1504 }