Sha256: 8cb5bdda7066d90f87508153bacf05b227904711a3a274c436ddb4dc76275af0

Contents?: true

Size: 1.68 KB

Versions: 5

Compression:

Stored size: 1.68 KB

Contents

require 'erb'
# Renders ERB templates into OpenCL kernel source and provides small helpers
# (dtype -> C type, dtype -> lowest value, op name -> C operator) that the
# templates call while rendering.
class OpenclTemplateHelper
  # ERB.new gained keyword arguments (trim_mode:) in Ruby 2.6; the positional
  # form ERB.new(src, safe_level, trim_mode) warns as deprecated on newer
  # Rubies. Detect keyword support once so both old and new Rubies work.
  ERB_ACCEPTS_KEYWORDS = !ERB.instance_method(:initialize).parameters.assoc(:key).nil?

  # Operation name => C operator token emitted into kernel source.
  C_OPERATORS = {
    'less' => '<',
    'less_equal' => '<=',
    'equal' => '==',
    'greater' => '>',
    'greater_equal' => '>=',
    'not_equal' => '!=',
    'logical_and' => '&&',
    'div' => '/',
    'add' => '+',
    'sub' => '-',
    'mul' => '*',
    'mod' => '%'
  }.freeze

  # @param source [String] ERB template source rendered by #generate
  def initialize(source)
    @source = source
  end

  # Renders the template given to the constructor.
  #
  # @param args [Hash] names => values exposed to the template as local variables
  # @return [String] the rendered source
  def generate(args = {})
    render_erb(@source, args)
  end

  # @return [Boolean] whether dtype is one of TensorStream's floating point types
  def floating_point?(dtype)
    TensorStream::Ops::FLOATING_POINT_TYPES.include?(dtype)
  end

  # Renders the partial "kernels/_<template>" located next to this file.
  #
  # @param template [String] partial name without the leading underscore
  # @param locals [Hash] names => values exposed to the partial as local variables
  # @return [String] the rendered partial
  def render(template, locals = {})
    filename = File.join(File.dirname(__FILE__), 'kernels', "_#{template}")
    render_erb(File.read(filename), locals)
  end

  # @param dtype [String, Symbol] tensor dtype name
  # @return [String] the C type used for that dtype in kernels
  # @raise [RuntimeError] for an unknown dtype
  def dtype_to_c_type(dtype)
    case dtype.to_s
    when 'float64'
      'double'
    when 'float32', 'float'
      'float'
    when 'int32', 'int'
      'int'
    when 'int16'
      'short'
    when 'boolean'
      'short'
    else
      raise "unknown dtype #{dtype}"
    end
  end

  # C expression for the lowest finite value representable by dtype.
  #
  # FLT_MIN/DBL_MIN are the smallest *positive* normalized floats, not the most
  # negative ones, so the float cases use -FLT_MAX/-DBL_MAX to match the
  # meaning of INT_MIN/SHRT_MIN. The negation is parenthesized so it stays a
  # single operand wherever a template splices it in.
  #
  # @param dtype [String, Symbol] tensor dtype name
  # @return [String] C expression
  # @raise [RuntimeError] for an unknown dtype
  def min_value_for(dtype)
    case dtype.to_s
    when 'float64'
      '(-DBL_MAX)'
    when 'float32', 'float'
      '(-FLT_MAX)'
    when 'int32', 'int'
      'INT_MIN'
    when 'int16'
      'SHRT_MIN'
    when 'boolean'
      '0'
    else
      raise "unknown dtype #{dtype}"
    end
  end

  # @param op [String, Symbol] operation name such as 'add' or :less_equal
  # @return [String] the matching C operator
  # @raise [RuntimeError] for an unsupported operation
  def operator_to_c(op)
    C_OPERATORS.fetch(op.to_s) { raise "unsupported op #{op}" }
  end

  private

  # Compiles source with '%' trim mode (lines starting with % are Ruby code)
  # and evaluates it with locals bound as local variables. Evaluation happens
  # in a fresh binding, so a local can never overwrite this helper's own
  # variables (e.g. a local named `source`).
  def render_erb(source, locals)
    erb = ERB_ACCEPTS_KEYWORDS ? ERB.new(source, trim_mode: '%') : ERB.new(source, nil, '%')
    scope = clean_binding
    locals.each { |name, value| scope.local_variable_set(name.to_sym, value) }
    erb.result(scope)
  end

  # A binding with self = this helper (so templates can call its methods)
  # and no pre-existing local variables.
  def clean_binding
    binding
  end
end

Version data entries

5 entries across 5 versions & 1 rubygems

Version Path
tensor_stream-0.8.1 lib/tensor_stream/evaluator/opencl/opencl_template_helper.rb
tensor_stream-0.8.0 lib/tensor_stream/evaluator/opencl/opencl_template_helper.rb
tensor_stream-0.7.0 lib/tensor_stream/evaluator/opencl/opencl_template_helper.rb
tensor_stream-0.6.1 lib/tensor_stream/evaluator/opencl/opencl_template_helper.rb
tensor_stream-0.6.0 lib/tensor_stream/evaluator/opencl/opencl_template_helper.rb