Sha256: 083ac55597184f3973aa20fbaf61659e1702533b9d31876c7ce109e3a3d012eb
Contents?: true
Size: 1.64 KB
Versions: 5
Compression:
Stored size: 1.64 KB
Contents
module Torch
  module Hub
    # Maximum number of HTTP redirects download_url_to_file follows before giving up.
    MAX_REDIRECTS = 10

    class << self
      # Lists the entrypoints published by a hub repository on GitHub.
      # Not implemented yet; +github+ and +force_reload+ are currently ignored.
      #
      # @raise [NotImplementedYet] always
      def list(github, force_reload: false)
        raise NotImplementedYet
      end

      # Downloads +url+ to the local path +dst+.
      #
      # The response body is streamed into a uniquely named file under Dir.tmpdir and is
      # moved to +dst+ only after the request has completed, so +dst+ is never left
      # holding a partial download. HTTP redirects are followed (relative Location
      # headers are resolved against the requesting URL) up to MAX_REDIRECTS hops.
      # The temporary file is always removed, including on redirects and errors.
      #
      # @param url [String] http or https URL to fetch
      # @param dst [String] destination file path
      # @return [nil]
      # @raise [Error] on a non-success/non-redirect response, a redirect without a
      #   Location header, or more than MAX_REDIRECTS redirects
      def download_url_to_file(url, dst)
        fetch_following_redirects(url, dst, MAX_REDIRECTS)
      end

      # Downloads (once) and loads a serialized checkpoint from +url+.
      #
      # The cache directory defaults to "$TORCH_HOME/checkpoints"; when TORCH_HOME is
      # unset it falls back to "$XDG_CACHE_HOME/torch/checkpoints", then to
      # "$HOME/.cache/torch/checkpoints". The cached file is named after the last path
      # segment of +url+ and is only downloaded when it does not already exist.
      #
      # @param url [String] URL of the checkpoint file
      # @param model_dir [String, nil] directory to cache the file in
      # @return whatever Torch.load returns for the cached file
      # @raise [ArgumentError] if +url+ has no file name in its path
      def load_state_dict_from_url(url, model_dir: nil)
        unless model_dir
          torch_home = ENV["TORCH_HOME"] || "#{ENV["XDG_CACHE_HOME"] || "#{ENV["HOME"]}/.cache"}/torch"
          model_dir = File.join(torch_home, "checkpoints")
        end
        FileUtils.mkdir_p(model_dir)

        filename = File.basename(URI(url).path.to_s)
        # An empty or root path would make cached_file point at model_dir itself.
        if filename.empty? || filename == "/"
          raise ArgumentError, "Cannot determine file name from URL: #{url}"
        end

        cached_file = File.join(model_dir, filename)
        unless File.exist?(cached_file)
          # TODO support hash_prefix
          download_url_to_file(url, cached_file)
        end
        Torch.load(cached_file)
      end

      private

      # Performs one GET of +url+ into a temp file; on a redirect, recurses with one
      # fewer hop allowed, otherwise moves the temp file to +dst+.
      def fetch_following_redirects(url, dst, redirects_left)
        uri = URI(url)
        tmp = File.join(Dir.tmpdir, "torch-hub-#{Process.pid}-#{Time.now.to_f}-#{rand(1 << 32)}")
        location = nil
        begin
          Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
            request = Net::HTTP::Get.new(uri)
            puts "Downloading #{url}..."
            File.open(tmp, "wb") do |f|
              http.request(request) do |response|
                case response
                when Net::HTTPRedirection
                  target = response["location"]
                  # Without a Location there is nothing to follow; don't save an empty file.
                  raise Error, "Bad response" unless target
                  # Location may be relative, so resolve it against the request URL.
                  location = URI.join(uri.to_s, target).to_s
                when Net::HTTPSuccess
                  response.read_body { |chunk| f.write(chunk) }
                else
                  raise Error, "Bad response"
                end
              end
            end
          end
          unless location
            FileUtils.mv(tmp, dst)
            return nil
          end
        ensure
          # Removes the empty (redirect) or partial (error) temp file; no-op after mv.
          FileUtils.rm_f(tmp)
        end

        raise Error, "Too many redirects" if redirects_left <= 0
        fetch_following_redirects(location, dst, redirects_left - 1)
      end
    end
  end
end
Version data entries
5 entries across 5 versions & 1 rubygem
Version | Path |
---|---|
torch-rb-0.3.1 | lib/torch/hub.rb |
torch-rb-0.3.0 | lib/torch/hub.rb |
torch-rb-0.2.7 | lib/torch/hub.rb |
torch-rb-0.2.6 | lib/torch/hub.rb |
torch-rb-0.2.5 | lib/torch/hub.rb |