Sha256: 56688830fe88e90dccccb0eddd89b0fc6fc2030c41678f387399b590b0f1935d
Size: 376 Bytes
Versions: 56
Contents
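module Torch
  module NN
    # Negative log-likelihood loss. The input is expected to hold
    # log-probabilities (for example, the output of a log_softmax).
    class NLLLoss < WeightedLoss
      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
        super(weight, reduction)
        # Target entries equal to ignore_index are excluded from the loss.
        @ignore_index = ignore_index
      end

      def forward(input, target)
        F.nll_loss(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
      end
    end
  end
end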
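NLLLoss does no arithmetic itself. It forwards `weight`, `ignore_index` and `reduction` to `F.nll_loss`. The sketch below shows roughly what that call computes for a 2-D input of shape [N, C]. It works on plain Ruby arrays rather than tensors. `nll_loss_sketch` is a hypothetical helper, not part of torch-rb, and it is not the library's implementation. It assumes the PyTorch-documented semantics: the loss for sample i is `-weight[target[i]] * input[i][target[i]]`, ignored samples contribute zero, and "mean" divides by the sum of the weights of the non-ignored samples.

```ruby
# Hypothetical pure-Ruby sketch of the computation NLLLoss delegates to
# F.nll_loss. NOT torch-rb's implementation; it mirrors the documented
# semantics for a 2-D input of shape [N, C] given as nested arrays.
def nll_loss_sketch(input, target, weight: nil, ignore_index: -100, reduction: "mean")
  losses = []   # per-sample weighted losses
  weights = []  # per-sample class weights (0.0 for ignored samples)

  target.each_with_index do |cls, i|
    if cls == ignore_index
      # Ignored targets add nothing to the loss or to the "mean" normalizer.
      losses << 0.0
      weights << 0.0
      next
    end
    w = weight ? weight[cls] : 1.0
    # input[i][cls] is the log-probability assigned to the true class.
    losses << -w * input[i][cls]
    weights << w
  end

  case reduction
  when "none" then losses
  when "sum"  then losses.sum
  when "mean" then losses.sum / weights.sum # weighted mean over non-ignored samples
  else raise ArgumentError, "unknown reduction: #{reduction}"
  end
end

log_probs = [[Math.log(0.7), Math.log(0.2), Math.log(0.1)],
             [Math.log(0.1), Math.log(0.8), Math.log(0.1)]]
p nll_loss_sketch(log_probs, [0, 1])                   # mean of -log(0.7) and -log(0.8)
p nll_loss_sketch(log_probs, [0, -100], reduction: "none")
```

The "mean" reduction divides by the summed class weights, not by N. This is why it is a weighted average when `weight` is given. If every target is ignored, the sketch computes 0.0 / 0.0, which gives NaN in Ruby.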