Sha256: caca8dbe978a458c884b787756abc48e4c0145114162abda1ee1194982e94c1a
Size: 512 Bytes
Versions: 1
Stored size: 512 Bytes
Contents
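```ruby
module TensorFlow
  module Keras
    module Layers
      class Flatten
        # input_shape is the per-sample shape (batch dimension excluded).
        def initialize(input_shape: nil)
          @input_shape = input_shape
        end

        # The batch size stays unknown (-1); all remaining dimensions
        # collapse into their product. Requires input_shape to be set.
        def output_shape
          flattened_dim = @input_shape.inject(&:*)
          [-1, flattened_dim]
        end

        # Flatten has no weights, so it contributes no parameters.
        def count_params
          0
        end

        # Skip the batch dimension, multiply the rest, and reshape
        # the input to [batch, flattened_dim].
        def call(inputs)
          flattened_dim = inputs.shape[1..-1].inject(&:*)
          TensorFlow.reshape(inputs, [-1, flattened_dim])
        end
      end
    end
  end
end
```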
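To make the shape arithmetic concrete, here is a minimal pure-Ruby sketch. `FlattenSketch` is a hypothetical class, not part of the tensorflow gem: it mirrors the `output_shape` and `call` logic above but operates on nested Arrays instead of tensors, so it runs without TensorFlow. It assumes row-major element order, which is what `Array#flatten` produces for nested arrays.

```ruby
# Hypothetical stand-in for the Flatten layer (not the gem's API).
# Works on nested Arrays, where the outer Array is the batch.
class FlattenSketch
  def initialize(input_shape: nil)
    @input_shape = input_shape # per-sample shape, batch excluded
  end

  # Same arithmetic as Flatten#output_shape: batch stays -1,
  # the other dimensions collapse into their product.
  def output_shape
    [-1, @input_shape.inject(&:*)]
  end

  # Flattens each sample of the batch into a flat Array,
  # keeping the batch dimension intact.
  def call(batch)
    batch.map(&:flatten)
  end
end

image_layer = FlattenSketch.new(input_shape: [28, 28, 3])
p image_layer.output_shape # prints [-1, 2352]

small_layer = FlattenSketch.new(input_shape: [2, 2])
batch = [
  [[1, 2], [3, 4]],
  [[5, 6], [7, 8]]
]
p small_layer.output_shape # prints [-1, 4]
p small_layer.call(batch)  # prints [[1, 2, 3, 4], [5, 6, 7, 8]]
```

A 28×28 RGB image therefore becomes a vector of 28 · 28 · 3 = 2352 values per sample, while the batch size is left open.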
Version data entries

1 entry across 1 version and 1 rubygem

| Version | Path |
|---|---|
| tensorflow-0.2.0 | lib/tensorflow/keras/layers/flatten.rb |