Sha256: a761f32d80508d1825832a03abb379883c2b4c0099086ce813ed24331bce73c8
Contents?: true
Size: 1.13 KB
Versions: 13
Compression:
Stored size: 1.13 KB
Contents
% c_dtype = dtype_to_c_type(dtype)
/*
 * Centered RMSProp update; each work-item updates the single element at
 * index get_global_id(0).
 *
 * lr, rho, momentum and epsilon are read as scalars from element 0 of their
 * buffers. grad is only read; ms, mg, mom and output are updated in place:
 *
 *   ms     <- ms + (grad * grad - ms) * (1 - rho)
 *   mg     <- mg + (grad - mg) * (1 - rho)
 *   mom    <- momentum * mom + lr * grad / sqrt(ms - mg * mg + epsilon)
 *   output <- output - mom
 *
 * The index is not range-checked here, so the launch range presumably has to
 * match the buffer length (confirm against the host-side enqueue code).
 */
__kernel void apply_centered_rms_prop_<%= dtype %>(__global const <%= c_dtype %> *lr,
                                                   __global const <%= c_dtype %> *rho,
                                                   __global const <%= c_dtype %> *momentum,
                                                   __global const <%= c_dtype %> *epsilon,
                                                   __global const <%= c_dtype %> *grad,
                                                   __global <%= c_dtype %> *output,
                                                   __global <%= c_dtype %> *ms,
                                                   __global <%= c_dtype %> *mg,
                                                   __global <%= c_dtype %> *mom)
{
    // Index of the element handled by this work-item.
    const int id = get_global_id(0);
    const <%= c_dtype %> g = grad[id];

    // Running average of the squared gradient.
    ms[id] += (g * g - ms[id]) * (1.0 - rho[0]);

    // Running average of the gradient. This must accumulate (+=): a plain
    // assignment would throw away the previous average instead of decaying it.
    mg[id] += (g - mg[id]) * (1.0 - rho[0]);

    // Computed after BOTH averages are updated so ms and mg describe the same
    // set of gradients; mixing a new ms with an old mg can drive the
    // difference negative and make sqrt() return NaN.
    const <%= c_dtype %> denom = ms[id] - mg[id] * mg[id] + epsilon[0];

    mom[id] = mom[id] * momentum[0] + (g * lr[0]) / sqrt(denom);
    output[id] -= mom[id];
}
Version data entries
13 entries across 13 versions & 1 rubygems