Sha256: 33a3546d43f310f07948f40113f2a7c3dd418eb37aee6a27141fcf6fa4a84d39
Size: 326 Bytes
Versions: 55
Stored size: 326 Bytes
Contents
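module Torch
  module NN
    # Loss module that delegates to the functional
    # F.multilabel_soft_margin_loss.
    class MultiLabelSoftMarginLoss < WeightedLoss
      # weight: optional per-class weights (nil means every class counts equally)
      # reduction: "none", "mean" (default) or "sum"
      def initialize(weight: nil, reduction: "mean")
        super(weight, reduction)
      end

      def forward(input, target)
        F.multilabel_soft_margin_loss(input, target, weight: @weight, reduction: @reduction)
      end
    end
  end
end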
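To show what `forward` computes, here is a plain-Ruby sketch that needs no torch-rb. It assumes that torch-rb's `F.multilabel_soft_margin_loss` follows PyTorch's definition. For each sample, every class i contributes the binary cross-entropy on sigmoid(x_i), optionally scaled by `weight[i]`. The class terms are then averaged over the C classes, and `reduction` decides how the per-sample values are combined. The method name `multilabel_soft_margin_loss_sketch` and the nested-array input format are hypothetical illustrations, not part of torch-rb's API.

```ruby
# Hypothetical plain-Ruby sketch, assuming the PyTorch definition of the
# multilabel soft margin loss. input and target are N x C nested arrays;
# weight is an optional array of C per-class weights.

# Numerically stable log(sigmoid(x)).
def log_sigmoid(x)
  x >= 0 ? -Math.log(1 + Math.exp(-x)) : x - Math.log(1 + Math.exp(x))
end

def multilabel_soft_margin_loss_sketch(input, target, weight: nil, reduction: "mean")
  per_sample = input.zip(target).map do |xs, ys|
    terms = xs.each_with_index.map do |x, i|
      y = ys[i]
      # Binary cross-entropy for class i: -(y*log(s(x)) + (1-y)*log(1-s(x))),
      # using log(1 - sigmoid(x)) == log_sigmoid(-x).
      term = -(y * log_sigmoid(x) + (1 - y) * log_sigmoid(-x))
      weight ? term * weight[i] : term
    end
    terms.sum / xs.size # average over the C classes
  end

  case reduction
  when "none" then per_sample
  when "sum"  then per_sample.sum
  when "mean" then per_sample.sum / per_sample.size
  else raise ArgumentError, "unknown reduction: #{reduction}"
  end
end

p multilabel_soft_margin_loss_sketch([[0.0, 0.0]], [[1, 0]])
```

With all logits at zero, every class term is log 2, so the default `"mean"` reduction also returns log 2 (about 0.693). Confident, correct logits such as `[[10.0, -10.0]]` against `[[1, 0]]` drive the loss toward zero.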