1 module grain.functions.topology;
2 
3 import std.stdio;
4 
5 import numir;
6 import mir.ndslice;
7 
8 import grain.autograd;
9 import grain.utility;
10 import grain.testing : gradCheck;
11 import grain.functions.common;
12 
/*
  Sanity checks of how mir.ndslice represents common views of a 3 x 4 x 5
  universal slice: start pointer (`_iterator`), shape (`_lengths`) and
  strides (`_strides`).
 */
unittest {
    auto base = iota(3, 4, 5).slice.universal;
    auto origin = base._iterator;

    // Assert that `sl` starts `offset` elements after `origin` and has the
    // given lengths and strides.
    void expect(S)(S sl, ptrdiff_t offset, size_t[] lengths, ptrdiff_t[] strides) {
        assert(sl._iterator == origin + offset);
        assert(sl._lengths == lengths);
        assert(sl._strides == strides);
    }

    // the source itself: row-major strides, no offset
    expect(base, 0, [3, 4, 5], [20, 5, 1]);
    // swapping axes 0 and 1 only permutes lengths and strides
    expect(base.swapped(0, 1), 0, [4, 3, 5], [5, 20, 1]);
    // fixing an index to 0 drops that axis without moving the start pointer
    expect(base[0], 0, [4, 5], [5, 1]);
    expect(base[0 .. $, 0], 0, [3, 5], [20, 1]);
    expect(base[0 .. $, 0 .. $, 0], 0, [3, 4], [20, 5]);
    // reversing axis 0 starts at the last plane: 2 * 20 = 40
    expect(base.reversed!0, 40, [3, 4, 5], [-20, 5, 1]);
    // reversing every axis starts at the last element: 2*20 + 3*5 + 4 = 59
    expect(base.allReversed, 59, [3, 4, 5], [-20, -5, -1]);
    // numir's view with -1 infers the remaining length (60 / 3 = 20)
    expect(base.view(3, -1), 0, [3, 20], [20, 1]);
}
63 
/// Multiply together all elements of `x` (e.g. a shape array).
///
/// The product is seeded with the `long` value 1, so an empty input yields 1.
auto prod(T)(T x) {
    enum long one = 1;
    auto elems = x.sliced;
    return reduce!"a * b"(one, elems);
}
67 
68 
/// Differentiable reshape of a `sourceDim`-D variable into a `targetDim`-D one.
///
/// `forward` reshapes the input with numir's `view` and wraps the result's
/// shape and strides around the input's own `data`. `backward` reshapes the
/// incoming gradient back to the shape recorded by the last `forward`.
struct View(T, size_t sourceDim, size_t targetDim, alias Storage) {
    import numir : view;

    /// Requested output shape handed to `view`; the unittest uses -1 to let
    /// one axis be inferred.
    ptrdiff_t[targetDim] targetShape;
    /// Input shape captured by `forward` and restored by `backward`.
    ptrdiff_t[sourceDim] sourceShape;

    /// Reshape `x` to `targetShape`, remembering `x`'s shape for `backward`.
    auto forward(Variable!(T, sourceDim, Storage) x) {
        alias Output = Variable!(T, targetDim, Storage);
        // NOTE(review): a check like `x.shape[].prod == targetShape[].prod`
        // would be wrong whenever targetShape contains -1; presumably numir's
        // `view` validates the element count itself - confirm.
        this.sourceShape = x.shape.castArray!ptrdiff_t; // TODO if train
        auto reshaped = x.sliced.view(this.targetShape);
        return Output(x.requiresGrad,
                      reshaped.shape.castArray!uint,
                      reshaped.strides.castArray!int,
                      x.data);
    }

    /// Reshape the gradient `gy` back to the shape seen by the last `forward`.
    auto backward(Variable!(T, targetDim, Storage) gy) {
        alias Input = Variable!(T, sourceDim, Storage);
        auto restored = gy.sliced.view(this.sourceShape);
        return Input(gy.requiresGrad,
                     restored.shape.castArray!uint,
                     restored.strides.castArray!int,
                     gy.data);
    }

    mixin FunctionCommon;
}
99 
///
unittest {
    // Round trip (3, 4, 5) -> (3, 20) -> (3, 4, 5) on the host and, with
    // grain_cuda, on the device.
    auto f = View!(float, 3, 2, HostStorage)([3, -1]);
    auto x = iota(3, 4, 5).as!float.slice.variable;
    auto y = f.forward(x);
    assert(y.sliced == iota(3, 20));
    // forward must remember the input shape so that backward can restore it
    assert(f.sourceShape == [3, 4, 5]);
    auto hgy = uniform!float(3, 20).slice.variable;
    auto hgx = f.backward(hgy);
    assert(hgy.sliced.view(3, 4, 5) == hgx.sliced);
    // gradCheck(f, x, hgy);

    version (grain_cuda) {
        auto df = View!(float, 3, 2, DeviceStorage)([3, -1]);
        auto dy = df.forward(x.to!DeviceStorage);
        assert(dy.to!HostStorage.sliced == iota(3, 20));
        auto dgx = df.backward(hgy.to!DeviceStorage);
        // Compare directly: reshaping dgx before comparing would hide a
        // backward pass that returned a wrongly-shaped gradient.
        assert(dgx.to!HostStorage.sliced == hgx.sliced);
    }
}
119 
120