/// CIFAR10/100 image recognition datasets
module grain.dataset.cifar;

import std.exception : enforce;

/**
CIFAR-10 / CIFAR-100 samples read from the "binary version" record files.

Every record is one label byte (CIFAR-10) or a coarse and a fine label byte
(CIFAR-100) followed by 3072 pixel bytes. Pixel bytes are converted to float
as they are, without any scaling.
*/
struct Dataset {
    import mir.ndslice : Slice;

    /// images; shape is [data, rgb, height, width] (the files store each channel row by row)
    Slice!(float*, 4) inputs;
    /// class label of every image (the fine label for CIFAR-100)
    Slice!(int*, 1) targets;
    /// coarse label of every image; only allocated when reading CIFAR-100
    Slice!(int*, 1) coarses;

    /**
    Reads all records of the given files, in the given order.

    The record layout (CIFAR-10 or CIFAR-100) is chosen by looking for
    "cifar-100" in the first path.

    Params:
        paths = binary record files to load; must not be empty

    Throws:
        Exception if `paths` is empty or if the size of a file is not a
        multiple of the record size.
    */
    this(string[] paths) {
        import std.algorithm : canFind;
        import std.format : format;
        import mir.ndslice : flattened;
        import numir : empty;
        import std.stdio : File;

        enforce(paths.length > 0, "no CIFAR binary files were given");

        immutable isCIFAR100 = paths[0].canFind("cifar-100");
        // label byte(s) + 3 * 32 * 32 pixel bytes
        immutable size_t dataSize = isCIFAR100 ? 3074 : 3073;

        // validate every file on its own: a single truncated file would
        // otherwise yield a short last chunk and an out-of-range read below
        ulong fileSize = 0;
        foreach (p; paths) {
            immutable bytes = File(p, "rb").size;
            enforce(bytes % dataSize == 0,
                    "%s: file size %s is not a multiple of the record size %s"
                    .format(p, bytes, dataSize));
            fileSize += bytes;
        }
        auto numData = cast(size_t) (fileSize / dataSize);

        this.inputs = empty!float(numData, 3, 32, 32);
        this.targets = empty!int(numData);
        if (isCIFAR100) {
            this.coarses = empty!int(numData);
        }

        size_t i = 0;
        // number of label bytes in front of the pixel data
        immutable size_t imageOffset = isCIFAR100 ? 2 : 1;
        foreach (p; paths) {
            foreach (chunk; File(p, "rb").byChunk(dataSize)) {
                if (isCIFAR100) {
                    this.coarses[i] = cast(int) chunk[0];
                    this.targets[i] = cast(int) chunk[1];
                } else {
                    this.targets[i] = cast(int) chunk[0];
                }

                // copy the pixel bytes in file order into the i-th image
                size_t n = 0;
                foreach (ref inp; this.inputs[i].flattened) {
                    inp = cast(float) chunk[imageOffset + n];
                    ++n;
                }
                ++i;
            }
        }
    }

    /**
    Groups the samples into consecutive mini-batches of `batchSize` samples.

    Trailing samples that do not fill a whole batch are left out.

    Params:
        batchSize = number of samples per batch; must be positive

    Returns:
        a tuple of `niter` (number of batches), `inputs` viewed as
        [niter, batchSize, 3, 32, 32] and `targets` viewed as
        [niter, batchSize]. `coarses` is not part of the result.

    Throws:
        Exception if `batchSize` is zero.
    */
    auto makeBatch(size_t batchSize) {
        import numir : view;
        import std.typecons : tuple;

        // guard the division below
        enforce(batchSize > 0, "batchSize must be positive");

        auto niter = this.inputs.length!0 / batchSize; // omit last
        auto ndata = batchSize * niter;
        return tuple!("niter", "inputs", "targets")(
            niter,
            this.inputs[0..ndata].view(niter, batchSize, 3, 32, 32),
            this.targets[0..ndata].view(niter, batchSize)
        );
    }
}

/**
Downloads and unpacks a CIFAR binary archive if needed, then loads it.

The archive is stored as `dir/<dataset>-binary.tar.gz` and unpacked with the
external `tar` program into `dir/<dataset>-binary`. Both steps are skipped
when their result already exists.

Params:
    dataset = "cifar-10" or "cifar-100"
    dir = directory that holds the archive and the unpacked files

Returns:
    a tuple of `train` and `test` (`Dataset`), `labels` (class names) and
    `coarses` (coarse class names, empty unless `dataset` is "cifar-100").

Throws:
    Exception if the download does not answer with HTTP status 200, if `tar`
    fails, or if no `.bin` file is found after unpacking.
*/
auto prepareDataset(string dataset, string dir = "data") {
    import std.stdio : writeln;
    import std.typecons : tuple;
    import std.array : array, join;
    import std.format : format;
    import std.algorithm : filter, canFind, map, sort;
    import std.string : replace, split;
    import std.path : baseName, extension, dirName;
    import std.file : exists, mkdirRecurse, dirEntries, SpanMode, readText,
        remove, rmdirRecurse;
    import std.net.curl : download, HTTP;
    import std.process : execute;

    immutable url = "https://www.cs.toronto.edu/~kriz/%s-binary.tar.gz".format(dataset);
    immutable root = dir ~ "/" ~ url.baseName.replace(".tar.gz", "");
    if (!root.exists) {
        immutable gz = dir ~ "/" ~ url.baseName;
        if (!gz.exists) {
            // the archive is written into `dir`, so it has to exist first
            mkdirRecurse(dir);
            writeln("downloading ", url);
            // a partial archive or a saved error page would be taken for a
            // finished download by the next call
            scope (failure) {
                if (gz.exists) {
                    remove(gz);
                }
            }
            auto conn = HTTP(url);
            download(url, gz, conn);
            auto code = conn.statusLine().code;
            // FIXME: does not work?
            enforce(code == 200, "status code: %s".format(code));
        }
        mkdirRecurse(root);
        // an existing `root` makes later calls skip the extraction
        scope (failure) rmdirRecurse(root);
        // argument list instead of a shell command line: paths containing
        // spaces or shell metacharacters are passed to tar unchanged
        string[] cmd = ["tar", "-xvf", gz, "-C", root];
        writeln("uncompressing ", cmd.join(" "));
        auto ret = execute(cmd);
        enforce(ret.status == 0, ret.output);
    }

    auto bins = root
        .dirEntries("*", SpanMode.depth)
        .filter!(a => a.name.extension == ".bin")
        .map!"a.name".array;
    enforce(bins.length > 0, "no .bin files found under " ~ root);
    // directory listing order depends on the file system; sort so that the
    // sample order is the same on every run
    bins.sort();

    // decide by file name only: `dir` itself may contain "test"
    auto train = Dataset(bins.filter!(a => !a.baseName.canFind("test")).array);
    auto test = Dataset(bins.filter!(a => a.baseName.canFind("test")).array);
    auto meta = dataset == "cifar-10" ? "batches.meta.txt" : "fine_label_names.txt";
    auto labels = readText(bins[0].dirName ~ "/" ~ meta).split;

    string[] coarseLabels;
    if (dataset == "cifar-100") {
        coarseLabels = readText(bins[0].dirName ~ "/" ~ "coarse_label_names.txt").split;
    }
    return tuple!("train", "test", "labels", "coarses")(train, test, labels, coarseLabels);
}