lib/omniauth-ldap/adaptor.rb in gitlab_omniauth-ldap-1.0.3 vs lib/omniauth-ldap/adaptor.rb in gitlab_omniauth-ldap-1.0.4

- old
+ new

@@ -1,37 +1,41 @@ #this code borrowed pieces from activeldap and net-ldap require 'rack' require 'net/ldap' require 'net/ntlm' -require 'uri' require 'sasl' require 'kconv' module OmniAuth module LDAP class Adaptor class LdapError < StandardError; end class ConfigurationError < StandardError; end class AuthenticationError < StandardError; end class ConnectionError < StandardError; end - VALID_ADAPTER_CONFIGURATION_KEYS = [:host, :port, :method, :bind_dn, :password, :try_sasl, :sasl_mechanisms, :uid, :base, :allow_anonymous] + VALID_ADAPTER_CONFIGURATION_KEYS = [:host, :port, :method, :bind_dn, :password, :try_sasl, :sasl_mechanisms, :uid, :base, :allow_anonymous, :filter] - MUST_HAVE_KEYS = [:host, :port, :method, :uid, :base] + # A list of needed keys. Possible alternatives are specified using sub-lists. + MUST_HAVE_KEYS = [:host, :port, :method, [:uid, :filter], :base] METHOD = { :ssl => :simple_tls, :tls => :start_tls, :plain => nil, } attr_accessor :bind_dn, :password - attr_reader :connection, :uid, :base, :auth + attr_reader :connection, :uid, :base, :auth, :filter def self.validate(configuration={}) message = [] - MUST_HAVE_KEYS.each do |name| - message << name if configuration[name].nil? + MUST_HAVE_KEYS.each do |names| + names = [names].flatten + missing_keys = names.select{|name| configuration[name].nil?} + if missing_keys == names + message << names.join(' or ') + end end raise ArgumentError.new(message.join(",") +" MUST be provided") unless message.empty? end def initialize(configuration={}) Adaptor.validate(configuration) @@ -46,11 +50,10 @@ :host => @host, :port => @port, :encryption => method, :base => @base } - @uri = construct_uri(@host, @port, @method != :plain) @bind_method = @try_sasl ? :sasl : (@allow_anonymous||!@bind_dn||!@password ? :anonymous : :simple) @auth = sasl_auths({:username => @bind_dn, :password => @password}).first if @bind_method == :sasl @@ -138,12 +141,8 @@ t3_msg.serialize } [Net::NTLM::Message::Type1.new.serialize, nego] end - def construct_uri(host, port, ssl) - protocol = ssl ? "ldaps" : "ldap" - URI.parse("#{protocol}://#{host}:#{port}").to_s - end end end end