Sha256: e07ef686b57523a6e831512c081e76544ee62f37e6b5841180e35f31aae00a34
Contents?: true
Size: 722 Bytes
Versions: 13
Compression:
Stored size: 722 Bytes
Contents
% c_dtype = dtype_to_c_type(dtype)
// Element-wise momentum update. The kernel name is suffixed with the template's
// dtype, and every buffer uses the C type that dtype_to_c_type returns for it.
//
// Arguments:
//   grad          - gradient values, read at the work-item's index.
//   learning_rate - only element 0 is read.
//   momentum      - only element 0 is read.
//   output        - values being optimized; updated in place at the index.
//   acc           - momentum accumulator; updated in place at the index.
//
// No length argument is passed and no bounds check is made, so the launch is
// presumably sized to the number of elements -- confirm at the call site.
__kernel void apply_momentum_<%= dtype %>(__global const <%= c_dtype %> *grad,
                                          __global const <%= c_dtype %> *learning_rate,
                                          __global const <%= c_dtype %> *momentum,
                                          __global <%= c_dtype %> *output,
                                          __global <%= c_dtype %> *acc) {
    // Each work-item handles the single element at its global id in dimension 0.
    const int index = get_global_id(0);

    // Keep the accumulator value from before this step; both update rules
    // below use this previous value, not the one written back on the next line.
    <%= c_dtype %> acc_m = acc[index];
    acc[index] = acc_m * momentum[0] + grad[index];

    // NOTE(review): TensorFlow's ApplyMomentum applies the *updated* accumulator
    // to the variable, while this kernel applies the previous one. Confirm this
    // is intentional (e.g. it matches the reference evaluator) before changing.
    // The branch is chosen when the template is rendered, via the nesterov flag.
<% if nesterov %>
    output[index] -= grad[index] * learning_rate[0] + acc_m * momentum[0] * learning_rate[0];
<% else %>
    output[index] -= acc_m * learning_rate[0];
<% end %>
}
Version data entries
13 entries across 13 versions & 1 rubygems