/**
A module for gradient-descent optimizers
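
A typical training step looks like the following rough sketch (`model` and `loss`
are placeholders; the unittests below are runnable):
---
auto optim = SGD!(typeof(model))(model, 0.1);
model.zeroGrad();   // clear gradients of every variable in the model
loss.backward();    // accumulate fresh gradients
optim.update();     // apply the optimizer rule to every variable
---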
*/
module grain.optim;

import std.stdio;
import std.traits : hasMember;
import grain.autograd : isVariable, zero_, isHost, UntypedVariable, Variable,
    HostStorage, iterVariables;

version (grain_cuda) {
    import grain.cuda : zero_;
    import grain.cudnn : transform;
}

/// fill gradient arrays with zero
void zeroGrad(C)(ref C chain) {
    foreach (ref field; chain.tupleof) {
        alias F = typeof(field);
        static if (isVariable!F) {
            field.grad.zero_();
        }
        else static if (hasMember!(F, "tupleof")) { // static if (isChain!F) {
            field.zeroGrad();
        }
    }
}

/// trait to identify optimizer
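/// (i.e., T.init.step(name, variable) compiles for a Variable!(float, 2) argument)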
enum bool isOptimizer(T) = is(typeof({
            import grain.autograd;

            Variable!(float, 2) v;
            T.init.step("", v);
        }));

/// structure to memorize optimizer statistics (e.g., momentum)
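/// (keyed by the variable paths that iterVariables yields, e.g. ".fc1.weight")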
alias StateDict = UntypedVariable[string];

/// public API to update a target model
void update(O)(ref O optimizer) { // if (isOptimizer!O) {
    iterVariables!((k, v) { optimizer.step(k, v); })(optimizer.target, "");
}

/// CPU version of cudnn.transform: dst = alpha * src + beta * dst
void transform(T, size_t dim)(Variable!(T, dim, HostStorage) src,
        ref Variable!(T, dim, HostStorage) dst, T alpha = 1, T beta = 0) {
    if (beta == 0) {
        dst.sliced[] = alpha * src.sliced;
        return;
    }
    if (beta != 1)
        dst.sliced[] = beta * dst.sliced;
    dst.sliced[] += alpha * src.sliced;
}
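
/// a minimal sketch checking the host transform semantics above
unittest {
    import grain.autograd;

    auto src = [[1.0f, 2.0f]].variable;
    auto dst = [[10.0f, 20.0f]].variable;
    // beta == 1: accumulate alpha * src into dst
    transform(src, dst, 2.0f, 1.0f);
    assert(dst.sliced == [[12.0f, 24.0f]]);
    // beta == 0 (default): overwrite dst with alpha * src
    transform(src, dst);
    assert(dst.sliced == [[1.0f, 2.0f]]);
}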

version (unittest) {
    struct MLP(T, alias Storage) {
        import grain.autograd : Variable;
        import grain.chain : Linear, relu;

        alias L = Linear!(T, Storage);
        L fc1, fc2, fc3;

        this(int nhidden) {
            this.fc1 = L(2, nhidden);
            this.fc2 = L(nhidden, nhidden);
            this.fc3 = L(nhidden, 10);
        }

        auto opCall(Variable!(T, 2, Storage) x) {
            auto h1 = relu(this.fc1(x));
            auto h2 = relu(this.fc2(h1));
            auto h3 = this.fc3(h2);
            return h3;
        }
    }
}

/// stochastic gradient descent optimizer
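/// update rule (per variable): w -= lr * w.grad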
struct SGD(Chain) {
    Chain* target;
    float lr = 1.0;
    // float momentum = 0.0;
    // float weightDecay = 0.0;

    this(ref Chain target, float lr = 1.0) {
        this.target = &target;
        this.lr = lr;
    }

    ///
    void step(V)(string name, ref V field) if (isVariable!V) {
        // transform(field.gradVariable, field, -this.lr, 1.0);
        // FIXME: the code below is much faster than the transform above (250fps -> 300fps in example/mnist.d)
        static if (isHost!V) {
            field.sliced[] -= this.lr * field.gradSliced[];
        }
        else {
            import grain.cuda : axpy;

            axpy(field.grad, field.data, -this.lr);
        }
    }
}

///
unittest {
    import std.stdio;
    import numir;
    import grain.autograd; // : Variable, HostStorage;

    {
        auto mlp = MLP!(float, HostStorage)(3);
        mlp.fc1.weight.grad[0] = 1.0;
        mlp.zeroGrad();
        assert(mlp.fc1.weight.grad[0] == 0.0);

        auto sgd = SGD!(typeof(mlp))(mlp, 0.5);
        mlp.fc1.weight.data.zero_();
        mlp.fc1.weight.grad = [[1.0f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable.data;
        sgd.update();
        assert(mlp.fc1.weight.sliced == [[-0.5, 0.0, 0.0], [0.0, 0.0, 0.0]]);
    }
    version (grain_cuda) {
        auto mlp = MLP!(float, DeviceStorage)(3);
        mlp.fc1.weight.grad = [[1.0f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable.to!DeviceStorage.data;
        mlp.zeroGrad();
        assert(mlp.fc1.weight.to!HostStorage.gradSliced == [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);

        auto sgd = SGD!(typeof(mlp))(mlp, 0.5);
        mlp.fc1.weight.data.zero_();
        mlp.fc1.weight.grad = [[1.0f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable.to!DeviceStorage.data;
        sgd.update();
        assert(mlp.fc1.weight.to!HostStorage.sliced == [[-0.5, 0.0, 0.0], [0.0, 0.0, 0.0]]);
    }
}

/// http://jmlr.org/papers/v12/duchi11a.html
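/// update rule (per variable): m += g * g; w -= lr * g / sqrt(m + eps)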
struct AdaGrad(Chain) {
    import grain.autograd;

    Chain* target;
    float lr = 1.0;
    float eps = 1e-8;
    StateDict memory;

    ///
    this(ref Chain target, float lr = 1e-3, float eps = 1e-8) {
        this.target = &target;
        this.lr = lr;
        this.eps = eps;
        iterVariables!((k, v) { this.initStates(k, v); })(this.target);
    }

    ///
    void initStates(V)(string name, ref V field) if (isVariable!V) {
        if (name !in this.memory) {
            auto m = field.uninit();
            m.data.zero_();
            this.memory[name] = UntypedVariable(m);
        }
    }

    ///
    void step(V)(string name, ref V field) if (isVariable!V) {
        import grain.chain : pow;

        auto m = memory[name].to!V;
        auto g = field.gradVariable;
        auto mn = m + g * g;
        auto diff = g / pow(mn + this.eps, 0.5); // TODO implement sqrt
        memory[name] = UntypedVariable(mn);
        transform(diff, field, -this.lr, 1.0);
    }
}

///
unittest {
    import grain.autograd;
    import numir;

    {
        float lr = 0.1;
        float eps = 1e-8;
        auto model = MLP!(float, HostStorage)(3);
        auto optim = AdaGrad!(typeof(model))(model, lr, eps);
        static assert(isOptimizer!(typeof(optim)));
        model.fc1.weight.data.zero_();
        model.fc1.weight.grad = [[0.2f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable.data;
        optim.update();
        auto w = model.fc1.weight;
        assert(approxEqual(w.sliced,
                [[-lr * 0.2 / (0.2 * 0.2 + eps) ^^ 0.5, 0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
        auto m = optim.memory[".fc1.weight"].to!(typeof(w));
        assert(approxEqual(m.sliced, [[0.2 * 0.2, 0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
    }
    version (grain_cuda) {
        auto model = MLP!(float, DeviceStorage)(3);
        auto optim = AdaGrad!(typeof(model))(model, 0.1);
        optim.update();
    }
}

/// https://arxiv.org/pdf/1412.6980v8.pdf
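/// update rule (per variable, without the paper's bias correction):
/// m1 = beta1 * m1 + (1 - beta1) * g; m2 = beta2 * m2 + (1 - beta2) * g * g;
/// w -= lr * m1 / sqrt(m2 + eps)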
struct Adam(Chain) {
    import grain.autograd;

    Chain* target;
    float lr = 1.0;
    float beta1 = 0.9;
    float beta2 = 0.999;
    float eps = 1e-8;
    StateDict moment1, moment2;

    ///
    this(ref Chain target, float lr, float eps = 1e-8) {
        this.target = &target;
        this.lr = lr;
        this.eps = eps;
        iterVariables!((k, v) { this.initStates(k, v); })(this.target);
    }

    ///
    void initStates(V)(string name, ref V field) if (isVariable!V) {
        if (name !in this.moment1) {
            auto m = field.uninit();
            m.data.zero_();
            this.moment1[name] = UntypedVariable(m);
        }
        if (name !in this.moment2) {
            auto m = field.uninit();
            m.data.zero_();
            this.moment2[name] = UntypedVariable(m);
        }
    }

    ///
    void step(V)(string name, ref V field) if (isVariable!V) {
        import grain.chain : pow;

        auto g = field.gradVariable;
        auto m1 = this.moment1[name].to!V;
        auto m2 = this.moment2[name].to!V;
        auto nextM1 = (1.0 - this.beta1) * (g - m1) + m1;
        auto nextM2 = (1.0 - this.beta2) * (g * g - m2) + m2;
        auto diff = nextM1 / pow(nextM2 + this.eps, 0.5); // TODO implement sqrt
        this.moment1[name] = UntypedVariable(nextM1);
        this.moment2[name] = UntypedVariable(nextM2);
        transform(diff, field, -this.lr, 1.0);
    }
}

///
unittest {
    import grain.autograd;
    import numir;

    {
        auto model = MLP!(float, HostStorage)(3);
        auto optim = Adam!(typeof(model))(model, 1e-3);
        static assert(isOptimizer!(typeof(optim)));
        model.fc1.weight.data.zero_();
        model.fc1.weight.grad = [[0.2f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable.data;
        optim.update();
        auto w = model.fc1.weight;
        auto m1 = (1.0 - optim.beta1) * (0.2 - 0.0) + 0.0;
        auto m2 = (1.0 - optim.beta2) * (0.2 * 0.2 - 0.0) + 0.0;
        assert(approxEqual(w.sliced,
                [[-optim.lr * m1 / (m2 + optim.eps) ^^ 0.5, 0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
        auto m1_ = optim.moment1[".fc1.weight"].to!(typeof(w));
        assert(approxEqual(m1_.sliced, [[m1, 0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
        auto m2_ = optim.moment2[".fc1.weight"].to!(typeof(w));
        assert(approxEqual(m2_.sliced, [[m2, 0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
    }
    version (grain_cuda) {
        auto model = MLP!(float, DeviceStorage)(3);
        auto optim = Adam!(typeof(model))(model, 0.1);
        optim.update();
    }
}

/// http://www.matthewzeiler.com/pubs/googleTR2012/googleTR2012.pdf
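/// update rule (per variable): den = rho * den + (1 - rho) * g * g;
/// diff = g * sqrt((num + eps) / (den + eps)); num = rho * num + (1 - rho) * diff * diff;
/// w -= lr * diff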
struct AdaDelta(Chain) {
    import grain.autograd;

    Chain* target;
    float lr = 1.0;
    float rho = 0.95;
    float eps = 1e-6;

    StateDict den, num;

    ///
    this(ref Chain target, float lr = 1.0, float rho = 0.95, float eps = 1e-8) {
        this.target = &target;
        this.lr = lr;
        this.rho = rho;
        this.eps = eps;
        iterVariables!((k, v) { this.initStates(k, v); })(this.target);
    }

    ///
    void initStates(V)(string name, ref V field) if (isVariable!V) {
        if (name !in this.den) {
            auto m = field.uninit();
            m.data.zero_();
            this.den[name] = UntypedVariable(m);
        }
        if (name !in this.num) {
            auto m = field.uninit();
            m.data.zero_();
            this.num[name] = UntypedVariable(m);
        }
    }

    ///
    void step(V)(string name, ref V field) if (isVariable!V) {
        import grain.chain : pow;

        auto g = field.gradVariable;
        auto d = this.den[name].to!V;
        auto n = this.num[name].to!V;
        auto nextDen = (1.0 - this.rho) * g * g + this.rho * d;
        auto diff = g * pow((n + this.eps) / (nextDen + this.eps), 0.5); // TODO implement sqrt
        auto nextNum = (1.0 - this.rho) * diff * diff + this.rho * n;
        this.den[name] = UntypedVariable(nextDen);
        this.num[name] = UntypedVariable(nextNum);
        transform(diff, field, -this.lr, 1.0);
    }
}

/// create an optimizer O!C for a chain, deducing C from the argument
auto make(alias O, C, Args...)(ref C chain, Args args) if (isOptimizer!(O!C)) {
    return O!C(chain, args);
}

///
unittest {
    import grain.autograd;
    import numir;

    {
        auto model = MLP!(float, HostStorage)(3);
        // auto optim = AdaDelta!(typeof(model))(model);
        auto optim = make!AdaDelta(model);
        // static assert(isOptimizer!(typeof(optim)));
        model.fc1.weight.data.zero_();
        model.fc1.weight.grad = [[0.2f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable.data;
        optim.update();
        auto w = model.fc1.weight;
        auto d = (1.0 - optim.rho) * 0.2 * 0.2;
        auto diff = 0.2 * cast(float)((0.0 + optim.eps) / (d + optim.eps)) ^^ 0.5;
        auto n = (1.0 - optim.rho) * diff * diff;
        assert(approxEqual(w.sliced,
                [[-optim.lr * diff, 0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
        auto d_ = optim.den[".fc1.weight"].to!(typeof(w));
        auto n_ = optim.num[".fc1.weight"].to!(typeof(w));
        assert(approxEqual(d_.sliced, [[d, 0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
        assert(approxEqual(n_.sliced[0, 0 .. 1], [n].nparray));
    }
    version (grain_cuda) {
        auto model = MLP!(float, DeviceStorage)(3);
        auto optim = AdaDelta!(typeof(model))(model);
        optim.update();
    }
}