lib/dnn/core/layers/embedding.rb in ruby-dnn-1.0.0 vs lib/dnn/core/layers/embedding.rb in ruby-dnn-1.1.0
- old
+ new
@@ -6,24 +6,27 @@
attr_reader :input_length
attr_reader :weight
attr_reader :weight_initializer
attr_reader :weight_regularizer
+ attr_reader :mask_zero
# @param [Integer | Array] input_dim_or_shape Set input data dimension or shape.
# @param [Integer] input_length Set the time series length of input data.
# @param [DNN::Initializers::Initializer] weight_initializer Weight initializer.
# @param [DNN::Regularizers::Regularizer | NilClass] weight_regularizer Weight regularizer.
def initialize(input_dim_or_shape, input_length,
weight_initializer: Initializers::RandomUniform.new,
- weight_regularizer: nil)
+ weight_regularizer: nil,
+ mask_zero: false)
super()
@input_shape = input_dim_or_shape.is_a?(Array) ? input_dim_or_shape : [input_dim_or_shape]
@input_length = input_length
@weight_initializer = weight_initializer
@weight_regularizer = weight_regularizer
@weight = Param.new(nil, Xumo::SFloat[0])
+ @mask_zero = mask_zero
end
def build(input_shape)
super(@input_shape)
@weight.data = Xumo::SFloat.new(@input_length)
@@ -33,20 +36,32 @@
def forward_node(x)
@x = x
y = Xumo::SFloat.zeros(*x.shape)
x.shape[0].times do |i|
- y[i, false] = @weight.data[x[i, false]]
+ if @mask_zero
+ x.shape[1].times do |j|
+ index = x[i, j]
+ y[i, j] = index == 0 ? 0 : @weight.data[index]
+ end
+ else
+ y[i, false] = @weight.data[x[i, false]]
+ end
end
y
end
def backward_node(dy)
@weight.grad += Xumo::SFloat.zeros(*@weight.data.shape)
@x.shape[0].times do |i|
@x.shape[1].times do |j|
- @weight.grad[@x[i, j]] += dy[i, j]
+ index = @x[i, j]
+ if @mask_zero
+ @weight.grad[index] += dy[i, j] unless index == 0
+ else
+ @weight.grad[index] += dy[i, j]
+ end
end
end
nil
end
@@ -54,16 +69,18 @@
@weight_regularizer ? [@weight_regularizer] : []
end
def to_hash
super(input_shape: @input_shape, input_length: @input_length,
- weight_initializer: @weight_initializer.to_hash, weight_regularizer: @weight_regularizer&.to_hash)
+ weight_initializer: @weight_initializer.to_hash, weight_regularizer: @weight_regularizer&.to_hash,
+ mask_zero: @mask_zero)
end
def load_hash(hash)
initialize(hash[:input_shape], hash[:input_length],
weight_initializer: Initializers::Initializer.from_hash(hash[:weight_initializer]),
- weight_regularizer: Regularizers::Regularizer.from_hash(hash[:weight_regularizer]))
+ weight_regularizer: Regularizers::Regularizer.from_hash(hash[:weight_regularizer]),
+ mask_zero: hash[:mask_zero])
end
def get_params
{ weight: @weight }
end