Sha256: a4f645309e4c7ac9c50f2ffc6a3a57ae89f41f03781489204fd9e6faac27ee43
Contents?: true
Size: 390 Bytes
Versions: 56
Compression:
Stored size: 390 Bytes
Contents
module Torch
  module NN
    # Module-style wrapper around F.cross_entropy.
    #
    # The constructor captures the loss options. #forward hands them,
    # together with the input and target, to F.cross_entropy.
    class CrossEntropyLoss < WeightedLoss
      # weight:       optional weight (default nil), passed positionally to
      #               WeightedLoss#initialize.
      # ignore_index: target value to ignore (default -100). Presumably targets
      #               equal to this value are excluded from the loss; confirm
      #               against F.cross_entropy.
      # reduction:    reduction mode string (default "mean"), passed
      #               positionally to WeightedLoss#initialize.
      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
        super(weight, reduction)
        @ignore_index = ignore_index
      end

      # Returns F.cross_entropy(input, target, ...) using the options captured
      # at construction time.
      # NOTE(review): @weight and @reduction are assumed to be assigned by
      # WeightedLoss#initialize; confirm in the parent class.
      def forward(input, target)
        options = {
          weight: @weight,
          ignore_index: @ignore_index,
          reduction: @reduction
        }
        F.cross_entropy(input, target, **options)
      end
    end
  end
end
Version data entries
56 entries across 56 versions & 1 rubygem