1 module grain.functions.random;
2 
3 import grain.autograd;
4 import grain.utility : castArray;
5 import grain.functions.common;
6 
/**
Dropout layer ("inverted dropout").

On the host, `forward` zeroes each element of the input with probability
`ratio` and rescales the surviving elements by `1 / (1 - ratio)`, so the
expected value of each output element equals the corresponding input.
The rescaled mask is kept in `hostMask` and reused by `backward`.

The device (`grain_cuda`) path delegates to `CudnnDropout`.

TODO create cuRAND wrappers
*/
struct Dropout(T, size_t dim) {
    import numir.random : generate;
    import mir.ndslice : as, slice;
    import mir.random.variable : BernoulliVariable;

    /// probability of dropping (zeroing) each element; must lie in [0, 1]
    float ratio = 0.5;
    /// rescaled keep-mask produced by the last host `forward`, consumed by `backward`
    Variable!(T, dim, HostStorage) hostMask;

    /// Params: ratio = drop probability, asserted to be within [0, 1]
    this(double ratio) {
        assert(0.0 <= ratio && ratio <= 1.0);
        this.ratio = ratio;
    }

    /**
    Applies dropout to a host variable and records the mask for `backward`.

    Returns: `x` itself when `ratio == 0`; otherwise `hostMask * x`, where each
    mask element is either 0 or `1 / (1 - ratio)`.
    */
    auto forward(Variable!(T, dim, HostStorage) x) {
        if (this.ratio == 0.0) return x;

        import mir.ndslice; //  : universal;
        const shape = x.shape.castArray!size_t;
        // probability that an element is kept
        const float survived = 1.0 - this.ratio;
        // Inverted dropout: kept elements are scaled by 1 / P(keep).
        // When everything is dropped (ratio == 1) the mask is all zeros, so use
        // scale 0 instead of 1/0 to avoid 0 * inf = NaN in the mask.
        const float scale = survived > 0.0f ? 1.0f / survived : 0.0f;
        auto mask = BernoulliVariable!T(survived).generate(shape).as!T.slice.universal;
        mask[] *= scale;
        this.hostMask = mask.variable;
        return this.hostMask * x;
    }

    /**
    Gradient of the host `forward`: multiplies `gy` by the mask saved by the
    most recent host `forward` call.

    When `ratio == 0`, `forward` was the identity and stored no mask, so `gy`
    is returned unchanged.
    */
    auto backward(Variable!(T, dim, HostStorage) gy) {
        if (this.ratio == 0.0) return gy;
        assert(gy.shape == this.hostMask.shape);
        return this.hostMask * gy;
    }

    version (grain_cuda) {
        import grain.cudnn : CudnnDropout;
        CudnnDropout impl;

        /// Device forward; delegates to cuDNN with the configured `ratio`.
        auto forward(Variable!(T, dim, DeviceStorage) x) {
            return this.impl.forward(x, this.ratio);
        }

        /// Device backward; delegates to the cuDNN dropout state from `forward`.
        auto backward(Variable!(T, dim, DeviceStorage) gy) {
            return this.impl.backward(gy);
        }
    }

    mixin FunctionCommon;
}
54 
55 
unittest {
    Dropout!(float, 2) dropout;
    auto input = [[1f, 2f, 3f], [4f, 5f, 6f]].variable;

    // With the default ratio (0.5) every output element is either dropped (0)
    // or kept and doubled, and the gradient must reuse exactly the same mask.
    alias checkMasked = (output, grad) {
        foreach (row; 0 .. input.shape[0]) {
            foreach (col; 0 .. input.shape[1]) {
                auto kept = output.sliced[row, col];
                assert(kept == 0 || kept == 2.0 * input.sliced[row, col]);
                assert(kept == grad.sliced[row, col]);
            }
        }
    };

    // host path: forward first so backward sees the freshly drawn mask
    auto hostOut = dropout.forward(input);
    auto hostGrad = dropout.backward(input);
    checkMasked(hostOut, hostGrad);

    version (grain_cuda) {
        auto deviceInput = input.to!DeviceStorage;
        auto deviceOut = dropout.forward(deviceInput).to!HostStorage;
        auto deviceGrad = dropout.backward(deviceInput).to!HostStorage;
        checkMasked(deviceOut, deviceGrad);
    }
}