Sha256: e54e2dde329a35daf6bab8ce16a8f1588e5040f265734beb4e69aa2f8281bbd4

Contents?: true

Size: 1.02 KB

Versions: 9

Compression:

Stored size: 1.02 KB

Contents

require "tensor_stream/train/slot_creator"
require "tensor_stream/train/optimizer"
require "tensor_stream/train/gradient_descent_optimizer"
require "tensor_stream/train/momentum_optimizer"
require "tensor_stream/train/adam_optimizer"
require "tensor_stream/train/adadelta_optimizer"
require "tensor_stream/train/adagrad_optimizer"
require "tensor_stream/train/rmsprop_optimizer"
require "tensor_stream/train/saver"
require "tensor_stream/train/learning_rate_decay"

module TensorStream
  # Namespace for training-related entry points. The modules extended below
  # contribute their instance methods as module-level methods of Trainer.
  module Trainer
    extend TensorStream::Train::Utils
    extend TensorStream::Train::LearningRateDecay
    extend TensorStream::StringHelper

    # Serializes +graph+ to a string and writes it to +filename+ under +path+.
    #
    # @param graph the object handed to the serializer's +get_string+
    # @param path [String] directory portion of the destination
    # @param filename [String] file name joined onto +path+
    # @param as_text [Boolean] must be truthy; no other mode is implemented
    # @param serializer [Symbol, Class] when a Symbol, it is passed through
    #   +camelize+ and resolved as a constant under +TensorStream::+ via
    #   +constantize+; any other value is used directly and must respond
    #   to +new+
    # @return the string returned by the serializer's +get_string+
    # @raise [RuntimeError] when +as_text+ is falsy
    def self.write_graph(graph, path, filename, as_text: true, serializer: :yaml)
      raise "only supports as_text=true for now" unless as_text

      # camelize/constantize are presumably supplied by the StringHelper
      # extension above -- confirm against that module.
      serializer_class =
        if serializer.is_a?(Symbol)
          constantize("TensorStream::#{camelize(serializer.to_s)}")
        else
          serializer
        end

      # Build the destination before serializing so a bad path/filename
      # fails before any serialization work is done.
      destination = File.join(path, filename)
      serialized = serializer_class.new.get_string(graph)
      File.write(destination, serialized)
      serialized
    end
  end
end

Version data entries

9 entries across 9 versions & 1 rubygems

Version Path
tensor_stream-1.0.9 lib/tensor_stream/trainer.rb
tensor_stream-1.0.8 lib/tensor_stream/trainer.rb
tensor_stream-1.0.7 lib/tensor_stream/trainer.rb
tensor_stream-1.0.6 lib/tensor_stream/trainer.rb
tensor_stream-1.0.5 lib/tensor_stream/trainer.rb
tensor_stream-1.0.4 lib/tensor_stream/trainer.rb
tensor_stream-1.0.3 lib/tensor_stream/trainer.rb
tensor_stream-1.0.2 lib/tensor_stream/trainer.rb
tensor_stream-1.0.1 lib/tensor_stream/trainer.rb