/**
Metrics (e.g., accuracy)

TODO: perplexity, AUC, F1, BLEU, edit distance
*/
module grain.metric;

import grain.autograd : isVariable;

/**
Compute classification accuracy by comparing predictions `y` (per-class
scores, e.g. a histogram) against target ids `t`.

For each batch element `i`, the position returned by `maxIndex` on `y[i]`
is compared with `t[i]`. The sample counts as correct when the two are equal.

Params:
    y = prediction variable indexed as `y[i]` per batch element;
        presumably shaped (batch, nclass) — TODO confirm against callers
    t = target variable holding one id per batch element, indexed as `t[i]`

Returns:
    the fraction of correctly predicted elements as a `double`.
    For an empty batch this is `0.0 / 0`, i.e. `double.nan`.
*/
auto accuracy(Vy, Vt)(Vy y, Vt t) if (isVariable!Vy && isVariable!Vt)
{
    import mir.ndslice : maxIndex;
    import grain.autograd : to, HostStorage;

    auto nbatch = t.shape[0];
    // Every target needs exactly one prediction row. Without this check, extra
    // rows of y are silently ignored and missing rows make hy[i] go out of range.
    assert(y.shape[0] == nbatch,
           "accuracy: y and t must have the same batch size (shape[0])");

    // Bring both variables to HostStorage before reading them element-wise
    // as ndslices.
    auto hy = y.to!HostStorage.sliced;
    auto ht = t.to!HostStorage.sliced;

    size_t correct = 0;
    foreach (i; 0 .. nbatch)
    {
        // maxIndex yields one index per dimension of hy[i]. Element [0] is the
        // argmax along its first axis (presumably the class axis).
        if (hy[i].maxIndex[0] == ht[i])
        {
            ++correct;
        }
    }
    return cast(double) correct / nbatch;
}