1 /**
   Common components for autograd function objects
3  */
4 module grain.functions.common;
5 
6 import grain.autograd;
7 import grain.cuda;
8 import grain.utility : fromTuple, castArray;
9 import mir.ndslice : isSlice;
10 
11 import std.stdio;
12 import std.traits : hasMember;
13 
14 version (grain_cuda) {
15     import cudnn = derelict.cudnn7;
16 }
17 
/// a simple compile-time check that forward/backward function signatures are compatible
19 mixin template TypeChecker(alias forward, alias backward) {
20     static assert(allSatisfy!(isVariable, Parameters!forward),
21                   "all the forward function args should be variable.");
22     static assert(allSatisfy!(isVariable, Parameters!backward),
23                   "all the backward function args should be variable.");
24     static if (arity!forward == 1 && arity!backward == 1) {
25         static assert(is(ReturnType!backward == Parameters!forward[0]));
26         static assert(is(ReturnType!forward == Parameters!backward[0]));
27     } else static if (arity!backward == 1) {
28         static assert(is(ReturnType!backward == Tuple!(Parameters!forward)));
29         static assert(is(ReturnType!forward == Parameters!backward[0]));
30     } else static if (arity!forward == 1) {
31         static assert(is(ReturnType!backward == Parameters!forward[0]));
32         static assert(is(ReturnType!forward == Tuple!(Parameters!backward)));
33     } else {
34         static assert(is(ReturnType!backward == Tuple!(Parameters!forward)));
35         static assert(is(ReturnType!forward == Tuple!(Parameters!backward)));
36     }
37 }
38 
39 /// a trait to identify autograd functions
40 enum bool isFunction(T) = hasMember!(T, "forward") && hasMember!(T, "backward");
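
/// usage sketch for TypeChecker and isFunction: `Identity` is an illustrative
/// struct, not part of grain; it compiles because the backward return type
/// matches the forward parameter type (and vice versa)
unittest {
    import std.meta : allSatisfy;
    import std.traits : arity, Parameters, ReturnType;
    import std.typecons : Tuple;

    alias F1H = Variable!(float, 1, HostStorage);
    struct Identity {
        F1H forward(F1H x) { return x; }
        F1H backward(F1H gy) { return gy; }
        // signatures are mutually consistent, so this static check passes
        mixin TypeChecker!(forward, backward);
    }
    static assert(isFunction!Identity);
}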
41 
42 /// common components (typecheck and backprop wrappers) for autograd functions
43 mixin template FunctionCommon() {
44     import std.meta : allSatisfy;
45     import std.typecons : isTuple, tuple, Tuple, RefCounted;
46     import std.traits : arity, Parameters, ReturnType;
47     import grain.utility : toTuple;
48     
    @disable this(this); // not copyable
50 
    static foreach (forward; __traits(getOverloads, typeof(this), "forward")) {
        static foreach (backward; __traits(getOverloads, typeof(this), "backward")) {
53             static if (!allSatisfy!(isHost, Parameters!forward) &&
54                        !allSatisfy!(isHost, Parameters!backward)) {
55                 alias DeviceRets = Tuple!(Parameters!backward);
56                 alias DeviceArgs = Tuple!(Parameters!forward);
57                 mixin TypeChecker!(forward, backward);
58                 DeviceArgs _mixin_dargs;
59             }
60             static if (allSatisfy!(isHost, Parameters!forward) &&
61                        allSatisfy!(isHost, Parameters!backward)) {
62                 alias HostRets = Tuple!(Parameters!backward);
63                 alias HostArgs = Tuple!(Parameters!forward);
64                 mixin TypeChecker!(forward, backward);
65                 HostArgs _mixin_hargs;
66             }
67         }
68     }
69     static assert(isFunction!(typeof(this)));
70 
    /// stores a grain.autograd.BackProp object in the variables returned by the forward function
72     auto applyForward(Args...)(Args args) {
73         static import grain.config;
74 
75         enum isHost = allSatisfy!(isHost, Args);
76         static foreach (i, a; args) {
77             static if (isHost) _mixin_hargs[i] = a;
78             else _mixin_dargs[i] = a;
79         }
80         auto rets = this.forward(args).toTuple;
81         auto bp = BackProp(&this.applyBackward!isHost,
82                            new UntypedVariable[rets.length]);
83         if (grain.config.backprop) {
84             foreach (ref r; rets) {
85                 r.bprop = bp;
86             }
87         }
88         static if (rets.length > 1) {
89             return rets;
90         } else {
91             return rets[0];
92         }
93     }
94 
    /// type-erased version of the backward function, used by the grain.autograd.BackProp object
96     void applyBackward(bool isHost_)(UntypedVariable[] ugradOutputs) {
97         static if (isHost_) {
98             HostRets vgradOutputs;
99             alias args = _mixin_hargs;
100         } else {
101             DeviceRets vgradOutputs;
102             alias args = _mixin_dargs;
103         }
104         static foreach (i; 0 .. vgradOutputs.length) {
105             vgradOutputs[i] = ugradOutputs[i].to!(typeof(vgradOutputs[i]));
106         }
107         auto vgradInputs = this.backward(vgradOutputs.expand).toTuple;
108         static assert(typeof(vgradInputs).length == args.length,
109                       "invalid number of input gradients");
110 
111         UntypedVariable[vgradInputs.length] ugradInputs;
112         foreach (i, v; vgradInputs) {
113             ugradInputs[i] = UntypedVariable(v);
114         }
115 
116         foreach (i, vgi; vgradInputs) {
117             // TODO reconsider this condition
118             if (args[i].requiresGrad) {
119                 auto data = args[i].grad;
120                 static if (vgi.isHost) {
121                     import mir.ndslice.slice : sliced;
122                     auto shape = vgradInputs[i].shape.castArray!size_t;
123                     data[] += vgradInputs[i].data[]; // .sliced(shape); FIXME use shape
124                 } else version (grain_cuda) {
125                     import std.traits : isFloatingPoint;
126                     // TODO support integral types
127                     static if (isFloatingPoint!(ElementType!(typeof(vgi)))) {
128                         // FIXME if contiguous
129                         import grain.cuda;
130                         grain.cuda.axpy(vgradInputs[i].data, data);
131                         /*
132                         import grain.cudnn;
133                         auto shape = vgradInputs[i].shape;
134                         auto strides = vgradInputs[i].strides;
135                         auto x = V(false, shape, strides, data);
136                         auto gx = V(false, shape, strides, vgradInputs[i].data);
137                         grain.cudnn.tensorOp!CUDNN_OP_TENSOR_ADD(
138                             // FIXME grad can have different strides
139                             gx,
140                             x,
141                             x);
142                         */
143                     }
144                 }
145             }
146             args[i].bprop.backward(&ugradInputs[i]); // FIXME support pos
147         }
148     }
149 }
150 
151 
/// test NG (must-not-compile) functions
153 unittest {
154     alias F1H = Variable!(float, 1, HostStorage);
    version (grain_cuda) alias F1D = Variable!(float, 1, DeviceStorage);
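    // wrapping in a template delays semantic analysis, so the signature mismatch
    // is only detected when `A!void` is instantiated inside __traits(compiles) below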
156     struct A(DelayInstantiation) {
157         mixin FunctionCommon;
158         // mismatch of args
159         F1H forward(F1H x) { return x; };
160         F1H backward(F1H x, F1H y) { return x; };
161     }
162     static assert(!__traits(compiles, A!void));
163 
164     version (grain_cuda) {
165         struct B(DelayInstantiation) {
166             mixin FunctionCommon;
167             F1H forward(F1H x) { return x; };
168             F1H backward(F1H x) { return x; };
169             // mismatch of args in device
170             version (grain_cuda) {
171                 F1D forward(F1D x) { return x; };
172                 F1D backward(F1D x, F1D y) { return x; };
173             }
174         }
175         static assert(!__traits(compiles, B!void));
176     }
177 }
178 
179 
180 /// check if broadcastable
181 auto broadcastable(T, size_t dim, alias Storage)(Variable!(T, dim, Storage) a, Variable!(T, dim, Storage) b) {
182     import std.typecons : tuple;
183     import std.algorithm : max;
184     int[dim] resultShape;
185     bool ok = false;
186     foreach (i; 0 .. dim) {
187         ok = a.shape[i] == b.shape[i] || a.shape[i] == 1 || b.shape[i] == 1;
188         if (ok) {
189             resultShape[i] = max(a.shape[i], b.shape[i]);
190         } else break;
191     }
192     return tuple!("ok", "shape")(ok, resultShape);
193 }
194 
/// expand a dimension, i.e. repeat the slice n times along dim
196 auto expand(size_t dim, S)(S s, size_t n) if (isSlice!S) {
    import std.format : format;
    import mir.primitives : DimensionCount;
    static assert(dim < DimensionCount!S,
                  format!"accessing invalid dim %d (should be < %d)"(dim, DimensionCount!S));
199     assert(s.length!dim == 1);
200 
201     import mir.ndslice : repeat, swapped, transposed, unpack;
    // e.g. [a, 1, b] -> repeat [n, a, 1, b] -> swapped!(0, dim+1) [1, a, n, b] -> [0] [a, n, b]
203     return s.repeat(n).unpack.swapped!(0, dim+1)[0];
204 }
205 
206 ///
207 nothrow pure @safe
208 unittest {
209     import mir.ndslice;
210     assert(iota(1, 1, 3).expand!1(3) ==
211            [[[0,1,2],[0,1,2],[0,1,2]]]);
212     assert(iota(1, 1, 3).expand!0(2).expand!1(3) ==
213            [[[0,1,2],[0,1,2],[0,1,2]],
214             [[0,1,2],[0,1,2],[0,1,2]]]);
215     assert(iota(1, 3, 2).expand!0(2) == iota(3, 2).repeat(2).unpack);
216 }
217 
/// expand the dimension if s.length!dim == 1; otherwise do nothing, but keep the same return type by applying the equivalent repeat/unpack/swapped/[0] chain
219 auto maybeExpand(size_t dim, S)(S s, size_t n) if (isSlice!S) {
    import mir.ndslice : repeat, swapped, transposed, unpack;
    return s.length!dim == 1 ? s.expand!dim(n) :
        // identity chain: [a, c, b] -> repeat [1, a, c, b] -> swapped (no-op) [1, a, c, b] -> [0] [a, c, b]
        s.repeat(1).unpack.swapped!(dim+1, dim+1)[0];
225 }
226 
227 ///
228 @nogc nothrow pure @safe
229 unittest {
230     import mir.ndslice;
231     assert(iota(1, 3, 2).maybeExpand!0(2) == iota(3, 2).repeat(2).unpack);
232     assert(iota(3, 2).maybeExpand!0(2) == iota(3, 2));
233 }
234 
235 /**
236    Returns:
237       broadcasted slice.
238       For example, when a has its shape [a, 1] and b has [1, b],
239       this function returns expanded a and b with a broadcasted shape [a, b].
240 
241    See_also:
242       https://docs.scipy.org/doc/numpy-1.13.0/user/basics.broadcasting.html
243 */
244 auto broadcast(S1, S2)(S1 a0, S2 b0) if (isSlice!S1 && isSlice!S2) {
245     import std.format : format;
246     import std.typecons : tuple;
247     import mir.primitives : DimensionCount;
248     static assert(DimensionCount!S1 == DimensionCount!S2); // TODO support dim mismatched slices by unsqueezing like numpy
249     enum dim = DimensionCount!S1;
250     static foreach (d; 1 .. dim+1) {
251         mixin(format!q{auto a%d = a%d.maybeExpand!(d-1)(b0.length!(d-1));}(d, d-1));
252         mixin(format!q{auto b%d = b%d.maybeExpand!(d-1)(a0.length!(d-1));}(d, d-1));
253     }
254     mixin(format!q{auto ax = a%d;}(dim));
255     mixin(format!q{auto bx = b%d;}(dim));
256     return tuple(ax, bx);
257 }
258 
259 ///
260 @nogc nothrow pure @safe
261 unittest {
262     import mir.ndslice;
263     auto a = iota(1, 3, 1);
264     auto b = iota(1, 1, 2);
    auto x = broadcast(a, b);
    assert(x[0] == a.expand!2(2));
    assert(x[1] == b.expand!1(3));
268 }
269 
/// reduce a slice into targetShape (TODO: make this @nogc)
271 auto reduceShape(alias fun, S, size_t N)(S s0, size_t[N] targetShape) {
272     import numir;
273     import mir.ndslice;
274     import mir.math : sum;
275     import std.format : format;
    import std.exception : assumeWontThrow; // TODO: unsqueeze could be nothrow, making assumeWontThrow unnecessary
277 
278     static if (N == 1) {
279         return s0;
280     } else {
281         auto rec(size_t n, T)(T t) {
282             static if (n == N) return t;
283             else {
284                 return
285                     rec!(n+1)(
286                         targetShape[n] == 1
287                         ? assumeWontThrow(t.alongDim!n.map!fun.slice.unsqueeze!n).slice
288                         : t.slice);
289             }
290         }
291         return rec!0(s0);
292     }
293 }
294 
295 nothrow pure @safe unittest {
296     import mir.ndslice;
297     import mir.math;
298     import numir;
    import std.exception : assumeWontThrow; // TODO: unsqueeze could be nothrow, making assumeWontThrow unnecessary
300     assert(iota(2, 3).reduceShape!sum([2, 1]) == assumeWontThrow(iota(2, 3).alongDim!1.map!sum.slice.unsqueeze!1));
301 }