Sha256: af93ce821118b1cf2f1297a4f4a7c0f17a51eb5f827b1f1f7e42134363897660

Contents?: true

Size: 1.66 KB

Versions: 1

Compression:

Stored size: 1.66 KB

Contents

module TensorFlow
  module Keras
    module Models
      # A linear stack of layers, in the spirit of Keras' `Sequential` model.
      #
      # Only layer bookkeeping and {#summary} work; {#compile}, {#fit} and
      # {#evaluate} are placeholders that always raise.
      class Sequential
        # Thin rule printed under the title, between layer rows, and at the end.
        SEPARATOR = "_________________________________________________________________\n".freeze
        # Thick rule framing the block of layer rows.
        DIVIDER = "=================================================================\n".freeze
        # One table row: layer type, output shape, parameter count. The column
        # widths are intended to line up with the header line in #summary.
        ROW_FORMAT = "%-28s %-25s %-10s\n".freeze
        private_constant :SEPARATOR, :DIVIDER, :ROW_FORMAT

        # @param layers [Array] layers to append, in order, via {#add}
        def initialize(layers = [])
          @layers = []
          layers.each { |layer| add(layer) }
        end

        # Appends a layer to the end of the stack.
        #
        # @param layer [Object] the layer to append
        # @return [Array] the model's internal layer list (as returned by Array#<<)
        def add(layer)
          @layers << layer
        end

        # Placeholder: compilation is not implemented.
        #
        # @raise [RuntimeError] always
        def compile(optimizer: nil, loss: nil, metrics: nil)
          raise "Not implemented"
        end

        # Placeholder: training is not implemented.
        #
        # @raise [RuntimeError] always
        def fit(x, y, epochs: nil)
          raise "Not implemented"
        end

        # Placeholder: evaluation is not implemented.
        #
        # @raise [RuntimeError] always
        def evaluate(x, y)
          raise "Not implemented"
        end

        # Prints a Keras-style table of the model's layers to stdout.
        #
        # Shape inference runs first: every layer that responds to `build` is
        # built with the previous layer's output shape (`nil` for the first
        # layer). A layer reporting no output shape is shown as "?".
        #
        # @return [nil]
        def summary
          infer_output_shapes
          total_params = @layers.sum(&:count_params)

          text = String.new("")
          text << "Model: \"sequential\"\n"
          text << SEPARATOR
          text << "Layer (type)                 Output Shape              Param #   \n"
          text << DIVIDER
          text << @layers.map { |layer| summary_row(layer) }.join(SEPARATOR)
          text << DIVIDER
          text << "Total params: #{total_params}\n"
          # NOTE(review): every parameter is reported as trainable and none as
          # non-trainable; revisit once layers expose trainability.
          text << "Trainable params: #{total_params}\n"
          text << "Non-trainable params: 0\n"
          text << SEPARATOR
          puts text
        end

        private

        # Builds each layer in order, threading output shapes from layer to layer.
        # NOTE(review): this re-runs `build` on every #summary call; if a layer's
        # build (re)initializes weights, repeated summaries may reset them — confirm.
        def infer_output_shapes
          @layers.reduce(nil) do |input_shape, layer|
            layer.build(input_shape) if layer.respond_to?(:build)
            layer.output_shape
          end
        end

        # Formats one table row: unqualified class name, output shape, param count.
        def summary_row(layer)
          format(
            ROW_FORMAT,
            layer.class.name.split("::").last,
            format_shape(layer.output_shape),
            layer.count_params
          )
        end

        # Renders a shape with its leading (batch) dimension replaced by nil,
        # or "?" when the layer reports no shape at all.
        def format_shape(shape)
          return "?" if shape.nil?

          # drop(1) is [] even for an empty shape, unlike shape[1..-1] (nil).
          [nil, *shape.drop(1)].inspect
        end
      end
    end
  end
end

Version data entries

1 entries across 1 versions & 1 rubygems

Version Path
tensorflow-0.2.0 lib/tensorflow/keras/models/sequential.rb