require 'dionysus'
require 'digest'

##
# Convenience methods for the Digest module.
#
#     require 'dionysus/digest'
# 
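# The <tt>Digest::DEFAULT_DIGESTS</tt> are registered automatically when their
# libraries are available.  You can register additional digests with
# Digest.register_digest.  The given class (or object) must compute the digest
# through its <tt>digest(string)</tt> method, or through the method named by
# the <tt>:method</tt> option.
#
# TODO: add proc-based digest detection via the digests hash (length-based
# detection is already provided by Digest.detect_digest).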
module Digest
  DEFAULT_DIGESTS = [:md5, :sha1, :sha2, :sha256, :sha384, :sha512].freeze
  @digests = {}

  ##
  # Register a digest.  Raises an error if the interpreted class doesn't exist.
  # It will interpret the klass as <tt>Digest::SYM</tt> if it's <tt>nil</tt>,
  # and it will run the digest method on the klass to determine the digest's
  # bit length if bit_length is <tt>nil</tt>.
  #
  # This will register <tt>:my_digest</tt> and automatically determine the bit
  # length by executing the class's <tt>digest</tt> method on the string
  # <tt>'1'</tt>:
  #
  #     Digest.register_digest!( :my_digest, :klass => MyDigestClass )
  #
  # Options:
  # [klass]       The digest class (also can be an arbitrary object).  Default:
  #               <tt>Digest::#{sym.to_s.upcase}</tt>
  # [bit_length]  The bit length of the digest.  Default: 8 times the byte
  #               size of the digest of the string <tt>'1'</tt>.
  # [method]      The calculation method for the digest.  Default:
  #               <tt>:digest</tt>
  #
  # Returns the stored options hash.
  def self.register_digest!( sym, options = {} )
    options = options.with_indifferent_access
    options[:method]     ||= :digest
    options[:klass]      ||= "Digest::#{sym.to_s.upcase}".constantize
    options[:bit_length] ||= begin
      sample = options[:klass].send(options[:method], '1')
      # Count bytes rather than characters, so a binary digest carrying a
      # multibyte encoding tag still reports its true size.
      (sample.respond_to?(:bytesize) ? sample.bytesize : sample.length) * 8
    end
    @digests[sym.to_sym] = options
  end

  ##
  # Register a digest like Digest.register_digest!, but return +nil+ instead of
  # raising when the digest's library cannot be loaded (LoadError).  Any other
  # error propagates.
  def self.register_digest( sym, options = {} )
    register_digest!(sym, options)
  rescue LoadError
    nil
  end

  ##
  # The hash of registered digests (digest name => registration options).
  def self.digests
    @digests
  end

  ##
  # The names of the registered digests.
  def self.available_digests
    digests.keys
  end

  ##
  # The lengths of the registered digests in the given encoding, as a hash of
  # digest name => length in characters.
  #
  # +encoding+ may be <tt>:bit</tt> (or 1), a symbol present in
  # <tt>String::ENCODING_BITS_PER_CHAR</tt>, or a positive Integer number of
  # bits per character.  Raises ArgumentError for anything else.
  def self.digest_lengths( encoding = :binary )
    bits_per_char =
      case encoding
      when :bit, 1 then 1
      when Symbol  then String::ENCODING_BITS_PER_CHAR[encoding]
      when Integer then (encoding if encoding > 0)
      end
    raise ArgumentError, "Invalid encoding" unless bits_per_char

    _digest_lengths(bits_per_char)
  end

  ##
  # Calculate the given digest of the given string.
  #
  # +sym+ may be:
  # * the name (Symbol or String) of a registered digest, in which case its
  #   <tt>:klass</tt> and <tt>:method</tt> are used;
  # * the name of any other <tt>Digest::</tt> class;
  # * a digest class or object that responds to <tt>digest(string)</tt>.
  #
  # Examples:
  #
  #     Digest.digest(:sha512, 'foobar') #=> binary digest
  #     Digest.digest(Digest::SHA512, 'foobar') #=> binary digest
  def self.digest( sym, str )
    case sym
    when Symbol, String
      info = digests[sym.to_sym]
      return info[:klass].send(info[:method], str) if info

      Digest.const_get(sym.to_s.upcase).digest(str)
    else
      # A digest class (or digest object) was passed directly; calling
      # to_s.upcase on it would yield e.g. "DIGEST::SHA512", which cannot
      # be resolved.
      sym.digest(str)
    end
  end

  ##
  # Detect the digest of the string from its length.  Returns nil if the
  # digest cannot be determined.
  #
  # Example:
  #     Digest.detect_digest("wxeCFXPVXePFcpwuFDjonyn1G/w=", :base64) #=> :sha1
  #     Digest.detect_digest("foobar", :hex) #=> nil
  def self.detect_digest( string, encoding = :binary )
    # Surrounding whitespace is stripped only for non-binary encodings; binary
    # input is measured as-is.
    string = string.strip unless encoding == :binary
    # Hash#invert keeps the last key for duplicated values, so when several
    # digests share a length the later registry entry wins.
    dig = digest_lengths(encoding).invert[string.bytesize]
    # Report a :sha2 match under the explicit :sha256 name.
    dig == :sha2 ? :sha256 : dig
  end

  ##
  # Detect the digest of the string from its length.  Raises a RuntimeError
  # if the digest cannot be determined.
  #
  # Example:
  #     Digest.detect_digest!("wxeCFXPVXePFcpwuFDjonyn1G/w=", :base64) #=> :sha1
  #     Digest.detect_digest!("foobar", :hex) #=> RuntimeError
  def self.detect_digest!( string, encoding = :binary )
    detect_digest(string, encoding) || raise("Unknown digest")
  end

  ##
  # Length in characters of every registered digest, for an encoding that
  # stores +bits_per_char+ bits per character.  The length is rounded up to
  # a whole number of lcm(bits_per_char, 8)-bit groups (e.g. multiples of 4
  # characters for 6-bit characters), mirroring padded encodings.
  def self._digest_lengths( bits_per_char ) # :nodoc:
    padding_factor = bits_per_char.lcm(8) / bits_per_char

    digests.each_with_object({}) do |(dig, info), result|
      # Ceiling division: a partial trailing character still occupies a slot.
      chars = (info[:bit_length] + bits_per_char - 1) / bits_per_char
      remainder = chars % padding_factor
      result[dig] = remainder.zero? ? chars : chars + (padding_factor - remainder)
    end
  end
  # A bare `private` does not apply to `def self.` methods.
  private_class_method :_digest_lengths
end

# Register each of the default digests.  Digest.register_digest returns nil
# (rather than raising) when a digest's library cannot be loaded, so missing
# ones are simply skipped.
Digest::DEFAULT_DIGESTS.each { |name| Digest.register_digest(name) }