1 module grain.testing;
2 
3 import std.typecons  : isTuple, tuple;
4 
5 import grain.utility : castArray, toTuple;
6 
7 import std.traits : isIntegral;
8 import grain.autograd : variable, ElementType;
9 
10 
/++
Estimates the gradients of `func.forward` with respect to its inputs by
central finite differences.

For every non-integral input `x` and every flat element index `i` the
estimate is

    gx[i] = sum over outputs m of
            sum((f_m(x + eps * e_i) - f_m(x - eps * e_i)) * gy_m) / (2 * eps)

where `f_m` is the m-th output of `func.forward` and `gy_m` the m-th
element of `gradOutputs`.

Each input element is perturbed in place through `x.sliced` and restored
before the next element is processed; this relies on `x.sliced` sharing
storage with the variable that is handed to `func.forward`.

The outputs returned by `func.forward` are only read, never written, so
buffers that `func` might share with its inputs or keep internally are not
modified by this function.

Params:
    func = object providing `forward`, called with the expanded inputs
    inputs = a single input or a tuple of inputs (normalised with `toTuple`);
             `In` must be indexable because results are stored as `gradInputs[n]`
    gradOutputs = a single output gradient or a tuple of them, one per output
                  of `func.forward`, each with the shape of that output
    eps = step size of the central difference

Returns:
    a value of type `In` whose n-th element is the numeric gradient of the
    n-th input; elements that belong to integral inputs are left
    default-initialised
+/
auto numericGrad(F, In, Out)(ref F func, In inputs, Out gradOutputs, float eps) {
    import numir; // : zeros_like, view;
    import mir.ndslice;
    import mir.math : sum;
    In gradInputs;
    foreach (n, x; inputs.toTuple) {
        static if (isIntegral!(ElementType!(typeof(x)))) {
            // integral inputs (e.g. indices) are not differentiable: skip
            continue;
        } else {
            auto xFlat = x.sliced.view(-1);
            auto gxFlat = zeros_like(xFlat);
            foreach (i; 0 .. xFlat.length) {
                auto origin = xFlat[i];
                // evaluate f at x + eps * e_i and x - eps * e_i, then undo the perturbation
                xFlat[i] = origin + eps;
                auto a = func.forward(inputs.toTuple.expand).toTuple;
                xFlat[i] = origin - eps;
                auto b = func.forward(inputs.toTuple.expand).toTuple;
                xFlat[i] = origin;
                foreach (m, gy; gradOutputs.toTuple) {
                    // Lazy element-wise (a - b) * gy: the previous in-place
                    // `sa[] -= sb; sa[] *= gy` overwrote the storage returned by
                    // forward, which is unsafe when that storage is shared.
                    auto weightedDiff = zip(a[m].sliced, b[m].sliced, gy.sliced)
                        .map!((ya, yb, g) => (ya - yb) * g);
                    gxFlat[i] += sum!"fast"(weightedDiff) / (2.0 * eps);
                }
            }
            // give the flat gradient the shape of the corresponding input
            auto gx = gxFlat.universal.view(x.shape.castArray!ptrdiff_t);
            gradInputs[n] = gx.variable;
        }
    }
    return gradInputs;
}
43 
/++
Gradient check: asserts that the gradients returned by `func.backward`
agree with the finite-difference estimate from `numericGrad`.

`func.forward` is called once on `inputs`, then `func.backward` on
`gradOutputs`. Every non-integral input is compared with
`numir.testing.approxEqual`; integral inputs are skipped.

Params:
    func = object providing `forward` and `backward`
    inputs = a single input or a tuple of inputs
    gradOutputs = a single output gradient or a tuple of them
    eps = finite-difference step forwarded to `numericGrad`
    rtol = relative tolerance forwarded to `approxEqual`
    atol = absolute tolerance forwarded to `approxEqual`

The `file` and `line` template parameters default to the call site and are
only used in the assertion message.
+/
auto gradCheck(F, In, Out, string file = __FILE__, size_t line = __LINE__)(
    ref F func, In inputs, Out gradOutputs,
    float eps=1e-3, float rtol=1e-3, float atol=1e-5) {
    import numir.testing : approxEqual;
    import std.format : format;

    // The forward result is not inspected here; the call is made before
    // backward (presumably so that func can record what backward needs).
    auto outputs = func.forward(inputs.toTuple.expand).toTuple;
    auto analytic = func.backward(gradOutputs.toTuple.expand).toTuple;
    // FIXME transfer device variable to host before computing numericGrad
    auto numeric = numericGrad(func, inputs.toTuple, gradOutputs.toTuple, eps).toTuple;

    static foreach (i; 0 .. inputs.toTuple.length) {{
        alias Elem = ElementType!(typeof(inputs.toTuple[i]));
        static if (!isIntegral!Elem) {
            // the message is only formatted when the comparison fails
            assert(
                approxEqual(analytic[i].sliced, numeric[i].sliced, rtol, atol),
                format!"%d th input grad %s != %s from %s %d"(
                    i, analytic[i].sliced, numeric[i].sliced, file , line));
        }
    }}
}
61 
62 
63 // TODO CPU-CUDA comparison function