/// Autograd primitives: typed Variable, type-erased UntypedVariable and the BackProp hook.
module grain.autograd;
2 
3 import std.traits : isArray;
4 import std.typecons : RefCounted, RefCountedAutoInitialize;
5 import mir.ndslice : isSlice;
6 import std.range : ElementType;
7 
8 import grain.cuda;
9 import grain.utility : castArray;
10 
11 alias HostStorage(T) = T[];
12 
/// fill a CPU array with zeros in place and return it
auto zero_(T)(T[] s) {
14     import std.algorithm.mutation : fill;
15     fill(s, 0);
16     return s;
17 }
18 
/// allocate a new CPU array of length `n` filled with zeros
auto zeros(T)(size_t n) if (isArray!T) {
20     auto s = new ElementType!T[n];
21     return s.zero_();
22 }
23 
24 unittest {
25     float[] h = [1f, 2f, 3f];
26     h.zero_();
27     assert(h == [0f, 0f, 0f]);
28     assert(zeros!(HostStorage!float)(3) == [0f, 0f, 0f]);
29 }
30 
31 version(grain_cuda) {
    /// CUDA device memory storage
    alias DeviceStorage(T) = CuPtr!T;

    /// trait: true iff T can be copied back to the host via toHost()
    enum bool isDevice(T) = is(typeof({T.init.toHost();}));

    /// copy a host array into newly allocated device memory
    auto to(alias S : DeviceStorage, T)(T[] src) {
38         import std.array : empty;
39         return src.empty ? DeviceStorage!T() : DeviceStorage!T(src);
40     }
41 
    /// copy device memory back into a newly allocated host array
    auto to(alias S : HostStorage, Src)(Src src) if (isDevice!Src) {
43         return src.toHost();
44     }
45 
46     // auto to(alias S : HostStorage, Src)(Src src) if (!isDevice!Src) {
47     //     return src;
48     // }
49 
50     // auto to(alias S : DeviceStorage, Src)(Src src) if (isDevice!Src) {
51     //     return src;
52     // }
53 
54     unittest {
55         auto h = [[0.1f, 0.2f, 0.3f], [0.4f, 0.5f, 0.6f]].variable;
56         auto d = h.to!DeviceStorage;
57         assert(h.data == d.to!HostStorage.data);
58     }
59 }
60 
61 
/// type-erased variable for storing heterogeneous nodes in the autograd graph
63 struct UntypedVariable {
64     import std.variant;
65     bool requiresGrad;
66     size_t dim;
    // TODO: use size_t[]
    int[] shape;
    // TODO: use ptrdiff_t[]
    int[] strides;
71     TypeInfo elem;
72     Variant data, grad;
73     size_t outPosition = 0;
    // TODO: wrap in RefCounted
75     BackProp bprop;
76 
77     this(T, size_t dim, alias Storage)(Variable!(T, dim, Storage) v) {
78         this.elem = typeid(T);
79         this.requiresGrad = v.requiresGrad;
80         this.shape = v.shape.dup;
81         this.strides = v.strides.dup;
82         this.dim = dim;
83         this.data = v.data;
84         this.grad = v.grad;
85     }
86 
87     auto get(T)() {
88         return this.data.get!(RefCounted!T);
89     }
90 
91     auto to(V : Variable!(T, dim, Storage), T, size_t dim, alias Storage)() {
92         auto d = this.data.get!(RefCounted!(Storage!T));
93         return Variable!(T, dim, Storage)(
94             this.requiresGrad, this.shape[0..dim], this.strides[0..dim], d);
95     }
96 
97     string toString() const {
98         import std.format : format;
99         return "UntypedVariable(%s, dim=%d, data=%s, shape=%s, strides=%s)".format(
100             elem, dim, data, shape, strides);
101     }
102 
    /// view the gradient as a mir ndslice (host variables only)
    auto gradSlice(V)() if (isVariable!V && isHost!V) {
104         import mir.ndslice.slice : sliced;
105         return grad.get!(typeof(V.init.data)).ptr.sliced(this.shape[0 .. Ndim!V].castArray!size_t);
106     }
107 }
108 
/// view the gradient of a host variable as a mutable mir ndslice sharing its memory
auto gradSlice(V)(V v) if (isVariable!V && isHost!V) {
110     import mir.ndslice.slice : sliced;
111     return v.grad.ptr.sliced(v.shape.castArray!size_t);
112 }
113 
114 
115 
/// global flag to enable recording the computation graph for backpropagation. FIXME: maybe make this a singleton?
shared bool backprop = false;
118 
/// information required for backpropagation: the backward procedure, its inputs and the accumulated output gradients
120 struct BackProp {
121     alias Proc = void delegate(UntypedVariable[], UntypedVariable[]);
122     Proc proc;
123     UntypedVariable[] inputs;
124     UntypedVariable[] gradOutputs;
125     size_t nGrad = 0;
126 
127     void backward(UntypedVariable* grad=null, size_t pos=0) {
128         import std.exception : enforce;
129         import std.range : empty;
130         // enforce(!this.inputs.empty, "nothing to backprop");
131         if (this.inputs.empty) return;
132         ++this.nGrad;
133         if (grad is null) {
            enforce(this.gradOutputs.length == 1, "this variable is not a loss");
135         } else {
136             this.gradOutputs[pos] = *grad; // FIXME??
137         }
138         if (grad is null || this.nGrad == this.gradOutputs.length) {
139             proc(this.gradOutputs, this.inputs);
140         }
141 
142         // FIXME: reconsider this maybe
143         // import core.memory : GC;
144         // destroy(gradOutputs);
145         // GC.free(&gradOutputs);
146         // destroy(this);
147         // GC.free(&this);
148     }
149 }
150 
151 ///
152 unittest {
153     import std.stdio;
154     UntypedVariable u;
155     {
156         auto v = [[0f, 1f], [2f, 3f]].variable;
157         u = UntypedVariable(v);
158     }
159     assert(u.get!(HostStorage!float) == [0, 1, 2, 3]);
160 }
161 
/// statically typed variable of rank `dim` holding data, gradient and a backprop hook
// TODO add SliceKind
struct Variable(T, size_t dim, alias Storage = HostStorage) {
164     bool requiresGrad = true;
    // TODO: use size_t[dim]
    int[dim] shape;
    // TODO: use ptrdiff_t[dim]
    int[dim] strides;
169     RefCounted!(Storage!T) data;
170     RefCounted!(Storage!T) grad;
    // TODO: wrap in RefCounted
172     BackProp bprop;
173     enum isHost = is(Storage!T == HostStorage!T);
174 
175     this(bool requiresGrad, int[dim] shape, int[dim] strides, RefCounted!(Storage!T) data) {
176         this.requiresGrad = requiresGrad;
177         this.shape = shape;
178         this.strides = strides;
179         this.data = data;
180         // this.grad.isHost = is(Storage!T == HostStorage!T);
181         // if (this.requiresGrad) { // TODO enable this
182             static if (is(Storage!T == HostStorage!T)) {
183                 this.grad = zeros!(Storage!T)(this.data.length);
184             } else version (grain_cuda) {
185                 // TODO why is grain.cuda. required?
186                 this.grad = grain.cuda.zeros!(CuPtr!T)(this.data.length);
187             }
188         // }
189     }
190 
191     @property
192     bool defined() { return cast(size_t) data.ptr != 0; }
193 
194     auto dup() {
195         static if (is(Storage!T == HostStorage!T)) {
196             RefCounted!(Storage!T) d = new T[data.length];
197             d[] = data[];
198         } else {
199             RefCounted!(Storage!T) d = data.dup;
200         }
201         auto y = Variable(this.requiresGrad, this.shape, this.strides, d);
202         return y;
203     }
204 
205     static if (is(Storage!T == HostStorage!T)) {
206         auto sliced() {
207             import mir.ndslice; // .slice : Slice, Universal;
208             static if (dim == 0) {
209                 return [this.data[0]].sliced.universal;
210             } else {
211                 return Slice!(Universal, [dim], T*)(
212                     this.shape.castArray!size_t,
213                     this.strides.castArray!ptrdiff_t, data.ptr);
214             }
215         }
216 
217         auto gradSliced() {
218             import mir.ndslice; // .slice : Slice, Universal;
219             static if (dim == 0) {
220                 return [this.grad[0]].sliced.universal;
221             } else {
222                 return Slice!(Universal, [dim], T*)(
223                     this.shape.castArray!size_t,
224                     this.strides.castArray!ptrdiff_t, grad.ptr);
225             }
226         }
227     }
228 
229     // TODO pass gradOutput
230     void backward(UntypedVariable* grad, size_t pos=0) {
231         this.bprop.backward(grad, pos);
232     }
233 
234     void backward() {
235         auto grad = UntypedVariable(1.0f.variable.to!Storage);
236         this.bprop.backward(&grad, 0);
237     }
238 
239 
240     string toString() const {
241         import std.format : format;
242         return "Variable!(%s, dim=%d, %s)(data=%s, shape=%s, strides=%s)"
243             .format(T.stringof, dim, Storage.stringof,
244                     data, shape, strides);
245     }
246 }
247 
248 /// test Variable.defined
249 unittest {
250     Variable!(float, 1, HostStorage) h;
251     assert(!h.defined);
252     assert(0.variable.defined);
253     assert(0.1f.variable.defined);
254     assert([0].variable.defined);
255     assert([0.1f].variable.defined);
256 
257     version (grain_cuda) {
258         Variable!(float, 1, DeviceStorage) d;
259         assert(!d.defined);
260         assert(!h.to!DeviceStorage.defined);
261         assert(0.variable.to!DeviceStorage.defined);
262         assert(0.1f.variable.to!DeviceStorage.defined);
263         assert([0].variable.to!DeviceStorage.defined);
264         assert([0.1f].variable.to!DeviceStorage.defined);
265     }
266 }
267 
/// trait: true iff T is an instantiation of Variable
enum bool isVariable(T) = is(T : Variable!(Elem, dim, Storage), Elem, size_t dim, alias Storage);

/// trait: true iff the variable V is stored in host (CPU) memory
enum bool isHost(V : Variable!(Elem, dim, Storage), Elem, size_t dim, alias Storage) = is(Storage!Elem == HostStorage!Elem);

/// the number of dimensions of the variable V
enum size_t Ndim(V : Variable!(Elem, dim, Storage), Elem, size_t dim, alias Storage) = dim;

/// the element type (e.g. float or double) of the variable V
alias ElementType(V : Variable!(Elem, dim, Storage), Elem, size_t dim, alias Storage) = Elem;
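
/// example: the traits above inspect a concrete Variable type at compile time
unittest {
    auto v = [[1f, 2f]].variable;
    alias V = typeof(v);
    static assert(isVariable!V);
    static assert(isHost!V);
    static assert(Ndim!V == 2);
    static assert(is(ElementType!V == float));
}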
272 
273 
/// wrap a mir ndslice as a Variable without copying the underlying memory
auto variable(Sl)(Sl sl, bool requiresGrad = false) if (isSlice!Sl) {
275     import mir.ndslice : universal, DeepElementType;
276     import std.algorithm : reduce;
277 
278     import numir : Ndim;
279     auto s = sl.universal;
280     alias S = typeof(s);
281     alias E = DeepElementType!S;
282     auto size = s._lengths.reduce!"a * b";
283     RefCounted!(E[]) data = s._iterator[0..size];
284     int[Ndim!S] shape, strides;
285     static foreach (i; 0 .. Ndim!S) {
286         assert(s._lengths[i] < int.max);
287         assert(s._strides[i] < int.max);
288         shape[i] = cast(int) s.length!i;
289         strides[i] = cast(int) s._strides[i];
290     }
291     return Variable!(E, Ndim!S, HostStorage)(
292         requiresGrad, shape, strides, data);
293 }
294 
import std.traits : isNumeric;

/// wrap a numeric scalar as a 0-dimensional Variable
auto variable(alias Storage=HostStorage, bool requiresGrad=false, T)(T x) if (isNumeric!T) {
297     RefCounted!(T[]) data = [x];
298     return Variable!(T, 0, Storage)(requiresGrad, [], [], data);
299 }
300 
/// wrap a (possibly nested) built-in array as a Variable, copying it into contiguous storage
auto variable(A)(A a, bool requiresGrad=false) if (isArray!A) {
302     import numir.core : nparray;
303     return a.nparray.variable(requiresGrad);
304 }
305 
306 ///
307 version (grain_cuda) unittest {
308     auto h = 0.5f.variable;
309     auto d = h.to!DeviceStorage;
310     assert(d.to!HostStorage.data == h.data);
311 }
312 
/// copy a Variable (data and gradient) into the given storage; identity when source and destination storages match
Variable!(T, dim, Dst) to(alias Dst, T, size_t dim, alias Src)(Variable!(T, dim, Src) src) {
314     static if (is(Dst!T == Src!T)) return src;
315     else {
317         RefCounted!(Dst!T) d = src.data.to!Dst;
318         RefCounted!(Dst!T) g = src.grad.to!Dst;
319         // FIXME: consider grad
320         auto ret = typeof(return)(src.requiresGrad, src.shape, src.strides, d);
321         ret.grad = g;
322         return ret;
323     }
324 }
325 
326 
327 ///
328 unittest {
329     import std.stdio;
330     {
331         // Variable!(float, 1) x;
332         auto x = [-1f, -2f, -3f].variable;
333         auto y = x.dup;
334         x.data[0] = 1.0;
335         static assert(isVariable!(typeof(x)));
336         static assert(!isVariable!void);
337         static assert(isHost!(typeof(x)));
338         assert(y.data[0] == -1);
339     }
340     version (grain_cuda) {
341         {
342             auto x = [[1f, 3f],
343                       [5f, 7f],
344                       [9f, 11f]].variable;
345 
346             assert(x.data.length == 6);
347             static assert(!isHost!(typeof(x.to!DeviceStorage)));
348             auto xx = x.dup;
349             assert(x.to!DeviceStorage.to!HostStorage.sliced == x.sliced);
350         }
351     }
352 }
353 
354