lib/rbbt/vector/model/pytorch_lightning.rb in rbbt-dm-1.2.9 vs lib/rbbt/vector/model/pytorch_lightning.rb in rbbt-dm-1.2.10
- old
+ new
@@ -1,35 +1,31 @@
require 'rbbt/vector/model/torch'
class PytorchLightningModel < TorchModel
attr_accessor :loader, :val_loader, :trainer
- def initialize(module_name, class_name, dir = nil, model_options = {})
- super(dir, model_options)
- @module_name = module_name
- @class_name = class_name
+ def initialize(...)
+ super(...)
- init_model do
- RbbtPython.pyimport @module_name
- RbbtPython.class_new_obj(@module_name, @class_name, @model_options[:model_args] || {})
- end
-
train_model do |features,labels|
model = init
- raise "Use the loader" if @loader.nil?
- raise "Use the trainer" if @trainer.nil?
-
- trainer.fit(model, @loader, @val_loader)
- end
-
- eval_model do |features,list|
- if list
- model.call(RbbtPython.call_method(:torch, :tensor, features))
- else
- model.call(RbbtPython.call_method(:torch, :tensor, [features]))
+ loader = self.loader
+ val_loader = self.val_loader
+ if (features && features.any?) && loader.nil?
+ TmpFile.with_file do |tsv_dataset_file|
+ TorchModel.feature_dataset(tsv_dataset_file, features, labels)
+ RbbtPython.pyimport :rbbt_dm
+ loader = RbbtPython.rbbt_dm.tsv(tsv_dataset_file)
+ end
end
+ trainer.fit(model, loader, val_loader)
+ TorchModel.save_architecture(model, model_path) if @directory
+ TorchModel.save_state(model, model_path) if @directory
end
-
end
-end
-if __FILE__ == $0
+ def trainer
+ @trainer ||= begin
+ options = @model_options[:training_args] || @model_options[:trainer_args]
+ RbbtPython.class_new_obj("pytorch_lightning", "Trainer", options || {})
+ end
+ end
end