lib/cas/client/middleware.rb in cas-client-0.1.3 vs lib/cas/client/middleware.rb in cas-client-0.2.0

- old
+ new

@@ -1,57 +1,61 @@ require 'net/http' module Cas module Client class Middleware - def initialize(app, config={}) + def initialize(app, &block) @app = app - @config = config - @config[:extra_attributes] = [] if config[:extra_attributes].nil? - @request = nil + + Cas::Client.configure(&block) if block_given? end def call(env) @request = Rack::Request.new(env) - server = Cas::Client::Server.new(@config[:server_url]) status, headers, rack_body = @app.call(env) - log(env, "Middleware called. Status: #{status}, Headers: #{headers}") if ticket_validation? - attributes = server.validate_service(self_url(@request), ticket_param, {extra_attributes: @config[:extra_attributes]}) - set_session(@request, attributes) - return redirect_to(self_url(@request)) + attributes = server.validate_service(self_url, ticket_param) + set_session(attributes) + + return redirect_to(self_url) elsif status == 401 - return redirect_to(server.login_url({service_url: self_url(@request)})) + log(env, "Cas::Client::Middleware detected 401, Status: #{status}, Headers: #{headers}\n") + + return redirect_to(server.login_url({ service_url: self_url })) else return [status, headers, rack_body] end end private - def set_session(req, attributes) - req.session['cas'] = attributes + def server + @_server ||= Cas::Client::Server.new end + def set_session(attributes) + @request.session['cas'] = attributes + end + def redirect_to(url, status=302) [ status, { 'Location' => url, 'Content-Type' => 'text/plain' }, ["Redirecting you to #{url}"] ] end - def self_url(req) - req.url.split('?')[0] + def self_url + @request.url.split('?')[0] end def ticket_validation? - !!(@request.get? && ticket_param && ticket_param.to_s =~ /\AST\-[^\s]{1,253}\Z/) + @request.get? && param_service_ticket? end def ticket_param @request.params['ticket'] end - def xml_namespace - @config[:cas_namespace] || 'cas' + def param_service_ticket? + ticket_param.to_s =~ /\AST\-[^\s]{1,253}\Z/ end def log(env, message, level = :info) if env['rack.logger'] env['rack.logger'].send(level, message)