/**
   Common components for autograd function objects
 */
module grain.functions.common;

import grain.autograd;
import grain.cuda;
import grain.utility : fromTuple, castArray;
import mir.ndslice : isSlice;

import std.stdio;
import std.traits : hasMember;

version (grain_cuda) {
    import cudnn = derelict.cudnn7;
}

/// a simple type check of forward/backward functions compatibility
mixin template TypeChecker(alias forward, alias backward) {
    static assert(allSatisfy!(isVariable, Parameters!forward),
                  "all the forward function args should be variable.");
    static assert(allSatisfy!(isVariable, Parameters!backward),
                  "all the backward function args should be variable.");
    static if (arity!forward == 1 && arity!backward == 1) {
        static assert(is(ReturnType!backward == Parameters!forward[0]));
        static assert(is(ReturnType!forward == Parameters!backward[0]));
    } else static if (arity!backward == 1) {
        static assert(is(ReturnType!backward == Tuple!(Parameters!forward)));
        static assert(is(ReturnType!forward == Parameters!backward[0]));
    } else static if (arity!forward == 1) {
        static assert(is(ReturnType!backward == Parameters!forward[0]));
        static assert(is(ReturnType!forward == Tuple!(Parameters!backward)));
    } else {
        static assert(is(ReturnType!backward == Tuple!(Parameters!forward)));
        static assert(is(ReturnType!forward == Tuple!(Parameters!backward)));
    }
}

/// a trait to identify autograd functions
enum bool isFunction(T) = hasMember!(T, "forward") && hasMember!(T, "backward");

/// common components (typecheck and backprop wrappers) for autograd functions
mixin template FunctionCommon() {
    import std.meta : allSatisfy;
    import std.typecons : isTuple, tuple, Tuple, RefCounted;
    import std.traits : arity, Parameters, ReturnType;
    import grain.utility : toTuple;

    @disable this(this); // not copyable

    static foreach (i, forward; __traits(getOverloads, typeof(this), "forward")) {
        static foreach (i, backward; __traits(getOverloads, typeof(this), "backward")) {
            static if (!allSatisfy!(isHost, Parameters!forward) &&
                       !allSatisfy!(isHost, Parameters!backward)) {
                alias DeviceRets = Tuple!(Parameters!backward);
                alias DeviceArgs = Tuple!(Parameters!forward);
                mixin TypeChecker!(forward, backward);
                DeviceArgs _mixin_dargs;
            }
            static if (allSatisfy!(isHost, Parameters!forward) &&
                       allSatisfy!(isHost, Parameters!backward)) {
                alias HostRets = Tuple!(Parameters!backward);
                alias HostArgs = Tuple!(Parameters!forward);
                mixin TypeChecker!(forward, backward);
                HostArgs _mixin_hargs;
            }
        }
    }
    static assert(isFunction!(typeof(this)));

    /// store grain.autograd.BackProp object in returned variables from forward function
    auto applyForward(Args...)(Args args) {
        static import grain.config;

        enum isHost = allSatisfy!(isHost, Args);
        // keep the inputs so that applyBackward can accumulate their gradients later
        static foreach (i, a; args) {
            static if (isHost) _mixin_hargs[i] = a;
            else _mixin_dargs[i] = a;
        }
        auto rets = this.forward(args).toTuple;
        auto bp = BackProp(&this.applyBackward!isHost,
                           new UntypedVariable[rets.length]);
        if (grain.config.backprop) {
            foreach (ref r; rets) {
                r.bprop = bp;
            }
        }
        static if (rets.length > 1) {
            return rets;
        } else {
            return rets[0];
        }
    }

    /// type-erased version of backward function used in grain.autograd.BackProp object
    void applyBackward(bool isHost_)(UntypedVariable[] ugradOutputs) {
        static if (isHost_) {
            HostRets vgradOutputs;
            alias args = _mixin_hargs;
        } else {
            DeviceRets vgradOutputs;
            alias args = _mixin_dargs;
        }
        static foreach (i; 0 .. vgradOutputs.length) {
            vgradOutputs[i] = ugradOutputs[i].to!(typeof(vgradOutputs[i]));
        }
        auto vgradInputs = this.backward(vgradOutputs.expand).toTuple;
        static assert(typeof(vgradInputs).length == args.length,
                      "invalid number of input gradients");

        UntypedVariable[vgradInputs.length] ugradInputs;
        foreach (i, v; vgradInputs) {
            ugradInputs[i] = UntypedVariable(v);
        }

        // accumulate the input gradients into the stored arguments and keep backpropagating
        foreach (i, vgi; vgradInputs) {
            // TODO reconsider this condition
            if (args[i].requiresGrad) {
                auto data = args[i].grad;
                static if (vgi.isHost) {
                    import mir.ndslice.slice : sliced;

                    auto shape = vgradInputs[i].shape.castArray!size_t;
                    data[] += vgradInputs[i].data[]; // .sliced(shape); FIXME use shape
                } else version (grain_cuda) {
                    import std.traits : isFloatingPoint;

                    // TODO support integral types
                    static if (isFloatingPoint!(ElementType!(typeof(vgi)))) {
                        // FIXME if contiguous
                        import grain.cuda;

                        grain.cuda.axpy(vgradInputs[i].data, data);
                        /*
                        import grain.cudnn;
                        auto shape = vgradInputs[i].shape;
                        auto strides = vgradInputs[i].strides;
                        auto x = V(false, shape, strides, data);
                        auto gx = V(false, shape, strides, vgradInputs[i].data);
                        grain.cudnn.tensorOp!CUDNN_OP_TENSOR_ADD(
                            // FIXME grad can have different strides
                            gx,
                            x,
                            x);
                        */
                    }
                }
            }
            args[i].bprop.backward(&ugradInputs[i]); // FIXME support pos
        }
    }
}
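
/// a positive counterpart to the NG cases below: an illustrative sketch (not part of
/// the original tests) of a host-only identity function whose forward/backward
/// signatures match, so FunctionCommon and its TypeChecker mixin compile. The struct
/// name `Identity` is only illustrative.
unittest {
    alias F1H = Variable!(float, 1, HostStorage);
    struct Identity {
        mixin FunctionCommon;
        F1H forward(F1H x) { return x; }
        F1H backward(F1H gy) { return gy; }
    }
    static assert(isFunction!Identity);
    // the mixin generates argument/gradient tuple aliases for the host overload pair
    static assert(is(Identity.HostArgs));
    static assert(is(Identity.HostRets));
}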

/// test NG functions
unittest {
    alias F1H = Variable!(float, 1, HostStorage);
    version (grain_cuda) alias F1D = Variable!(float, 1, DeviceStorage);

    struct A(DelayInstantiation) {
        mixin FunctionCommon;
        // mismatch of args
        F1H forward(F1H x) { return x; }
        F1H backward(F1H x, F1H y) { return x; }
    }
    static assert(!__traits(compiles, A!void));

    version (grain_cuda) {
        struct B(DelayInstantiation) {
            mixin FunctionCommon;
            F1H forward(F1H x) { return x; }
            F1H backward(F1H x) { return x; }
            // mismatch of args in device
            version (grain_cuda) {
                F1D forward(F1D x) { return x; }
                F1D backward(F1D x, F1D y) { return x; }
            }
        }
        static assert(!__traits(compiles, B!void));
    }
}

/// check if two variables are broadcastable and compute the broadcasted shape
auto broadcastable(T, size_t dim, alias Storage)(Variable!(T, dim, Storage) a,
                                                 Variable!(T, dim, Storage) b) {
    import std.typecons : tuple;
    import std.algorithm : max;

    int[dim] resultShape;
    bool ok = false;
    foreach (i; 0 .. dim) {
        ok = a.shape[i] == b.shape[i] || a.shape[i] == 1 || b.shape[i] == 1;
        if (ok) {
            resultShape[i] = max(a.shape[i], b.shape[i]);
        } else break;
    }
    return tuple!("ok", "shape")(ok, resultShape);
}
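
/// a usage sketch for broadcastable (added for illustration, not from the original
/// tests). It assumes the slice-to-Variable `variable` helper from grain.autograd,
/// as used elsewhere in grain.
unittest {
    import mir.ndslice : iota, as, slice;

    auto a = iota(1, 3).as!float.slice.variable; // shape [1, 3]
    auto b = iota(2, 1).as!float.slice.variable; // shape [2, 1]
    auto r = broadcastable(a, b);
    assert(r.ok);
    assert(r.shape[0] == 2 && r.shape[1] == 3);

    // [1, 3] vs [2, 2]: the second dims are neither equal nor 1
    auto c = iota(2, 2).as!float.slice.variable;
    assert(!broadcastable(a, c).ok);
}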

/// expand a dimension, i.e., repeat the slice n times along dim (requires s.length!dim == 1)
auto expand(size_t dim, S)(S s, size_t n) if (isSlice!S) {
    import std.format : format;
    import mir.primitives : DimensionCount;

    static assert(dim < DimensionCount!S,
                  format!"accessing invalid dim %d (should be < %d)"(dim, DimensionCount!S));
    assert(s.length!dim == 1);

    import mir.ndslice : repeat, swapped, transposed, unpack;

    // e.g. [a, 1, b] -> repeat [n, a, 1, b] -> swapped [1, a, n, b] -> [0] drops the leading 1
    return s.repeat(n).unpack.swapped!(0, dim + 1)[0];
}

///
nothrow pure @safe
unittest {
    import mir.ndslice;

    assert(iota(1, 1, 3).expand!1(3) ==
           [[[0, 1, 2], [0, 1, 2], [0, 1, 2]]]);
    assert(iota(1, 1, 3).expand!0(2).expand!1(3) ==
           [[[0, 1, 2], [0, 1, 2], [0, 1, 2]],
            [[0, 1, 2], [0, 1, 2], [0, 1, 2]]]);
    assert(iota(1, 3, 2).expand!0(2) == iota(3, 2).repeat(2).unpack);
}

/// expand the dimension if s.length!dim == 1, otherwise do nothing;
/// both branches go through the same repeat/unpack/swapped/[0] expressions so that they yield the same type
auto maybeExpand(size_t dim, S)(S s, size_t n) if (isSlice!S) {
    import mir.ndslice : repeat, swapped, transposed, unpack;

    return s.length!dim == 1
        ? s.expand!dim(n)
        // [a, c, b] -> repeat [1, a, c, b] -> swapped (no-op) -> [0] [a, c, b]
        : s.repeat(1).unpack.swapped!(dim + 1, dim + 1)[0];
}

///
@nogc nothrow pure @safe
unittest {
    import mir.ndslice;

    assert(iota(1, 3, 2).maybeExpand!0(2) == iota(3, 2).repeat(2).unpack);
    assert(iota(3, 2).maybeExpand!0(2) == iota(3, 2));
}

/**
   Returns:
      broadcasted slices.
      For example, when `a0` has the shape [a, 1] and `b0` has [1, b],
      this function returns `a0` and `b0` expanded to the broadcasted shape [a, b].

   See_also:
      https://docs.scipy.org/doc/numpy-1.13.0/user/basics.broadcasting.html
 */
auto broadcast(S1, S2)(S1 a0, S2 b0) if (isSlice!S1 && isSlice!S2) {
    import std.format : format;
    import std.typecons : tuple;
    import mir.primitives : DimensionCount;

    static assert(DimensionCount!S1 == DimensionCount!S2); // TODO support dim mismatched slices by unsqueezing like numpy
    enum dim = DimensionCount!S1;
    static foreach (d; 1 .. dim + 1) {
        mixin(format!q{auto a%d = a%d.maybeExpand!(d-1)(b0.length!(d-1));}(d, d - 1));
        mixin(format!q{auto b%d = b%d.maybeExpand!(d-1)(a0.length!(d-1));}(d, d - 1));
    }
    mixin(format!q{auto ax = a%d;}(dim));
    mixin(format!q{auto bx = b%d;}(dim));
    return tuple(ax, bx);
}

///
@nogc nothrow pure @safe
unittest {
    import mir.ndslice;

    auto a = iota(1, 3, 1);
    auto b = iota(1, 1, 2);
    auto x = broadcast(a, b);
    assert(broadcast(a, b)[0] == a.expand!2(2));
    assert(broadcast(a, b)[1] == b.expand!1(3));
}

/// reduce slice into targetShape, TODO @nogc
auto reduceShape(alias fun, S, size_t N)(S s0, size_t[N] targetShape) {
    import numir;
    import mir.ndslice;
    import mir.math : sum;
    import std.format : format;
    import std.exception : assumeWontThrow; // TODO unsqueeze can be assumeWontThrow

    static if (N == 1) {
        return s0;
    } else {
        auto rec(size_t n, T)(T t) {
            static if (n == N) return t;
            else {
                return rec!(n + 1)(
                    targetShape[n] == 1
                    ? assumeWontThrow(t.alongDim!n.map!fun.slice.unsqueeze!n).slice
                    : t.slice);
            }
        }
        return rec!0(s0);
    }
}

nothrow pure @safe unittest {
    import mir.ndslice;
    import mir.math;
    import numir;
    import std.exception : assumeWontThrow; // TODO unsqueeze can be assumeWontThrow

    assert(iota(2, 3).reduceShape!sum([2, 1]) ==
           assumeWontThrow(iota(2, 3).alongDim!1.map!sum.slice.unsqueeze!1));
}
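
/// an end-to-end sketch (added for illustration, not from the original tests) of how
/// the helpers above are typically combined: expand both operands with broadcast for
/// the forward pass, then sum a gradient of the broadcasted shape back to the original
/// shape with reduceShape for the backward pass.
unittest {
    import mir.ndslice;
    import mir.math : sum;

    // forward: [2, 1] and [1, 3] are broadcast to a common [2, 3] shape
    auto a = iota(2, 1);
    auto b = iota(1, 3);
    auto ab = broadcast(a, b);
    assert(ab[0] == [[0, 0, 0], [1, 1, 1]]);
    assert(ab[1] == [[0, 1, 2], [0, 1, 2]]);

    // backward: a gradient with the broadcasted shape [2, 3] is reduced back to
    // a's original shape [2, 1] by summing over the expanded dimension
    auto gy = ab[1].slice;
    assert(gy.reduceShape!sum([2, 1]) == [[3], [3]]);
}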