Sha256: caca8dbe978a458c884b787756abc48e4c0145114162abda1ee1194982e94c1a

Size: 512 Bytes

Versions: 1

Stored size: 512 Bytes

Contents

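module TensorFlow
  module Keras
    module Layers
      # Flattens each input example into one dimension while keeping the
      # leading batch dimension, e.g. [batch, 28, 28] -> [batch, 784].
      class Flatten
        # input_shape excludes the batch dimension, e.g. [28, 28].
        def initialize(input_shape: nil)
          @input_shape = input_shape
        end

        # Output shape of the layer; -1 stands for the (unknown) batch size.
        def output_shape
          raise ArgumentError, "input_shape is required to compute output_shape" unless @input_shape

          # inject(1, :*) returns 1 for an empty shape instead of nil
          flattened_dim = @input_shape.inject(1, :*)
          [-1, flattened_dim]
        end

        # Flatten has no weights, so it contributes no parameters.
        def count_params
          0
        end

        def call(inputs)
          # Multiply every dimension except the first (batch) one.
          flattened_dim = inputs.shape[1..-1].inject(1, :*)
          TensorFlow.reshape(inputs, [-1, flattened_dim])
        end
      end
    end
  end
end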
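The shape arithmetic behind `Flatten#output_shape` and `Flatten#call` can be checked without TensorFlow installed. The sketch below is a plain-Ruby illustration only: `flatten_output_shape` and `flatten_batch` are hypothetical helpers, not part of the tensorflow gem. It assumes, as the layer does, that the per-example shape excludes the batch dimension, and that flattening proceeds in row-major order like `TensorFlow.reshape`.

```ruby
# Hypothetical helpers (not part of the tensorflow gem) that mirror the
# shape arithmetic of Flatten#output_shape and Flatten#call.

# Per-example input shape (batch dimension excluded) -> output shape,
# with -1 standing for the batch size as in Flatten#output_shape.
def flatten_output_shape(input_shape)
  [-1, input_shape.inject(1, :*)]
end

# Flatten a nested array of shape [batch, d1, d2, ...] into
# [batch, d1 * d2 * ...] in row-major order, keeping the batch dimension.
# A scalar example (rank-1 input) becomes a one-element row.
def flatten_batch(batch)
  batch.map { |example| Array(example).flatten }
end

p flatten_output_shape([28, 28])                      # prints [-1, 784]
p flatten_batch([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) # prints [[1, 2, 3, 4], [5, 6, 7, 8]]
```

Using `inject(1, :*)` rather than `inject(&:*)` makes the empty-shape case well defined: the product of no dimensions is 1, so a batch of scalars flattens to `[-1, 1]`.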

Version data entries

1 entry across 1 version & 1 rubygem

Version Path
tensorflow-0.2.0 lib/tensorflow/keras/layers/flatten.rb