lib/arel/middleware/chain.rb in arel_toolkit-0.4.0 vs lib/arel/middleware/chain.rb in arel_toolkit-0.4.1

- old
+ new

@@ -1,74 +1,109 @@ module Arel module Middleware class Chain - def initialize(internal_middleware = [], internal_context = {}) + attr_reader :executing_middleware + attr_reader :executor + + def initialize( + internal_middleware = [], + internal_context = {}, + executor_class = Arel::Middleware::DatabaseExecutor + ) @internal_middleware = internal_middleware @internal_context = internal_context + @executor = executor_class.new(internal_middleware) + @executing_middleware = false end - def execute(sql, binds = []) - return sql if internal_middleware.length.zero? + def execute(sql, binds = [], &execute_sql) + return execute_sql.call(sql, binds).to_casted_result if internal_middleware.length.zero? - result = Arel.sql_to_arel(sql, binds: binds) + check_middleware_recursion(sql) + updated_context = context.merge(original_sql: sql) + enhanced_arel = Arel.enhance(Arel.sql_to_arel(sql, binds: binds)) - internal_middleware.each do |middleware_item| - result = result.map do |arel| - middleware_item.call(arel, updated_context.dup) - end - end + result = executor.run(enhanced_arel, updated_context, execute_sql) - result.to_sql + result.to_casted_result + rescue ::PgQuery::ParseError + execute_sql.call(sql, binds) + ensure + @executing_middleware = false end def current internal_middleware.dup end def apply(middleware, &block) - continue_chain(middleware, internal_context, &block) + new_middleware = Array.wrap(middleware) + continue_chain(new_middleware, internal_context, &block) end + alias only apply - def only(middleware, &block) - continue_chain(middleware, internal_context, &block) - end - def none(&block) continue_chain([], internal_context, &block) end def except(without_middleware, &block) - new_middleware = internal_middleware.reject do |middleware| - middleware == without_middleware - end - + without_middleware = Array.wrap(without_middleware) + new_middleware = internal_middleware - without_middleware continue_chain(new_middleware, internal_context, &block) end def insert_before(new_middleware, existing_middleware, &block) + new_middleware = Array.wrap(new_middleware) index = internal_middleware.index(existing_middleware) - updated_middleware = internal_middleware.insert(index, new_middleware) + updated_middleware = internal_middleware.insert(index, *new_middleware) continue_chain(updated_middleware, internal_context, &block) end + def prepend(new_middleware, &block) + new_middleware = Array.wrap(new_middleware) + updated_middleware = new_middleware + internal_middleware + continue_chain(updated_middleware, internal_context, &block) + end + def insert_after(new_middleware, existing_middleware, &block) + new_middleware = Array.wrap(new_middleware) index = internal_middleware.index(existing_middleware) - updated_middleware = internal_middleware.insert(index + 1, new_middleware) + updated_middleware = internal_middleware.insert(index + 1, *new_middleware) continue_chain(updated_middleware, internal_context, &block) end + def append(new_middleware, &block) + new_middleware = Array.wrap(new_middleware) + updated_middleware = internal_middleware + new_middleware + continue_chain(updated_middleware, internal_context, &block) + end + def context(new_context = nil, &block) if new_context.nil? && !block.nil? raise 'You cannot do a block statement while calling context without arguments' end return internal_context if new_context.nil? continue_chain(internal_middleware, new_context, &block) end + def to_sql(type, &block) + middleware = Arel::Middleware::ToSqlMiddleware.new(type) + + new_chain = Arel::Middleware::Chain.new( + internal_middleware + [middleware], + internal_context, + Arel::Middleware::ToSqlExecutor, + ) + + maybe_execute_block(new_chain, &block) + + middleware.sql + end + protected attr_reader :internal_middleware attr_reader :internal_context @@ -85,9 +120,29 @@ previous_chain = Middleware.current_chain Arel::Middleware.current_chain = new_chain yield block ensure Arel::Middleware.current_chain = previous_chain + end + + def check_middleware_recursion(sql) + if executing_middleware + message = <<~ERROR + Middleware is being called from within middleware, aborting execution + to prevent endless recursion. You can do the following if you want to execute SQL + inside middleware: + + - Set middleware context before entering the middleware + - Use `Arel.middleware.none { ... }` to temporarily disable middleware + + SQL that triggered the error: + #{sql} + ERROR + + raise message + else + @executing_middleware = true + end end end end end