Compute the negative log-likelihood loss: -log P(y=t)
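For reference, a minimal sketch of the computation, assuming plain arrays rather than grain variables (the function name `nll` and the ignore value used here are illustrative, not grain's API): average -x[i][t[i]] over the targets that are not equal to ignoreIndex.

// illustrative sketch only: manual NLL over "log-probabilities" x and targets t,
// skipping targets equal to ignoreIndex and normalizing by the number kept
float nll(float[][] x, int[] t, int ignoreIndex)
{
    float sum = 0;
    size_t count = 0;
    foreach (i, row; x)
    {
        if (t[i] == ignoreIndex)
            continue;            // ignored targets contribute nothing
        sum += -row[t[i]];       // pick the value at the target class
        ++count;
    }
    return count ? sum / count : 0;
}

unittest
{
    auto x = [[0.2f, 0.4f, 0.4f], [0.1f, 0.5f, 0.4f], [0.1f, 0.5f, 0.4f]];
    // the third target is ignored, so only two rows count: -(0.4 + 0.1) / 2
    assert(nll(x, [1, 0, -1], -1) == -0.25f);
}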
store a grain.autograd.BackProp object in the variables returned from the forward function
type-erased version of the backward function, used inside the grain.autograd.BackProp object
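As a rough, hedged sketch of the idea behind the two notes above (BackPropSketch and everything else below are made-up names, not grain's actual types), a forward step can stash a closure that wraps the typed backward computation behind a type-erased signature, so the autograd graph can call it without knowing the concrete types:

import std.variant : Variant;

// illustrative stand-in for the role grain.autograd.BackProp plays:
// it holds a type-erased backward closure
struct BackPropSketch
{
    Variant delegate(Variant gy) backward;
}

unittest
{
    // "forward": y = x * x, remembering x so the backward pass can use it
    float x = 3.0f;
    float y = x * x;

    BackPropSketch bp;
    bp.backward = delegate Variant(Variant gy) {
        // typed chain rule dL/dx = dL/dy * 2x, hidden behind the erased interface
        return Variant(gy.get!float * 2.0f * x);
    };

    // later, the graph only sees the erased signature
    assert(bp.backward(Variant(1.0f)).get!float == 6.0f);
    assert(y == 9.0f);
}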
test NLL in a simple case, with gradCheck, and for CPU/CUDA equality
/++ equivalent torch v0.4 code
    >>> x = torch.FloatTensor([[0.2, 0.4, 0.4], [0.1, 0.5, 0.4]])
    >>> x.requires_grad = True
    >>> t = torch.LongTensor([1, 0])
    >>> l = torch.nn.functional.nll_loss(x, t)
    >>> print(l)
    tensor(-0.2500)
    >>> l.backward()
    >>> print(x.grad)
    tensor([[0.0, -0.5, 0.0],
            [-0.5, 0.0, 0.0]])
+/
import std.typecons;
import grain.testing;

NegativeLogLikelihood!(float, int) func;
auto hx = [[0.2f, 0.4f, 0.4f],
           [0.1f, 0.5f, 0.4f],
           [0.1f, 0.5f, 0.4f]].variable;
auto ht = [1, 0, func.ignoreIndex].variable; // the last target is ignored
auto hl = func.forward(hx, ht);
assert(func._normalize == 0.5);
assert(hl.sliced == [-(0.4f + 0.1f + 0.0f) / 2]);

// backward on CPU
auto hgx = func.backward(1.0f.variable);
assert(hgx[0].sliced == [[0.0, -0.5, 0.0],
                         [-0.5, 0.0, 0.0],
                         [0.0, 0.0, 0.0]]);
assert(!hgx[1].defined); // no gradient w.r.t. the integer targets
gradCheck(func, tuple(hx, ht), 1.0f.variable);

// the same computation on CUDA must match the CPU results
version (grain_cuda) {
    auto dx = hx.to!DeviceStorage;
    auto dt = ht.to!DeviceStorage;
    auto dl = func.forward(dx, dt);
    assert(func._normalize == 0.5);
    assert(dl.to!HostStorage.sliced == [-(0.4f + 0.1f + 0.0f) / 2]);
    auto dgx = func.backward(1.0f.variable.to!DeviceStorage);
    assert(dgx[0].to!HostStorage.sliced == [[0.0, -0.5, 0.0],
                                            [-0.5, 0.0, 0.0],
                                            [0.0, 0.0, 0.0]]);
    assert(!dgx[1].defined);
}
test Variable.backward with grain.config.backprop enabled
import std.typecons;
import grain.testing;
import mir.ndslice;
static import grain.config;

grain.config.backprop = true; // record the graph so that backward() can run

NegativeLogLikelihood!(float, int) func;
auto hx = [[0.2f, 0.4f, 0.4f],
           [0.1f, 0.5f, 0.4f],
           [0.1f, 0.5f, 0.4f]].variable;
hx.requiresGrad = true;
auto ht = [1, 0, func.ignoreIndex].variable;
auto hl = func.applyForward(hx, ht);
assert(func._normalize == 0.5);
assert(hl.sliced == [-(0.4f + 0.1f + 0.0f) / 2]);

// seed the backward pass with an upstream gradient of 1
auto u = UntypedVariable(1.0f.variable);
hl.backward(&u);
assert(hx.grad[].sliced(3, 3) == [[0.0, -0.5, 0.0],
                                  [-0.5, 0.0, 0.0],
                                  [0.0, 0.0, 0.0]]);
// TODO assert(!ht.grad.defined);