ext/nmatrix/math/scal.h in nmatrix-0.1.0.rc3 vs ext/nmatrix/math/scal.h in nmatrix-0.1.0.rc4

- old
+ new

@@ -21,11 +21,11 @@ // // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement // // == scal.h // -// LAPACK scal function in native C. +// BLAS scal function. // #ifndef SCAL_H #define SCAL_H @@ -45,28 +45,48 @@ /* modified 12/3/93, array(1) declarations changed to array(*) */ /* ===================================================================== */ template <typename DType> -inline void scal(const int n, const DType da, DType* dx, const int incx) { +inline void scal(const int n, const DType scalar, DType* x, const int incx) { - // This used to have unrolled loops, like dswap. They were in the way. + if (n <= 0 || incx <= 0) { + return; + } - if (n <= 0 || incx <= 0) return; - for (int i = 0; incx < 0 ? i > n*incx : i < n*incx; i += incx) { - dx[i] = da * dx[i]; + x[i] = scalar * x[i]; } -} /* scal */ +} +#if defined HAVE_CBLAS_H || defined HAVE_ATLAS_CBLAS_H +template <> +inline void scal(const int n, const float scalar, float* x, const int incx) { + cblas_sscal(n, scalar, x, incx); +} +template <> +inline void scal(const int n, const double scalar, double* x, const int incx) { + cblas_dscal(n, scalar, x, incx); +} + +template <> +inline void scal(const int n, const Complex64 scalar, Complex64* x, const int incx) { + cblas_cscal(n, (const void*)(&scalar), (void*)(x), incx); +} + +template <> +inline void scal(const int n, const Complex128 scalar, Complex128* x, const int incx) { + cblas_zscal(n, (const void*)(&scalar), (void*)(x), incx); +} +#endif + /* * Function signature conversion for LAPACK's scal function. */ template <typename DType> -inline void clapack_scal(const int n, const void* da, void* dx, const int incx) { - // FIXME: See if we can call the clapack version instead of our C++ version. - scal<DType>(n, *reinterpret_cast<const DType*>(da), reinterpret_cast<DType*>(dx), incx); +inline void cblas_scal(const int n, const void* scalar, void* x, const int incx) { + scal<DType>(n, *reinterpret_cast<const DType*>(scalar), reinterpret_cast<DType*>(x), incx); } }} // end of nm::math #endif