Sha256: db45eee4c425567dc4dbb980cf76eeec47a7ad1f72794ff8d786fd0b1c58cb20

Contents?: true

Size: 1.39 KB

Versions: 15

Compression:

Stored size: 1.39 KB

Contents

require 'rbbt/vector/model'
require 'rbbt/tensorflow'

# Vector model whose underlying predictor is a TensorFlow/Keras graph driven
# through RbbtPython.
#
# Training and evaluation are installed as the @train_model and @eval_model
# procs. NOTE(review): presumably VectorModel invokes them with the model file
# path plus feature (and label) data — confirm in rbbt/vector/model.
class TensorFlowModel < VectorModel
  # graph           - the Keras model (given to the constructor, built by
  #                   #keras_graph, or loaded from the model file on evaluation)
  # epochs          - number of epochs passed to graph.fit
  # compile_options - keyword options splatted into graph.compile
  attr_accessor :graph, :epochs, :compile_options

  # Evaluates +block+ inside RbbtPython.run "tensorflow"; code in the block
  # uses `tensorflow` to reach that Python module (assumes RbbtPython.run
  # exposes it under that name — confirm in RbbtPython).
  #
  # The block runs through RbbtPython.module_eval, so +self+ is RbbtPython:
  # the caller's local variables are visible (and assignable), but @ivars read
  # or written in the block belong to RbbtPython, not to this model.
  def tensorflow(&block)
    RbbtPython.run "tensorflow" do
      RbbtPython.module_eval(&block)
    end
  end

  # Like #tensorflow, but also runs inside RbbtPython.run "tensorflow.keras"
  # with as: 'keras', so the block can refer to `keras`. The same
  # module_eval caveat about +self+ and @ivars applies.
  def keras(&block)
    RbbtPython.run "tensorflow.keras", as: 'keras' do
      RbbtPython.run "tensorflow" do
        RbbtPython.module_eval(&block)
      end
    end
  end

  # @param dir forwarded to VectorModel#initialize
  # @param graph Keras model to train; when nil, #keras_graph is called the
  #   first time the model is trained
  # @param epochs [Integer] epochs passed to graph.fit (default 3)
  # @param compile_options keyword arguments forwarded to graph.compile
  def initialize(dir, graph = nil, epochs = 3, **compile_options)
    @graph = graph
    @epochs = epochs
    @compile_options = compile_options

    super(dir)

    # Converts the inputs to tensors, builds the graph if none was given,
    # compiles and fits it, then saves it to +file+.
    @train_model = Proc.new do |file, features, labels|
      tensorflow do
        features = tensorflow.convert_to_tensor(features)
        labels = tensorflow.convert_to_tensor(labels)
      end
      @graph ||= keras_graph
      @graph.compile(**@compile_options)
      @graph.fit(features, labels, :epochs => @epochs, :verbose => true)
      @graph.save(file)
    end

    # Converts +features+ to a tensor, loads the graph from +file+ if this
    # model has none yet, and predicts. A prediction row with more than one
    # value yields the index of its largest value; a single-value row yields
    # that value.
    @eval_model = Proc.new do |file, features|
      tensorflow do
        features = tensorflow.convert_to_tensor(features)
      end
      # Touch @graph only outside the module_eval'd block: inside it, @graph
      # would be RbbtPython's instance variable (shared by every model), so a
      # graph loaded for one model would be reused by all the others and this
      # model's own graph would be ignored.
      @graph ||= keras { keras.models.load_model(file) }
      predictions = @graph.predict(features, :verbose => false).tolist
      predictions.collect { |p| p.length > 1 ? p.index(p.max) : p.first }
    end
  end

  # Builds the graph by evaluating +block+ through #keras, stores the value
  # #keras returns (presumably the block's result) as @graph and returns it.
  def keras_graph(&block)
    @graph = keras(&block)
  end
end

Version data entries

15 entries across 15 versions and 1 rubygem

Version Path
rbbt-dm-1.2.7 lib/rbbt/vector/model/tensorflow.rb
rbbt-dm-1.2.6 lib/rbbt/vector/model/tensorflow.rb
rbbt-dm-1.2.4 lib/rbbt/vector/model/tensorflow.rb
rbbt-dm-1.2.3 lib/rbbt/vector/model/tensorflow.rb
rbbt-dm-1.2.1 lib/rbbt/vector/model/tensorflow.rb
rbbt-dm-1.1.63 lib/rbbt/vector/model/tensorflow.rb
rbbt-dm-1.1.62 lib/rbbt/vector/model/tensorflow.rb
rbbt-dm-1.1.61 lib/rbbt/vector/model/tensorflow.rb
rbbt-dm-1.1.60 lib/rbbt/vector/model/tensorflow.rb
rbbt-dm-1.1.59 lib/rbbt/vector/model/tensorflow.rb
rbbt-dm-1.1.58 lib/rbbt/vector/model/tensorflow.rb
rbbt-dm-1.1.57 lib/rbbt/vector/model/tensorflow.rb
rbbt-dm-1.1.56 lib/rbbt/vector/model/tensorflow.rb
rbbt-dm-1.1.55 lib/rbbt/vector/model/tensorflow.rb
rbbt-dm-1.1.54 lib/rbbt/vector/model/tensorflow.rb