1 module grain.functions;
2 
3 import grain.autograd;
4 import grain.cuda;
5 import grain.utility : toTuple, fromTuple, castArray;
6 
7 import std.stdio;
8 
9 version (grain_cuda) {
10     import cudnn = derelict.cudnn7;
11 }
12 
13 
/++
 Compile-time contract between a `forward` overload and a `backward` overload:
 backward's return type(s) must equal forward's parameter type(s) and
 forward's return type(s) must equal backward's parameter type(s).
 Multi-argument / multi-output cases are expressed as `Tuple!(...)`.

 Mixed into scopes (`isFunction`, `FunctionCommon`) that already import
 `allSatisfy`, `arity`, `Parameters`, `ReturnType` and `Tuple`.
 +/
mixin template TypeChecker(alias forward, alias backward) {
    // every argument on both sides must be a grain Variable
    static assert(allSatisfy!(isVariable, Parameters!forward),
                  "all the forward function args should be variable.");
    static assert(allSatisfy!(isVariable, Parameters!backward),
                  "all the backward function args should be variable.");
    static if (arity!forward == 1 && arity!backward == 1) {
        // 1-in / 1-out: plain type equality in both directions
        static assert(is(ReturnType!backward == Parameters!forward[0]));
        static assert(is(ReturnType!forward == Parameters!backward[0]));
    } else static if (arity!backward == 1) {
        // N-in / 1-out: backward returns a tuple of all input gradients
        static assert(is(ReturnType!backward == Tuple!(Parameters!forward)));
        static assert(is(ReturnType!forward == Parameters!backward[0]));
    } else static if (arity!forward == 1) {
        // 1-in / N-out: forward returns a tuple matching backward's args
        static assert(is(ReturnType!backward == Parameters!forward[0]));
        static assert(is(ReturnType!forward == Tuple!(Parameters!backward)));
    } else {
        // N-in / N-out: tuples on both sides
        static assert(is(ReturnType!backward == Tuple!(Parameters!forward)));
        static assert(is(ReturnType!forward == Tuple!(Parameters!backward)));
    }
}
33 
34 
35 
/++
 `true` iff `T` declares consistent `forward`/`backward` overload pairs.
 Every host-host and device-device overload combination is checked with
 `TypeChecker`; mixed host/device combinations are skipped. Implemented as
 an immediately-invoked compile-time lambda, so a mismatched pair surfaces
 as a compile error inside the assert rather than evaluating to `false`.
 +/
enum bool isFunction(T) = {
    import std.meta : allSatisfy;
    import std.typecons : isTuple, tuple, Tuple, RefCounted;
    import std.traits : arity, Parameters, ReturnType;
    static foreach (i, forward; __traits(getOverloads, T, "forward")) {
        static foreach (i, backward; __traits(getOverloads, T, "backward")) {
            // device pair: no parameter lives in host storage
            static if (!allSatisfy!(isHost, Parameters!forward) &&
                       !allSatisfy!(isHost, Parameters!backward)) {
                mixin TypeChecker!(forward, backward);
            }
            // host pair: every parameter lives in host storage
            static if (allSatisfy!(isHost, Parameters!forward) &&
                       allSatisfy!(isHost, Parameters!backward)) {
                mixin TypeChecker!(forward, backward);
            }
        }
    }
    return true;
        }();
54 
/++
 Common scaffolding mixed into every autograd function struct (ReLU, MatMul,
 LogSoftmax, ...). At compile time it pairs the host/device
 `forward`/`backward` overloads, validates them with `TypeChecker`, and
 defines the typed argument/gradient tuple aliases. At run time it provides
 `applyForward` (forward pass that records the backprop graph) and
 `applyBackward` (gradient conversion, accumulation, and chaining).
 +/
mixin template FunctionCommon() {
    import std.meta : allSatisfy;
    import std.typecons : isTuple, tuple, Tuple, RefCounted;
    import std.traits : arity, Parameters, ReturnType;

    @disable this(this); // no copyable: BackProp captures `&this.applyBackward`

    // pair every forward overload with every backward overload; only the
    // host-host and device-device combinations are validated and aliased
    static foreach (i, forward; __traits(getOverloads, typeof(this), "forward")) {
        static foreach (i, backward; __traits(getOverloads, typeof(this), "backward")) {
            static if (!allSatisfy!(isHost, Parameters!forward) &&
                       !allSatisfy!(isHost, Parameters!backward)) {
                // device (CUDA) pair
                alias DeviceRets = Tuple!(Parameters!backward);
                alias DeviceArgs = Tuple!(Parameters!forward);
                mixin TypeChecker!(forward, backward);
            }
            static if (allSatisfy!(isHost, Parameters!forward) &&
                       allSatisfy!(isHost, Parameters!backward)) {
                // host (CPU) pair
                alias HostRets = Tuple!(Parameters!backward);
                alias HostArgs = Tuple!(Parameters!forward);
                mixin TypeChecker!(forward, backward);
            }
        }
    }
    static assert(isFunction!(typeof(this)));

    /++
     Run `forward(args)` and, when the global `grain.autograd.backprop` flag
     is set, attach a `BackProp` record to every output so gradients can
     later flow back into `args`. Returns the single output, or the tuple
     of outputs when `forward` produces more than one.
     +/
    auto applyForward(Args...)(Args args) {
        import std.algorithm : each;
        // type-erased copies of the inputs, shared with the BackProp record
        RefCounted!(UntypedVariable[]) uargs;
        uargs.length = args.length;
        foreach (i, a; args) {
            uargs[i] = UntypedVariable(a);
            uargs[i].bprop = a.bprop; // pass the chain to backprop
        }
        auto rets = this.forward(args).toTuple;
        enum isHost = allSatisfy!(isHost, Args); // shadows the module-level isHost inside this scope
        foreach (i, r; rets) {
            auto u = UntypedVariable(r);
            if (grain.autograd.backprop) {
                // RefCounted!
                BackProp bp = BackProp(&this.applyBackward!isHost,
                                       uargs);
                bp.gradOutputs.length = rets.length;
                u.bprop = bp;
                u.outPosition = i; // which output of this function `r` is
                rets[i].bprop = bp;
            }
        }
        static if (rets.length > 1) {
            return rets;
        } else {
            return rets[0];
        }
    }

    /++
     Backward step invoked by the autograd engine: converts the type-erased
     output gradients to their static types, runs `backward`, accumulates
     the resulting input gradients into `uinputs` (`+=` on host, `axpy` on
     device, floating point only), then recurses into each input's own
     backprop chain.
     +/
    void applyBackward(bool isHost)(UntypedVariable[] ugradOutputs, UntypedVariable[] uinputs) {
        static if (isHost) {
            HostRets vgradOutputs;
        } else {
            DeviceRets vgradOutputs;
        }
        static foreach (i; 0 .. vgradOutputs.length) {
            vgradOutputs[i] = ugradOutputs[i].to!(typeof(vgradOutputs[i]));
        }
        auto vgradInputs = this.backward(vgradOutputs.expand).toTuple;
        assert(vgradInputs.length == uinputs.length, "invalid number of input gradients");
        UntypedVariable[vgradInputs.length] ugradInputs; // TODO use refcounted?
        foreach (i, v; vgradInputs) {
            ugradInputs[i] = UntypedVariable(v);
        }

        foreach (i, vgi; vgradInputs) {
            // TODO reconsider this condition
            if (uinputs[i].requiresGrad) {
                alias Storage = typeof(vgradInputs[i].data);
                alias V = typeof(vgradInputs[i]);
                auto data = uinputs[i].grad.get!Storage;
                static if (vgradInputs[i].isHost) {
                    import mir.ndslice.slice : sliced;
                    auto shape = vgradInputs[i].shape.castArray!size_t;
                    data[] += vgradInputs[i].data[]; // .sliced(shape); FIXME use shape
                } else version (grain_cuda) {
                    import std.traits : isFloatingPoint;
                    // TODO support integral types
                    static if (isFloatingPoint!(ElementType!V)) {
                        axpy(vgradInputs[i].data, data);
                    }
                }
            }
            // continue the chain even when this input does not require grad
            uinputs[i].bprop.backward(&ugradInputs[i], uinputs[i].outPosition);
        }
    }
}
147 
148 
/// Chain two ReLUs with applyForward and check that (1) the backprop graph
/// survives the scope its functions were created in, and (2) repeated
/// backward passes accumulate into x.grad rather than overwrite it.
unittest {
    import std.typecons;
    grain.autograd.backprop = true;
    // scope (exit) grain.autograd.backprop = false;
    {
        auto x = [-1.0f, 2.0f, 3.0f].variable(true);
        x.requiresGrad = true;
        Variable!(float, 1) y, h;
        y.requiresGrad = true;
        h.requiresGrad = true;
        // bprop will survive even if deeper scope
        {
            // FIXME cannot use RefCounted instead of new here
            // RefCounted!(ReLU!(float, 1)) func0 = ReLU!(float, 1)();
            auto func0 = new ReLU!(float, 1);
            h = func0.applyForward(x);
            assert(h.bprop.inputs[0].data == x.data);
            auto func1 = new ReLU!(float, 1);
            y = func1.applyForward(h);
            // test the chain to backprop
            assert(y.bprop.inputs[0].data == h.data);
            assert(y.bprop.inputs[0].bprop.inputs[0].data == x.data);
        }
        auto gy = [1.0f, 2.0f, 3.0f].variable;
        auto ugy = UntypedVariable(gy);
        y.backward(&ugy);
        assert(x.grad == [0, 2, 3]);

        auto func2 = new ReLU!(float, 1);
        auto y2 = func2.applyForward(x);
        y2.backward(&ugy);
        assert(x.grad == [0, 4, 6]); // summation
    }
    version (grain_cuda) {
        auto func = new ReLU!(float, 1);
        auto x = [-1.0f, 2.0f, 3.0f].variable(true).to!DeviceStorage;
        auto y = func.applyForward(x);
        auto gy = [1.0f, 2.0f, 3.0f].variable.to!DeviceStorage;
        auto ugy = UntypedVariable(gy);
        y.backward(&ugy);
        assert(x.grad.toHost() == [0, 2, 3]);

        auto func2 = new ReLU!(float, 1);
        // NOTE(review): calls `func`, not the freshly created `func2` — confirm intended
        auto y2 = func.applyForward(x);
        y2.backward(&ugy);
        assert(x.grad.toHost() == [0, 4, 6]); // summation
    }
}
198 
199 
/// Negative tests: structs whose forward/backward signatures do not match
/// must be rejected at compile time by FunctionCommon's TypeChecker. The
/// template parameter `DelayInstantiation` defers the check so the failure
/// is observable via __traits(compiles, ...).
unittest {
    alias F1H = Variable!(float, 1, HostStorage);
    // NOTE(review): F1D aliases HostStorage although its name suggests device — confirm
    version (grain_cuda) alias F1D = Variable!(float, 1, HostStorage);
    struct A(DelayInstantiation) {
        mixin FunctionCommon;
        // mismatch of args
        F1H forward(F1H x) { return x; };
        F1H backward(F1H x, F1H y) { return x; };
    }
    static assert(!__traits(compiles, A!void));

    version (grain_cuda) {
        struct B(DelayInstantiation) {
            mixin FunctionCommon;
            F1H forward(F1H x) { return x; };
            F1H backward(F1H x) { return x; };
            // mismatch of args in device
            version (grain_cuda) {
                F1D forward(F1D x) { return x; };
                F1D backward(F1D x, F1D y) { return x; };
            }
        }
        static assert(!__traits(compiles, B!void));
    }
}
226 
/++
 Rectified linear unit: y = max(x, 0) element-wise.
 The forward input is cached so `backward` can mask the gradient.
 +/
struct ReLU(T, size_t dim) {
    mixin FunctionCommon;
    bool inplace = false;  // overwrite x in forward instead of copying
    bool useCuDNN = true;  // device path: cuDNN activation vs custom kernel
    Variable!(T, dim, HostStorage) hx; // cached host input for backward

    /// Host forward: zero negative elements (in a copy unless `inplace`).
    auto forward(Variable!(T, dim, HostStorage) x) {
        import mir.ndslice : each;
        // FIXME if train
        this.hx = x.dup; // cache BEFORE clamping so inplace mode still works
        auto y = this.inplace ? x : x.dup;
        y.sliced.each!((ref a) { if (a < 0) a = 0; });
        return y;
    }

    /// Host backward: gradient passes where the cached input was >= 0,
    /// and is zeroed where it was strictly negative.
    auto backward(Variable!(T, dim, HostStorage) gy) {
        auto gx = gy.dup;
        foreach (i; 0..gx.data.length) {
            if (this.hx.data[i] < 0.0) gx.data[i] = 0.0;
        }
        return gx;
    }

    // TODO use cudnn
    version(grain_cuda) {
        import grain.cudnn;
        Variable!(T, dim, DeviceStorage) dx, dy; // cached device input/output

        /// Device forward via cuDNN activation, or the custom `relu` kernel
        /// when `useCuDNN` is false.
        auto forward(Variable!(T, dim, DeviceStorage) x) {
            // FIXME if train
            this.dx = x.dup;
            auto y = this.inplace ? x : x.dup;

            if (this.useCuDNN) {
                this.dy = y; // cuDNN backward needs the forward output too
                activationForward!CUDNN_ACTIVATION_RELU(x, y);
            } else {
                import grain.kernel : relu;
                auto n = cast(uint) y.data.length; // FIXME use y.nElement
                Global.kernel!relu
                    .call(y.data.ptr, n).launch(n);
            }
            return y;
        }

        /// Device backward. NOTE(review): the relu unittest suggests the
        /// device path zeroes the gradient at x == 0 while the host path
        /// passes it through — confirm which behavior is intended.
        auto backward(Variable!(T, dim, DeviceStorage) gy) {
            auto gx = gy.dup; // TODO: create empty
            if (this.useCuDNN) {
                activationBackward!CUDNN_ACTIVATION_RELU(gx, gy, dx, dy);
            } else {
                import grain.kernel : reluGrad;
                auto n = cast(uint) gy.data.length;
                Global.kernel!reluGrad
                    .call(gx.data.ptr, gy.data.ptr, this.dx.data.ptr, n).launch(n);
            }
            return gx;
        }
    }
}
286 
/// Forward a 2-in/1-out function (MatMul) through applyForward and check
/// that Variable.backward distributes gradients to both inputs, matching
/// a direct call to func.backward.
unittest {
    import std.typecons;
    import numir;
    import mir.ndslice;
    grain.autograd.backprop = true;
    scope (exit) grain.autograd.backprop = false;
    {
        auto func = new MatMul!float;
        auto a = uniform!float(3, 4).slice.variable(true);
        auto b = uniform!float(4, 2).slice.variable(true);
        auto c = func.applyForward(a, b);
        auto gc = uniform!float(3, 2).slice.variable;
        auto ugc = UntypedVariable(gc);
        c.backward(&ugc);

        // reference gradients from a direct backward call
        auto gab = func.backward(gc);
        assert(a.gradSlice == gab[0].sliced);
        assert(b.gradSlice == gab[1].sliced);
    }
}
308 
309 
310 
/// Test ReLU forward/backward on host (and device when available) across
/// all inplace/useCuDNN combinations.
unittest {
    import grain.testing : gradCheck;
    foreach (inplace; [true, false]) {
        foreach (useCuDNN; [true, false]) {
            auto func = new ReLU!(float, 1);
            func.inplace = inplace;
            func.useCuDNN = useCuDNN;

            // test CPU
            {
                auto x = [-1.0f, 1.0f, 0.0f].variable;
                // fail because of non-smooth function?
                // gradCheck(func, x, [0.1f, 0.1f, 0.1f].variable);

                auto y = func.forward(x);
                assert(x.data == (inplace ? y.data : [-1.0f, 1.0f, 0.0f]));
                assert(y.data == [0.0f, 1.0f, 0.0f]);

                auto gy = [1.0f, 2.0f, 3.0f].variable;
                auto gx = func.backward(gy);
                assert(gx.data == [0.0f, 2.0f, 3.0f]);
            }

            // test CUDA
            version(grain_cuda) {
                auto x = [-1.0f, 1.0f, 0.0f].variable;
                auto xd = x.to!DeviceStorage;
                auto yd = func.forward(xd);
                x = xd.to!HostStorage;
                auto y = yd.to!HostStorage;
                assert(x.data == (inplace ? y.data : [-1.0f, 1.0f, 0.0f]));
                assert(y.data == [0.0f, 1.0f, 0.0f]);

                x = [-1.0f, 1.0f, 0.0f].variable;
                auto gy = [1.0f, 2.0f, 3.0f].variable;
                auto gxd = func.backward(gy.to!DeviceStorage);
                auto gx = gxd.to!HostStorage;
                // NOTE(review): device grad at x == 0 is 0 here, but the CPU
                // assert above expects 3.0 for the same element — confirm the
                // intended subgradient convention at zero.
                assert(gx.data == [0.0, 2.0, 0.0]);
            }
        }
    }
}
354 
355 /++
356  Matrix-Matrix multiplication
357 
358  See_Also: https://github.com/chainer/chainer/blob/v1/chainer/functions/connection/linear.py#L11
359  +/
360 struct MatMul(T) {
361     import mir.ndslice : transposed, universal;
362     import std.typecons : tuple;
363     import lubeck : mtimes;
364     T alpha = 1;
365     T beta = 0;
366     Variable!(T, 2, HostStorage) ha, hb;
367 
368     // TODO uncomment this line
369     mixin FunctionCommon;
370 
371     auto forward(Variable!(T, 2, HostStorage) a, Variable!(T, 2, HostStorage) b) {
372         // TODO if training
373         this.ha = a;
374         this.hb = b;
375         return mtimes(a.sliced, b.sliced).variable(a.requiresGrad || b.requiresGrad);
376     }
377 
378     auto backward(Variable!(T, 2, HostStorage) gy) {
379         auto ga = mtimes(gy.sliced, this.hb.sliced.transposed).variable;
380         auto gb = mtimes(this.ha.sliced.transposed, gy.sliced).variable;
381         return tuple(ga, gb);
382     }
383 
384     version(grain_cuda) {
385         Variable!(T, 2, DeviceStorage) da, db;
386 
387         auto forward(Variable!(T, 2, DeviceStorage) a, Variable!(T, 2, DeviceStorage) b) {
388             import grain.cublas;
389             static if (is(T == float)) {
390                 alias gemm = cublasSgemm_v2;
391             } else static if (is(T == double)) {
392                 alias gemm = cublasDgemm_v2;
393             } else {
394                 static assert(false, "unsupported type");
395             }
396 
397             import std.typecons : RefCounted;
398             assert(a.shape[1] == b.shape[0]);
399             auto cdata = RefCounted!(CuPtr!T)(a.shape[0] * b.shape[1]);
400             auto c = Variable!(T, 2, DeviceStorage)(
401                 a.requiresGrad || b.requiresGrad, [a.shape[0], b.shape[1]], [b.shape[1], 1], cdata);
402             // C = A x B = (BT x AT)T
403             // TODO support transposed (CUBLAS_OP_T)
404             // see https://github.com/libmir/mir-blas/blob/master/source/mir/blas.d#L299
405             // TODO if train
406             this.da = a;
407             this.db = b;
408             checkCublasErrors(gemm(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N,
409                                    cast(int) b.shape[1],
410                                    cast(int) a.shape[0], cast(int) a.shape[1],
411                                    &alpha,
412                                    cast(const T*) b.data.ptr, cast(int) b.strides[0],
413                                    cast(const T*) a.data.ptr, cast(int) a.strides[0],
414                                    &beta,
415                                    cast(T*) c.data.ptr, cast(int) c.strides[0]));
416             return c;
417         }
418 
419         auto backward(Variable!(T, 2, DeviceStorage) gc) {
420             import grain.cublas;
421             static if (is(T == float)) {
422                 alias gemm = cublasSgemm_v2;
423             } else static if (is(T == double)) {
424                 alias gemm = cublasDgemm_v2;
425             } else {
426                 static assert(false, "unsupported type");
427             }
428             auto ga = this.da.dup;
429             auto gb = this.db.dup;
430             // auto ga = mtimes(gc.sliced, this.hb.sliced.transposed).variable;
431             checkCublasErrors(gemm(cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N,
432                                    cast(int) ga.shape[1],
433                                    cast(int) ga.shape[0], cast(int) gc.shape[1],
434                                    &alpha,
435                                    cast(const T*) db.data.ptr, cast(int) db.strides[0],
436                                    cast(const T*) gc.data.ptr, cast(int) gc.strides[0],
437                                    &beta,
438                                    cast(T*) ga.data.ptr, cast(int) ga.strides[0]));
439             // auto gb = mtimes(this.ha.sliced.transposed, gc.sliced).variable;
440             checkCublasErrors(gemm(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_T,
441                                    cast(int) gb.shape[1],
442                                    cast(int) gb.shape[0], cast(int) da.shape[0],
443                                    &alpha,
444                                    cast(const T*) gc.data.ptr, cast(int) gc.strides[0],
445                                    cast(const T*) da.data.ptr, cast(int) da.strides[0],
446                                    &beta,
447                                    cast(T*) gb.data.ptr, cast(int) gb.strides[0]));
448             return tuple(ga, gb);
449         }
450     }
451 }
452 
/// MatMul: numerical gradient check on host, plus CPU/CUDA equality of the
/// forward output and both input gradients over several shapes.
unittest {
    foreach (i; [2, 3, 4]) {
        foreach (j; [2, 3, 4]) {
            import std.typecons : tuple;
            import numir : uniform;
            import mir.ndslice : slice;
            import grain.testing;

            auto k = 3;
            auto a = uniform!float(i, k).slice.variable;
            auto b = uniform!float(k, j).slice.variable;
            auto gc = uniform!float(i, j).slice.variable;
            MatMul!float func;
            gradCheck(func, tuple(a, b), gc, 1e-3, 1e-3, 1e-3);

            version (grain_cuda) {
                import numir.testing;
                MatMul!float func2;
                auto hc = func.forward(a, b);
                auto dc = func2.forward(a.to!DeviceStorage, b.to!DeviceStorage);
                assert(approxEqual(dc.to!HostStorage.sliced, hc.sliced));
                auto hgab = func.backward(gc);
                auto dgab = func2.backward(gc.to!DeviceStorage);
                // writefln!"%s vs %s"(dgab[0].to!HostStorage.sliced, hgab[0].sliced);
                assert(approxEqual(dgab[0].to!HostStorage.sliced, hgab[0].sliced));
                assert(approxEqual(dgab[1].to!HostStorage.sliced, hgab[1].sliced));
            }
        }
    }
}
484 
485 // TODO add to numir
486 import mir.ndslice : isSlice;
487 import numir : Ndim;
/++
 Numerically stable log-sum-exp reduction over a 1-d slice:
 log(sum_i exp(x_i)) computed as m + log(sum_i exp(x_i - m)) where m is the
 maximum element, so no intermediate exp can overflow.
 +/
pure nothrow @nogc
logsumexp(S)(S x) if (isSlice!S && Ndim!S == 1) {
    import mir.ndslice : map, maxIndex;
    import mir.math : log, sum, exp;
    auto m = x[x.maxIndex]; // shift for numerical stability
    return m + log(sum!"fast"(map!exp(x - m)));
}
496 
497 ///
498 pure nothrow @nogc
499 unittest {
500     import numir;
501     import mir.ndslice;
502     // import mir.math;
503     import std.math;
504     static immutable x = [-1.0, 2.0, 3.0];
505     static immutable e = log(exp(-1.0) + exp(2.0) + exp(3.0));
506     assert(approxEqual(x.sliced.logsumexp, e));
507     static immutable xs = [-1.0, 2.0, 3.0,
508                            -1.0, 2.0, 3.0,
509                            -1.0, 2.0, 3.0];
510     static immutable es = [e, e, e];
511     assert(approxEqual(xs.sliced(3, 3).alongDim!1.map!logsumexp, es));
512 }
513 
514 /++
515 See_also: https://github.com/chainer/chainer/blob/v1/chainer/functions/activation/log_softmax.py
516  +/
517 struct LogSoftmax(T, size_t dim=2) {
518     // TODO support custom dim to compute softmax over (now only dim=1)
519      mixin FunctionCommon;
520 
521     Variable!(T, dim, HostStorage) hy;
522 
523     auto forward(Variable!(T, dim, HostStorage) x) {
524         import mir.ndslice;
525         import numir;
526         // return slice(x.sliced.alongDim!0.map!(e => e - e.logsumexp)).variable;
527         auto y = x.dup;
528         foreach (i; 0 .. y.shape[0]) {
529             y.sliced[i][] -= x.sliced[i].logsumexp;
530         }
531         // TODO if train
532         this.hy = y;
533         return y;
534     }
535 
536     auto backward(Variable!(T, dim, HostStorage) gy) {
537         import mir.math;
538         import numir;
539         import mir.ndslice;
540         auto gx = gy.dup;
541         auto m = gy.sliced.alongDim!1.map!(sum!"fast");
542         foreach (i; 0 .. gx.shape[0]) {
543             gx.sliced[i][] -= this.hy.sliced[i].map!exp * m[i];
544         }
545         return gx;
546     }
547 
548     version (grain_cuda) {
549         import grain.cudnn;
550         Variable!(T, dim, DeviceStorage) dy;
551 
552         auto forward(Variable!(T, dim, DeviceStorage) x) {
553             auto y = x.dup;
554             softmaxForward!CUDNN_SOFTMAX_LOG(x, y);
555             // TODO if train
556             this.dy = y;
557             return y;
558         }
559 
560         auto backward(Variable!(T, dim, DeviceStorage) gy) {
561             auto gx = gy.dup;
562             softmaxBackward!CUDNN_SOFTMAX_LOG(gx, gy, this.dy);
563             return gx;
564         }
565     }
566 }
567 
/// LogSoftmax: closed-form check on a row-constant matrix, numerical
/// gradcheck, and CPU/CUDA equality of forward and backward.
unittest {
    import grain.testing;
    import std.typecons;
    import numir;
    import mir.ndslice;
    import mir.math;
    auto e = log(exp(-1.0) + exp(2.0) + exp(3.0));
    auto xs = [[-1.0f, 2.0f, 3.0f], [-1.0f, 2.0f, 3.0f], [-1.0f, 2.0f, 3.0f]].nparray;
    LogSoftmax!float hfunc;
    auto _hx = xs.variable;
    auto _hy = hfunc.forward(_hx);
    assert(approxEqual(_hy.sliced, xs - e)); // every row shares the same logsumexp

    auto hx = uniform!float(2, 2).slice.variable;
    auto hy = hfunc.forward(hx);
    auto hgy = uniform!float(2, 2).slice.variable;
    auto hgx = hfunc.backward(hgy);
    gradCheck(hfunc, hx, hgy, 1e-3, 1e-3, 1e-3);

    version (grain_cuda) {
        alias Storage = DeviceStorage;
        auto func = LogSoftmax!float();
        auto dx = hx.to!Storage;
        auto dy = func.forward(dx);
        assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
        auto dgy = hgy.to!Storage;
        auto dgx = func.backward(dgy);
        assert(approxEqual(dgx.to!HostStorage.sliced, hgx.sliced));
    }
}
599 
600 
struct NegativeLogLikelihood(F, I=long) {
    /++
    Compute negative log-likelihood: -logP(y=t), optionally averaged over
    the batch (`sizeAverage`). Targets equal to `ignoreIndex` contribute
    nothing to the loss or the gradient.
    Params:
      logP: log softmax output as prediction. shape: (nBatch, nClass)
      targetId: target integer id of class. shape: (nBatch)
      +/

    mixin FunctionCommon;

    bool sizeAverage = true;   // divide the summed loss by the counted targets
    int ignoreIndex = -100;    // targets with this id are skipped
    // TODO: bool reduce = true;

    // cache for backward
    Variable!(I, 1, HostStorage) _htargetId; // targets from the last forward
    F _normalize;                            // 1/count when averaging, else 1
    int _nClass;                             // number of classes in logP

    /// Host forward: sum -logP[i, t_i] over non-ignored targets, then
    /// average when `sizeAverage` and at least one target counted.
    auto forward(Variable!(F, 2, HostStorage) logP, Variable!(I, 1, HostStorage) targetId) {
        import mir.math;
        import mir.ndslice;
        F result = 0.0;
        size_t count = 0;
        foreach (i; 0 .. targetId.sliced.length) {
            auto t = targetId.sliced[i];
            if (t != this.ignoreIndex) {
                result -= logP.sliced[i, t];
                ++count;
            }
        }
        if (this.sizeAverage && count > 0) {
            result /= count;
        }
        // TODO if train
        this._nClass = logP.shape[1];
        this._htargetId = targetId;
        this._normalize = this.sizeAverage && count > 0 ? 1.0 / count : 1.0;
        return result.variable;
    }

    /// Host backward: scatter -gy * normalize into the target column of each
    /// non-ignored row; the target input gets an empty (undefined) gradient.
    auto backward(Variable!(F, 0, HostStorage) gy) {
        import std.typecons;
        import mir.math;
        import mir.ndslice;
        import numir;

        auto nBatch = this._htargetId.shape[0];
        auto glogP = zeros!F(nBatch, this._nClass);
        auto coeff = gy.data[0] * this._normalize;
        foreach (i; 0 .. nBatch) {
            auto t = this._htargetId.sliced[i];
            if (t != this.ignoreIndex) {
                glogP[i][t] = -coeff;
            }
        }
        return tuple(glogP.variable, typeof(this._htargetId)());
    }

    version (grain_cuda) {
        Variable!(I, 1, DeviceStorage) _dtargetId; // device-side target cache

        /// Device forward: the `nll` kernel accumulates the summed loss and
        /// the non-ignored count on device; both are copied back to host for
        /// the final averaging.
        auto forward(Variable!(F, 2, DeviceStorage) logP, Variable!(I, 1, DeviceStorage) targetId) {
            static assert(is(F == float), "only float is supported now");
            static assert(is(I == int), "only int is supported now");

            import grain.kernel : nll;
            this._nClass = logP.shape[1];
            auto dresult = CuPtr!F([0]); // [result].variable.to!DeviceStorage; <- FIXME
            auto dcount = CuPtr!int([0]); // [count].variable.to!DeviceStorage;

            auto batchSize = targetId.shape[0];
            Global.kernel!nll
                .call(dresult.ptr, dcount.ptr, logP.data.ptr,
                      targetId.data.ptr, this.ignoreIndex, batchSize, this._nClass).launch(batchSize);

            F result = 0.0;
            int count = 0;
            dresult.toHost(&result);
            dcount.toHost(&count);

            if (this.sizeAverage && count > 0) {
                result /= count;
            }
            // TODO if train
            this._nClass = logP.shape[1]; // NOTE(review): redundant, already set above
            this._dtargetId = targetId;
            this._normalize = this.sizeAverage && count > 0 ? 1.0 / count : 1.0;
            return result.variable.to!DeviceStorage;
        }

        /// Device backward: zero-filled glogP scattered by the `nllGrad`
        /// kernel with -coeff at each non-ignored target position.
        auto backward(Variable!(F, 0, DeviceStorage) gy) {
            static assert(is(F == float), "only float is supported now");
            static assert(is(I == int), "only int is supported now");

            import grain.kernel;
            import std.typecons : tuple, RefCounted;
            auto nBatch = this._dtargetId.shape[0];
            RefCounted!(CuPtr!F) glogP = CuPtr!F(nBatch * this._nClass);
            glogP.zero_();
            auto coeff = gy.to!HostStorage.data[0] * this._normalize;
            Global.kernel!nllGrad
                .call(glogP.ptr, -coeff, this._dtargetId.data.ptr, this.ignoreIndex, nBatch, this._nClass).launch(nBatch);
            auto v = Variable!(F, 2, DeviceStorage)(false, [nBatch, this._nClass], [this._nClass, 1], glogP);
            return tuple(v, typeof(this._dtargetId)());
        }

    }
}
709 
/// NLL: simple case with an ignored target, gradcheck, and CPU/CUDA
/// equality. The reference values come from the torch snippet below.
unittest {
    /++ equivalent torch v0.4 code
     x = torch.FloatTensor([[0.2, 0.4, 0.4], [0.1,0.5,0.4]])
     x.requires_grad = True
     t = torch.LongTensor([1, 0])
     l = torch.nn.functional.nll_loss(x, t)
     print(l)       # tensor(-0.2500)
     l.backward()
     print(x.grad)  # tensor([[0.0, -0.5, 0.0], [-0.5, 0.0, 0.0]])
     +/
    import std.typecons;
    import grain.testing;
    NegativeLogLikelihood!(float, int) func;
    auto hx = [[0.2f, 0.4f, 0.4f], [0.1f, 0.5f, 0.4f], [0.1f, 0.5f, 0.4f]].variable;
    auto ht = [1, 0, func.ignoreIndex].variable; // third row is ignored
    auto hl = func.forward(hx, ht);
    assert(func._normalize == 0.5); // averaged over the 2 counted targets
    assert(hl.sliced == [-(0.4f + 0.1f + 0.0f) / 2]);
    auto hgx = func.backward(1.0f.variable);
    assert(hgx[0].sliced == [[0.0, -0.5, 0.0], [-0.5, 0.0, 0.0], [0.0, 0.0, 0.0]]);
    assert(!hgx[1].defined); // targets get no gradient
    gradCheck(func, tuple(hx, ht), 1.0f.variable);

    version (grain_cuda) {
        auto dx = hx.to!DeviceStorage;
        auto dt = ht.to!DeviceStorage;
        auto dl = func.forward(dx, dt);
        assert(func._normalize == 0.5);
        assert(dl.to!HostStorage.sliced == [-(0.4f + 0.1f + 0.0f) / 2]);
        auto dgx = func.backward(1.0f.variable.to!DeviceStorage);
        assert(dgx[0].to!HostStorage.sliced == [[0.0, -0.5, 0.0], [-0.5, 0.0, 0.0], [0.0, 0.0, 0.0]]);
        assert(!dgx[1].defined);
    }
}
745 
746 
/++
 Check whether two equal-rank variables can be broadcast together
 (NumPy-style: per-axis extents must match, or one of them must be 1).

 Params:
     a = left operand
     b = right operand (same element type, rank and storage as `a`)
 Returns: named tuple `(ok, shape)`; `ok` is true when every axis is
     compatible, and `shape` holds the per-axis maximum extents filled in up
     to the first incompatible axis (remaining entries stay 0).
 +/
auto broadcastable(T, size_t dim, alias Storage)(Variable!(T, dim, Storage) a, Variable!(T, dim, Storage) b) {
    // local imports keep this function self-contained: neither `max` nor
    // `tuple` is imported at this module's top level
    import std.algorithm : max;
    import std.typecons : tuple;

    int[dim] resultShape;
    bool ok = false; // NOTE(review): stays false when dim == 0 — confirm intended
    foreach (i; 0 .. dim) {
        ok = a.shape[i] == b.shape[i] || a.shape[i] == 1 || b.shape[i] == 1;
        if (ok) {
            resultShape[i] = max(a.shape[i], b.shape[i]);
        } else break; // leave the remaining extents at 0
    }
    return tuple!("ok", "shape")(ok, resultShape);
}
758 
759 
/// TODO generalize to broadcastable addition
/++
 Adds a bias vector `b` (length nColumn) to every row of a 2-d variable `a`.
 Backward passes gy through unchanged for `a` and sums gy over the batch
 dimension for `b`.
 +/
struct AddBias(T) {
    mixin FunctionCommon;

    import mir.ndslice : map, slice;
    import std.typecons : tuple, RefCounted;

    /// Host forward: ret[i, :] = a[i, :] + b for every row i.
    auto forward(Variable!(T, 2, HostStorage) a, Variable!(T, 1, HostStorage) b) {
        assert(a.shape[1] == b.shape[0]);
        auto ret = a.dup;
        foreach (i; 0 .. a.shape[0]) {
            ret.sliced[i][] += b.sliced;
        }
        return ret;
    }

    /// Host backward: ga = gy unchanged; gb[j] = sum over rows of gy[:, j]
    /// (verified by the unittest below).
    auto backward(Variable!(T, 2, HostStorage) gy) {
        import numir : alongDim;
        import mir.math : sum;
        auto gb = gy.sliced.alongDim!0.map!sum.slice.variable;
        return tuple(gy, gb);
    }

    version (grain_cuda) {
        import grain.kernel : addBias, addBiasGrad;

        /// Device forward via the `addBias` kernel over all n elements of y.
        auto forward(Variable!(T, 2, DeviceStorage) a, Variable!(T, 1, DeviceStorage) b) {
            assert(a.shape[1] == b.shape[0]);
            auto y = a.dup;
            auto n = cast(uint) y.data.length;
            auto blen = cast(uint) b.data.length;
            Global.kernel!addBias
                .call(y.data.ptr, b.data.ptr, blen, n).launch(n);
            return y;
        }

        /// Device backward: zero-initialized gb accumulated by `addBiasGrad`.
        auto backward(Variable!(T, 2, DeviceStorage) gy) {
            RefCounted!(CuPtr!T) gb = CuPtr!T(gy.shape[1]);
            gb.zero_();
            auto n = cast(uint) gy.data.length;
            auto blen = cast(uint) gb.length;
            Global.kernel!addBiasGrad
                .call(gy.data.ptr, gb.ptr, blen, n).launch(n);
            return tuple(gy, Variable!(T, 1, DeviceStorage)(false, [cast(int) blen], [1], gb));
        }
    }
}
806 
807 
/// AddBias: forward/backward on a concrete 3x2 case, gradcheck, and
/// CPU/CUDA equality.
unittest {
    import std.typecons;
    import grain.testing;
    import numir;
    import mir.ndslice;

    AddBias!float func;
    auto hx = [[0f, 1f], [2f, 3f], [4f, 5f]].variable; // 3x2
    auto hb = [-1f, 1f].variable; // 2
    auto hy = func.forward(hx, hb);
    assert(hy.sliced == [[-1f, 2f], [1f, 4f], [3f, 6f]]);

    auto hgy = uniform!float(hy.shape.castArray!size_t).slice.variable;
    auto hgxb = func.backward(hgy);
    assert(hgxb[0].sliced == hgy.sliced); // gx == gy, passed through
    // gb is the per-column sum of gy over rows
    assert(hgxb[1].sliced == [hgy.sliced[0, 0] + hgy.sliced[1, 0] + hgy.sliced[2, 0],
                              hgy.sliced[0, 1] + hgy.sliced[1, 1] + hgy.sliced[2, 1]]);
    gradCheck(func, tuple(hx, hb), hgy);

    version (grain_cuda) {
        auto dx = hx.to!DeviceStorage;
        auto db = hb.to!DeviceStorage;
        auto dy = func.forward(dx, db);
        assert(dy.to!HostStorage.sliced == [[-1f, 2f], [1f, 4f], [3f, 6f]]);
        auto dgy = hgy.to!DeviceStorage;
        auto dgxb = func.backward(dgy);
        assert(dgxb[0].to!HostStorage.sliced == hgxb[0].sliced);
        assert(dgxb[1].to!HostStorage.sliced == hgxb[1].sliced);
    }
}
838 
839 
840 
/// Variable.backward end-to-end through NLL's applyForward: the loss
/// gradient reaches hx.grad via the recorded backprop chain.
/// NOTE(review): sets the global backprop flag without a scope(exit) reset
/// (unlike the MatMul unittest above) — confirm whether later tests rely
/// on it staying enabled.
unittest {
    import std.typecons;
    import grain.testing;
    import mir.ndslice;

    grain.autograd.backprop = true;

    NegativeLogLikelihood!(float, int) func;
    auto hx = [[0.2f, 0.4f, 0.4f], [0.1f, 0.5f, 0.4f], [0.1f, 0.5f, 0.4f]].variable;
    hx.requiresGrad = true;
    auto ht = [1, 0, func.ignoreIndex].variable;
    auto hl = func.applyForward(hx, ht);
    // hl.bprop.writeln;
    assert(func._normalize == 0.5);
    assert(hl.sliced == [-(0.4f + 0.1f + 0.0f) / 2]);
    auto u = UntypedVariable(1.0f.variable);
    hl.backward(&u);
    // hl.bprop.inputs[0].writeln;
    assert(hx.grad[].sliced(3, 3) == [[0.0, -0.5, 0.0], [-0.5, 0.0, 0.0], [0.0, 0.0, 0.0]]);
    // assert(!hgx[1].defined);
}