Sha256: c86da446a112f631e66acd1f77c46ad58f3ce82a64d40636a5d1ab7079992a19

Contents?: true

Size: 770 Bytes

Versions: 1

Compression:

Stored size: 770 Bytes

Contents

require 'rbbt/vector/model'
require 'rbbt/util/python'

# Register the Python directory resolved by Rbbt.python.find(:lib) with the
# rbbt Python bridge (presumably appended to Python's import path — confirm in
# RbbtPython.add_path), then initialise the bridge before TorchModel is defined.
python_lib_dir = Rbbt.python.find(:lib)
RbbtPython.add_path(python_lib_dir)
RbbtPython.init_rbbt

class TorchModel < VectorModel

  # The wrapped PyTorch model object (accessed through PyCall).
  # NOTE(review): nothing in this class assigns it; presumably set by callers
  # or by VectorModel — confirm.
  attr_accessor :model

  # Resolves a (possibly nested) attribute of +model+ from a dot-separated
  # path such as "encoder.fc", calling PyCall.getattr once per path segment.
  #
  # @param model [Object] root Python object (a PyCall wrapper)
  # @param layer [String] dot-separated attribute path; "" returns +model+ itself
  # @return [Object] the object reached by following the path
  def self.get_layer(model, layer)
    layer.split(".").inject(model) { |current, name| PyCall.getattr(current, name.to_sym) }
  end

  # Returns the +weight+ attribute of the layer found at +layer+ inside +model+.
  #
  # @param model [Object] root Python object (a PyCall wrapper)
  # @param layer [String] dot-separated attribute path, as for get_layer
  # @return [Object] the layer's +weight+ attribute
  def self.get_weights(model, layer)
    PyCall.getattr(get_layer(model, layer), :weight)
  end

  # Recursively sets +requires_grad = false+ on the +weight+ of +layer+ and of
  # every module reachable through +children+.
  #
  # NOTE(review): only +weight+ is touched; +bias+ and other parameters remain
  # trainable — confirm this is intended.
  # NOTE(review): this class method shadows Ruby's Object#freeze on the
  # TorchModel class object, so a plain `TorchModel.freeze` raises ArgumentError.
  #
  # @param layer [Object] a Python module object exposing +children+
  def self.freeze(layer)
    # Container/activation modules have no weight, and some layers keep it as
    # None (nil in Ruby); both are skipped rather than treated as errors.
    if PyCall.hasattr?(layer, :weight)
      weight = PyCall.getattr(layer, :weight)
      weight.requires_grad = false unless weight.nil?
    end

    RbbtPython.iterate(layer.children) do |child|
      freeze(child)
    end
  end

  # Freezes the layer found at the dot-separated path +layer+ inside +model+
  # (see freeze for exactly what is frozen).
  #
  # @param model [Object] root Python object (a PyCall wrapper)
  # @param layer [String] dot-separated attribute path, as for get_layer
  def self.freeze_layer(model, layer)
    freeze(get_layer(model, layer))
  end

  # Delegates directly to VectorModel#initialize.
  #
  # @param dir [Object] forwarded unchanged to VectorModel
  # @param model_options [Hash] forwarded unchanged to VectorModel
  def initialize(dir, model_options = {})
    super(dir, model_options)
  end
end

Version data entries

1 entry across 1 version & 1 rubygem

Version Path
rbbt-dm-1.2.9 lib/rbbt/vector/model/torch.rb