module grain.functions.random;

import grain.autograd;
import grain.utility : castArray;
import grain.functions.common;

/++
Dropout function.

On host storage, `forward` draws a Bernoulli keep-mask, rescales the kept
entries by `1 / (1 - ratio)` and multiplies the input by that mask; `backward`
multiplies the incoming gradient by the same stored mask.
When built with `version (grain_cuda)`, device-storage variables are forwarded
to `CudnnDropout`.

Params:
    T = element type of the variables
    dim = number of dimensions of the variables

TODO: create cuRAND wrappers
+/
struct Dropout(T, size_t dim) {
    import numir.random : generate;
    import mir.ndslice : as, slice;
    import mir.random.variable : BernoulliVariable;

    /// Probability of zeroing each element; must lie in [0, 1].
    float ratio = 0.5;
    /// Scaled keep-mask produced by the last host `forward`, reused by `backward`.
    Variable!(T, dim, HostStorage) hostMask;

    /// Creates a dropout function with the given drop probability in [0, 1].
    this(double ratio) {
        assert(0.0 <= ratio && ratio <= 1.0);
        this.ratio = ratio;
    }

    /++
    Applies dropout to a host variable.

    Params:
        x = input variable

    Returns: `x` itself when `ratio == 0`; otherwise `x` multiplied
        element-wise by a freshly sampled mask whose entries are either 0 or
        `1 / (1 - ratio)`. The mask is kept in `hostMask`.
    +/
    auto forward(Variable!(T, dim, HostStorage) x) {
        if (this.ratio == 0.0) return x;

        import mir.ndslice; // : universal;
        const shape = x.shape.castArray!size_t;
        const float survived = 1.0 - this.ratio;
        // Kept elements are scaled by the inverse of the keep probability.
        // When nothing survives (ratio == 1) every mask entry is already 0,
        // so use a zero scale instead of 1/0 to avoid 0 * inf = NaN.
        const float scale = survived > 0 ? 1.0f / survived : 0.0f;
        auto mask = BernoulliVariable!T(survived).generate(shape).as!T.slice.universal;
        mask[] *= scale;
        this.hostMask = mask.variable;
        return this.hostMask * x;
    }

    /++
    Propagates a host gradient through the mask sampled by the last `forward`.

    Params:
        gy = gradient with respect to the output of `forward`

    Returns: `gy` itself when `ratio == 0` (forward was a pass-through and
        stored no mask); otherwise `hostMask * gy`.
    +/
    auto backward(Variable!(T, dim, HostStorage) gy) {
        // forward() does not set hostMask when ratio == 0, so there is
        // nothing to compare shapes against or multiply by.
        if (this.ratio == 0.0) return gy;

        assert(gy.shape == this.hostMask.shape);
        return this.hostMask * gy;
    }

    version (grain_cuda) {
        import grain.cudnn : CudnnDropout;
        /// cuDNN-backed implementation used for device-storage variables.
        CudnnDropout impl;

        /// Applies dropout to a device variable by delegating to `impl`.
        auto forward(Variable!(T, dim, DeviceStorage) x) {
            return this.impl.forward(x, this.ratio);
        }

        /// Propagates a device gradient by delegating to `impl`.
        auto backward(Variable!(T, dim, DeviceStorage) gy) {
            return this.impl.backward(gy);
        }
    }

    mixin FunctionCommon;
}


unittest {
    import std.math : fabs;

    Dropout!(float, 2) func;
    auto x = [[1f, 2f, 3f], [4f, 5f, 6f]].variable;
    auto y = func.forward(x);
    auto gx = func.backward(x);
    foreach (i; 0 .. x.shape[0]) {
        foreach (j; 0 .. x.shape[1]) {
            auto yij = y.sliced[i, j];
            assert(yij == 0 || yij == 2.0 * x.sliced[i, j]);
            assert(yij == gx.sliced[i, j]);
        }
    }

    // A ratio other than 0.5 distinguishes 1 / (1 - ratio) from 1 / ratio.
    Dropout!(float, 2) quarter;
    quarter.ratio = 0.25;
    auto yq = quarter.forward(x);
    auto gq = quarter.backward(x);
    foreach (i; 0 .. x.shape[0]) {
        foreach (j; 0 .. x.shape[1]) {
            const yij = yq.sliced[i, j];
            const expected = x.sliced[i, j] / 0.75;
            assert(yij == 0 || fabs(yij - expected) < 1e-4);
            assert(yij == gq.sliced[i, j]);
        }
    }

    // ratio == 0 is a pass-through in both directions.
    Dropout!(float, 2) off;
    off.ratio = 0.0;
    auto y0 = off.forward(x);
    auto g0 = off.backward(x);
    foreach (i; 0 .. x.shape[0]) {
        foreach (j; 0 .. x.shape[1]) {
            assert(y0.sliced[i, j] == x.sliced[i, j]);
            assert(g0.sliced[i, j] == x.sliced[i, j]);
        }
    }

    version (grain_cuda) {
        auto cx = x.to!DeviceStorage;
        auto cy = func.forward(cx).to!HostStorage;
        auto cgx = func.backward(cx).to!HostStorage;

        foreach (i; 0 .. x.shape[0]) {
            foreach (j; 0 .. x.shape[1]) {
                auto yij = cy.sliced[i, j];
                assert(yij == 0 || yij == 2.0 * x.sliced[i, j]);
                assert(yij == cgx.sliced[i, j]);
            }
        }
    }
}