Sha256: de5fb251c92be267c56c48142eb95bd3e8e18081b78c7e3ba50bffac6f21d28f
Contents?: true
Size: 1.91 KB
Versions: 2
Compression:
Stored size: 1.91 KB
Contents
module TensorStream
  # Evaluates tensors through a pluggable evaluator class and remembers the
  # context hash used by the most recent call to {#run}.
  class Session
    # @return [Hash, nil] the context hash handed to the evaluator by the most
    #   recent {#run}; nil until {#run} has been called.
    attr_reader :last_session_context

    # @param evaluator [Symbol, String] snake_case evaluator name; it is
    #   camelized and looked up under TensorStream::Evaluator
    # @param thread_pool_class [Class] instantiated with no arguments; the
    #   instance is passed to every evaluator as `thread_pool:`
    # @raise [NameError] when no matching evaluator constant exists
    def initialize(evaluator = :ruby_evaluator, thread_pool_class: Concurrent::ImmediateExecutor)
      @evaluator_class = Object.const_get("TensorStream::Evaluator::#{camelize(evaluator.to_s)}")
      @thread_pool = thread_pool_class.new
    end

    # @return [Session] a lazily created, memoized session
    def self.default_session
      @session ||= Session.new
    end

    # Evaluates each given tensor with a fresh evaluator instance.
    #
    # A trailing Hash argument is treated as options:
    # * `:feed_dict` - hash whose Placeholder keys are copied into the
    #   evaluation context under `name.to_sym`; other keys are ignored
    # * `:retain` - stored in the context under `:retain` as-is
    #
    # @param args [Array] tensors to evaluate, optionally followed by options
    # @return [Object, Array] the single result when exactly one tensor was
    #   given, otherwise the array of results (empty when none were given)
    def run(*args)
      options = args.last.is_a?(Hash) ? args.pop : {}

      context = placeholder_context(options[:feed_dict])
      context[:retain] = options[:retain]

      evaluator = @evaluator_class.new(self, context, thread_pool: @thread_pool)

      # One execution context is shared by all tensors of this call.
      execution_context = {}
      result = args.map { |tensor| evaluator.run(tensor, execution_context) }
      @last_session_context = context

      result.size == 1 ? result.first : result
    end

    # Describes the graph nodes that are Tensors reporting `internal?`.
    #
    # @param tensor [#graph] see {#dump_ops}
    # @return [Array<String>] see {#dump_ops}
    def dump_internal_ops(tensor)
      dump_ops(tensor, ->(_key, node) { node.is_a?(Tensor) && node.internal? })
    end

    # Describes the graph nodes that are Tensors not reporting `internal?`.
    #
    # @param tensor [#graph] see {#dump_ops}
    # @return [Array<String>] see {#dump_ops}
    def dump_user_ops(tensor)
      dump_ops(tensor, ->(_key, node) { node.is_a?(Tensor) && !node.internal? })
    end

    # Builds one "<key> <math> = <value>" line for every selected node of the
    # tensor's graph that has a truthy entry in the last session context.
    #
    # @param tensor [#graph] object whose `graph.nodes` yields key/node pairs
    # @param selector [#call] called with (key, node); truthy keeps the node
    # @return [Array<String>] empty when {#run} has not been called yet
    def dump_ops(tensor, selector)
      context = @last_session_context
      # Nothing has been evaluated yet, so there are no values to report
      # (previously this raised NoMethodError on nil).
      return [] unless context

      selected = tensor.graph.nodes.select { |key, node| selector.call(key, node) }
      selected.map do |key, node|
        # NOTE(review): placeholders are stored under symbol keys by #run;
        # presumably the evaluator records other nodes under node.name - confirm.
        value = context[node.name]
        next unless value

        "#{key} #{node.to_math(true, 1)} = #{value}"
      end.compact
    end

    private

    # Collects the Placeholder entries of a feed dictionary.
    #
    # @param feed_dict [Hash, nil]
    # @return [Hash] placeholder name (as Symbol) => fed value
    def placeholder_context(feed_dict)
      return {} unless feed_dict

      feed_dict.each_with_object({}) do |(key, value), context|
        context[key.name.to_sym] = value if key.is_a?(Placeholder)
      end
    end

    # Converts a snake_case path such as "open_cl/evaluator" into a constant
    # path such as "OpenCl::Evaluator".
    #
    # @param string [String]
    # @param uppercase_first_letter [Boolean] when false the leading segment
    #   keeps a lowercase first letter
    # @return [String]
    def camelize(string, uppercase_first_letter = true)
      string =
        if uppercase_first_letter
          string.sub(/^[a-z\d]*/) { Regexp.last_match(0).capitalize }
        else
          string.sub(/^(?:(?=\b|[A-Z_])|\w)/) { Regexp.last_match(0).downcase }
        end

      # Group 1 is the "/" separator (nil after "_"), group 2 the next word.
      camelized = string.gsub(/(?:_|(\/))([a-z\d]*)/) do
        "#{Regexp.last_match(1)}#{Regexp.last_match(2).capitalize}"
      end
      camelized.gsub('/', '::')
    end
  end
end
Version data entries
2 entries across 2 versions & 1 rubygems
Version | Path |
---|---|
tensor_stream-0.1.1 | lib/tensor_stream/session.rb |
tensor_stream-0.1.0 | lib/tensor_stream/session.rb |