lib/tensorflow/keras/layers/flatten.rb in tensorflow-0.1.2 vs lib/tensorflow/keras/layers/flatten.rb in tensorflow-0.2.0
- old
+ new
@@ -1,9 +1,24 @@
module TensorFlow
module Keras
module Layers
class Flatten
def initialize(input_shape: nil)
+ @input_shape = input_shape
+ end
+
+ def output_shape
+ flattened_dim = @input_shape.inject(&:*)
+ [-1, flattened_dim]
+ end
+
+ def count_params
+ 0
+ end
+
+ def call(inputs)
+ flattened_dim = inputs.shape[1..-1].inject(&:*)
+ TensorFlow.reshape(inputs, [-1, flattened_dim])
end
end
end
end
end