1 /// MNIST hand-written digit recognition dataset
2 module grain.dataset.mnist;
3 
4 import std.typecons : tuple;
5 import mir.ndslice;
6 
/// Base names (without the `.gz` suffix) of the four MNIST archives:
/// training images, training labels, test images and test labels.
enum string[] files = [
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
];
14 
/++
One split (train or test) of the data set held in memory.

`prepareDataset` fills `inputs` with one 28 x 28 image per sample, scaled to
`0.0 .. 1.0`, and `targets` with one label per sample.
+/
struct Dataset {
    /// Images as a 3-dimensional slice of `float`.
    Slice!(float*, 3) inputs;

    /// Labels as a 1-dimensional slice of `int`.
    Slice!(int*, 1) targets;
}
20 
/++
Loads the four MNIST archives into memory, downloading any that are missing.

Every name in `files` is looked up as `data/<name>.gz` (the `data` directory
is created when absent). A missing archive is fetched from
`http://yann.lecun.com/exdb/mnist/`. Each archive is then decompressed and
converted:

$(UL
    $(LI image files: the first 16 bytes (header) are skipped, the rest is
         viewed as `N x 28 x 28` and scaled from `0 .. 255` to `0.0 .. 1.0`)
    $(LI label files: the first 8 bytes (header) are skipped and every
         remaining byte is converted to `int`)
)

Files whose name contains `"train"` fill the train split, all others the
test split.

Returns: a named tuple `(train, test)` of `Dataset*`.
Throws: `Exception` when a decompressed file is shorter than its header.
        Failures while downloading, reading or decompressing propagate
        unchanged.
+/
auto prepareDataset() {
    import std.stdio : writeln;
    import std.file : exists, read, mkdir, rename;
    import std.algorithm : canFind;
    import std.exception : enforce;
    import std.zlib : UnCompress;
    import std.net.curl : download;

    enum size_t imageHeaderBytes = 16;
    enum size_t labelHeaderBytes = 8;
    enum size_t side = 28;

    auto train = new Dataset;
    auto test = new Dataset;
    if (!exists("data")) {
        mkdir("data");
    }
    foreach (f; files) {
        auto gz = "data/" ~ f ~ ".gz";
        writeln("loading " ~ gz);
        if (!exists(gz)) {
            auto url = "http://yann.lecun.com/exdb/mnist/" ~ f ~ ".gz";
            // Download under a temporary name and rename only on success, so
            // an interrupted transfer never leaves a truncated archive that a
            // later run would accept because it merely checks `exists`.
            // A leftover ".part" file is overwritten by the next attempt.
            auto partial = gz ~ ".part";
            download(url, partial);
            rename(partial, gz);
        }
        auto unc = new UnCompress;
        auto decomp = cast(ubyte[]) unc.uncompress(gz.read);
        auto dataset = f.canFind("train") ? train : test;
        if (f.canFind("images")) {
            enforce(decomp.length >= imageHeaderBytes,
                    gz ~ ": decompressed image file is shorter than its header");
            // skip header
            decomp = decomp[imageHeaderBytes .. $];
            auto ndata = decomp.length / (side * side);
            auto imgs = decomp.sliced(ndata, side, side);
            // normalize 0 .. 255 to 0.0 .. 1.0
            dataset.inputs = imgs.map!(i => 1.0f * i / 255).slice;
        } else { // labels
            enforce(decomp.length >= labelHeaderBytes,
                    gz ~ ": decompressed label file is shorter than its header");
            // skip header
            decomp = decomp[labelHeaderBytes .. $];
            dataset.targets = decomp.sliced.as!int.slice;
        }
    }
    return tuple!("train", "test")(train, test);
}
59 
/++
Groups the samples of a data set into mini-batches of equal size.

Every image is flattened to a vector of `inSize` elements. Samples that do
not fill a complete batch (the last `length % batchSize`) are dropped from
both inputs and targets.

Params:
    d = data set to split; its `inputs` must hold at least one sample,
        because the flattened size is measured on the first one
    batchSize = number of samples per batch, must be greater than zero

Returns: a named tuple `(niter, inputs, targets)`. `niter` is the number of
         complete batches. `inputs` and `targets` are the trimmed data passed
         through numir's `view` with trailing lengths `(batchSize, inSize)`
         and `(batchSize)`; the leading `-1` is presumably inferred by
         `view` (confirm against numir).

Throws: `Exception` when `batchSize` is zero.
+/
auto makeBatch(Dataset* d, size_t batchSize) {
    import std.exception : enforce;
    import numir.core : view;

    // Both the division and the modulo below would divide by zero otherwise.
    enforce(batchSize > 0, "makeBatch: batchSize must be greater than zero");

    auto niter = d.inputs.shape[0] / batchSize; // omit last
    auto inSize = d.inputs[0].view(-1).length!0;
    // One row per sample, each image flattened to inSize elements.
    auto flatInputs = d.inputs.view(-1, inSize);
    return tuple!("niter", "inputs", "targets")(
        niter,
        flatInputs[0 .. $ - ($ % batchSize)].view(-1, batchSize, inSize),
        d.targets[0 .. $ - ($ % batchSize)].view(-1, batchSize)
        );
}