require 'odbc'
require 'marjoree/odbc'
require 'marjoree/result_set'
require 'marjoree/odbc_connection_wrapper'

class Hash
  # Returns a new Hash whose keys have been converted with +to_sym+.
  # The receiver itself is left unchanged.
  def symbolize_keys
    inject({}) do |result, (key, value)|
      result[key.to_sym] = value
      result
    end
  end
end

# This is the main Marjoree mixin.
#
# Every call is delegated to the global +$connection+. Before calling anything
# else, install a connection object with +establish_connection_to+ and open it
# with +connect_me_to+, or no messages will reach your DB :o)
#
# To make sure the connection is dropped when you are finished, you might put
# something like this in your test code:
#
#   Kernel.at_exit { disconnect_me }
module Marjoree
  # Type names whose DECLARE needs an explicit length. When the length is
  # omitted in a T-SQL variable declaration it defaults to 1, which would
  # truncate the value returned through an output parameter.
  SIZED_TYPE_NAMES = %w[char varchar nchar nvarchar binary varbinary].freeze

  # Opens the connection to the ODBC data source named +odbc_name+
  # (delegates to $connection.connect_me_to).
  def connect_me_to(odbc_name)
    $connection.connect_me_to(odbc_name)
  end

  # Installs +connection+ as the global connection used by every Marjoree call.
  def establish_connection_to(connection)
    $connection = connection
  end

  # Drops the connection to the dataserver (delegates to $connection.disconnect_me).
  def disconnect_me
    $connection.disconnect_me
  end

  # Runs the stored procedure +proc_name+.
  #
  # +map+ supplies the input parameters in the form
  #   { :procedure_input_param_name => procedure_input_param_value }
  # which is executed as
  #   EXEC proc_name @procedure_input_param_name = procedure_input_param_value
  #
  # If the procedure has output parameters, they are automatically bound onto
  # the returned ResultSet, e.g.
  #   assert_equal(expected_output_parameter_value, result_set.output_parameter_name)
  def run_sproc(proc_name, map = {})
    if has_output_params?(proc_name)
      run_complex_sproc(proc_name, map)
    else
      run_simple_sproc(proc_name, map)
    end
  end

  # Performs a SELECT * on +table_name+, optionally filtered by +where_map+
  # (see +num_rows+ for how the map becomes a WHERE clause).
  def select(table_name, where_map = nil)
    where_clause = build_where_statement(where_map)
    execute "SELECT * FROM #{table_name} #{where_clause}"
  end

  # Returns the number of rows in table +table_name+.
  #
  # The +where_map+ entries are AND'ed together, e.g.
  #   { :table_column_name_one => 1,
  #     :table_column_name_two => 2 }
  # produces
  #   SELECT COUNT(*) AS number_of_rows
  #   FROM table_name
  #   WHERE table_column_name_one = 1
  #     AND table_column_name_two = 2
  # A nil value is matched with IS NULL rather than = NULL.
  def num_rows(table_name, where_map = nil)
    where_clause = build_where_statement(where_map)
    result_set = execute("SELECT COUNT(*) AS number_of_rows FROM #{table_name} #{where_clause}")
    result_set.hashes.first[:number_of_rows]
  end

  # Same as +num_rows+.
  def count(table_name, where_map = nil)
    num_rows(table_name, where_map)
  end

  # Returns true if table +table_name+ contains at least one row matching +map+.
  #
  # +map+ supplies column_name_as_sym => value pairs, AND'ed together.
  def contains?(table_name, map = {})
    count(table_name, map) > 0
  end

  # Performs an INSERT INTO +table_name+.
  #
  # +map+ supplies column_name_as_sym => value for the inserted row.
  def insert(table_name, map = {})
    columns = build_column_headers(map)
    values = build_column_values(map)
    execute "INSERT INTO #{table_name} (#{columns}) VALUES (#{values})"
  end

  # Performs an UPDATE on +table_name+.
  #
  # +where_map+ supplies column_name_as_sym => value pairs for the WHERE
  # section; they are AND'ed together. An empty map updates every row.
  #
  # +map+ supplies column_name_as_sym => value pairs for the SET section.
  def update(table_name, where_map = {}, map = {})
    where_clause = build_where_statement(where_map)
    assignments = build_set_statement(map)
    execute "UPDATE #{table_name} SET #{assignments} #{where_clause}"
  end

  # Performs a DELETE FROM +table_name+.
  #
  # +map+ supplies column_name_as_sym => value pairs for the WHERE section of
  # the DELETE statement; they are AND'ed together. An empty map deletes every row.
  def delete(table_name, map = {})
    where_clause = build_where_statement(map)
    execute "DELETE FROM #{table_name} #{where_clause}"
  end

  # Performs a TRUNCATE TABLE on +table_name+.
  def truncate(table_name)
    execute "TRUNCATE TABLE #{table_name}"
  end

  # Returns the names of all user tables (sysobjects type 'U').
  def get_user_tables
    get_objects('U')
  end

  # Returns the names of all user views (sysobjects type 'V').
  def get_user_views
    get_objects('V')
  end

  # Returns the names of all user stored procedures (sysobjects type 'P').
  def get_user_sprocs
    get_objects('P')
  end

  # DROPs all user tables.
  def drop_user_tables
    drop_objects('TABLE') { get_user_tables }
  end

  # DROPs all user views.
  def drop_user_views
    drop_objects('VIEW') { get_user_views }
  end

  # DROPs all user stored procedures.
  def drop_user_sprocs
    drop_objects('PROCEDURE') { get_user_sprocs }
  end

  # Asserts that the ResultSet has no rows.
  def assert_empty(result_set)
    assert_equal 0, result_set.hashes.size
  end

  # Asserts that the columns and rows of the ExpectedResultSet are contained
  # within the actual ResultSet.
  def assert_results(expected, result_set)
    assert_column_headers expected, result_set
    assert_row_data expected, result_set
  end

  # Asserts that the values in the ExpectedResultSet are NOT contained within
  # the actual ResultSet. The assertion passes when either an expected column
  # header is missing or the expected rows are not present.
  def assert_not_equal_results(expected, result_set)
    matches = has_column_headers?(expected, result_set) && has_correct_data?(expected, result_set)
    assert !matches, "Expected the result set not to contain the expected results, but it does."
  end

  # Asserts that running the block raises an ODBC::Error carrying
  # +expected_error_code+. When +expected_error_message+ is non-nil, the
  # error message must match it too.
  # Errors of any other class propagate unchanged.
  def assert_error_thrown(expected_error_code, expected_error_message)
    yield
  rescue ODBC::Error => exception
    assert_equal(expected_error_code, exception.error_code)
    assert_equal(expected_error_message, exception.error_message) unless expected_error_message.nil?
  else
    assert false, "Expected the block to raise ODBC::Error, but it completed normally."
  end

  # Asserts that running the block makes the DB raise an ODBC::Error,
  # e.g. inserting into a table that does not exist.
  def assert_db_error
    yield
  rescue ODBC::Error
    # Expected outcome: the database rejected the statement.
  else
    assert false, "Expected the block to raise ODBC::Error, but it completed normally."
  end

  # Returns true if every expected column header is present in +result_set+.
  #
  # NOTE(review): the bare rescue only catches StandardError. That covers
  # Test::Unit's AssertionFailedError, but a framework whose assertion errors
  # are not StandardErrors (e.g. Minitest::Assertion < Exception) would
  # propagate here instead of returning false — confirm for your framework.
  def has_column_headers?(expected, result_set)
    assert_column_headers expected, result_set
    true
  rescue
    false
  end

  # Returns true if the expected rows are present in +result_set+
  # (same rescue caveat as +has_column_headers?+).
  def has_correct_data?(expected, result_set)
    assert_row_data expected, result_set
    true
  rescue
    false
  end

  # Asserts that +table_name+ contains a row matching +value_map+ (see +contains?+).
  def assert_contains(table_name, value_map)
    assert contains?(table_name, value_map)
  end

  # Asserts that +table_name+ contains no row matching +value_map+ (see +contains?+).
  def assert_does_not_contain(table_name, value_map)
    assert !contains?(table_name, value_map)
  end

  # Executes raw +sql+ against the DB via $connection.execute, passing along
  # the names of any output parameters to bind.
  def execute(sql, output_params = [])
    $connection.execute(sql, output_params)
  end

  private

  # "col1 = v1, col2 = v2" for the SET section of an UPDATE.
  def build_set_statement(map)
    map.map { |key, value| "#{key} = #{display_value_for(value)}" }.join(", ")
  end

  # Comma-separated SQL literals for the map's values (same order as build_column_headers).
  def build_column_values(map)
    map.values.map { |value| display_value_for(value) }.join(", ")
  end

  # "@param1 = v1, @param2 = v2" for the input parameters of an EXEC.
  def build_input_parameter_statement(map)
    map.map { |key, value| "@#{key} = #{display_value_for(value)}" }.join(", ")
  end

  # Comma-separated column names for the map's keys.
  def build_column_headers(map)
    map.keys.join(", ")
  end

  # " WHERE col1 = v1 AND col2 IS NULL ..." or '' for a nil/empty map.
  # nil values use IS NULL because "col = NULL" never matches under ANSI_NULLS.
  def build_where_statement(map)
    return '' if map.nil? || map.empty?
    conditions = map.map do |key, value|
      value.nil? ? "#{key} IS NULL" : "#{key} = #{display_value_for(value)}"
    end
    " WHERE #{conditions.join(' AND ')}"
  end

  # Renders +value+ as a SQL literal: NULL for nil, numbers as-is, and
  # everything else single-quoted with embedded quotes doubled so the
  # statement stays well-formed.
  def display_value_for(value)
    return 'NULL' if value.nil?
    return value if value.is_a?(Numeric)
    "'#{value.to_s.gsub("'", "''")}'"
  end

  # Asserts, one assertion per expected column, that each expected header is
  # among the ResultSet's columns; the failure message lists the actual headers.
  def assert_column_headers(expected, result_set)
    actual_column_headers = result_set.columns.keys
    formatted_actual_headers = actual_column_headers.join(" \n\t")
    expected.columns.each do |expected_column_header|
      error_message = "\nColumn Header: '#{expected_column_header}' does not exist.\nActual Column Headers available are \n\t#{formatted_actual_headers}\n"
      assert actual_column_headers.include?(expected_column_header), error_message
    end
  end

  # Asserts that result_set.has?(expected); the failure message shows the
  # expected rows next to the actual rows restricted to the expected columns.
  def assert_row_data(expected, result_set)
    result = result_set.has?(expected)
    expected_rows = expected.row_hashes.map { |row| row.symbolize_keys.rehash }
    possible_error_message = "\nExpected: #{expected_rows.inspect}\nActual: [#{actual_values(result_set, expected)}]\n"
    assert result, possible_error_message
  end

  # Inspect strings of the actual rows, keeping only the expected columns, joined by ", ".
  def actual_values(result_set, expected)
    expected_columns = expected.columns
    rows = result_set.hashes.map do |hash|
      hash.reject { |key, _value| !expected_columns.include?(key.to_s) }.rehash.inspect
    end
    rows.join(", ")
  end

  def has_output_params?(proc_name)
    $connection.has_output_params?(proc_name)
  end

  # "@out1 = @out1 output, ..." for each output parameter description hash.
  def build_output_parameter_statement(list)
    list.map { |hash| "#{hash[:COLUMN_NAME]} = #{hash[:COLUMN_NAME]} output" }.join(", ")
  end

  # Runs a procedure that has output parameters: declares a variable per
  # output parameter, EXECs the procedure, then SELECTs each variable so the
  # connection can bind the values onto the ResultSet.
  def run_complex_sproc(proc_name, input_param_map)
    output_param_data = get_output_params(proc_name)
    output_params = output_param_names(output_param_data)
    sql = build_sql_for_sproc_with_output_params(proc_name, input_param_map, output_param_data)
    execute sql, output_params
  end

  def build_sql_for_sproc_with_output_params(proc_name, input_param_map, output_params)
    input_text = build_input_parameter_statement(input_param_map)
    output_text = build_output_parameter_statement(output_params)
    # Separate the input list from the output list only when both exist.
    input_text << "," unless input_param_map.empty?
    sql = build_output_declarations(output_params)
    sql << "EXEC #{proc_name} #{input_text} #{output_text}\n"
    output_params.each { |hash| sql << "SELECT #{hash[:COLUMN_NAME]}\n" }
    sql
  end

  def output_param_names(output_params)
    output_params.map { |hash| hash[:COLUMN_NAME] }
  end

  # One "DECLARE @name type\n" line per output parameter.
  def build_output_declarations(output_params)
    output_params.inject("") do |sql, hash|
      sql << "DECLARE #{hash[:COLUMN_NAME]} #{format_type_declaration(hash)}\n"
    end
  end

  # The type to DECLARE for an output parameter. Length-typed names get their
  # COLUMN_SIZE appended when one is available.
  def format_type_declaration(hash)
    column_type = hash[:TYPE_NAME]
    column_size = hash[:COLUMN_SIZE]
    if SIZED_TYPE_NAMES.include?(column_type) && !column_size.nil?
      "#{column_type}(#{column_size})"
    else
      column_type
    end
  end

  def get_output_params(proc_name)
    $connection.get_output_params(proc_name)
  end

  # Runs a procedure without output parameters.
  def run_simple_sproc(proc_name, map = {})
    input_param_text = build_input_parameter_statement(map)
    execute "EXEC #{proc_name} #{input_param_text}"
  end

  # Names of all sysobjects rows of the given +type+.
  def get_objects(type)
    select("sysobjects", :type => type).hashes.map { |row| row[:name] }
  end

  # Repeatedly DROPs every object name returned by the block until none remain.
  # Failed drops are retried on the next pass, e.g. for an object still
  # referenced by one not yet dropped. If a pass drops nothing, the remaining
  # names are reported and the method gives up instead of looping forever.
  def drop_objects(object_type)
    remaining = yield
    puts "Dropping #{remaining.size} #{object_type}(s)."
    until remaining.empty?
      remaining.each do |object_name|
        begin
          execute "DROP #{object_type} #{object_name}"
        rescue
          # Best effort: retried on the next pass.
        end
      end
      still_remaining = yield
      if still_remaining.size >= remaining.size
        puts "Could not drop #{still_remaining.size} #{object_type}(s): #{still_remaining.join(', ')}"
        return
      end
      remaining = still_remaining
    end
    puts "All user #{object_type}s dropped."
  end
end