/++
A chain is an autograd operator in grain, equivalent to
- pytorch: torch.nn.Module
- chainer: chainer.Chain or chainer.Link

Unlike grain.functions, which must be instantiated with `new` and applied
with applyForward, grain.chains can be applied to a Variable directly with opCall.
 +/
module grain.chain;

import numir : normal;

import grain.autograd; // : Variable, variable, to;


// enum isChain(T) = {
//     import std.traits;
//     import std.meta;
//     import std.typecons;
//     alias R = ReturnType!(T.init);
//     if (isVariable!R) return true;
//     return isTuple!R && allSatisfy!(isVariable, R.Types);
// }();

/// linear (affine) operator: y = x W + b
struct Linear(T, alias Storage) {
    import mir.ndslice : slice;
    import std.traits : isFloatingPoint;
    import grain.functions : MatMul, AddBias;
    static assert(isFloatingPoint!T);
    Variable!(T, 2, Storage) weight;
    Variable!(T, 1, Storage) bias;

    /// initialize weight and bias uniformly in [-stdv, stdv]
    this(int ninput, int noutput) {
        import numir;
        import mir.random.variable;
        auto stdv = 1.0 / (cast(T) noutput ^^ 0.5);
        this.weight = UniformVariable!T(-stdv, stdv).generate(ninput, noutput).slice.variable(true).to!Storage;
        this.bias = UniformVariable!T(-stdv, stdv).generate(noutput).slice.variable(true).to!Storage;
    }

    /// compute x W + b
    auto opCall(Variable!(T, 2, Storage) x) {
        auto matmul = new MatMul!T;
        auto wx = matmul.applyForward(x, this.weight);
        auto addbias = new AddBias!T;
        return addbias.applyForward(wx, this.bias);
    }
}
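
/// usage sketch: a chain is applied with opCall instead of new/applyForward.
/// The shapes and values below are illustrative; only the output shape is
/// checked because the weights are randomly initialized.
unittest {
    import grain.autograd;
    auto f = Linear!(float, HostStorage)(3, 2);
    auto x = [[1.0f, 2.0f, 3.0f],
              [4.0f, 5.0f, 6.0f]].variable;
    // equivalent to MatMul followed by AddBias via applyForward
    auto y = f(x);
    assert(y.sliced.shape == [2, 2]);
}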

/// rectified linear unit nonlinearity
auto relu(T, size_t dim, alias Storage)(Variable!(T, dim, Storage) x) {
    import grain.functions : ReLU;
    auto func = new ReLU!(T, dim);
    return func.applyForward(x);
}
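
/// a small sketch of relu on the host (the input values are illustrative)
unittest {
    import grain.autograd;
    import numir;
    auto x = [[-1.0f, 0.0f], [2.0f, -3.0f]].variable;
    auto y = relu(x);
    // negative entries are clamped to zero, non-negative entries pass through
    assert(approxEqual(y.sliced, [[0.0f, 0.0f], [2.0f, 0.0f]].nparray));
}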

/// cross entropy loss (logsoftmax followed by negative log-likelihood)
auto crossEntropy(alias Storage)(Variable!(float, 2, Storage) x, Variable!(int, 1, Storage) t, int ignoreIndex=-100) {
    import grain.functions : LogSoftmax, NegativeLogLikelihood;
    auto lsmax = new LogSoftmax!(float, 2);
    auto y = lsmax.applyForward(x);
    auto nll = new NegativeLogLikelihood!(float, int);
    nll.ignoreIndex = ignoreIndex;
    return nll.applyForward(y, t);
}


/// test variable.backward
unittest {
    /* pytorch equivalent
       >>> x = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], requires_grad=True)
       >>> t = torch.tensor([1, 0, -100], dtype=torch.long)
       >>> l = torch.nn.functional.cross_entropy(x, t)
       >>> l
       tensor(0.6944)
       >>> l.backward()
       >>> x.grad
       tensor([[ 0.2375, -0.2375],
               [-0.2625,  0.2625],
               [ 0.0000,  0.0000]])
     */
    import std.stdio;
    import std.typecons;
    import mir.ndslice;
    import grain.autograd;
    import numir;

    grain.autograd.backprop = true;

    auto hx = [[0.1f, 0.2f], [0.3f, 0.4f], [0.5f, 0.6f]].variable(true);
    auto ht = [1, 0, -100].variable;
    auto hl = crossEntropy(hx, ht);
    hl.backward();
    assert(approxEqual(hx.gradSliced,
                       [[ 0.2375, -0.2375],
                        [-0.2625,  0.2625],
                        [ 0.0000,  0.0000]].nparray));

    version (grain_cuda) {
        auto dx = hx.to!DeviceStorage;
        dx.grad.zero_();
        auto dt = ht.to!DeviceStorage;
        auto dl = crossEntropy(dx, dt);
        assert(approxEqual(hl.sliced, dl.to!HostStorage.sliced));
        dl.backward();
        assert(approxEqual(dx.to!HostStorage.gradSliced, hx.gradSliced));
    }
}