/**
   A module for gradient descent optimizers (SGD, AdaGrad, Adam, and AdaDelta)
 */
module grain.optim;

import std.stdio;
import grain.autograd : isVariable, zero_, isHost, UntypedVariable, Variable,
    HostStorage, iterVariables;
import std.traits : hasMember;

version (grain_cuda) {
    import grain.cuda : zero_;
    import grain.cudnn : transform;
}

/// fill all gradient arrays in a chain (recursively) with zeros
void zeroGrad(C)(ref C chain) {
    foreach (ref field; chain.tupleof) {
        alias F = typeof(field);
        static if (isVariable!F) {
            field.grad.zero_();
        }
        else static if (hasMember!(F, "tupleof")) { // recurse into nested chains
            field.zeroGrad();
        }
    }
}
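
/// minimal check that zeroGrad clears nested gradients, mirroring the SGD unittest below
unittest {
    auto mlp = MLP!(float, HostStorage)(2);
    mlp.fc1.weight.grad[0] = 1.0f;
    mlp.zeroGrad();
    assert(mlp.fc1.weight.grad[0] == 0.0f);
}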

/// trait to identify an optimizer: any type with a step(name, variable) member
enum bool isOptimizer(T) = is(typeof({
            import grain.autograd;

            Variable!(float, 2) v;
            T.init.step("", v);
        }));

/// state dictionary that maps variable names to optimizer statistics (e.g., momentum)
alias StateDict = UntypedVariable[string];

/// public API to update every variable in the target model of an optimizer
void update(O)(ref O optimizer) { // TODO: constrain with if (isOptimizer!O)
    iterVariables!((k, v) { optimizer.step(k, v); })(optimizer.target, "");
}

/// CPU version of grain.cudnn.transform: dst = alpha * src + beta * dst
void transform(T, size_t dim)(Variable!(T, dim, HostStorage) src,
        ref Variable!(T, dim, HostStorage) dst, T alpha = 1, T beta = 0) {
    if (beta == 0) {
        dst.sliced[] = alpha * src.sliced;
        return;
    }
    if (beta != 1)
        dst.sliced[] = beta * dst.sliced;
    dst.sliced[] += alpha * src.sliced;
}
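
/// minimal check of the affine blend dst = alpha * src + beta * dst (host only)
unittest {
    import grain.autograd;

    auto src = [[1.0f, 2.0f]].variable;
    auto dst = [[10.0f, 20.0f]].variable;
    transform(src, dst, 2.0f, 1.0f); // dst = 2 * src + 1 * dst
    assert(dst.sliced == [[12.0f, 24.0f]]);
}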

version (unittest) {
    struct MLP(T, alias Storage) {
        import grain.autograd : Variable;
        import grain.chain : Linear, relu;

        alias L = Linear!(T, Storage);
        L fc1, fc2, fc3;

        this(int nhidden) {
            this.fc1 = L(2, nhidden);
            this.fc2 = L(nhidden, nhidden);
            this.fc3 = L(nhidden, 10);
        }

        auto opCall(Variable!(T, 2, Storage) x) {
            auto h1 = relu(this.fc1(x));
            auto h2 = relu(this.fc2(h1));
            auto h3 = this.fc3(h2);
            return h3;
        }
    }
}

/// stochastic gradient descent optimizer
struct SGD(Chain) {
    Chain* target;
    float lr = 1.0;
    // float momentum = 0.0;
    // float weightDecay = 0.0;
    this(ref Chain target, float lr = 1.0) {
        this.target = &target;
        this.lr = lr;
    }

    /// update a variable in place: data -= lr * grad
    void step(V)(string name, ref V field) if (isVariable!V) {
        // NOTE: this direct slice update is faster than
        // transform(field.gradVariable, field, -this.lr, 1.0)
        // (250fps -> 300fps in example/mnist.d)
        static if (isHost!V) {
            field.sliced[] -= this.lr * field.gradSliced[];
        }
        else {
            import grain.cuda : axpy;

            // y += alpha * x with x = grad, y = data, alpha = -lr
            axpy(field.grad, field.data, -this.lr);
        }
    }
}

///
unittest {
    import std.stdio;
    import numir;
    import grain.autograd;

    {
        auto mlp = MLP!(float, HostStorage)(3);
        mlp.fc1.weight.grad[0] = 1.0;
        mlp.zeroGrad();
        assert(mlp.fc1.weight.grad[0] == 0.0);

        auto sgd = SGD!(typeof(mlp))(mlp, 0.5);
        mlp.fc1.weight.data.zero_();
        mlp.fc1.weight.grad = [[1.0f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable
            .data;
        sgd.update();
        assert(mlp.fc1.weight.sliced == [[-0.5, 0.0, 0.0], [0.0, 0.0, 0.0]]);
    }
    version (grain_cuda) {
        auto mlp = MLP!(float, DeviceStorage)(3);
        mlp.fc1.weight.grad = [[1.0f, 0.0f, 0.0f], [0.0f, 0.0f,
            0.0f]].variable.to!DeviceStorage.data;
        mlp.zeroGrad();
        assert(mlp.fc1.weight.to!HostStorage.gradSliced == [[0.0, 0.0, 0.0], [0.0,
                0.0, 0.0]]);

        auto sgd = SGD!(typeof(mlp))(mlp, 0.5);
        mlp.fc1.weight.data.zero_();
        mlp.fc1.weight.grad = [[1.0f, 0.0f, 0.0f], [0.0f, 0.0f,
            0.0f]].variable.to!DeviceStorage.data;
        sgd.update();
        assert(mlp.fc1.weight.to!HostStorage.sliced == [[-0.5, 0.0, 0.0], [0.0, 0.0,
                0.0]]);
    }
}

/// AdaGrad optimizer (Duchi et al., 2011): http://jmlr.org/papers/v12/duchi11a.html
struct AdaGrad(Chain) {
    import grain.autograd;

    Chain* target;
    float lr = 1.0;
    float eps = 1e-8;
    StateDict memory;

    ///
    this(ref Chain target, float lr = 1e-3, float eps = 1e-8) {
        this.target = &target;
        this.lr = lr;
        this.eps = eps;
        iterVariables!((k, v) { this.initStates(k, v); })(this.target);
    }

    ///
    void initStates(V)(string name, ref V field) if (isVariable!V) {
        if (name !in this.memory) {
            auto m = field.uninit();
            m.data.zero_();
            this.memory[name] = UntypedVariable(m);
        }
    }

    ///
    void step(V)(string name, ref V field) if (isVariable!V) {
        import grain.chain : pow;

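        // accumulate squared gradients and scale the step:
        // w -= lr * g / sqrt(accumulated g^2 + eps)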
        auto m = memory[name].to!V;
        auto g = field.gradVariable;
        auto mn = m + g * g;
        auto diff = g / pow(mn + this.eps, 0.5); // TODO implement sqrt
        memory[name] = UntypedVariable(mn);
        transform(diff, field, -this.lr, 1.0);
    }
}

///
unittest {
    import grain.autograd;
    import numir;

    {
        float lr = 0.1;
        float eps = 1e-8;
        auto model = MLP!(float, HostStorage)(3);
        auto optim = AdaGrad!(typeof(model))(model, lr, eps);
        static assert(isOptimizer!(typeof(optim)));
        model.fc1.weight.data.zero_();
        model.fc1.weight.grad = [[0.2f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable
            .data;
        optim.update();
        auto w = model.fc1.weight;
        assert(approxEqual(w.sliced, [[-lr * 0.2 / (0.2 * 0.2 + eps) ^^ 0.5, 0.0,
                0.0], [0.0, 0.0, 0.0]].nparray));
        auto m = optim.memory[".fc1.weight"].to!(typeof(w));
        assert(approxEqual(m.sliced, [[0.2 * 0.2, 0.0, 0.0], [0.0, 0.0, 0.0]]
                .nparray));
    }
    version (grain_cuda) {
        auto model = MLP!(float, DeviceStorage)(3);
        auto optim = AdaGrad!(typeof(model))(model, 0.1);
        optim.update();
    }
}

/// Adam optimizer (Kingma and Ba, 2015): https://arxiv.org/pdf/1412.6980v8.pdf
struct Adam(Chain) {
    import grain.autograd;

    Chain* target;
    float lr = 1.0;
    float beta1 = 0.9;
    float beta2 = 0.999;
    float eps = 1e-8;
    StateDict moment1, moment2;

    ///
    this(ref Chain target, float lr, float eps = 1e-8) {
        this.target = &target;
        this.lr = lr;
        this.eps = eps;
        iterVariables!((k, v) { this.initStates(k, v); })(this.target);
    }

    ///
    void initStates(V)(string name, ref V field) if (isVariable!V) {
        if (name !in this.moment1) {
            auto m = field.uninit();
            m.data.zero_();
            this.moment1[name] = UntypedVariable(m);
        }
        if (name !in this.moment2) {
            auto m = field.uninit();
            m.data.zero_();
            this.moment2[name] = UntypedVariable(m);
        }
    }

    ///
    void step(V)(string name, ref V field) if (isVariable!V) {
        import grain.chain : pow;

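        // exponential moving averages of the gradient and its square:
        // m + (1 - beta) * (g - m) is equivalent to beta * m + (1 - beta) * g.
        // NOTE: unlike the paper, no bias correction (1 - beta^t) is applied,
        // and eps is added inside the square root.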
        auto g = field.gradVariable;
        auto m1 = this.moment1[name].to!V;
        auto m2 = this.moment2[name].to!V;
        auto nextM1 = (1.0 - this.beta1) * (g - m1) + m1;
        auto nextM2 = (1.0 - this.beta2) * (g * g - m2) + m2;
        auto diff = nextM1 / pow(nextM2 + this.eps, 0.5); // TODO implement sqrt
        this.moment1[name] = UntypedVariable(nextM1);
        this.moment2[name] = UntypedVariable(nextM2);
        transform(diff, field, -this.lr, 1.0);
    }
}

///
unittest {
    import grain.autograd;
    import numir;

    {
        auto model = MLP!(float, HostStorage)(3);
        auto optim = Adam!(typeof(model))(model, 1e-3);
        static assert(isOptimizer!(typeof(optim)));
        model.fc1.weight.data.zero_();
        model.fc1.weight.grad = [[0.2f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable
            .data;
        optim.update();
        auto w = model.fc1.weight;
        auto m1 = (1.0 - optim.beta1) * (0.2 - 0.0) + 0.0;
        auto m2 = (1.0 - optim.beta2) * (0.2 * 0.2 - 0.0) + 0.0;
        assert(approxEqual(w.sliced, [[-optim.lr * m1 / (m2 + optim.eps) ^^ 0.5,
                0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
        auto m1_ = optim.moment1[".fc1.weight"].to!(typeof(w));
        assert(approxEqual(m1_.sliced, [[m1, 0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
        auto m2_ = optim.moment2[".fc1.weight"].to!(typeof(w));
        assert(approxEqual(m2_.sliced, [[m2, 0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
    }
    version (grain_cuda) {
        auto model = MLP!(float, DeviceStorage)(3);
        auto optim = Adam!(typeof(model))(model, 0.1);
        optim.update();
    }
}

/// AdaDelta optimizer (Zeiler, 2012): http://www.matthewzeiler.com/pubs/googleTR2012/googleTR2012.pdf
struct AdaDelta(Chain) {
    import grain.autograd;

    Chain* target;
    float lr = 1.0;
    float rho = 0.95;
    float eps = 1e-6;

    StateDict den, num;

    ///
    this(ref Chain target, float lr = 1.0, float rho = 0.95, float eps = 1e-6) {
        this.target = &target;
        this.lr = lr;
        this.rho = rho;
        this.eps = eps;
        iterVariables!((k, v) { this.initStates(k, v); })(this.target);
    }

    ///
    void initStates(V)(string name, ref V field) if (isVariable!V) {
        if (name !in this.den) {
            auto m = field.uninit();
            m.data.zero_();
            this.den[name] = UntypedVariable(m);
        }
        if (name !in this.num) {
            auto m = field.uninit();
            m.data.zero_();
            this.num[name] = UntypedVariable(m);
        }
    }

    ///
    void step(V)(string name, ref V field) if (isVariable!V) {
        import grain.chain : pow;

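        // scale the gradient by RMS[dx] / RMS[g]; both running statistics are
        // exponential moving averages with decay rho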
        auto g = field.gradVariable;
        auto d = this.den[name].to!V;
        auto n = this.num[name].to!V;
        auto nextDen = (1.0 - this.rho) * g * g + this.rho * d;
        // NOTE: multiply by g so that zero-gradient entries are left unchanged
        auto diff = g * pow((n + this.eps) / (nextDen + this.eps), 0.5); // TODO implement sqrt
        auto nextNum = (1.0 - this.rho) * diff * diff + this.rho * n;
        this.den[name] = UntypedVariable(nextDen);
        this.num[name] = UntypedVariable(nextNum);
        transform(diff, field, -this.lr, 1.0);
    }
}

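/// convenience factory deducing the chain type, e.g., auto optim = make!AdaDelta(model);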
auto make(alias O, C, Args ...)(ref C chain, Args args) if (isOptimizer!(O!C)) {
    return O!C(chain, args);
}

///
unittest {
    import grain.autograd;
    import numir;

    {
        auto model = MLP!(float, HostStorage)(3);
        auto optim = make!AdaDelta(model);
        static assert(isOptimizer!(typeof(optim)));
        model.fc1.weight.data.zero_();
        model.fc1.weight.grad = [[0.2f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable
            .data;
        optim.update();
        auto w = model.fc1.weight;
        auto d = (1.0 - optim.rho) * 0.2 * 0.2;
        auto diff = 0.2f * cast(float)((0.0 + optim.eps) / (d + optim.eps)) ^^ 0.5;
        auto n = (1.0 - optim.rho) * diff * diff;
        assert(approxEqual(w.sliced, [[-optim.lr * diff, 0.0, 0.0],
                [0.0, 0.0, 0.0]].nparray));
        auto d_ = optim.den[".fc1.weight"].to!(typeof(w));
        auto n_ = optim.num[".fc1.weight"].to!(typeof(w));
        assert(approxEqual(d_.sliced, [[d, 0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
        assert(approxEqual(n_.sliced[0, 0 .. 1], [n].nparray));
    }
    version (grain_cuda) {
        auto model = MLP!(float, DeviceStorage)(3);
        auto optim = AdaDelta!(typeof(model))(model);
        optim.update();
    }
}