Sha256: 5fb6d3ab1c5e7b6c4ea5f077abbcb7a35536d94940a1e06a7b091cb5ddf6e3e7
Contents?: true
Size: 911 Bytes
Versions: 1
Compression:
Stored size: 911 Bytes
Contents
require 'rbbt/vector/model/torch'

# Vector model backed by a PyTorch Lightning module.
#
# The Python class +class_name+ from module +module_name+ is instantiated
# through RbbtPython. Training ignores the features/labels handed to the
# train block: it calls +fit+ on the object assigned to +trainer+, using the
# data loaders assigned to +loader+ and +val_loader+. +loader+ and +trainer+
# must be set before training.
class PytorchLightningModel < TorchModel
  # loader     - training data loader passed to trainer.fit (required)
  # val_loader - validation data loader passed to trainer.fit (may be nil)
  # trainer    - object responding to #fit(model, loader, val_loader);
  #              presumably a Lightning Trainer (TODO confirm)
  attr_accessor :loader, :val_loader, :trainer

  # module_name   - name of the Python module to import
  # class_name    - class in that module to instantiate as the model
  # dir           - forwarded to TorchModel#initialize
  # model_options - forwarded to TorchModel#initialize; its :model_args entry
  #                 (default {}) is passed to the Python class constructor
  def initialize(module_name, class_name, dir = nil, model_options = {})
    super(dir, model_options)
    @module_name = module_name
    @class_name = class_name

    # NOTE(review): @model_options is assumed to be set by the superclass
    # constructor from +model_options+ — confirm against TorchModel.
    init_model do
      RbbtPython.pyimport @module_name
      RbbtPython.class_new_obj(@module_name, @class_name, @model_options[:model_args] || {})
    end

    # The features/labels arguments are unused: data comes from @loader.
    train_model do |_features, _labels|
      # Validate before `init` so a missing loader/trainer fails fast instead
      # of first instantiating the Python model.
      raise "Use the loader" if @loader.nil?
      raise "Use the trainer" if @trainer.nil?

      model = init
      @trainer.fit(model, @loader, @val_loader)
    end

    # When +list+ is truthy, +features+ is converted to a tensor as-is;
    # otherwise it is wrapped in a one-element array first (presumably to add
    # a leading batch dimension for a single sample).
    eval_model do |features, list|
      batch = list ? features : [features]
      model.call(RbbtPython.call_method(:torch, :tensor, batch))
    end
  end
end
Version data entries
1 entry across 1 version & 1 rubygem
Version | Path |
---|---|
rbbt-dm-1.2.9 | lib/rbbt/vector/model/pytorch_lightning.rb |