Sha256: 083ac55597184f3973aa20fbaf61659e1702533b9d31876c7ce109e3a3d012eb

Contents?: true

Size: 1.64 KB

Versions: 5

Compression:

Stored size: 1.64 KB

Contents

module Torch
  module Hub
    # Upper bound on HTTP redirects followed by download_url_to_file; guards
    # against redirect loops that would otherwise never terminate.
    MAX_REDIRECTS = 10

    class << self
      # Lists the entrypoints published by a GitHub repository.
      #
      # Not implemented yet.
      #
      # @param github repository to inspect (currently unused)
      # @param force_reload (currently unused)
      # @raise [NotImplementedYet] always
      def list(github, force_reload: false)
        raise NotImplementedYet
      end

      # Downloads +url+ to the local path +dst+, following HTTP redirects.
      #
      # The body is streamed into a temporary file that is moved onto +dst+
      # only after a 2xx response, so +dst+ is never left holding a partial or
      # empty download, and the temporary file is removed on every exit path.
      #
      # @param url [String] http(s) URL to fetch
      # @param dst [String] destination file path
      # @return [nil]
      # @raise [Error] on a non-2xx/non-3xx response, a 3xx without a Location
      #   header, or more than MAX_REDIRECTS redirects
      def download_url_to_file(url, dst)
        uri = URI(url)
        redirects = 0

        # fetch_to_file returns the redirect target, or nil once dst is written.
        while (location = fetch_to_file(uri, dst))
          redirects += 1
          raise Error, "Too many redirects" if redirects > MAX_REDIRECTS

          # Location may be relative (e.g. "/files/model.pth"); resolve it
          # against the URI that produced it.
          uri = uri.merge(location)
        end

        nil
      end

      # Loads the object at +url+ with Torch.load, downloading it into
      # +model_dir+ first unless a file with the same basename is already there.
      #
      # When +model_dir+ is not given it defaults to "checkpoints" under
      # $TORCH_HOME, else $XDG_CACHE_HOME/torch, else ~/.cache/torch.
      #
      # @param url [String] URL whose path basename names the cached file
      # @param model_dir [String, nil] cache directory (created if missing)
      # @return whatever Torch.load returns for the cached file
      # @raise [ArgumentError] if +url+ has no file name in its path
      def load_state_dict_from_url(url, model_dir: nil)
        model_dir ||= File.join(default_torch_home, "checkpoints")
        FileUtils.mkdir_p(model_dir)

        filename = File.basename(URI(url).path.to_s)
        # An empty path would make cached_file == model_dir, which always
        # "exists" and would be handed to Torch.load as if it were a file.
        if filename.empty? || filename == "/"
          raise ArgumentError, "Cannot determine file name from URL: #{url}"
        end

        cached_file = File.join(model_dir, filename)
        unless File.exist?(cached_file)
          # TODO support hash_prefix
          download_url_to_file(url, cached_file)
        end

        Torch.load(cached_file)
      end

      private

      # Issues a single GET for +uri+.
      #
      # On a 2xx response the body is streamed to a temporary file, which is
      # then moved onto +dst+, and nil is returned. On a 3xx response the
      # Location header is returned and nothing is written. The temporary file
      # is removed in all cases (after a successful move it no longer exists).
      #
      # @raise [Error] on any other status, or a 3xx without a Location header
      def fetch_to_file(uri, dst)
        # pid + random suffix keep concurrent downloads from sharing a temp file
        tmp = File.join(Dir.tmpdir, "torch-hub-#{Process.pid}-#{Time.now.to_f}-#{rand(1 << 32).to_s(16)}")
        location = nil

        puts "Downloading #{uri}..."
        Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
          http.request(Net::HTTP::Get.new(uri)) do |response|
            case response
            when Net::HTTPRedirection
              location = response["location"]
              # e.g. 304 Not Modified has no target; don't treat it as success
              raise Error, "Bad response" unless location
            when Net::HTTPSuccess
              # The file is opened only once we know the body is wanted.
              File.open(tmp, "wb") do |f|
                response.read_body { |chunk| f.write(chunk) }
              end
            else
              raise Error, "Bad response"
            end
          end
        end

        FileUtils.mv(tmp, dst) unless location
        location
      ensure
        FileUtils.rm_f(tmp) if tmp
      end

      # Root cache directory: $TORCH_HOME, else $XDG_CACHE_HOME/torch,
      # else ~/.cache/torch.
      def default_torch_home
        ENV["TORCH_HOME"] || File.join(ENV["XDG_CACHE_HOME"] || File.join(Dir.home, ".cache"), "torch")
      end
    end
  end
end

Version data entries

5 entries across 5 versions & 1 rubygem

Version Path
torch-rb-0.3.1 lib/torch/hub.rb
torch-rb-0.3.0 lib/torch/hub.rb
torch-rb-0.2.7 lib/torch/hub.rb
torch-rb-0.2.6 lib/torch/hub.rb
torch-rb-0.2.5 lib/torch/hub.rb