module grain.functions.topology;

import std.stdio;

import numir;
import mir.ndslice;

import grain.autograd;
import grain.utility;
import grain.testing : gradCheck;
import grain.functions.common;

/*
  Sanity checks on how mir.ndslice topology operations change a slice's
  iterator, lengths, and strides.
 */
unittest {
    // Debug helper (not called below): prints the iterator offset relative
    // to `base` together with the lengths and strides of `sl`.
    auto info(S, P)(S sl, P base) {
        writefln!"Slice(ptr: %s, shape: %s, strides: %s)"(sl._iterator - base, sl._lengths, sl._strides);
    }

    auto cube = iota(3, 4, 5).slice.universal;
    auto origin = cube._iterator;

    // Freshly allocated 3x4x5 slice: strides are [20, 5, 1].
    assert(cube._iterator == origin);
    assert(cube._lengths == [3, 4, 5]);
    assert(cube._strides == [20, 5, 1]);

    // Swapping two axes permutes lengths and strides; the iterator stays put.
    auto axesSwapped = cube.swapped(0, 1);
    assert(axesSwapped._iterator == origin);
    assert(axesSwapped._lengths == [4, 3, 5]);
    assert(axesSwapped._strides == [5, 20, 1]);

    // Indexing the leading axis drops that axis.
    auto firstPlane = cube[0];
    assert(firstPlane._iterator == origin);
    assert(firstPlane._lengths == [4, 5]);
    assert(firstPlane._strides == [5, 1]);

    // Indexing the middle axis drops its length/stride pair.
    auto middleFixed = cube[0..$, 0];
    assert(middleFixed._iterator == origin);
    assert(middleFixed._lengths == [3, 5]);
    assert(middleFixed._strides == [20, 1]);

    // Indexing the trailing axis drops its length/stride pair.
    auto lastFixed = cube[0..$, 0..$, 0];
    assert(lastFixed._iterator == origin);
    assert(lastFixed._lengths == [3, 4]);
    assert(lastFixed._strides == [20, 5]);

    // Reversing axis 0 negates its stride and moves the iterator by 40.
    auto leadingReversed = cube.reversed!0;
    assert(leadingReversed._iterator == origin + 40);
    assert(leadingReversed._lengths == [3, 4, 5]);
    assert(leadingReversed._strides == [-20, 5, 1]);

    // Reversing every axis negates all strides and moves the iterator by 59.
    auto fullyReversed = cube.allReversed;
    assert(fullyReversed._iterator == origin + 59);
    assert(fullyReversed._lengths == [3, 4, 5]);
    assert(fullyReversed._strides == [-20, -5, -1]);

    // view(3, -1): the -1 entry comes out as 20 here.
    auto flattened = cube.view(3, -1);
    assert(flattened._iterator == origin);
    assert(flattened._lengths == [3, 20]);
    assert(flattened._strides == [20, 1]);
}

/// Multiplies together all elements of `x.sliced`, starting from `1L`.
auto prod(T)(T x) {
    auto elements = x.sliced;
    return reduce!"a * b"(1L, elements);
}


/**
Reshaping or viewing to the other shape.

`forward` views a `sourceDim`-dimensional variable under `targetShape` and
records the input shape; `backward` views a `targetDim`-dimensional gradient
under that recorded shape. Both results are built over the input's `data`.
*/
struct View(T, size_t sourceDim, size_t targetDim, alias Storage) {
    import numir : view;

    /// Shape requested for the output of `forward`.
    ptrdiff_t[targetDim] targetShape;
    /// Shape of the most recent `forward` input; consumed by `backward`.
    ptrdiff_t[sourceDim] sourceShape;

    /// Returns `x` viewed under `targetShape`, keeping `x.requiresGrad` and `x.data`.
    auto forward(Variable!(T, sourceDim, Storage) x) {
        alias Output = Variable!(T, targetDim, Storage);

        // assert(x.shape[].prod == targetShape[].prod);
        // Remember the input shape so backward can restore it.
        this.sourceShape = x.shape.castArray!ptrdiff_t; // TODO if train
        auto viewed = x.sliced.view(targetShape);
        return Output(
            x.requiresGrad,
            viewed.shape.castArray!uint,
            viewed.strides.castArray!int,
            x.data
        );
    }

    /// Returns `gy` viewed under the shape recorded by the last `forward` call.
    auto backward(Variable!(T, targetDim, Storage) gy) {
        alias Input = Variable!(T, sourceDim, Storage);

        auto restored = gy.sliced.view(this.sourceShape);
        return Input(
            gy.requiresGrad,
            restored.shape.castArray!uint,
            restored.strides.castArray!int,
            gy.data
        );
    }

    mixin FunctionCommon;
}

///
unittest {
    auto func = View!(float, 3, 2, HostStorage)([3, -1]);
    auto input = iota(3, 4, 5).as!float.slice.variable;

    // Forward: a 3x4x5 input is viewed as 3x20.
    auto output = func.forward(input);
    assert(output.sliced == iota(3, 20));

    // Backward: a 3x20 gradient is viewed back as 3x4x5.
    auto hostGradOut = uniform!float(3, 20).slice.variable;
    auto hostGradIn = func.backward(hostGradOut);
    assert(hostGradOut.sliced.view(3, 4, 5) == hostGradIn.sliced);
    // gradCheck(func, input, hostGradOut);

    version (grain_cuda) {
        // Same checks with device storage, compared after copying to host.
        auto deviceFunc = View!(float, 3, 2, DeviceStorage)([3, -1]);
        auto deviceOutput = deviceFunc.forward(input.to!DeviceStorage);
        assert(deviceOutput.to!HostStorage.sliced == iota(3, 20));
        auto deviceGradIn = deviceFunc.backward(hostGradOut.to!DeviceStorage);
        assert(deviceGradIn.to!HostStorage.sliced.view(3, 4, 5) == hostGradIn.sliced);
    }
}