Sha256: c86da446a112f631e66acd1f77c46ad58f3ce82a64d40636a5d1ab7079992a19
Contents?: true
Size: 770 Bytes
Versions: 1
Compression:
Stored size: 770 Bytes
Contents
require 'rbbt/vector/model'
require 'rbbt/util/python'

# Load-time Python setup performed when this file is required.
# NOTE(review): presumably this makes the package's python/lib directory
# importable from the embedded interpreter — confirm against RbbtPython.
RbbtPython.add_path Rbbt.python.find(:lib)
RbbtPython.init_rbbt

# VectorModel subclass holding a Python (PyTorch) model accessed via PyCall.
# The class-level helpers navigate sub-modules of such a model by a dotted
# attribute path and disable gradients on their weights.
class TorchModel < VectorModel
  attr_accessor :model

  # Follows a dot-separated chain of attribute names starting at +model+.
  #
  # @param model [Object] root Python object (PyCall wrapper)
  # @param layer [String] attribute names joined by "." (e.g. "encoder.fc")
  # @return [Object] the object reached after the last attribute lookup
  # @raise [PyCall::PyError] if any attribute along the path is missing
  def self.get_layer(model, layer)
    layer.split(".").inject(model) { |acc, name| PyCall.getattr(acc, name.to_sym) }
  end

  # Returns the +weight+ attribute of the sub-module found at +layer+.
  #
  # @param model [Object] root Python object (PyCall wrapper)
  # @param layer [String] dotted attribute path, as accepted by get_layer
  # @return [Object] the +weight+ attribute of that sub-module
  def self.get_weights(model, layer)
    PyCall.getattr(get_layer(model, layer), :weight)
  end

  # Recursively sets +requires_grad = false+ on the +weight+ of +layer+ and
  # of every module reachable through +children+.
  #
  # This is best-effort per module. If a module has no +weight+, or Python
  # rejects the change, that module is skipped and its children are still
  # visited. A +weight+ that is None (nil on the Ruby side) is skipped.
  # Only +weight+ is touched; other attributes such as +bias+ are left as-is.
  #
  # @param layer [Object] Python module object (PyCall wrapper)
  def self.freeze(layer)
    begin
      weight = PyCall.getattr(layer, :weight)
      weight.requires_grad = false unless weight.nil?
    rescue PyCall::PyError
      # Python-side failure (e.g. AttributeError for modules without a
      # weight). Skip this module deliberately. Ruby-side errors are no
      # longer swallowed.
    end
    RbbtPython.iterate(layer.children) { |child| freeze(child) }
  end

  # Resolves +layer+ (a dotted path) inside +model+ and freezes it
  # recursively with freeze.
  #
  # @param model [Object] root Python object (PyCall wrapper)
  # @param layer [String] dotted attribute path, as accepted by get_layer
  def self.freeze_layer(model, layer)
    freeze(get_layer(model, layer))
  end

  # Passes both arguments straight through to VectorModel#initialize.
  def initialize(dir, model_options = {})
    super(dir, model_options)
  end
end
Version data entries
1 entries across 1 versions & 1 rubygems
Version | Path |
---|---|
rbbt-dm-1.2.9 | lib/rbbt/vector/model/torch.rb |