1 /**
   A module for reduction functions
3 */
4 module grain.functions.reduction;
5 
6 import grain.autograd;
7 import grain.cuda;
8 import grain.functions.common;
9 import grain.utility : toTuple, fromTuple, castArray;
10 
/// Sum reduction to a 0-dim scalar variable. `mode` selects the summation
/// algorithm and is forwarded to mir.math.sum (e.g. "fast", "kahan", "pairwise").
struct Sum(string mode = "fast", T, size_t dim) {
    import std.traits : isFloatingPoint;
    // NOTE: fixed typo in the diagnostic ("float point" -> "floating point").
    static assert(isFloatingPoint!T, "currently only floating point is supported.");

    /// input shape recorded in forward so backward can rebuild the gradient tensor
    uint[dim] shape;

    /// Forward: sum every element of `x` into a scalar (0-dim) host variable.
    auto forward(Variable!(T, dim, HostStorage) x) {
        import mir.math : sum;

        // TODO if train
        this.shape = x.shape;
        auto result = x.sliced.sum!mode;
        return result.variable;
    }

    /// Backward: broadcast the upstream scalar gradient to the saved input shape.
    auto backward(Variable!(T, 0, HostStorage) y) {
        auto gx = uninitVariable!T(this.shape, y.requiresGrad);
        import std.algorithm : fill;
        // d(sum)/dx_i == 1 for every i, so each element gets the upstream scalar.
        fill(gx.data, y.data[0]);
        return gx;
    }

    version (grain_cuda) {
        /// Forward on device storage: naive device-side sum of all elements.
        auto forward(Variable!(T, dim, DeviceStorage) x) {
            import std.algorithm : reduce;
            import grain.cuda : sum, sumNaive;

            this.shape = x.shape;
            // auto y = CuPtr!float([0]);
            // Global.kernel!sum.call(x.data.ptr, y.ptr, cast(int) x.data.length)
            //     .launch(cast(uint[3]) [1U,1,1], cast(uint[3]) [1U,1,1], 0U);
            // checkCudaErrors(cuCtxSynchronize());
            return x.data.sumNaive.variable.to!DeviceStorage;
        }

        /// Backward on device storage: fill the device gradient with the
        /// upstream scalar (copied to host once to read its value).
        auto backward(Variable!(T, 0, DeviceStorage) y) {
            auto gx = uninitVariable!(T, DeviceStorage)(this.shape, y.requiresGrad);
            gx.data.fill_(y.data.toHost[0]);
            return gx;
        }
    }

    mixin FunctionCommon;
}
56 
57 ///
58 unittest {
59     import mir.ndslice;
60     import mir.math;
61     auto x = [1f, 2f, 3f, 4f].sliced(2, 2).variable;
62     Sum!("fast", float, 2) func;
63     auto y = func.forward(x);
64     assert(y == 10f.variable);
65     assert(func.backward(1.2f.variable) == [1.2f, 1.2f, 1.2f, 1.2f].sliced(2, 2).variable);
66 
67     version (grain_cuda) {
68         auto cx = x.to!DeviceStorage;
69         auto cy = func.forward(cx).to!HostStorage;
70         assert(cy == 10f.variable);
71         auto cgx = func.backward(1.2f.variable.to!DeviceStorage).to!HostStorage;
72         assert(cgx.sliced == [1.2f, 1.2f, 1.2f, 1.2f].sliced(2, 2));
73     }
74 }