Sha256: 186d8bd445877d06db9f9055cf93c844a391c246a98d417481951172ea364979
Size: 494 Bytes
Versions: 55
Compression:
Stored size: 494 Bytes
Contents
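module Torch
  module NN
    # Triplet margin loss: for each (anchor, positive, negative) triplet,
    # max(d(anchor, positive) - d(anchor, negative) + margin, 0),
    # where d is a p-norm distance.
    class TripletMarginLoss < Loss
      # margin:    required gap between the positive and negative distances
      # p:         norm degree used for the pairwise distance
      # eps:       small constant added for numerical stability in the distance
      # swap:      if true, apply the distance swap (use the smaller of the
      #            anchor-negative and positive-negative distances)
      # reduction: "mean", "sum", or "none"; stored as @reduction by the Loss base class
      def initialize(margin: 1.0, p: 2.0, eps: 1e-6, swap: false, reduction: "mean")
        super(reduction)
        @margin = margin
        @p = p
        @eps = eps
        @swap = swap
      end

      # Delegates to the functional form with the stored hyperparameters.
      def forward(anchor, positive, negative)
        F.triplet_margin_loss(
          anchor, positive, negative,
          margin: @margin, p: @p, eps: @eps, swap: @swap, reduction: @reduction
        )
      end
    end
  end
end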
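To make the arithmetic behind `F.triplet_margin_loss` concrete, here is a minimal pure-Ruby sketch over plain arrays. It is not torch-rb's implementation. It assumes PyTorch-style semantics: the distance is ||x - y + eps||_p, `swap: true` replaces the anchor-negative distance with the smaller of the anchor-negative and positive-negative distances, and `reduction` is one of "mean", "sum", or "none". The module name `TripletSketch` and its methods are hypothetical.

```ruby
# Hypothetical pure-Ruby sketch of triplet margin loss, assuming
# PyTorch-style semantics (NOT torch-rb's actual implementation).
module TripletSketch
  module_function

  # p-norm distance ||x - y + eps||_p between two equal-length Float arrays.
  def distance(x, y, p: 2.0, eps: 1e-6)
    total = x.zip(y).sum { |a, b| (a - b + eps).abs**p }
    total**(1.0 / p)
  end

  # anchors, positives, negatives: arrays of Float arrays, one row per sample.
  def triplet_margin_loss(anchors, positives, negatives, margin: 1.0, p: 2.0,
                          eps: 1e-6, swap: false, reduction: "mean")
    losses = anchors.each_index.map do |i|
      d_ap = distance(anchors[i], positives[i], p: p, eps: eps)
      d_an = distance(anchors[i], negatives[i], p: p, eps: eps)
      # Distance swap: use the harder (smaller) of the two negative distances.
      d_an = [d_an, distance(positives[i], negatives[i], p: p, eps: eps)].min if swap
      # Hinge: zero once the negative is at least `margin` farther than the positive.
      [d_ap - d_an + margin, 0.0].max
    end

    case reduction
    when "mean" then losses.sum / losses.size
    when "sum" then losses.sum
    when "none" then losses
    else raise ArgumentError, "unknown reduction: #{reduction}"
    end
  end
end

anchors   = [[0.0, 0.0], [0.0, 0.0]]
positives = [[3.0, 4.0], [0.0, 0.0]]
negatives = [[0.0, 1.0], [3.0, 4.0]]
p TripletSketch.triplet_margin_loss(anchors, positives, negatives, eps: 0.0, reduction: "none")
```

With `eps: 0.0`, the first triplet has d(a, p) = 5 and d(a, n) = 1, so its loss is 5 - 1 + 1 = 5; the second triplet's negative is already 5 units farther than its positive, so its loss is clamped to 0. Setting `swap: true` matters when the negative sits closer to the positive than to the anchor, since the smaller positive-negative distance then drives the loss.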
Version data entries
55 entries across 55 versions & 1 rubygems