ext/nmatrix/math/scal.h in nmatrix-0.1.0.rc3 vs ext/nmatrix/math/scal.h in nmatrix-0.1.0.rc4
- old
+ new
@@ -21,11 +21,11 @@
//
// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
//
// == scal.h
//
-// LAPACK scal function in native C.
+// BLAS scal function.
//
#ifndef SCAL_H
#define SCAL_H
@@ -45,28 +45,48 @@
/* modified 12/3/93, array(1) declarations changed to array(*) */
/* ===================================================================== */
template <typename DType>
-inline void scal(const int n, const DType da, DType* dx, const int incx) {
+inline void scal(const int n, const DType scalar, DType* x, const int incx) {
- // This used to have unrolled loops, like dswap. They were in the way.
+ if (n <= 0 || incx <= 0) {
+ return;
+ }
- if (n <= 0 || incx <= 0) return;
-
for (int i = 0; incx < 0 ? i > n*incx : i < n*incx; i += incx) {
- dx[i] = da * dx[i];
+ x[i] = scalar * x[i];
}
-} /* scal */
+}
+#if defined HAVE_CBLAS_H || defined HAVE_ATLAS_CBLAS_H
+template <>
+inline void scal(const int n, const float scalar, float* x, const int incx) {
+ cblas_sscal(n, scalar, x, incx);
+}
+template <>
+inline void scal(const int n, const double scalar, double* x, const int incx) {
+ cblas_dscal(n, scalar, x, incx);
+}
+
+template <>
+inline void scal(const int n, const Complex64 scalar, Complex64* x, const int incx) {
+ cblas_cscal(n, (const void*)(&scalar), (void*)(x), incx);
+}
+
+template <>
+inline void scal(const int n, const Complex128 scalar, Complex128* x, const int incx) {
+ cblas_zscal(n, (const void*)(&scalar), (void*)(x), incx);
+}
+#endif
+
/*
* Function signature conversion for LAPACK's scal function.
*/
template <typename DType>
-inline void clapack_scal(const int n, const void* da, void* dx, const int incx) {
- // FIXME: See if we can call the clapack version instead of our C++ version.
- scal<DType>(n, *reinterpret_cast<const DType*>(da), reinterpret_cast<DType*>(dx), incx);
+inline void cblas_scal(const int n, const void* scalar, void* x, const int incx) {
+ scal<DType>(n, *reinterpret_cast<const DType*>(scalar), reinterpret_cast<DType*>(x), incx);
}
}} // end of nm::math
#endif