Sha256: f84a02b4fe336aa3f9faba841fa6861a824cccd1f6484e0704b668ef91777fa0

Contents?: true

Size: 1.62 KB

Versions: 3

Compression:

Stored size: 1.62 KB

Contents

% # ERB template for the OpenCL kernel `arg_axis_<dtype>`: argmax / argmin along one axis.
% # Template inputs (supplied by the renderer): dtype, out_dtype, shape (input shape),
% # o_shape (output shape), axis (an array; only axis[0] is used) and f (:argmax or :argmin).
% # dtype_to_c_type / min_value_for / max_value_for are helpers provided by the renderer.
% # Validate the reduction up front so an unknown f fails before any kernel text is produced.
% raise "unknown reduction func #{f}" unless %i[argmax argmin].include?(f)
% c_dtype = dtype_to_c_type(dtype)
% out_c_dtype = dtype_to_c_type(out_dtype)
% # Row-major strides: multipliers[k] is the product of every dimension after k.
% o_multipliers = o_shape.dup.drop(1).reverse.inject([1]) { |a, s| a << s * a.last }.reverse
% i_multipliers = shape.dup.drop(1).reverse.inject([1]) { |a, s| a << s * a.last }.reverse
% out_ops = o_multipliers.map.with_index { |m, index| "id_#{index} * #{m}" }.join(' + ')
% axis = axis[0]
% # Accept negative axes (counted from the last dimension). Without this no input dimension
% # matches `axis`, in_axis_ops is empty and the kernel would contain `int index = ;`.
% axis += shape.size if axis < 0
% in_axis_multipliers = i_multipliers.select.with_index { |_m, index| axis == index }
% in_axis_ops = in_axis_multipliers.map { |m| "i * #{m}" }.join(' + ')
% # The non-reduced input strides are paired, in order, with work-item ids id_0, id_1, ...
% # NOTE(review): assumes o_shape is `shape` with `axis` dropped (no keepdims) — confirm.
% in_output_multipliers = i_multipliers.reject.with_index { |_m, index| axis == index }
% in_output_ops = in_output_multipliers.map.with_index { |m, index| "id_#{index} * #{m}" }.join(' + ')
% # argmax starts from the type's minimum and keeps strictly larger values; argmin mirrors it.
% # The strict comparison means the first index wins on ties.
% cmp_op = f == :argmax ? '>' : '<'
% init_value = f == :argmax ? min_value_for(dtype) : max_value_for(dtype)
% # NOTE(review): one get_global_id dimension per output dimension; OpenCL only guarantees
% # 3 work-item dimensions, so outputs of rank > 3 are presumably unsupported — confirm.
__kernel void arg_axis_<%= dtype %>(__global const <%= c_dtype %> *value, __global <%= out_c_dtype %> *output) {
    // id_k is this work-item's global id in dimension k, used as its output coordinate k.
% o_multipliers.size.times do |dim|
    const int id_<%= dim %> = get_global_id(<%= dim %>);
% end

    <%= c_dtype %> min_or_max_value = <%= init_value %>;
    int min_or_max_index = 0;

    // Scan every element along the reduced axis.
    for (int i = 0; i < <%= shape[axis] %>; i++) {
        int index = <%= in_axis_ops %>;
% unless in_output_ops.empty?
        index += <%= in_output_ops %>;
% end
        if (value[index] <%= cmp_op %> min_or_max_value) {
            min_or_max_index = i;
            min_or_max_value = value[index];
        }
    }

    output[<%= out_ops %>] = (<%= out_c_dtype %>)min_or_max_index;
}

Version data entries

3 entries across 3 versions & 1 rubygem

Version Path
tensor_stream-opencl-0.3.2 lib/tensor_stream/opencl/kernels/arg_axis.cl
tensor_stream-opencl-0.3.1 lib/tensor_stream/opencl/kernels/arg_axis.cl
tensor_stream-opencl-0.3.0 lib/tensor_stream/opencl/kernels/arg_axis.cl