1 /**
   Testing utilities: finite-difference (numeric) gradients and gradient checks that compare them against autograd.
3  */
4 module grain.testing;
5 
6 import std.traits : isIntegral;
7 import std.typecons  : isTuple, tuple;
8 
9 import grain.autograd : variable, ElementType;
10 import grain.utility : castArray, toTuple;
11 import grain.functions.common : isFunction;
12 
13 
/**
   Computes input gradients numerically with central differences.

   Each element x_i of every non-integral input is set to x_i + eps and then
   x_i - eps (through `x.sliced.view(-1)`), `func` is evaluated at both
   points, the element is restored, and
   sum((y(x_i + eps) - y(x_i - eps)) * gy) / (2 * eps)
   is accumulated over all outputs y with their upstream gradients gy.
   Inputs whose element type is integral are skipped and keep a
   default-initialized entry in the result.

   Params:
       func = a value satisfying `isFunction` (evaluated via `forward`), or
              any callable taking the expanded inputs
       inputs = input variable(s); elements are perturbed during evaluation
                and restored afterwards
       gradOutputs = upstream gradient(s), one per output of `func`
       eps = half width of the finite-difference step

   Returns: a value of type `In` holding the numeric gradient of every
            non-integral input, reshaped to that input's shape
 */
auto numericGrad(F, In, Out)(ref F func, In inputs, Out gradOutputs, float eps) {
    import numir; // : zeros_like, view;
    import mir.ndslice;
    import mir.math : sum;

    // Runs func at the current (perturbed) inputs; dispatches on whether F
    // is a grain Function (forward) or a plain callable.
    auto evaluate() {
        static if (isFunction!F) {
            return func.forward(inputs.toTuple.expand).toTuple;
        } else {
            return func(inputs.toTuple.expand).toTuple;
        }
    }

    In gradInputs;
    foreach (n, x; inputs.toTuple) {
        static if (isIntegral!(ElementType!(typeof(x)))) {
            continue;
        } else {
            auto xFlat = x.sliced.view(-1);
            auto gxFlat = zeros_like(xFlat);
            foreach (i; 0 .. xFlat.length) {
                auto origin = xFlat[i];
                xFlat[i] = origin + eps;
                auto a = evaluate();
                xFlat[i] = origin - eps;
                auto b = evaluate();
                xFlat[i] = origin;
                foreach (m, gy; gradOutputs.toTuple) {
                    // Work on a private copy: the slice returned by func may
                    // share storage with an input or with state kept by func,
                    // so subtracting/multiplying in place could corrupt it.
                    auto diff = a[m].sliced.slice;
                    diff[] -= b[m].sliced;
                    diff[] *= gy.sliced;
                    gxFlat[i] += sum!"fast"(diff) / (2.0 * eps);
                }
            }
            auto gx = gxFlat.universal.view(x.shape.castArray!ptrdiff_t);
            gradInputs[n] = gx.variable;
        }
    }
    return gradInputs;
}
57 
58 
/**
   Computes input gradients numerically with central differences for a
   callable passed as a template alias.

   Each element x_i of every non-integral input is set to x_i + eps and then
   x_i - eps (through `x.sliced.view(-1)`), `func` is called at both points,
   the element is restored, and
   sum((y(x_i + eps) - y(x_i - eps)) * gy) / (2 * eps)
   is accumulated over all outputs y with their upstream gradients gy.
   Inputs whose element type is integral are skipped and keep a
   default-initialized entry in the result.

   Params:
       func = callable taking the expanded inputs (template alias parameter)
       inputs = input variable(s); elements are perturbed during evaluation
                and restored afterwards
       gradOutputs = upstream gradient(s), one per output of `func`
       eps = half width of the finite-difference step

   Returns: a value of type `In` holding the numeric gradient of every
            non-integral input, reshaped to that input's shape
 */
auto numericGradChain(alias func, In, Out)(In inputs, Out gradOutputs, float eps) {
    import numir; // : zeros_like, view;
    import mir.ndslice;
    import mir.math : sum;
    In gradInputs;
    foreach (n, x; inputs.toTuple) {
        static if (isIntegral!(ElementType!(typeof(x)))) {
            continue;
        } else {
            auto xFlat = x.sliced.view(-1);
            auto gxFlat = zeros_like(xFlat);
            foreach (i; 0 .. xFlat.length) {
                auto origin = xFlat[i];
                xFlat[i] = origin + eps;
                auto a = func(inputs.toTuple.expand).toTuple;
                xFlat[i] = origin - eps;
                auto b = func(inputs.toTuple.expand).toTuple;
                xFlat[i] = origin;
                foreach (m, gy; gradOutputs.toTuple) {
                    // Work on a private copy: the slice returned by func may
                    // share storage with an input (e.g. a view), so
                    // subtracting/multiplying in place could corrupt it.
                    auto diff = a[m].sliced.slice;
                    diff[] -= b[m].sliced;
                    diff[] *= gy.sliced;
                    gxFlat[i] += sum!"fast"(diff) / (2.0 * eps);
                }
            }
            auto gx = gxFlat.universal.view(x.shape.castArray!ptrdiff_t);
            gradInputs[n] = gx.variable;
        }
    }
    return gradInputs;
}
93 
/**
   Asserts that backprop gradients agree with numeric gradients.

   Backprop gradients are obtained as follows:
   - When `F` satisfies `isFunction`, via `func.forward` followed by
     `func.backward(gradOutputs...)`.
   - Otherwise, by calling `func`, invoking `backward` on every output with
     an `UntypedVariable` wrapping the matching upstream gradient, and
     reading each input's accumulated gradient.

   They are compared element-wise with `numericGrad` using
   `approxEqual(…, rtol, atol)`. Inputs with an integral element type are not
   checked. On mismatch the assertion message reports both gradients and the
   caller's file and line.

   Params:
       func = grain Function or plain callable under test
       inputs = input variable(s) fed to `func`
       gradOutputs = upstream gradient(s), one per output of `func`
       eps = finite-difference step passed to `numericGrad`
       rtol = relative tolerance of the comparison
       atol = absolute tolerance of the comparison
 */
auto gradCheck(F, In, Out, string file = __FILE__, size_t line = __LINE__)(
    ref F func, In inputs, Out gradOutputs,
    float eps=1e-3, float rtol=1e-3, float atol=1e-5) {
    import std.format : format;
    import numir.testing : approxEqual;
    static if (isFunction!F) {
        auto ys = func.forward(inputs.toTuple.expand).toTuple;
        auto agrad = func.backward(gradOutputs.toTuple.expand).toTuple;
    } else {
        auto xs = inputs.toTuple;
        auto ys = func(xs.expand).toTuple;
        auto gys = gradOutputs.toTuple;
        foreach (o, ref y; ys) {
            import grain.autograd;
            auto u = UntypedVariable(gys[o]);
            y.backward(&u);
        }
        typeof(inputs.toTuple) agrad;
        foreach (i, ref a; agrad) {
            // Integral inputs are never compared below, so their gradient
            // storage is not read.
            static if (!isIntegral!(ElementType!(typeof(a)))) {
                agrad[i] = xs[i].gradSliced.variable; // TODO support CUDA
            }
        }
    }
    // FIXME transfer device variable to host before computing numericGrad
    auto ngrad = numericGrad(func, inputs.toTuple, gradOutputs.toTuple, eps).toTuple;
    static foreach (i; 0 .. inputs.toTuple.length) {
        static if (!isIntegral!(ElementType!(typeof(inputs.toTuple[i])))) {
            assert(approxEqual(agrad[i].sliced, ngrad[i].sliced, rtol, atol),
                   format!"%d th input grad %s (backprop) != %s (numeric) from %s %d"(i, agrad[i].sliced, ngrad[i].sliced, file , line));
        }
    }
}
126 
127 
/**
   Asserts that backprop gradients of a callable passed as a template alias
   agree with numeric gradients.

   The check proceeds as follows:
   - `func` is called on `inputs`.
   - `backward` is invoked on every output with an `UntypedVariable`
     wrapping the matching upstream gradient.
   - Each non-integral input's accumulated gradient is compared element-wise
     against `numericGradChain` using `approxEqual(…, rtol, atol)`.

   On mismatch the assertion message reports both gradients and the caller's
   file and line.

   Params:
       func = callable under test (template alias parameter)
       inputs = input variable(s) fed to `func`
       gradOutputs = upstream gradient(s), one per output of `func`
       eps = finite-difference step passed to `numericGradChain`
       rtol = relative tolerance of the comparison
       atol = absolute tolerance of the comparison
 */
auto gradCheckChain(alias func, In, Out, string file = __FILE__, size_t line = __LINE__)
    (In inputs, Out gradOutputs, float eps=1e-3, float rtol=1e-3, float atol=1e-5) {
    import std.format : format;
    import numir.testing : approxEqual;
    auto xs = inputs.toTuple;
    auto ys = func(xs.expand).toTuple;
    auto gys = gradOutputs.toTuple;
    foreach (o, ref y; ys) {
        import grain.autograd;
        auto u = UntypedVariable(gys[o]);
        y.backward(&u);
    }
    typeof(inputs.toTuple) agrad;
    foreach (i, ref a; agrad) {
        // Integral inputs are never compared below, so their gradient
        // storage is not read.
        static if (!isIntegral!(ElementType!(typeof(a)))) {
            agrad[i] = xs[i].gradSliced.variable; // TODO support CUDA
        }
    }

    // FIXME transfer device variable to host before computing numericGrad
    auto ngrad = numericGradChain!func(inputs.toTuple, gradOutputs.toTuple, eps).toTuple;
    static foreach (i; 0 .. inputs.toTuple.length) {
        static if (!isIntegral!(ElementType!(typeof(inputs.toTuple[i])))) {
            assert(approxEqual(agrad[i].sliced, ngrad[i].sliced, rtol, atol),
                   format!"%d th input grad %s (backprop) != %s (numeric) from %s %d"(i, agrad[i].sliced, ngrad[i].sliced, file , line));
        }
    }
}
155 
156 
157 // TODO CPU-CUDA comparison function