lib/dnn/core/embedding.rb in ruby-dnn-0.13.4 vs lib/dnn/core/embedding.rb in ruby-dnn-0.14.0
- old
+ new
@@ -19,13 +19,13 @@
@input_length = input_length
@weight_initializer = weight_initializer
@weight_regularizer = weight_regularizer
end
- def call(input)
+ def call(input_tensor)
build unless built?
- [forward(input), Link.new(nil, self)]
+ Tensor.new(forward(input_tensor.data), Link.new(nil, self))
end
def build
@built = true
@weight = Param.new(Xumo::SFloat.new(@input_length), Xumo::SFloat[0])
@@ -52,9 +52,27 @@
nil
end
def regularizers
@weight_regularizer ? [@weight_regularizer] : []
+ end
+
+ def to_proc
+ method(:call).to_proc
+ end
+
+ def >>(layer)
+ if RUBY_VERSION < "2.6.0"
+ raise DNN_Error, "Function composition is not supported before ruby version 2.6.0."
+ end
+ to_proc >> layer
+ end
+
+ def <<(layer)
+ if RUBY_VERSION < "2.6.0"
+ raise DNN_Error, "Function composition is not supported before ruby version 2.6.0."
+ end
+ to_proc << layer
end
def to_hash
super(input_shape: @input_shape, input_length: @input_length,
weight_initializer: @weight_initializer.to_hash, weight_regularizer: @weight_regularizer&.to_hash)