lib/dnn/core/layers/embedding.rb in ruby-dnn-0.16.2 vs lib/dnn/core/layers/embedding.rb in ruby-dnn-1.0.0
- line removed (ruby-dnn-0.16.2)
+ line added (ruby-dnn-1.0.0)
@@ -1,9 +1,11 @@
module DNN
module Layers
class Embedding < TrainableLayer
+ include LayerNode
+
attr_reader :input_length
attr_reader :weight
attr_reader :weight_initializer
attr_reader :weight_regularizer
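
The new include LayerNode line is the reason the hand-written #call disappears in the next hunk: the Tensor/Link wiring every layer used to repeat now comes from a shared mixin. Below is a minimal sketch of that idea, assuming the mixin builds the layer on first use and wraps the output of forward_node; Tensor, Link, LayerNodeSketch and DemoLayer are illustrative stand-ins, not ruby-dnn's real classes.

# Hedged sketch only: stand-ins so the example runs on its own.
Tensor = Struct.new(:data, :link)
Link   = Struct.new(:prev, :layer_node)

module LayerNodeSketch
  # Mirrors the removed Embedding#call: build on first use, run the node's
  # forward pass, and return a Tensor that links back to the producing layer.
  def call(input_tensor)
    build(input_tensor.data.size) unless built?
    Tensor.new(forward_node(input_tensor.data), Link.new(nil, self))
  end
end

class DemoLayer
  include LayerNodeSketch

  def built?
    @built ? true : false
  end

  def build(_input_shape)
    @built = true
  end

  def forward_node(x)
    x.map { |v| v * 2 } # toy forward pass
  end
end

p DemoLayer.new.call(Tensor.new([1, 2, 3], nil)).data # => [2, 4, 6]

With the mixin in place, Embedding no longer needs to carry its own copy of this wrapper, which is exactly what the removal in the next hunk shows.
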
@@ -20,17 +22,12 @@
@weight_initializer = weight_initializer
@weight_regularizer = weight_regularizer
@weight = Param.new(nil, Xumo::SFloat[0])
end
- def call(input_tensor)
- build(@input_shape) unless built?
- Tensor.new(forward_node(input_tensor.data), Link.new(nil, self))
- end
-
def build(input_shape)
- @built = true
+ super(@input_shape)
@weight.data = Xumo::SFloat.new(@input_length)
@weight_initializer.init_param(self, @weight)
@weight_regularizer.param = @weight if @weight_regularizer
end
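
The other change in this hunk swaps the hand-set @built flag for a super(@input_shape) call, which suggests the base layer's build now owns both the input-shape bookkeeping and the built flag. A hedged sketch of that split, with BaseLayerSketch and EmbeddingSketch as illustrative stand-ins rather than ruby-dnn's Layer and Embedding:

# Sketch under assumptions: what a base-class build might centralize so that
# subclasses only add their parameter setup after calling super.
class BaseLayerSketch
  attr_reader :input_shape

  def build(input_shape)
    @input_shape = input_shape
    @built = true
  end

  def built?
    @built ? true : false
  end
end

class EmbeddingSketch < BaseLayerSketch
  def initialize(input_shape, input_length)
    @input_shape = input_shape
    @input_length = input_length
  end

  def build(input_shape)
    super(@input_shape)                      # shared bookkeeping, as in the new line above
    @weight = Array.new(@input_length, 0.0)  # stand-in for Xumo::SFloat.new(@input_length)
  end
end

layer = EmbeddingSketch.new([10], 10)
layer.build(layer.input_shape)
p layer.built? # => true
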
@@ -53,27 +50,9 @@
nil
end
def regularizers
@weight_regularizer ? [@weight_regularizer] : []
- end
-
- def to_proc
- method(:call).to_proc
- end
-
- def >>(layer)
- if RUBY_VERSION < "2.6.0"
- raise DNN_Error, "Function composition is not supported before ruby version 2.6.0."
- end
- to_proc >> layer
- end
-
- def <<(layer)
- if RUBY_VERSION < "2.6.0"
- raise DNN_Error, "Function composition is not supported before ruby version 2.6.0."
- end
- to_proc << layer
end
def to_hash
super(input_shape: @input_shape, input_length: @input_length,
weight_initializer: @weight_initializer.to_hash, weight_regularizer: @weight_regularizer&.to_hash)
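
Finally, the to_proc, >> and << helpers are no longer defined on Embedding itself; whether they moved into shared code or were dropped from the per-layer API is not visible from this hunk. For reference, a small sketch of the idiom those methods enabled, built on Ruby 2.6's Proc#>> composition (the lambdas stand in for layer #call methods):

# Illustrative only: the removed #>> delegated to Proc composition after
# guarding on RUBY_VERSION >= "2.6.0".
scale = ->(x) { x.map { |v| v * 3 } }  # stand-in for one layer's #call
shift = ->(x) { x.map { |v| v + 1 } }  # stand-in for the next layer's #call

pipeline = scale >> shift              # left-to-right composition
p pipeline.call([1, 2, 3])             # => [4, 7, 10]
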