lib/rbbt/vector/model/pytorch_lightning.rb in rbbt-dm-1.2.9 vs lib/rbbt/vector/model/pytorch_lightning.rb in rbbt-dm-1.2.10

- old
+ new

@@ -1,35 +1,31 @@ require 'rbbt/vector/model/torch' class PytorchLightningModel < TorchModel attr_accessor :loader, :val_loader, :trainer - def initialize(module_name, class_name, dir = nil, model_options = {}) - super(dir, model_options) - @module_name = module_name - @class_name = class_name + def initialize(...) + super(...) - init_model do - RbbtPython.pyimport @module_name - RbbtPython.class_new_obj(@module_name, @class_name, @model_options[:model_args] || {}) - end - train_model do |features,labels| model = init - raise "Use the loader" if @loader.nil? - raise "Use the trainer" if @trainer.nil? - - trainer.fit(model, @loader, @val_loader) - end - - eval_model do |features,list| - if list - model.call(RbbtPython.call_method(:torch, :tensor, features)) - else - model.call(RbbtPython.call_method(:torch, :tensor, [features])) + loader = self.loader + val_loader = self.val_loader + if (features && features.any?) && loader.nil? + TmpFile.with_file do |tsv_dataset_file| + TorchModel.feature_dataset(tsv_dataset_file, features, labels) + RbbtPython.pyimport :rbbt_dm + loader = RbbtPython.rbbt_dm.tsv(tsv_dataset_file) + end end + trainer.fit(model, loader, val_loader) + TorchModel.save_architecture(model, model_path) if @directory + TorchModel.save_state(model, model_path) if @directory end - end -end -if __FILE__ == $0 + def trainer + @trainer ||= begin + options = @model_options[:training_args] || @model_options[:trainer_args] + RbbtPython.class_new_obj("pytorch_lightning", "Trainer", options || {}) + end + end end