Sha256: 7bcd24c8269b399c9c8458c299b152b1336fa23d94ccbcbf16c176454cfd5755
Size: 340 Bytes
Versions: 55
Stored size: 340 Bytes
Contents
module Torch
  module NN
    class CosineEmbeddingLoss < Loss
      def initialize(margin: 0, reduction: "mean")
        super(reduction)
        @margin = margin
      end

      def forward(input1, input2, target)
        F.cosine_embedding_loss(input1, input2, target, margin: @margin, reduction: @reduction)
      end
    end
  end
end
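For reference, the computation that `forward` delegates to `F.cosine_embedding_loss` can be sketched in plain Ruby. This assumes the function follows the PyTorch/LibTorch definition: for each pair of rows, the loss is `1 - cos(x1, x2)` when the target is 1 and `max(0, cos(x1, x2) - margin)` when the target is -1, followed by the chosen reduction (`"none"`, `"sum"` or `"mean"`). The method names `cosine_similarity` and `cosine_embedding_loss_sketch` are hypothetical, inputs are plain arrays of rows rather than tensors, no gradients are tracked, and the zero-norm guard is a simplification that may not match LibTorch's exact epsilon handling.

```ruby
# Hypothetical sketch: cosine similarity between two plain Ruby vectors.
# The eps floor on the denominator is an assumed zero-norm guard,
# not necessarily LibTorch's exact handling.
def cosine_similarity(a, b, eps = 1e-8)
  dot = a.zip(b).sum { |x, y| x * y }
  norm_a = Math.sqrt(a.sum { |x| x * x })
  norm_b = Math.sqrt(b.sum { |x| x * x })
  dot / [norm_a * norm_b, eps].max
end

# Hypothetical sketch of the cosine embedding loss over rows of input1/input2,
# assuming the PyTorch definition (target 1 = similar, -1 = dissimilar).
def cosine_embedding_loss_sketch(input1, input2, target, margin: 0.0, reduction: "mean")
  losses = input1.each_index.map do |i|
    cos = cosine_similarity(input1[i], input2[i])
    case target[i]
    when 1 then 1.0 - cos                 # pull similar pairs together
    when -1 then [0.0, cos - margin].max  # push dissimilar pairs below the margin
    else raise ArgumentError, "target must be 1 or -1, got #{target[i].inspect}"
    end
  end

  # Apply the requested reduction to the per-pair losses.
  case reduction
  when "none" then losses
  when "sum" then losses.sum
  when "mean" then losses.sum / losses.size
  else raise ArgumentError, "unknown reduction: #{reduction}"
  end
end

x1 = [[1.0, 0.0], [1.0, 0.0]]
x2 = [[0.0, 1.0], [1.0, 0.0]]
y  = [1, -1]
p cosine_embedding_loss_sketch(x1, x2, y, margin: 0.5, reduction: "none") # prints [1.0, 0.5]
p cosine_embedding_loss_sketch(x1, x2, y, margin: 0.5)                    # prints 0.75
```

With `margin: 0.5`, the first pair is orthogonal but labeled similar (loss 1.0), and the second pair is identical but labeled dissimilar (loss 0.5), so the `"mean"` reduction gives 0.75.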