/**
   A module for reduction functions
 */
module grain.functions.reduction;

import grain.autograd;
import grain.cuda;
import grain.functions.common;
import grain.utility : toTuple, fromTuple, castArray;

/++
Sum of all elements of a variable, reduced to a scalar (0-dim) variable.

`forward` records the input shape in `shape`; `backward` uses that shape to
build a gradient in which every element equals the incoming scalar gradient.

Params:
    mode = summation mode passed to `mir.math.sum` on the host path
           (the device path below calls `sumNaive` and does not consult it)
    T = element type; must be a floating point type
    dim = number of dimensions of the input variable
+/
struct Sum(string mode = "fast", T, size_t dim) {
    import std.traits : isFloatingPoint;

    static assert(isFloatingPoint!T, "currently only float point is supported.");

    /// Shape of the most recently forwarded input; read by `backward`.
    uint[dim] shape;

    /++
    Sums every element of a host variable.

    Params:
        x = input variable on host storage
    Returns: the sum of `x.sliced` wrapped by `variable`
    +/
    auto forward(Variable!(T, dim, HostStorage) x) {
        import mir.math : sum;

        // TODO if train
        this.shape = x.shape;
        auto result = x.sliced.sum!mode;
        return result.variable;
    }

    /++
    Expands a scalar gradient to the shape stored by the last `forward`.

    Params:
        y = scalar (0-dim) gradient on host storage
    Returns: a variable of shape `this.shape` whose data is filled with `y.data[0]`
    +/
    auto backward(Variable!(T, 0, HostStorage) y) {
        auto gx = uninitVariable!T(this.shape, y.requiresGrad);
        import std.algorithm : fill;

        // every element of the sum contributes equally, so each one
        // receives the same incoming gradient value
        fill(gx.data, y.data[0]);
        return gx;
    }

    version (grain_cuda) {
        /++
        Sums every element of a device variable via `sumNaive`.

        Params:
            x = input variable on device storage
        Returns: the result of `sumNaive`, wrapped by `variable` and
                 converted with `to!DeviceStorage`
        +/
        auto forward(Variable!(T, dim, DeviceStorage) x) {
            import grain.cuda : sumNaive;

            this.shape = x.shape;
            // Disabled kernel-based reduction, kept for reference.
            // If it is re-enabled, `sum` must be imported from grain.cuda again.
            // auto y = CuPtr!float([0]);
            // Global.kernel!sum.call(x.data.ptr, y.ptr, cast(int) x.data.length)
            //     .launch(cast(uint[3]) [1U,1,1], cast(uint[3]) [1U,1,1], 0U);
            // checkCudaErrors(cuCtxSynchronize());
            return x.data.sumNaive.variable.to!DeviceStorage;
        }

        /++
        Expands a scalar gradient to the shape stored by the last `forward`.

        Params:
            y = scalar (0-dim) gradient on device storage
        Returns: a device variable of shape `this.shape` filled with the
                 value read from `y.data.toHost[0]`
        +/
        auto backward(Variable!(T, 0, DeviceStorage) y) {
            auto gx = uninitVariable!(T, DeviceStorage)(this.shape, y.requiresGrad);
            gx.data.fill_(y.data.toHost[0]);
            return gx;
        }
    }

    mixin FunctionCommon;
}

///
unittest {
    import mir.ndslice;
    import mir.math;

    auto x = [1f, 2f, 3f, 4f].sliced(2, 2).variable;
    Sum!("fast", float, 2) func;
    auto y = func.forward(x);
    assert(y == 10f.variable);
    assert(func.backward(1.2f.variable) == [1.2f, 1.2f, 1.2f, 1.2f].sliced(2, 2).variable);

    version (grain_cuda) {
        auto cx = x.to!DeviceStorage;
        auto cy = func.forward(cx).to!HostStorage;
        assert(cy == 10f.variable);
        auto cgx = func.backward(1.2f.variable.to!DeviceStorage).to!HostStorage;
        assert(cgx.sliced == [1.2f, 1.2f, 1.2f, 1.2f].sliced(2, 2));
    }
}