1 /**
2 A module for serialization. HDF5 read/write is supported
3  */
4 module grain.serializer;
5 
6 import std.stdio;
7 import grain.autograd;
8 
9 import grain.hdf5;
10 import std..string : toStringz;
11 
12 
/// List of parameter names found in chain type C, evaluated as an enum
/// initializer by walking a default-initialized C with iterVariables.
enum variableNames(C) = {
    string[] names;

    // only entries whose value satisfies isVariable are recorded
    void collect(V)(string key, V value) if (isVariable!V) {
        names ~= key;
    }

    C instance;
    iterVariables!((key, value) { collect(key, value); })(&instance, "");
    return names;
}();
24 
///
unittest {
    import std.traits;

    auto net = MLP!(float, HostStorage)(3);
    // names are reported per layer field, weight before bias
    enum expected = [".fc1.weight", ".fc1.bias", ".fc2.weight", ".fc2.bias",
            ".fc3.weight", ".fc3.bias"];
    static assert(variableNames!(typeof(net)) == expected);
}
33 
// check that .slice yields row-major strides for a transposed view
unittest {
    import numir;
    import mir.ndslice;

    auto permuted = iota(3, 4, 5).transposed(1);
    // the transposed view keeps the strides of the original 3x4x5 layout
    assert(permuted.universal._strides == [5, 20, 1]);

    auto packed = permuted.slice;
    // after .slice the strides match a fresh 4x3x5 allocation
    assert(packed.universal._strides == [15, 5, 1]);
}
43 
version (unittest) {
    /// Three-layer perceptron used by the unit tests of this module
    struct MLP(T, alias Storage) {
        import grain.autograd : Variable;
        import grain.chain : Linear, relu;

        alias L = Linear!(T, Storage);
        L fc1, fc2, fc3;

        /// build the three Linear layers: L(2, nhidden), L(nhidden, nhidden), L(nhidden, 10)
        this(int nhidden) {
            this.fc1 = L(2, nhidden);
            this.fc2 = L(nhidden, nhidden);
            this.fc3 = L(nhidden, 10);
        }

        /// forward pass: fc1 -> relu -> fc2 -> relu -> fc3
        auto opCall(Variable!(T, 2, Storage) x) {
            auto h1 = relu(this.fc1(x));
            auto h2 = relu(this.fc2(h1));
            // return the last layer's output so that every layer's parameters
            // influence the result (the previous code returned h1 and
            // discarded the fc2/fc3 computations)
            return this.fc3(h2);
        }
    }
}
66 
/**
Convert a D basic type into the matching HDF5 predefined type-id.

The identifier is assembled as `H5T_<family><bits>LE`, where family is
`IEEE_F` for floating point types, `STD_I` for signed types and `STD_U`
otherwise, and bits is `T.sizeof * 8`.

See_Also: https://support.hdfgroup.org/HDF5/doc1.8/RM/PredefDTypes.html
 */
auto toH5Type(T)() {
    import std.traits;
    import std.format;

    static assert(isBasicType!T);

    static if (isFloatingPoint!T)
        enum family = "IEEE_F";
    else static if (isSigned!T)
        enum family = "STD_I";
    else
        enum family = "STD_U";

    enum bits = T.sizeof * 8;
    mixin(format("return H5T_%s%dLE;", family, bits));
}
76 
/**
Save chain parameters to an HDF5 file.

Every variable visited by iterVariables is converted to HostStorage and
written as a dataset named "/" ~ key. An existing file at path is truncated
(H5F_ACC_TRUNC).

Params:
    verbose = currently not used by this function
    chain = object whose variables are stored
    path = destination file path

Throws: Exception when an HDF5 call reports failure (negative return value)
 */
void save(bool verbose = true, C)(C chain, string path) {
    import std.exception : enforce;
    import std.string : toStringz;
    import grain.utility : castArray;

    auto file = H5Fcreate(path.toStringz, H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT);
    enforce(file >= 0, "failed to create HDF5 file: " ~ path);
    scope (exit)
        H5Fclose(file);

    void register(T, size_t dim, alias Storage)(string k, Variable!(T, dim, Storage) v) {
        // convert once and reuse it for both the shape and the data buffer
        auto h = v.to!HostStorage;
        // FIXME support check contiguous
        // auto s = h.sliced.slice;
        auto dims = h.shape.castArray!hsize_t;
        // current and maximum extents are both set to the variable's shape
        auto space = H5Screate_simple(cast(int) dims.length, dims.ptr, dims.ptr);
        enforce(space >= 0, "failed to create HDF5 dataspace for: " ~ k);
        scope (exit)
            H5Sclose(space);

        auto dataset = H5Dcreate2(file, toStringz("/" ~ k), toH5Type!T, space,
                H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);
        enforce(dataset >= 0, "failed to create HDF5 dataset: /" ~ k);
        scope (exit)
            H5Dclose(dataset);

        auto status = H5Dwrite(dataset, toH5Type!T, H5S_ALL, H5S_ALL,
                H5P_DEFAULT, cast(void*) h.data.ptr);
        enforce(status >= 0, "failed to write HDF5 dataset: /" ~ k);
    }

    iterVariables!((k, v) { register(k, v); })(&chain, "");
}
110 
/**
Load chain parameters from an HDF5 file.

For every variable visited by refIterVariables, the dataset named "/" ~ key
is read into a host buffer of v.data.length elements and then copied into
the variable (slice assignment for HostStorage, grain.cudnn.transform
otherwise).

Params:
    chain = object whose variables are overwritten
    path = source file path (opened read-only)

Throws: Exception when an HDF5 call reports failure (negative return value),
        e.g. when the file or a dataset cannot be opened
 */
void load(C)(ref C chain, string path) {
    import std.exception : enforce;
    import std.string : toStringz;
    import mir.ndslice : sliced;

    import grain.utility : castArray;

    auto file = H5Fopen(path.toStringz, H5F_ACC_RDONLY, H5P_DEFAULT);
    // without this check a missing file would go unnoticed and the
    // parameters would be overwritten with the buffer's initial contents
    enforce(file >= 0, "failed to open HDF5 file: " ~ path);
    scope (exit)
        H5Fclose(file);

    void register(T, size_t dim, alias Storage)(string k, ref Variable!(T, dim, Storage) v) {
        auto dataset = H5Dopen2(file, toStringz("/" ~ k), H5P_DEFAULT);
        enforce(dataset >= 0, "failed to open HDF5 dataset: /" ~ k);
        scope (exit)
            H5Dclose(dataset);

        // NOTE(review): the stored dataset's extent is not compared with
        // v.shape here; presumably a larger dataset would overrun `raw`
        // because H5S_ALL selects the whole dataset -- confirm and add a check
        auto raw = new T[v.data.length];
        // raw.ptr (instead of &raw[0]) stays valid for zero-length buffers
        auto status = H5Dread(dataset, toH5Type!T, H5S_ALL, H5S_ALL,
                H5P_DEFAULT, cast(void*) raw.ptr);
        enforce(status >= 0, "failed to read HDF5 dataset: /" ~ k);

        auto src = raw.sliced(v.shape.castArray!size_t).variable;
        static if (is(Storage!T == HostStorage!T)) {
            v.sliced[] = src.sliced;
        }
        else {
            import grain.cudnn : transform;

            transform(src.to!Storage, v);
        }
    }

    refIterVariables!((k, ref v) { register(k, v); })(chain, "");
}
147 
///
unittest {
    import std.file : exists, remove, tempDir;
    import std.path : buildPath;

    // use the platform's temporary directory and clean up afterwards
    auto path = buildPath(tempDir, "test_grain0.h5");
    scope (exit)
        if (exists(path))
            remove(path);

    auto model1 = MLP!(float, HostStorage)(3);
    model1.save(path);

    auto model2 = MLP!(float, HostStorage)(3);
    model2.load(path);
    assert(model1.fc1.bias.sliced == model2.fc1.bias.sliced);

    import numir;
    import mir.ndslice;

    // both models must produce the same output for the same input
    auto x = uniform!float(3, 2).slice.variable;
    assert(model1(x).sliced == model2(x).sliced);
}
163 
///
version (grain_cuda) unittest {
    import std.file : exists, remove, tempDir;
    import std.path : buildPath;

    // use the platform's temporary directory and clean up afterwards
    auto path = buildPath(tempDir, "test_grain1.h5");
    scope (exit)
        if (exists(path))
            remove(path);

    auto model1 = MLP!(float, DeviceStorage)(3);
    model1.save(path);

    auto model2 = MLP!(float, DeviceStorage)(3);
    model2.load(path);
    assert(model1.fc1.bias.to!HostStorage.sliced == model2.fc1.bias.to!HostStorage
            .sliced);

    import numir;
    import mir.ndslice;

    // both models must produce the same output for the same input
    auto x = uniform!float(3, 2).slice.variable.to!DeviceStorage;
    assert(model1(x).to!HostStorage.sliced == model2(x).to!HostStorage.sliced);
}
180 
///
version (grain_cuda) unittest {
    import std.file : exists, remove, tempDir;
    import std.path : buildPath;

    // use the platform's temporary directory and clean up afterwards
    auto path = buildPath(tempDir, "test_grain2.h5");
    scope (exit)
        if (exists(path))
            remove(path);

    // parameters saved from host storage are loaded into device storage
    auto model1 = MLP!(float, HostStorage)(3);
    model1.save(path);

    auto model2 = MLP!(float, DeviceStorage)(3);
    model2.load(path);
    assert(model1.fc1.bias.to!HostStorage.sliced == model2.fc1.bias.to!HostStorage
            .sliced);
}
191 
///
version (grain_cuda) unittest {
    import std.file : exists, remove, tempDir;
    import std.path : buildPath;

    // use the platform's temporary directory and clean up afterwards
    auto path = buildPath(tempDir, "test_grain3.h5");
    scope (exit)
        if (exists(path))
            remove(path);

    // parameters saved from device storage are loaded into host storage
    auto model1 = MLP!(float, DeviceStorage)(3);
    model1.save(path);

    auto model2 = MLP!(float, HostStorage)(3);
    model2.load(path);
    assert(model1.fc1.bias.to!HostStorage.sliced == model2.fc1.bias.to!HostStorage
            .sliced);
}