lib/torch/hub.rb in torch-rb-0.2.4 vs lib/torch/hub.rb in torch-rb-0.2.5

- old
+ new

@@ -3,15 +3,59 @@ class << self def list(github, force_reload: false) raise NotImplementedYet end - def download_url_to_file(url) - raise NotImplementedYet + def download_url_to_file(url, dst) + uri = URI(url) + tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name + location = nil + + Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http| + request = Net::HTTP::Get.new(uri) + + puts "Downloading #{url}..." + File.open(tmp, "wb") do |f| + http.request(request) do |response| + case response + when Net::HTTPRedirection + location = response["location"] + when Net::HTTPSuccess + response.read_body do |chunk| + f.write(chunk) + end + else + raise Error, "Bad response" + end + end + end + end + + if location + download_url_to_file(location, dst) + else + FileUtils.mv(tmp, dst) + nil + end end - def load_state_dict_from_url(url) - raise NotImplementedYet + def load_state_dict_from_url(url, model_dir: nil) + unless model_dir + torch_home = ENV["TORCH_HOME"] || "#{ENV["XDG_CACHE_HOME"] || "#{ENV["HOME"]}/.cache"}/torch" + model_dir = File.join(torch_home, "checkpoints") + end + + FileUtils.mkdir_p(model_dir) + + parts = URI(url) + filename = File.basename(parts.path) + cached_file = File.join(model_dir, filename) + unless File.exist?(cached_file) + # TODO support hash_prefix + download_url_to_file(url, cached_file) + end + + Torch.load(cached_file) end end end end