lib/tensorflow/keras/layers/flatten.rb in tensorflow-0.1.2 vs lib/tensorflow/keras/layers/flatten.rb in tensorflow-0.2.0

- old
+ new

@@ -1,9 +1,24 @@ module TensorFlow module Keras module Layers class Flatten def initialize(input_shape: nil) + @input_shape = input_shape + end + + def output_shape + flattened_dim = @input_shape.inject(&:*) + [-1, flattened_dim] + end + + def count_params + 0 + end + + def call(inputs) + flattened_dim = inputs.shape[1..-1].inject(&:*) + TensorFlow.reshape(inputs, [-1, flattened_dim]) end end end end end