1 /// PTB language modeling dataset
2 module grain.dataset.ptb;
3 
4 import std.array;
5 import std.algorithm : splitter;
6 import std.stdio : File;
7 import mir.ndslice;
8 import mir.random.variable : UniformVariable;
9 
10 import numir;
11 
/**
Fetches the PTB text files and loads them into a `Corpus`.

Creates the `data` directory when it does not exist, downloads each of
`data/ptb.train.txt`, `data/ptb.valid.txt` and `data/ptb.test.txt` that is
missing locally, then registers the three files under the names
"train", "valid" and "test" (in that order).

Returns: the populated `Corpus`.
*/
auto prepareDataset() {
    import std.format;
    import std.file : exists, mkdir;
    import std.net.curl : download;

    enum baseUrl = "https://github.com/tomsercu/lstm/raw/master/";
    static immutable string[] splits = ["train", "valid", "test"];

    if (!exists("data")) {
        mkdir("data");
    }

    // Download phase: the remote path below `baseUrl` mirrors the local path.
    foreach (split; splits) {
        string localPath = format!"data/ptb.%s.txt"(split);
        if (exists(localPath)) {
            continue; // already cached on disk
        }
        download(baseUrl ~ localPath, localPath);
    }

    // Load phase: runs only after every file is available.
    Corpus result;
    foreach (split; splits) {
        result.register(format!"data/ptb.%s.txt"(split), split);
    }
    return result;
}
33 
/**
Two-way mapping between words and integer ids.

`idx2word[i]` is the word with id `i`; `word2idx[w]` is the id of word `w`.
The end-of-sentence token `eos` is inserted with id 0 on the first call to
`register`, so every other word receives an id of 1 or more.
*/
struct Dictionary {
    enum eos = "<eos>";
    enum eosId = 0;
    string[] idx2word;
    int[string] word2idx;

    /**
    Assigns the next free id to `word` unless it is already known.

    Params:
        word = token to add to the vocabulary
    */
    void register(string word) {
        // ids are stored as int, so the table must stay below int.max entries
        assert(int.max > this.idx2word.length);

        // first use: seed the table with the end-of-sentence token at id 0
        if (this.idx2word.length == 0) {
            this.word2idx[eos] = 0;
            this.idx2word = [eos];
        }

        if (word in this.word2idx) {
            return; // already registered, keep its existing id
        }

        const nextId = cast(int) this.idx2word.length;
        this.idx2word ~= word;
        this.word2idx[word] = nextId;
    }
}
53 
/**
Tokenised text corpus: a shared `Dictionary` plus one word-id sequence per
named split.
*/
struct Corpus {
    /// vocabulary shared by every registered split
    Dictionary dict;
    /// word-id sequences keyed by split name
    int[][string] dataset;
    /// number of columns produced by `batchfy`
    size_t batchSize = 20;

    /**
    Reads the text file at `path`, converts it to word ids and stores the
    result under `name` (replacing any sequence already stored there).

    Each line is stripped and split on single spaces; every token is added
    to `dict`, and `Dictionary.eosId` is appended after each line.

    Params:
        path = file to read
        name = key under which the id sequence is stored in `dataset`
    */
    void register(string path, string name) {
        import std.string : strip;

        int[] data;
        foreach (line; File(path).byLine) {
            foreach (word; line.strip.splitter(' ')) {
                // `word` is a view into the line buffer, so give the
                // dictionary its own immutable copy to keep as a key
                this.dict.register(word.idup);
                data ~= this.dict.word2idx[word];
            }
            data ~= Dictionary.eosId;
        }
        this.dataset[name] = data;
    }

    /**
    Arranges the split `name` into `batchSize` parallel columns.

    The sequence is cut into `batchSize` contiguous chunks of equal length
    `len = data.length / batchSize`; trailing ids that do not fill a whole
    chunk are dropped. Column `j` of the result holds chunk `j`.

    Params:
        name = split previously stored by `register`

    Returns: word-id 2d slice shaped (seqlen, batchsize)
    */
    auto batchfy(string name) {
        import numir;

        // guard the division below; a zero batch size has no meaningful layout
        assert(this.batchSize > 0, "batchSize must be positive");
        auto data = this.dataset[name];
        const len = data.length / this.batchSize;
        return data[0 .. len * this.batchSize].sliced.view(this.batchSize, len).transposed.slice;
    }

    /// Returns: number of entries in the dictionary, as an `int`.
    auto vocabSize() {
        return cast(int) this.dict.idx2word.length;
    }
}