lib/cas/client/middleware.rb in cas-client-0.1.3 vs lib/cas/client/middleware.rb in cas-client-0.2.0
- old
+ new
@@ -1,57 +1,61 @@
require 'net/http'
module Cas
module Client
class Middleware
- def initialize(app, config={})
+ def initialize(app, &block)
@app = app
- @config = config
- @config[:extra_attributes] = [] if config[:extra_attributes].nil?
- @request = nil
+
+ Cas::Client.configure(&block) if block_given?
end
def call(env)
@request = Rack::Request.new(env)
- server = Cas::Client::Server.new(@config[:server_url])
status, headers, rack_body = @app.call(env)
- log(env, "Middleware called. Status: #{status}, Headers: #{headers}")
if ticket_validation?
- attributes = server.validate_service(self_url(@request), ticket_param, {extra_attributes: @config[:extra_attributes]})
- set_session(@request, attributes)
- return redirect_to(self_url(@request))
+ attributes = server.validate_service(self_url, ticket_param)
+ set_session(attributes)
+
+ return redirect_to(self_url)
elsif status == 401
- return redirect_to(server.login_url({service_url: self_url(@request)}))
+ log(env, "Cas::Client::Middleware detected 401, Status: #{status}, Headers: #{headers}\n")
+
+ return redirect_to(server.login_url({ service_url: self_url }))
else
return [status, headers, rack_body]
end
end
private
- def set_session(req, attributes)
- req.session['cas'] = attributes
+ def server
+ @_server ||= Cas::Client::Server.new
end
+ def set_session(attributes)
+ @request.session['cas'] = attributes
+ end
+
def redirect_to(url, status=302)
[ status, { 'Location' => url, 'Content-Type' => 'text/plain' }, ["Redirecting you to #{url}"] ]
end
- def self_url(req)
- req.url.split('?')[0]
+ def self_url
+ @request.url.split('?')[0]
end
def ticket_validation?
- !!(@request.get? && ticket_param && ticket_param.to_s =~ /\AST\-[^\s]{1,253}\Z/)
+ @request.get? && param_service_ticket?
end
def ticket_param
@request.params['ticket']
end
- def xml_namespace
- @config[:cas_namespace] || 'cas'
+ def param_service_ticket?
+ ticket_param.to_s =~ /\AST\-[^\s]{1,253}\Z/
end
def log(env, message, level = :info)
if env['rack.logger']
env['rack.logger'].send(level, message)