Sha256: 5fb6d3ab1c5e7b6c4ea5f077abbcb7a35536d94940a1e06a7b091cb5ddf6e3e7

Contents?: true

Size: 911 Bytes

Versions: 1

Compression:

Stored size: 911 Bytes

Contents

require 'rbbt/vector/model/torch'

# Vector model backed by a Python model class (intended for PyTorch Lightning).
#
# The Python class +class_name+ from module +module_name+ is imported and
# instantiated through RbbtPython, passing +model_options[:model_args]+ (or an
# empty hash when absent) as the constructor arguments.
#
# Training ignores the features/labels handed to the training block. Instead,
# the caller must assign +loader+ (and optionally +val_loader+) and +trainer+
# before training. NOTE(review): +trainer+ is presumably a pytorch_lightning
# Trainer and the loaders torch DataLoaders — confirm against callers.
class PytorchLightningModel < TorchModel
  attr_accessor :loader, :val_loader, :trainer

  # @param module_name [String, Symbol] Python module that defines the model class
  # @param class_name [String, Symbol] Python class to instantiate from +module_name+
  # @param dir [String, nil] forwarded to TorchModel (presumably the model directory)
  # @param model_options [Hash] forwarded to TorchModel; the +:model_args+ entry is
  #   read back from +@model_options+ (presumably stored by the superclass) and used
  #   as the Python constructor arguments
  def initialize(module_name, class_name, dir = nil, model_options = {})
    super(dir, model_options)
    @module_name = module_name
    @class_name = class_name

    init_model do
      RbbtPython.pyimport @module_name
      RbbtPython.class_new_obj(@module_name, @class_name, @model_options[:model_args] || {})
    end

    # Features and labels are deliberately unused: training data is supplied
    # through @loader (and @val_loader).
    train_model do |_features, _labels|
      # Check the required collaborators before building the model, so a
      # misconfigured instance fails fast without constructing the Python object.
      raise "Use the loader" if @loader.nil?
      raise "Use the trainer" if @trainer.nil?

      model = init
      # @val_loader may be nil; it is passed through unchanged.
      @trainer.fit(model, @loader, @val_loader)
    end

    # Converts the input with torch.tensor and runs it through the model.
    # When +list+ is falsy, +features+ is wrapped in a one-element array
    # (a batch of one) before conversion.
    # NOTE(review): relies on +model+ already being initialized (e.g. by
    # training or loading); this block does not call +init+ itself.
    eval_model do |features, list|
      batch = list ? features : [features]
      model.call(RbbtPython.call_method(:torch, :tensor, batch))
    end
  end
end

# Direct-execution hook: running this file as a script currently does nothing.
if $PROGRAM_NAME == __FILE__
  # intentionally empty
end

Version data entries

1 entries across 1 versions & 1 rubygems

Version Path
rbbt-dm-1.2.9 lib/rbbt/vector/model/pytorch_lightning.rb