lib/active_median/model.rb in active_median-0.4.1 vs lib/active_median/model.rb in active_median-0.5.0

- old
+ new

@@ -1,28 +1,32 @@ module ActiveMedian module Model def median(column) - calculate_percentile(column, 0.5, "median") + connection_pool.with_connection do |connection| + calculate_percentile(column, 0.5, "median", connection) + end end def percentile(column, percentile) - calculate_percentile(column, percentile, "percentile") + connection_pool.with_connection do |connection| + calculate_percentile(column, percentile, "percentile", connection) + end end private - def calculate_percentile(column, percentile, operation) + def calculate_percentile(column, percentile, operation, connection) percentile = Float(percentile, exception: false) raise ArgumentError, "invalid percentile" if percentile.nil? raise ArgumentError, "percentile is not between 0 and 1" if percentile < 0 || percentile > 1 # basic version of Active Record disallow_raw_sql! # symbol = column (safe), Arel node = SQL (safe), other = untrusted # matches table.column and column unless column.is_a?(Symbol) || column.is_a?(Arel::Nodes::SqlLiteral) column = column.to_s - unless /\A\w+(\.\w+)?\z/i.match(column) + unless /\A\w+(\.\w+)?\z/i.match?(column) raise ActiveRecord::UnknownAttributeReference, "Query method called with non-attribute argument(s): #{column.inspect}. Use Arel.sql() for known-safe values." end end column_alias = @@ -32,15 +36,15 @@ # Active Record 7.0.5+ ActiveRecord::Calculations::ColumnAliasTracker.new(connection).alias_for("#{operation} #{column.to_s.downcase}") end # safety check # could quote, but want to keep consistent with Active Record - raise "Bad column alias: #{column_alias}. Please report a bug." unless column_alias =~ /\A[a-z0-9_]+\z/ + raise "Bad column alias: #{column_alias}. Please report a bug." unless /\A[a-z0-9_]+\z/.match?(column_alias) # column resolution node = relation.send(:arel_columns, [column]).first node = Arel::Nodes::SqlLiteral.new(node) if node.is_a?(String) - column = relation.connection.visitor.accept(node, Arel::Collectors::SQLString.new).value + column = connection.visitor.accept(node, Arel::Collectors::SQLString.new).value # prevent SQL injection percentile = connection.quote(percentile) group_values = all.group_values