Sha256: 4771cb484895eab800f1da509c0305f8a6440f0ad27da314241d536e801072c1
Contents?: true
Size: 440 Bytes
Versions: 55
Compression:
Stored size: 440 Bytes
Contents
module Torch
  module NN
    # Loss module that wraps F.binary_cross_entropy_with_logits.
    #
    # The optional weight tensors are stored as buffers on the module.
    # The reduction mode is handed to the Loss base class and read back
    # from @reduction in #forward.
    class BCEWithLogitsLoss < Loss
      # weight:     optional tensor, registered as the "weight" buffer
      # reduction:  reduction mode forwarded to Loss (default "mean")
      # pos_weight: optional tensor, registered as the "pos_weight" buffer
      def initialize(weight: nil, reduction: "mean", pos_weight: nil)
        super(reduction)
        # Register in the same order as before: "weight" first, then "pos_weight".
        {
          "weight" => weight,
          "pos_weight" => pos_weight
        }.each do |buffer_name, tensor|
          register_buffer(buffer_name, tensor)
        end
      end

      # Computes the loss for +input+ against +target+.
      # NOTE(review): +input+ is presumably raw logits (per the class name)
      # rather than probabilities. Confirm against the functional docs.
      def forward(input, target)
        loss_options = {
          weight: weight,
          pos_weight: pos_weight,
          reduction: @reduction
        }
        F.binary_cross_entropy_with_logits(input, target, **loss_options)
      end
    end
  end
end
Version data entries
55 entries across 55 versions & 1 rubygem