/**
   A module for gradient-descent optimizers (SGD, AdaGrad, Adam, AdaDelta)
   operating on grain chains.

   Each optimizer holds a pointer to its target chain plus any per-variable
   state (a $(D StateDict)), and exposes $(D step(name, field)) which updates
   one variable in place from its accumulated gradient.
 */
module grain.optim;

import std.stdio;
import std.traits : hasMember;
import grain.autograd : isVariable, zero_, isHost, UntypedVariable, Variable,
    HostStorage, iterVariables;

version (grain_cuda) {
    import grain.cuda : zero_;
    import grain.cudnn : transform;
}

/// fill gradient arrays with zero, recursing into nested chains
void zeroGrad(C)(ref C chain) {
    foreach (ref field; chain.tupleof) {
        alias F = typeof(field);
        static if (isVariable!F) {
            field.grad.zero_();
        }
        else static if (hasMember!(F, "tupleof")) { // static if (isChain!F) {
            field.zeroGrad();
        }
    }
}

/// trait to identify an optimizer: any type providing step(string, ref Variable)
enum bool isOptimizer(T) = is(typeof({
            import grain.autograd;

            Variable!(float, 2) v;
            T.init.step("", v);
        }));

/// structure to memorize the stats e.g., momentum
alias StateDict = UntypedVariable[string];

/// public api to update a target model: visit every variable inside
/// optimizer.target and apply the optimizer's step rule to it
void update(O)(ref O optimizer) { // if (isOptimizer!O) {
    iterVariables!((k, v) { optimizer.step(k, v); })(optimizer.target, "");
}

/// CPU version of cudnn.transform: dst <- alpha * src + beta * dst
void transform(T, size_t dim)(Variable!(T, dim, HostStorage) src,
        ref Variable!(T, dim, HostStorage) dst, T alpha = 1, T beta = 0) {
    if (beta == 0) {
        dst.sliced[] = alpha * src.sliced;
        return;
    }
    if (beta != 1)
        dst.sliced[] = beta * dst.sliced;
    dst.sliced[] += alpha * src.sliced;
}

version (unittest) {
    /// simple 3-layer perceptron used as a fixture by the unittests below
    struct MLP(T, alias Storage) {
        import grain.autograd : Variable;
        import grain.chain : Linear, relu;

        alias L = Linear!(T, Storage);
        L fc1, fc2, fc3;

        this(int nhidden) {
            this.fc1 = L(2, nhidden);
            this.fc2 = L(nhidden, nhidden);
            this.fc3 = L(nhidden, 10);
        }

        auto opCall(Variable!(T, 2, Storage) x) {
            auto h1 = relu(this.fc1(x));
            auto h2 = relu(this.fc2(h1));
            auto h3 = this.fc3(h2);
            // FIXED: the original returned h1, silently discarding fc2/fc3;
            // the final-layer output is clearly what was intended.
            return h3;
        }
    }
}

/// stochastic gradient descent optimizer
struct SGD(Chain) {
    Chain* target;
    float lr = 1.0;
    // float momentum = 0.0;
    // float weightDecay = 0.0;

    /// store a pointer to the target chain and the learning rate
    this(ref Chain target, float lr = 1.0) {
        this.target = &target;
        this.lr = lr;
    }

    /// update rule: field <- field - lr * grad
    void step(V)(string name, ref V field) if (isVariable!V) {
        // transform(field.gradVariable, field, -this.lr, 1.0);

        // FIXME : this code is much faster than above (250fps -> 300fps in example/mnist.d)
        static if (isHost!V) {
            field.sliced[] -= this.lr * field.gradSliced[];
        }
        else {
            import grain.cuda : axpy;

            axpy(field.grad, field.data, -this.lr);
        }
    }
}

///
unittest {
    import std.stdio;
    import numir;
    import grain.autograd; // : Variable, HostStorage;

    {
        auto mlp = MLP!(float, HostStorage)(3);
        mlp.fc1.weight.grad[0] = 1.0;
        mlp.zeroGrad();
        assert(mlp.fc1.weight.grad[0] == 0.0);

        auto sgd = SGD!(typeof(mlp))(mlp, 0.5);
        mlp.fc1.weight.data.zero_();
        mlp.fc1.weight.grad = [[1.0f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable
            .data;
        sgd.update();
        assert(mlp.fc1.weight.sliced == [[-0.5, 0.0, 0.0], [0.0, 0.0, 0.0]]);
    }
    version (grain_cuda) {
        auto mlp = MLP!(float, DeviceStorage)(3);
        mlp.fc1.weight.grad = [[1.0f, 0.0f, 0.0f], [0.0f, 0.0f,
                0.0f]].variable.to!DeviceStorage.data;
        mlp.zeroGrad();
        assert(mlp.fc1.weight.to!HostStorage.gradSliced == [[0.0, 0.0, 0.0], [0.0,
                0.0, 0.0]]);

        auto sgd = SGD!(typeof(mlp))(mlp, 0.5);
        mlp.fc1.weight.data.zero_();
        mlp.fc1.weight.grad = [[1.0f, 0.0f, 0.0f], [0.0f, 0.0f,
                0.0f]].variable.to!DeviceStorage.data;
        sgd.update();
        assert(mlp.fc1.weight.to!HostStorage.sliced == [[-0.5, 0.0, 0.0], [0.0, 0.0,
                0.0]]);
    }
}

/// AdaGrad optimizer: http://jmlr.org/papers/v12/duchi11a.html
struct AdaGrad(Chain) {
    import grain.autograd;

    Chain* target;
    float lr = 1.0;
    float eps = 1e-8;
    /// per-variable accumulated squared gradients
    StateDict memory;

    ///
    this(ref Chain target, float lr = 1e-3, float eps = 1e-8) {
        this.target = &target;
        this.lr = lr;
        this.eps = eps;
        iterVariables!((k, v) { this.initStates(k, v); })(this.target);
    }

    /// lazily allocate a zeroed accumulator for each variable
    void initStates(V)(string name, ref V field) if (isVariable!V) {
        if (name !in this.memory) {
            auto m = field.uninit();
            m.data.zero_();
            this.memory[name] = UntypedVariable(m);
        }
    }

    /// update rule: m <- m + g^2; field <- field - lr * g / sqrt(m + eps)
    void step(V)(string name, ref V field) if (isVariable!V) {
        import grain.chain : pow;

        auto m = memory[name].to!V;
        auto g = field.gradVariable;
        auto mn = m + g * g;
        auto diff = g / pow(mn + this.eps, 0.5); // TODO implement sqrt
        memory[name] = UntypedVariable(mn);
        transform(diff, field, -this.lr, 1.0);
    }
}

///
unittest {
    import grain.autograd;
    import numir;

    {
        float lr = 0.1;
        float eps = 1e-8;
        auto model = MLP!(float, HostStorage)(3);
        auto optim = AdaGrad!(typeof(model))(model, lr, eps);
        static assert(isOptimizer!(typeof(optim)));
        model.fc1.weight.data.zero_();
        model.fc1.weight.grad = [[0.2f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable
            .data;
        optim.update();
        auto w = model.fc1.weight;
        assert(approxEqual(w.sliced, [[-lr * 0.2 / (0.2 * 0.2 + eps) ^^ 0.5, 0.0,
                0.0], [0.0, 0.0, 0.0]].nparray));
        auto m = optim.memory[".fc1.weight"].to!(typeof(w));
        assert(approxEqual(m.sliced, [[0.2 * 0.2, 0.0, 0.0], [0.0, 0.0, 0.0]]
                .nparray));
    }
    version (grain_cuda) {
        auto model = MLP!(float, DeviceStorage)(3);
        auto optim = AdaGrad!(typeof(model))(model, 0.1);
        optim.update();
    }
}

/// Adam optimizer: https://arxiv.org/pdf/1412.6980v8.pdf
/// NOTE(review): no bias correction (the paper's m_hat/v_hat) is applied;
/// the unittest below matches this uncorrected form — confirm intent.
struct Adam(Chain) {
    import grain.autograd;

    Chain* target;
    float lr = 1.0;
    float beta1 = 0.9;
    float beta2 = 0.999;
    float eps = 1e-8;
    /// first and second moment estimates per variable
    StateDict moment1, moment2;

    ///
    this(ref Chain target, float lr, float eps = 1e-8) {
        this.target = &target;
        this.lr = lr;
        this.eps = eps;
        iterVariables!((k, v) { this.initStates(k, v); })(this.target);
    }

    /// lazily allocate zeroed moment accumulators for each variable
    void initStates(V)(string name, ref V field) if (isVariable!V) {
        if (name !in this.moment1) {
            auto m = field.uninit();
            m.data.zero_();
            this.moment1[name] = UntypedVariable(m);
        }
        if (name !in this.moment2) {
            auto m = field.uninit();
            m.data.zero_();
            this.moment2[name] = UntypedVariable(m);
        }
    }

    /// update rule: m1 <- m1 + (1-b1)(g - m1); m2 <- m2 + (1-b2)(g^2 - m2);
    /// field <- field - lr * m1 / sqrt(m2 + eps)
    void step(V)(string name, ref V field) if (isVariable!V) {
        import grain.chain : pow;

        auto g = field.gradVariable;
        auto m1 = this.moment1[name].to!V;
        auto m2 = this.moment2[name].to!V;
        auto nextM1 = (1.0 - this.beta1) * (g - m1) + m1;
        auto nextM2 = (1.0 - this.beta2) * (g * g - m2) + m2;
        auto diff = nextM1 / pow(nextM2 + this.eps, 0.5); // TODO implement sqrt
        this.moment1[name] = UntypedVariable(nextM1);
        this.moment2[name] = UntypedVariable(nextM2);
        transform(diff, field, -this.lr, 1.0);
    }
}

///
unittest {
    import grain.autograd;
    import numir;

    {
        auto model = MLP!(float, HostStorage)(3);
        auto optim = Adam!(typeof(model))(model, 1e-3);
        static assert(isOptimizer!(typeof(optim)));
        model.fc1.weight.data.zero_();
        model.fc1.weight.grad = [[0.2f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable
            .data;
        optim.update();
        auto w = model.fc1.weight;
        auto m1 = (1.0 - optim.beta1) * (0.2 - 0.0) + 0.0;
        auto m2 = (1.0 - optim.beta2) * (0.2 * 0.2 - 0.0) + 0.0;
        assert(approxEqual(w.sliced, [[-optim.lr * m1 / (m2 + optim.eps) ^^ 0.5,
                0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
        auto m1_ = optim.moment1[".fc1.weight"].to!(typeof(w));
        assert(approxEqual(m1_.sliced, [[m1, 0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
        auto m2_ = optim.moment2[".fc1.weight"].to!(typeof(w));
        assert(approxEqual(m2_.sliced, [[m2, 0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
    }
    version (grain_cuda) {
        auto model = MLP!(float, DeviceStorage)(3);
        auto optim = Adam!(typeof(model))(model, 0.1);
        optim.update();
    }
}

/// AdaDelta optimizer: http://www.matthewzeiler.com/pubs/googleTR2012/googleTR2012.pdf
/// NOTE(review): the step below applies the RMS ratio directly as the update,
/// i.e. it does NOT multiply by the gradient g as the paper's Delta-x does, so
/// variables with zero gradient still move. The unittest below codifies exactly
/// this behavior (expecting -lr for zero-grad entries) — confirm intent before
/// changing.
struct AdaDelta(Chain) {
    import grain.autograd;

    Chain* target;
    float lr = 1.0;
    float rho = 0.95;
    float eps = 1e-6;

    /// den accumulates squared gradients, num accumulates squared updates
    StateDict den, num;

    ///
    this(ref Chain target, float lr = 1.0, float rho = 0.95, float eps = 1e-8) {
        this.target = &target;
        this.lr = lr;
        this.rho = rho;
        this.eps = eps;
        iterVariables!((k, v) { this.initStates(k, v); })(this.target);
    }

    /// lazily allocate zeroed accumulators for each variable
    void initStates(V)(string name, ref V field) if (isVariable!V) {
        if (name !in this.den) {
            auto m = field.uninit();
            m.data.zero_();
            this.den[name] = UntypedVariable(m);
        }
        if (name !in this.num) {
            auto m = field.uninit();
            m.data.zero_();
            this.num[name] = UntypedVariable(m);
        }
    }

    /// update rule: den <- rho*den + (1-rho)*g^2;
    /// diff = sqrt((num + eps) / (den + eps)); num <- rho*num + (1-rho)*diff^2;
    /// field <- field - lr * diff
    void step(V)(string name, ref V field) if (isVariable!V) {
        import grain.chain : pow;

        auto g = field.gradVariable;
        auto d = this.den[name].to!V;
        auto n = this.num[name].to!V;
        auto nextDen = (1.0 - this.rho) * g * g + this.rho * d;
        auto diff = pow((n + this.eps) / (nextDen + this.eps), 0.5); // TODO implement sqrt
        auto nextNum = (1.0 - this.rho) * diff * diff + this.rho * n;
        this.den[name] = UntypedVariable(nextDen);
        this.num[name] = UntypedVariable(nextNum);
        transform(diff, field, -this.lr, 1.0);
    }
}

/// convenience factory: deduce the chain type, e.g. make!AdaDelta(model)
auto make(alias O, C, Args ...)(ref C chain, Args args) if (isOptimizer!(O!C)) {
    return O!C(chain, args);
}

///
unittest {
    import grain.autograd;
    import numir;

    {
        auto model = MLP!(float, HostStorage)(3);
        // auto optim = AdaDelta!(typeof(model))(model);
        auto optim = make!AdaDelta(model);
        // static assert(isOptimizer!(typeof(optim)));
        model.fc1.weight.data.zero_();
        model.fc1.weight.grad = [[0.2f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable
            .data;
        optim.update();
        auto w = model.fc1.weight;
        auto d = (1.0 - optim.rho) * 0.2 * 0.2;
        auto diff = cast(float)((0.0 + optim.eps) / (d + optim.eps)) ^^ 0.5;
        auto n = (1.0 - optim.rho) * diff * diff;
        assert(approxEqual(w.sliced, [[-optim.lr * diff, -optim.lr,
                -optim.lr], [-optim.lr, -optim.lr, -optim.lr]].nparray));
        auto d_ = optim.den[".fc1.weight"].to!(typeof(w));
        auto n_ = optim.num[".fc1.weight"].to!(typeof(w));
        assert(approxEqual(d_.sliced, [[d, 0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
        assert(approxEqual(n_.sliced[0, 0 .. 1], [n].nparray));
    }
    version (grain_cuda) {
        auto model = MLP!(float, DeviceStorage)(3);
        auto optim = AdaDelta!(typeof(model))(model);
        optim.update();
    }
}