require "active_record" require "join_dependency/version" module JoinDependency class << self def from_relation(relation, &block) build(relation, collect_joins(relation, &block)) end private def collect_joins(relation, &block) joins = [] joins += relation.joins_values joins += relation.left_outer_joins_values if at_least?(5) buckets = joins.group_by do |join| case join when String :string_join when Hash, Symbol, Array :association_join when Arel::Nodes::Join :join_node else (block_given? && yield(join)) || raise("unknown class: %s" % join.class.name) end end end def build(relation, buckets) buckets.default = [] association_joins = buckets[:association_join] stashed_association_joins = buckets[:stashed_join] join_nodes = buckets[:join_node].uniq string_joins = buckets[:string_join].map(&:strip).uniq join_list = if at_least?(5, 2) join_nodes + relation.send(:convert_join_strings_to_ast, string_joins) elsif at_least?(5) join_nodes + relation.send(:convert_join_strings_to_ast, relation.table, string_joins) else relation.send(:custom_join_ast, relation.table.from(relation.table), string_joins) end if at_least?(5, 2) alias_tracker = ::ActiveRecord::Associations::AliasTracker.create(relation.klass.connection, relation.table.name, join_list) join_dependency = ::ActiveRecord::Associations::JoinDependency.new(relation.klass, relation.table, association_joins, alias_tracker) join_nodes.each do |join| join_dependency.send(:alias_tracker).aliases[join.left.name.downcase] = 1 end else join_dependency = ::ActiveRecord::Associations::JoinDependency.new(relation.klass, association_joins, join_list) join_nodes.each do |join| join_dependency.send(:alias_tracker).aliases[join.left.name.downcase] = 1 end end join_dependency end def at_least?(major, minor = 0) ActiveRecord::VERSION::MAJOR >= major && ActiveRecord::VERSION::MINOR >= minor end end end