1 /**
2    A module for a variable used as a node in autograd computation graph
3 
4    TODO:
5    - support shape ops
6  */
7 module grain.autograd;
8 
9 import std.traits : isArray, isBasicType;
10 import std.typecons : RefCounted, RefCountedAutoInitialize;
11 import mir.ndslice : isSlice, SliceKind, Contiguous, Universal;
12 import mir.primitives : DimensionCount;
13 import std.range : ElementType;
14 
15 import grain.cuda;
16 import grain.utility : castArray;
17 
18 // import std.algorithm : fill;
19 // alias fill_ = fill;
20 
21 /// CPU storage (i.e., GC dynamic array)
22 alias HostStorage(T) = T[];
23 
24 /// fill CPU array with zero
25 auto zero_(T)(T[] s) { // if (!isBasicType!T) {
26     import std.algorithm.mutation : fill;
27 
28     fill(s, 0);
29     return s;
30 }
31 
32 /// create new CPU array filled with zero
33 auto zeros(T)(size_t n) if (isArray!T) {
34     auto s = new ElementType!T[n];
35     return s.zero_();
36 }
37 
38 ///
39 unittest {
40     float[] h = [1f, 2f, 3f];
41     h.zero_();
42     assert(h == [0f, 0f, 0f]);
43     assert(zeros!(HostStorage!float)(3) == [0f, 0f, 0f]);
44 }
45 
/// create a new variable on CPU with an uninitialized array and the same shape/strides as v
47 auto uninit(T, size_t dim)(Variable!(T, dim, HostStorage) v) {
48     auto data = new T[v.length];
49     return Variable!(T, dim, HostStorage)(v.requiresGrad, v.shape, v.strides, data);
50 }
51 
/// create a new variable with an uninitialized array of the given shape on CPU/CUDA
53 auto uninitVariable(T, alias S = HostStorage, size_t dim)(uint[dim] shape, bool requiresGrad = false) {
54     import std.algorithm : reduce;
55 
56     const length = shape.reduce!"a * b";
57     static if (is(S!T == HostStorage!T)) {
58         auto data = new T[length];
59     }
60     version (grain_cuda) {
61         static if (is(S!T == DeviceStorage!T)) {
62             auto data = CuArray!T(CuPtr!T(length));
63         }
64     }
65     int[dim] strides;
66     strides[dim - 1] = 1;
67     foreach_reverse (i; 0 .. dim - 1) {
68         assert(shape[i + 1] < int.max);
69         strides[i] = cast(int) shape[i + 1] * strides[i + 1];
70     }
71     return Variable!(T, dim, S)(requiresGrad, shape, strides, data);
72 }
73 
74 ///
75 unittest {
76     import std.stdio;
77     import numir;
78     import mir.ndslice;
79 
80     auto x = numir.zeros(2, 3, 4).universal;
81     auto y = uninitVariable!float([2, 3, 4]);
82     assert(x.strides == y.strides);
83 }
84 
85 version (grain_cuda) {
    /// create a new variable on CUDA with an uninitialized array and the same shape/strides as v
87     auto uninit(T, size_t dim)(Variable!(T, dim, DeviceStorage) v) {
88         return uninitVariable!(T, DeviceStorage, dim)(v.shape, v.requiresGrad);
89     }
90 
91     alias DeviceStorage(T) = CuArray!T;
92 
93     // enum bool isDevice(T) = isDeviceMemory(typeof(T.data)); // is(typeof({T.init.toHost();}));
94     alias isDevice = isDeviceMemory;
95 
    /// CPU -> CUDA memory conversion
97     auto to(alias S : DeviceStorage, T)(T[] src) {
98         import std.array : empty;
99 
100         return src.empty ? DeviceStorage!T() : DeviceStorage!T(src);
101     }
102 
    /// CUDA -> CPU memory conversion
104     auto to(alias S : HostStorage, Src)(Src src) if (isDevice!Src) {
105         return src.toHost();
106     }
107 
108     ///
109     unittest {
110         auto h = [[0.1f, 0.2f, 0.3f], [0.4f, 0.5f, 0.6f]].variable;
111         auto d = h.to!DeviceStorage;
112         assert(h.data == d.to!HostStorage.data);
113     }
114 }
115 
116 /// type-erased variable used in BackProp object
117 struct UntypedVariable {
118     import std.variant;
119 
120     bool requiresGrad;
121     size_t dim;
122     uint[] shape;
123     int[] strides;
124     TypeInfo elem;
125     Variant data, grad;
126     // size_t outPosition = 0;
127 
128     ///
129     this(T, size_t dim, alias Storage)(Variable!(T, dim, Storage) v) {
130         this.elem = typeid(T);
131         this.requiresGrad = v.requiresGrad;
132         this.shape = v.shape.dup;
133         this.strides = v.strides.dup;
134         this.dim = dim;
135         this.data = v.data;
136         this.grad = v.grad;
137     }
138 
139     /// variant.get
140     auto get(T)() {
141         return this.data.get!T;
142     }
143 
144     /// untyped to typed
145     auto to(V : Variable!(T, dim, Storage), T, size_t dim, alias Storage)() {
146         auto d = this.data.get!(Storage!T);
147         return Variable!(T, dim, Storage)(this.requiresGrad,
148                 this.shape[0 .. dim], this.strides[0 .. dim], d);
149     }
150 
151     /// untyped grad to typed
152     auto gradTo(V : Variable!(T, dim, Storage), T, size_t dim, alias Storage)() {
        auto d = this.grad.get!(Storage!T);
154         return Variable!(T, dim, Storage)(this.requiresGrad,
155                 this.shape[0 .. dim], this.strides[0 .. dim], d);
156     }
157 
158     ///
159     string toString() const {
160         import std.format : format;
161 
162         return "UntypedVariable(%s, dim=%d, data=%s, shape=%s, strides=%s)"
163             .format(elem, dim, data, shape, strides);
164     }
165 
166     ///
167     auto gradSlice(V)() if (isVariable!V && isHost!V) {
168         import mir.ndslice.slice : sliced;
169 
170         return grad.get!(typeof(V.init.data)).ptr.sliced(this.shape[0 .. DimensionCount!V]
171                 .castArray!size_t);
172     }
173 
174     ///
175     auto dataSlice(V)() if (isVariable!V && isHost!V) {
176         import mir.ndslice.slice : sliced;
177 
178         return data.get!(typeof(V.init.data)).ptr.sliced(this.shape[0 .. DimensionCount!V]
179                 .castArray!size_t);
180     }
181 }
182 
183 ///
184 auto gradSlice(V)(V v) if (isVariable!V && isHost!V) {
185     import mir.ndslice.slice : sliced;
186 
187     return v.grad.ptr.sliced(v.shape.castArray!size_t);
188 }
189 
190 
191 /// stores information for backpropagation
192 struct BackProp {
193     alias Proc = void delegate(UntypedVariable[]);
194     Proc proc;
195     UntypedVariable[] gradOutputs;
196     size_t nGrad = 0;
197 
198     /// error backward propagation
199     void backward(UntypedVariable* grad = null, size_t pos = 0) {
200         import std.exception : enforce;
201         import std.range : empty;
202 
203         if (this.gradOutputs.empty) return;
204         ++this.nGrad;
205         if (grad is null) {
206             enforce(this.gradOutputs.length == 1, "this variable is not loss");
207         }
208         else {
            this.gradOutputs[pos] = *grad; // FIXME: multi-output functions are not fully supported yet
210         }
211         if (grad is null || this.nGrad == this.gradOutputs.length) {
212             proc(this.gradOutputs);
213         }
214     }
215 }
216 
217 ///
218 unittest {
219     import std.stdio;
220 
221     UntypedVariable u;
222     {
223         auto v = [[0f, 1f], [2f, 3f]].variable;
224         u = UntypedVariable(v);
225     }
226     assert(u.get!(HostStorage!float) == [0, 1, 2, 3]);
227 }
228 
229 /**
230    A variable has autograd ability with mir.ndslice.Slice like data
231 
232    TODO: add SliceKind
233 */
234 struct Variable(T, size_t dim, alias Storage = HostStorage, SliceKind kind = Contiguous) {
235     bool requiresGrad = true;
236     // size_t[dim]
237     uint[dim] shape;
238     // ptrdiff_t[dim]
239     int[dim] strides;
240     Storage!T data;
241     Storage!T grad;
242     BackProp bprop;
243     enum isHost = is(Storage!T == HostStorage!T);
244     uint offset = 0;
245 
246     // void opAssign(Variable!(T, dim, Storage) rhs) {
247     // }
248 
249     ///
250     this(bool requiresGrad, uint[dim] shape, int[dim] strides, Storage!T data) {
251         this.requiresGrad = requiresGrad;
252         this.shape = shape;
253         this.strides = strides;
254         this.data = data;
255         // if (this.requiresGrad) { // TODO enable this
256         static if (is(Storage!T == HostStorage!T)) {
257             this.grad = zeros!(Storage!T)(this.data.length);
258         }
259         else version (grain_cuda) {
260             this.grad = grain.cuda.zeros!(CuPtr!T)(this.data.length);
261         }
262     }
263 
264     /// get gradient as variable
265     auto gradVariable(bool requiresGrad = false) {
266         return Variable(requiresGrad, this.shape, this.strides, this.grad);
267     }
268 
269     /// detach the computation graph used in backward
270     ref detach() {
271         this.bprop = BackProp();
272         return this;
273     }
274 
275     /// data pointer
276     @property auto ptr() {
277         return this.data.ptr + offset;
278     }
279 
280     /// check data is not null
281     @property bool defined() {
282         return cast(size_t) data.ptr != 0;
283     }
284 
285     /// duplicate (deep copy) variable
286     auto dup() {
287         static if (is(Storage!T == HostStorage!T)) {
288             auto d = new T[data.length];
289             d[] = data[];
290         }
291         else {
292             auto d = CuArray!T(data.dup);
293         }
294         auto y = Variable(this.requiresGrad, this.shape, this.strides, d);
295         return y;
296     }
297 
298     static if (is(Storage!T == HostStorage!T)) {
299         ///
300         auto sliced() {
301             import mir.ndslice; // .slice : Slice, Universal;
302             static if (dim == 0) {
303                 return [this.data[0]].sliced.universal;
304             }
305             else {
306                 return Slice!(T*, dim, Universal)(
307                         this.shape.castArray!size_t,
308                         this.strides.castArray!ptrdiff_t, data.ptr);
309             }
310         }
311 
312         ///
313         auto gradSliced() {
314             import mir.ndslice; // .slice : Slice, Universal;
315             static if (dim == 0) {
316                 return [this.grad[0]].sliced.universal;
317             }
318             else {
319                 return Slice!(T*, dim, Universal)(
320                         this.shape.castArray!size_t,
321                         this.strides.castArray!ptrdiff_t, grad.ptr);
322             }
323         }
324     }
325     else {
326         ///
327         auto sliced() {
328             import mir.ndslice; // .slice : Slice, Universal;
329             static if (dim == 0) {
330                 return Slice!(T*, 1, Universal)([1], [1], cast(T*) data.ptr);
331             }
332             else {
333                 return Slice!(T*, dim, Universal)(
334                         this.shape.castArray!size_t,
335                         this.strides.castArray!ptrdiff_t, cast(T*) data.ptr);
336             }
337         }
338 
339         // TODO gradSliced?
340     }
341 
342     /// computes gradients of creator variables w.r.t. the arg grad
343     void backward(UntypedVariable* grad, size_t pos = 0) {
344         this.bprop.backward(grad, pos);
345     }
346 
347     /// computes gradients of creator variables w.r.t. this variable
348     static if (dim == 0) {
349         void backward() {
350             auto grad = UntypedVariable(1.0f.variable.to!Storage);
351             this.bprop.backward(&grad);
352         }
353     }
354 
355     ///
356     string toString() const {
357         import std.format : format;
358 
359         return "Variable!(%s, dim=%d, %s)(data=%s, shape=%s, strides=%s)"
360             .format(T.stringof, dim, Storage.stringof, data, shape, strides);
361     }
362 
    /// binary ops between two variables: this op b
    /// TODO implement contiguous with mir.ndslice and cudnnTransformTensor
365     auto opBinary(string op)(Variable!(T, dim, Storage) b) {
366         import grain.chain : opBinaryFunc, reciprocal;
367 
368         static if (op == "+" || op == "*") {
369             return opBinaryFunc!op(this, b);
370         }
371         else static if (op == "-") {
372             return opBinaryFunc!"+"(this, b, 1, -1);
373         }
374         else static if (op == "/") {
375             return opBinaryFunc!"*"(this, reciprocal(b));
376         }
377         else {
378             static assert(false, "unsupported op: " ~ op);
379         }
380     }
381 
    /// binary ops with a primitive scalar value (e.g., float, double): this op b
383     auto opBinary(string op)(T b) {
384         uint[dim] shape;
385         shape[] = 1;
386         auto v = uninitVariable!(T, Storage, dim)(shape, false);
387         static if (is(Storage!T == HostStorage!T)) {
388             import std.algorithm : fill;
389 
390             fill(v.data, b);
391         }
392         else {
393             fill_(v.data, b);
394         }
395         return this.opBinary!op(v);
396     }
397 
    /// binary ops with a scalar on the left: b op this
399     auto opBinaryRight(string op)(T b) {
400         static if (op == "+" || op == "*") {
401             return this.opBinary!op(b);
402         }
403         else static if (op == "-") {
404             return this.opBinary!"+"(-b);
405         }
406         else static if (op == "/") {
407             uint[dim] shape;
408             shape[] = 1;
409             auto v = uninitVariable!(T, Storage, dim)(shape, false);
410             static if (is(Storage!T == HostStorage!T)) {
411                 import std.algorithm : fill;
412 
413                 fill(v.data, b);
414             }
415             else {
416                 fill_(v.data, b);
417             }
418             return v.opBinary!op(this);
419         }
420         else {
421             static assert(false, "unsupported op: " ~ op);
422         }
423     }
424 }
425 
426 /// test opBinary(string op)(Variable ...)
427 unittest {
428     import mir.ndslice;
429     import numir;
430     import std.stdio;
431 
432     static foreach (op; ["+", "*", "-", "/"]) {
433         {
434             auto a = uniform!float(3, 2).slice.variable(true);
435             auto b = uniform!float(3, 2).slice.variable(true);
436             // this is equivalent to `a + b` if op == "+"
437             auto c = a.opBinary!op(b);
438             // this is equivalent to `a.sliced.slice + b.sliced.slice` if op == "+"
439             auto e = a.sliced.slice.opBinary!op(b.sliced.slice);
440             assert(approxEqual(c.sliced, e));
441 
442             auto gc = uniform!float(3, 2).slice.variable(true);
443             auto ugc = UntypedVariable(gc);
444             c.backward(&ugc);
445 
446             version (grain_cuda) {
447                 auto da = a.to!DeviceStorage;
448                 auto db = b.to!DeviceStorage;
449                 auto dc = da.opBinary!op(db);
450                 assert(approxEqual(dc.to!HostStorage.sliced, c.sliced));
451 
452                 import grain.cuda : zero_;
453 
454                 da.grad.zero_();
455                 db.grad.zero_();
456                 auto dugc = UntypedVariable(gc.to!DeviceStorage);
457                 dc.backward(&dugc);
458                 assert(approxEqual(da.to!HostStorage.gradSliced, a.gradSliced));
459             }
460         }
461     }
462 }
463 
464 /// test multiple addition
465 unittest {
466     static import grain.config;
467     grain.config.backprop = true;
468     auto x = [1f, 2f].variable(true);
    auto y = x + x; // y = 2 x
    auto z = y + y; // z = 4 x
471     auto g = [0f, 1f].variable;
472     auto u = UntypedVariable(g);
473     z.backward(&u);
474     assert(x.gradSliced == [0f, 4f]);
475 }
476 
477 // /// FIXME: test multiple addition with assign
478 // unittest {
479 //     import std.stdio;
480 //     grain.config.backprop = true;
481 //     auto x = [1f, 2f].variable(true);
482 //     x = x + x; // x = 2 x
483 //     x = x + x; // x = 4 x
484 //     auto g = [0f, 1f].variable;
485 //     auto u = UntypedVariable(g);
486 //     x.backward(&u);
487 //     x.gradSliced.writeln;
488 //     assert(x.gradSliced == [0f, 4f]);
489 // }
490 
491 /// test Variable.defined
492 unittest {
493     Variable!(float, 1, HostStorage) h;
494     assert(!h.defined);
495     assert(0.variable.defined);
496     assert(0.1f.variable.defined);
497     assert([0].variable.defined);
498     assert([0.1f].variable.defined);
499 
500     version (grain_cuda) {
501         Variable!(float, 1, DeviceStorage) d;
502         assert(!d.defined);
503         assert(!h.to!DeviceStorage.defined);
504         assert(0.variable.to!DeviceStorage.defined);
505         assert(0.1f.variable.to!DeviceStorage.defined);
506         assert([0].variable.to!DeviceStorage.defined);
507         assert([0.1f].variable.to!DeviceStorage.defined);
508     }
509 }
510 
511 /// a trait to identify variable object
512 enum bool isVariable(T) = is(T : Variable!(Elem, dim, Storage), Elem, size_t dim, alias Storage);
513 
514 /// a trait to identify variable stored in CPU memory
515 enum bool isHost(V : Variable!(Elem, dim, Storage), Elem, size_t dim, alias Storage) = is(
516             Storage!Elem == HostStorage!Elem);
517 
/// a trait to get the number of dimensions of a variable
519 enum size_t DimensionCount(V : Variable!(Elem, dim, Storage), Elem, size_t dim, alias Storage) = dim;
520 
521 /// an alias of element type (e.g., float, double and int) of variable
522 alias ElementType(V : Variable!(Elem, dim, Storage), Elem, size_t dim, alias Storage) = Elem;
523 
524 /// total number of elements in variable
525 auto length(V)(V v) if (isVariable!V) {
526     import std.algorithm : reduce;
527 
528     return v.shape.reduce!"a * b";
529 }
530 
/// a helper function to create a variable object from a mir slice
532 auto variable(Sl)(Sl sl, bool requiresGrad = false) if (isSlice!Sl) {
533     import mir.ndslice : universal, DeepElementType;
534     import std.algorithm : reduce;
535 
536     auto s = sl.universal;
537     alias S = typeof(s);
538     alias E = DeepElementType!S;
539     auto size = s._lengths.reduce!"a * b";
540     auto data = s._iterator[0 .. size];
541     uint[DimensionCount!S] shape;
542     int[DimensionCount!S] strides;
543     static foreach (i; 0 .. DimensionCount!S) {
544         assert(s._lengths[i] < int.max);
545         assert(s._strides[i] < int.max);
546         shape[i] = cast(uint) s.length!i;
547         strides[i] = cast(int) s._strides[i];
548     }
549     return Variable!(E, DimensionCount!S, HostStorage)(requiresGrad, shape, strides, data);
550 }
551 
552 import std.traits : isNumeric;
553 
/// a helper function to create a 0-dim variable object from a scalar value on CPU/CUDA
555 auto variable(alias Storage = HostStorage, bool requiresGrad = false, T)(T x)
556         if (isNumeric!T) {
557     return Variable!(T, 0, Storage)(requiresGrad, [], [], [x]);
558 }
559 
/// a helper function to create a variable object from a (nested) CPU array
561 auto variable(A)(A a, bool requiresGrad = false) if (isArray!A) {
562     import numir.core : nparray;
563 
564     return a.nparray.variable(requiresGrad);
565 }
566 
567 ///
568 version (grain_cuda) unittest {
569     auto h = 0.5f.variable;
570     auto d = h.to!DeviceStorage;
571     assert(d.to!HostStorage.data == h.data);
572 }
573 
/// copy a variable to another device (e.g., CPU -> CUDA or CUDA -> CPU)
575 Variable!(T, dim, Dst) to(alias Dst, T, size_t dim, alias Src)(Variable!(T, dim, Src) src) {
576     static if (is(Dst!T == Src!T))
577         return src;
578     else {
579         import std.range : empty;
580 
581         auto d = src.data.to!Dst;
582         auto g = src.grad.to!Dst;
583         // FIXME: consider grad
584         auto ret = typeof(return)(src.requiresGrad, src.shape, src.strides, d);
585         ret.grad = g;
586         return ret;
587     }
588 }
589 
590 ///
591 unittest {
592     import std.stdio;
593 
594     {
595         // Variable!(float, 1) x;
596         auto x = [-1f, -2f, -3f].variable;
597         auto y = x.dup;
598         x.data[0] = 1.0;
599         static assert(isVariable!(typeof(x)));
600         static assert(!isVariable!void);
601         static assert(isHost!(typeof(x)));
602         assert(y.data[0] == -1);
603     }
604     version (grain_cuda) {
605         {
606             auto x = [[1f, 3f], [5f, 7f], [9f, 11f]].variable;
607 
608             assert(x.data.length == 6);
609             static assert(!isHost!(typeof(x.to!DeviceStorage)));
610             auto xx = x.dup;
611             assert(x.to!DeviceStorage
612                     .to!HostStorage
613                     .sliced == x.sliced);
614         }
615     }
616 }
617 
/// a std.algorithm.each-like helper that iterates over variables inside a chain
619 void iterVariables(alias proc, C)(C* chain, string prefix = "") {
620     import std.traits;
621     import grain.autograd;
622 
623     foreach (name; FieldNameTuple!C) {
624         auto fullName = prefix ~ "." ~ name;
625         auto value = __traits(getMember, chain, name);
626         alias V = typeof(value);
627         static if (isVariable!V) {
628             proc(fullName, value);
629         }
630         else static if (hasMember!(V, "tupleof")) {
631             iterVariables!proc(&value, fullName);
632         }
633     }
634 }
635 
636 /// ditto
637 void refIterVariables(alias proc, C)(ref C chain, string prefix = "") {
638     import std.traits;
639     import grain.autograd;
640 
641     foreach (name; FieldNameTuple!C) {
642         auto fullName = prefix ~ "." ~ name;
643         auto value = __traits(getMember, chain, name);
644         alias V = typeof(value);
645         static if (isVariable!V) {
646             proc(fullName, value);
647         }
648         else static if (hasMember!(V, "tupleof")) {
649             refIterVariables!proc(value, fullName);
650         }
651     }
652 }