1 /// CIFAR10/100 image recognition datasets
2 module grain.dataset.cifar;
3 
4 import std.exception : enforce;
5 
/**
In-memory CIFAR-10/100 dataset decoded from the "binary version" batch files.

Each record in a batch file is laid out as
`[coarse label (CIFAR-100 only)][fine label][3072 pixel bytes]`,
where the pixel bytes are the red, green and blue planes in that order,
each plane stored row by row.
*/
struct Dataset {
    import mir.ndslice : Slice;

    /// shape is [data, rgb, height, width]; values are the raw bytes (0..255) as float
    Slice!(float*, 4) inputs;
    /// class label of every image (the "fine" label for CIFAR-100)
    Slice!(int*, 1) targets;
    /// superclass ("coarse") label of every image; left unallocated for CIFAR-10
    Slice!(int*, 1) coarses;

    /**
    Reads every record of the given batch files into memory.

    The record format is chosen from the first path: if it contains
    "cifar-100" each record has two leading label bytes (coarse, fine),
    otherwise one.

    Params:
        paths = binary batch files to load, concatenated in the given order

    Throws:
        Exception if `paths` is empty or if the size of any file is not a
        multiple of the record size.
    */
    this(string[] paths) {
        import std.algorithm : canFind;
        import std.format : format;
        import mir.ndslice : flattened;
        import numir : empty;
        import std.stdio : File;

        enforce(paths.length > 0, "no CIFAR batch files were given");

        immutable isCIFAR100 = paths[0].canFind("cifar-100");
        // number of label bytes preceding the pixels of a record
        immutable size_t imageOffset = isCIFAR100 ? 2 : 1;
        immutable size_t dataSize = imageOffset + 3 * 32 * 32;

        // Validate every file on its own: checking only the summed size would
        // let a truncated file slip through and make its short last chunk be
        // indexed out of bounds below.
        size_t numData = 0;
        foreach (p; paths) {
            immutable size = File(p, "rb").size;
            enforce(size % dataSize == 0,
                    "%s: size %s is not a multiple of the record size %s"
                    .format(p, size, dataSize));
            numData += cast(size_t) (size / dataSize);
        }

        this.inputs = empty!float(numData, 3, 32, 32);
        this.targets = empty!int(numData);
        if (isCIFAR100) {
            this.coarses = empty!int(numData);
        }

        size_t i = 0;
        foreach (p; paths) {
            // byChunk reuses its buffer, so each record is copied out
            // before the next iteration
            foreach (chunk; File(p, "rb").byChunk(dataSize)) {
                if (isCIFAR100) {
                    this.coarses[i] = cast(int) chunk[0];
                    this.targets[i] = cast(int) chunk[1];
                } else {
                    this.targets[i] = cast(int) chunk[0];
                }

                size_t n = imageOffset;
                foreach (ref inp; this.inputs[i].flattened) {
                    inp = cast(float) chunk[n];
                    ++n;
                }
                ++i;
            }
        }
    }

    /**
    Groups the data into consecutive mini-batches of equal size.

    Trailing samples that do not fill a whole batch are omitted.

    Params:
        batchSize = number of samples per batch; must be positive

    Returns:
        a named tuple with `niter` (number of batches), `inputs` viewed as
        [niter, batchSize, 3, 32, 32] and `targets` viewed as
        [niter, batchSize]

    Throws:
        Exception if `batchSize` is zero.
    */
    auto makeBatch(size_t batchSize) {
        import numir : view;
        import std.typecons : tuple;

        // guard the integer division below
        enforce(batchSize > 0, "batchSize must be positive");
        auto niter = this.inputs.length!0 / batchSize; // omit last
        auto ndata = batchSize * niter;
        return tuple!("niter", "inputs", "targets")(
            niter,
            this.inputs[0..ndata].view(niter, batchSize, 3, 32, 32),
            this.targets[0..ndata].view(niter, batchSize)
            );
    }
}
67 
/**
Downloads (if needed), unpacks and loads a CIFAR dataset.

The archive `<dataset>-binary.tar.gz` is fetched into `dir` and extracted
into `dir/<dataset>-binary`; both steps are skipped when their result
already exists. Extraction runs the external `tar` program.

Params:
    dataset = dataset name used to build the archive URL, "cifar-10" or "cifar-100"
    dir = directory that holds the downloaded archive and the extracted files

Returns:
    a named tuple with `train` and `test` (`Dataset`), `labels` (class
    names) and `coarses` (superclass names, empty unless `dataset` is
    "cifar-100")

Throws:
    Exception if the download does not end with HTTP status 200, if `tar`
    fails, or if no `.bin` file is found after extraction.
*/
auto prepareDataset(string dataset, string dir = "data") {
    import std.stdio : writeln;
    import std.typecons : tuple;
    import std.array : array;
    import std.format : format;
    import std.algorithm : filter, canFind, map, sort;
    import std.string : replace, split;
    import std.path : baseName, extension, dirName;
    import std.file : exists, mkdirRecurse, dirEntries, SpanMode, readText,
        remove, rmdirRecurse;
    import std.net.curl : download, HTTP;
    import std.process : execute;

    immutable url = "https://www.cs.toronto.edu/~kriz/%s-binary.tar.gz".format(dataset);
    immutable root = dir ~ "/" ~ url.baseName.replace(".tar.gz", "");
    if (!root.exists) {
        immutable gz = dir ~ "/" ~ url.baseName;
        if (!gz.exists) {
            // the archive is written into dir, so it has to exist first
            mkdirRecurse(dir);
            writeln("downloading ", url);
            // Never keep a partial file or a saved error page: the next run
            // would take it for a complete archive and skip the download.
            scope (failure) {
                if (gz.exists) {
                    remove(gz);
                }
            }
            auto conn = HTTP(url);
            download(url, gz, conn);
            auto code = conn.statusLine().code;
            // NOTE(review): the original author marked this check with
            // "FIXME: does not work?" - confirm that the status line is
            // populated after download() returns.
            enforce(code == 200, "status code: %s".format(code));
        }
        mkdirRecurse(root);
        // An existing root means "already extracted" on later runs, so drop
        // it again when the extraction does not succeed.
        scope (failure) rmdirRecurse(root);
        writeln("uncompressing ", "tar -xvf " ~ gz ~ " -C " ~ root);
        // Pass an argument list instead of a shell string so that paths with
        // spaces or shell metacharacters are handled verbatim.
        auto ret = execute(["tar", "-xvf", gz, "-C", root]);
        enforce(ret.status == 0, ret.output);
    }

    auto bins = root
        .dirEntries("*", SpanMode.depth)
        .filter!(a => a.name.extension == ".bin")
        .map!(a => a.name)
        .array;
    enforce(bins.length > 0, "no .bin files found under " ~ root);
    // dirEntries yields entries in an unspecified order; sort so that the
    // sample order is the same on every run and file system
    bins.sort();

    // look at the file name only: a directory component containing "test"
    // must not push every file into the test split
    static bool isTest(string path) {
        return path.baseName.canFind("test");
    }

    auto train = Dataset(bins.filter!(a => !isTest(a)).array);
    auto test = Dataset(bins.filter!(a => isTest(a)).array);
    auto meta = dataset == "cifar-10" ? "batches.meta.txt" : "fine_label_names.txt";
    // the label files sit next to the batch files
    immutable metaDir = bins[0].dirName;
    auto labels = readText(metaDir ~ "/" ~ meta).split;

    string[] coarseLabels;
    if (dataset == "cifar-100") {
        coarseLabels = readText(metaDir ~ "/" ~ "coarse_label_names.txt").split;
    }
    return tuple!("train", "test", "labels", "coarses")(train, test, labels, coarseLabels);
}