/**
Serialization of chain parameters. HDF5 read/write is supported.
*/
module grain.serializer;

import std.stdio;
import grain.autograd;

import grain.hdf5;
import std.string : toStringz;


/**
Enumerate the parameter names inside chain `C`.

Evaluated at compile time: a default-initialized `C` is walked with
`iterVariables` (root prefix `""`) and every key whose value is a
variable is collected, e.g. `[".fc1.weight", ".fc1.bias", ...]`.
*/
enum variableNames(C) = {
    string[] ret;
    void register(V)(string k, V v) if (isVariable!V) {
        ret ~= [k];
    }

    C chain;
    iterVariables!((k, v) { register(k, v); })(&chain, "");
    return ret;
}();

///
unittest {
    import std.traits;

    auto mlp = MLP!(float, HostStorage)(3);
    static assert(variableNames!(typeof(mlp)) == [".fc1.weight", ".fc1.bias",
            ".fc2.weight", ".fc2.bias", ".fc3.weight", ".fc3.bias"]);
}

// test .slice makes slice contiguous
unittest {
    import numir;
    import mir.ndslice;

    auto i = iota(3, 4, 5).transposed(1);
    assert(i.universal._strides == [5, 20, 1]);
    assert(i.slice.universal._strides == [15, 5, 1]);
}

version (unittest) {
    /// Small three-layer perceptron used as a fixture by the tests in this module.
    struct MLP(T, alias Storage) {
        import grain.autograd : Variable;
        import grain.chain : Linear, relu;

        alias L = Linear!(T, Storage);
        L fc1, fc2, fc3;

        this(int nhidden) {
            this.fc1 = L(2, nhidden);
            this.fc2 = L(nhidden, nhidden);
            this.fc3 = L(nhidden, 10);
        }

        /// Forward pass: fc1 -> relu -> fc2 -> relu -> fc3.
        auto opCall(Variable!(T, 2, Storage) x) {
            auto h1 = relu(this.fc1(x));
            auto h2 = relu(this.fc2(h1));
            auto h3 = this.fc3(h2);
            // return the last layer so save/load round-trip tests cover every parameter
            return h3;
        }
    }
}

/**
Convert a D basic type into the matching predefined little-endian HDF5 type-id:
`H5T_IEEE_F<bits>LE` for floating point, `H5T_STD_I<bits>LE` for signed and
`H5T_STD_U<bits>LE` for all other basic types.

See_Also: https://support.hdfgroup.org/HDF5/doc1.8/RM/PredefDTypes.html
*/
auto toH5Type(T)() {
    import std.traits;
    import std.format;

    static assert(isBasicType!T);
    mixin("return H5T_%s%dLE;".format(isFloatingPoint!T ? "IEEE_F"
            : (isSigned!T ? "STD_I" : "STD_U"), T.sizeof * 8));
}

/**
Save chain parameters to an HDF5 file.

Each variable visited by `iterVariables` is written as a dataset named
`"/" ~ key` (e.g. `/.fc1.weight`) with the variable's shape. Any existing
file at `path` is truncated (`H5F_ACC_TRUNC`).

Params:
    verbose = currently unused; kept for interface compatibility
    chain = chain whose parameters are written
    path = destination HDF5 file path

Throws: `Exception` when an HDF5 call reports failure.
*/
void save(bool verbose = true, C)(C chain, string path) {
    import std.exception : enforce;
    import grain.utility : castArray;

    auto file = H5Fcreate(path.toStringz, H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT);
    enforce(file >= 0, "failed to create HDF5 file: " ~ path);
    scope (exit)
        H5Fclose(file);

    void register(T, size_t dim, alias Storage)(string k, Variable!(T, dim, Storage) v) {
        // convert to host storage exactly once (a device variable is copied back here)
        auto h = v.to!HostStorage;
        // FIXME support check contiguous
        // auto s = h.sliced.slice;
        auto data = h.data;
        auto dims = h.shape.castArray!hsize_t;
        auto space = H5Screate_simple(cast(int) dims.length, dims.ptr, dims.ptr);
        enforce(space >= 0, "failed to create HDF5 dataspace for " ~ k);
        scope (exit)
            H5Sclose(space);

        const name = "/" ~ k;
        auto dataset = H5Dcreate2(file, name.toStringz, toH5Type!T, space,
                H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);
        enforce(dataset >= 0, "failed to create HDF5 dataset " ~ name ~ " in " ~ path);
        scope (exit)
            H5Dclose(dataset);

        const status = H5Dwrite(dataset, toH5Type!T, H5S_ALL, H5S_ALL,
                H5P_DEFAULT, cast(void*) data.ptr);
        enforce(status >= 0, "failed to write HDF5 dataset " ~ name ~ " to " ~ path);
    }

    iterVariables!((k, v) { register(k, v); })(&chain, "");
}

/**
Load chain parameters from an HDF5 file written by `save`.

Each variable visited by `refIterVariables` is overwritten in place with the
contents of dataset `"/" ~ key`. Host variables are assigned element-wise;
other storages receive the data through `grain.cudnn.transform`.

Params:
    chain = chain whose parameters are overwritten
    path = source HDF5 file path (opened read-only)

Throws: `Exception` when the file or a dataset cannot be opened or read,
instead of silently leaving uninitialized values in the parameter.
*/
void load(C)(ref C chain, string path) {
    import std.exception : enforce;
    import mir.ndslice : sliced;
    import grain.utility : castArray;

    auto file = H5Fopen(path.toStringz, H5F_ACC_RDONLY, H5P_DEFAULT);
    enforce(file >= 0, "failed to open HDF5 file: " ~ path);
    scope (exit)
        H5Fclose(file);

    void register(T, size_t dim, alias Storage)(string k, ref Variable!(T, dim, Storage) v) {
        const name = "/" ~ k;
        auto dataset = H5Dopen2(file, name.toStringz, H5P_DEFAULT);
        enforce(dataset >= 0, "HDF5 dataset " ~ name ~ " not found in " ~ path);
        scope (exit)
            H5Dclose(dataset);

        // NOTE(review): the on-disk dataset extent is not compared with v.shape;
        // a larger dataset would overrun `raw`. TODO validate via the dataspace.
        auto raw = new T[v.data.length];
        const status = H5Dread(dataset, toH5Type!T, H5S_ALL, H5S_ALL,
                H5P_DEFAULT, cast(void*) raw.ptr);
        enforce(status >= 0, "failed to read HDF5 dataset " ~ name ~ " from " ~ path);

        auto src = raw.sliced(v.shape.castArray!size_t).variable;
        static if (is(Storage!T == HostStorage!T)) {
            v.sliced[] = src.sliced;
        }
        else {
            import grain.cudnn : transform;

            transform(src.to!Storage, v);
        }
    }

    refIterVariables!((k, ref v) { register(k, v); })(chain, "");
}

///
unittest {
    import std.file : tempDir;
    import std.path : buildPath;

    const path = buildPath(tempDir, "test_grain0.h5");
    auto model1 = MLP!(float, HostStorage)(3);
    model1.save(path);

    auto model2 = MLP!(float, HostStorage)(3);
    model2.load(path);
    assert(model1.fc1.bias.sliced == model2.fc1.bias.sliced);

    import numir;
    import mir.ndslice;

    auto x = uniform!float(3, 2).slice.variable;
    assert(model1(x).sliced == model2(x).sliced);
}

///
version (grain_cuda) unittest {
    import std.file : tempDir;
    import std.path : buildPath;

    const path = buildPath(tempDir, "test_grain1.h5");
    auto model1 = MLP!(float, DeviceStorage)(3);
    model1.save(path);

    auto model2 = MLP!(float, DeviceStorage)(3);
    model2.load(path);
    assert(model1.fc1.bias.to!HostStorage.sliced == model2.fc1.bias.to!HostStorage
            .sliced);

    import numir;
    import mir.ndslice;

    auto x = uniform!float(3, 2).slice.variable.to!DeviceStorage;
    assert(model1(x).to!HostStorage.sliced == model2(x).to!HostStorage.sliced);
}

///
version (grain_cuda) unittest {
    import std.file : tempDir;
    import std.path : buildPath;

    const path = buildPath(tempDir, "test_grain2.h5");
    auto model1 = MLP!(float, HostStorage)(3);
    model1.save(path);

    auto model2 = MLP!(float, DeviceStorage)(3);
    model2.load(path);
    assert(model1.fc1.bias.to!HostStorage.sliced == model2.fc1.bias.to!HostStorage
            .sliced);
}

///
version (grain_cuda) unittest {
    import std.file : tempDir;
    import std.path : buildPath;

    const path = buildPath(tempDir, "test_grain3.h5");
    auto model1 = MLP!(float, DeviceStorage)(3);
    model1.save(path);

    auto model2 = MLP!(float, HostStorage)(3);
    model2.load(path);
    assert(model1.fc1.bias.to!HostStorage.sliced == model2.fc1.bias.to!HostStorage
            .sliced);
}