/// MNIST hand-written digit recognition dataset
module grain.dataset.mnist;

import std.typecons : tuple;
import mir.ndslice;

/// Base names of the four MNIST files (train/test images and labels).
/// ".gz" is appended to each name for both the download URL and the
/// local cache path.
enum files = [
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte"
];

/// One dataset split (train or test) held in memory.
struct Dataset {
    /// Images as (sample, 28, 28), pixel values scaled to 0.0 .. 1.0.
    Slice!(float*, 3) inputs;
    /// One label per image, converted from ubyte to int.
    Slice!(int*, 1) targets;
}

/**
Loads the MNIST files, downloading any archive not already cached in "data/".

For every entry of `files` the gzip archive "data/<name>.gz" is fetched if it
does not exist, decompressed in memory, stripped of its header and stored in
the train or test `Dataset` depending on whether the name contains "train".

Returns: a named tuple `(train, test)` of `Dataset*`.

Throws: `Exception` when a decompressed file is shorter than the header that
    is skipped (16 bytes for images, 8 bytes for labels). Failures raised by
    the file, download and zlib calls propagate unchanged.
*/
auto prepareDataset() {
    import std.stdio : writeln;
    import std.file : exists, read, mkdir;
    import std.algorithm : canFind;
    import std.exception : enforce;
    import std.zlib : UnCompress;
    import std.net.curl : download;

    // Number of leading bytes skipped before the payload of each file kind.
    enum size_t imageHeaderBytes = 16;
    enum size_t labelHeaderBytes = 8;
    // Image geometry used to shape the pixel payload.
    enum size_t rows = 28;
    enum size_t cols = 28;

    auto train = new Dataset;
    auto test = new Dataset;
    if (!exists("data")) {
        mkdir("data");
    }
    foreach (f; files) {
        auto gz = "data/" ~ f ~ ".gz";
        writeln("loading " ~ gz);
        if (!exists(gz)) {
            auto url = "http://yann.lecun.com/exdb/mnist/" ~ f ~ ".gz";
            download(url, gz);
        }
        auto unc = new UnCompress;
        auto decomp = cast(ubyte[]) unc.uncompress(gz.read);
        auto dataset = f.canFind("train") ? train : test;
        if (f.canFind("images")) {
            // Guard the header slice explicitly: a truncated or corrupt cache
            // file must produce a clear error even when array bounds checks
            // are compiled out.
            enforce(decomp.length >= imageHeaderBytes,
                    gz ~ ": decompressed image file is shorter than its header");
            // skip header
            decomp = decomp[imageHeaderBytes .. $];
            auto ndata = decomp.length / (rows * cols);
            auto imgs = decomp.sliced(ndata, rows, cols);
            // normalize 0 .. 255 to 0.0 .. 1.0
            dataset.inputs = imgs.map!(i => 1.0f * i / 255).slice;
        } else { // labels
            enforce(decomp.length >= labelHeaderBytes,
                    gz ~ ": decompressed label file is shorter than its header");
            // skip header
            decomp = decomp[labelHeaderBytes .. $];
            dataset.targets = decomp.sliced.as!int.slice;
        }
    }
    return tuple!("train", "test")(train, test);
}

/**
Splits a dataset into consecutive mini-batches of `batchSize` samples.

Trailing samples that do not fill a complete batch are dropped.

Params:
    d = dataset whose `inputs` and `targets` are batched
    batchSize = number of samples per batch; must be greater than zero

Returns: a named tuple `(niter, inputs, targets)` where `niter` is the number
    of complete batches, `inputs` is reshaped with numir's `view` to
    (-1, batchSize, inSize) with each image flattened to `inSize` values, and
    `targets` is reshaped to (-1, batchSize).

Throws: `Exception` when `batchSize` is zero (it would otherwise divide by
    zero).
*/
auto makeBatch(Dataset* d, size_t batchSize) {
    import numir.core : view;
    import std.exception : enforce;

    // batchSize is used as a divisor and a modulus below.
    enforce(batchSize > 0, "makeBatch: batchSize must be greater than zero");

    auto niter = d.inputs.shape[0] / batchSize; // omit last
    // Length of one image after flattening it to a single dimension.
    auto inSize = d.inputs[0].view(-1).length!0;
    return tuple!("niter", "inputs", "targets")(
        niter,
        d.inputs.view(-1, inSize)[0..$ - ($ % batchSize)].view(-1, batchSize, inSize),
        d.targets[0..$ - ($ % batchSize)].view(-1, batchSize)
    );
}