lib/nmatrix/blas.rb in nmatrix-0.1.0 vs lib/nmatrix/blas.rb in nmatrix-0.2.0
- old
+ new
@@ -26,11 +26,19 @@
# This file contains the safer accessors for the BLAS functions
# supported by NMatrix.
#++
module NMatrix::BLAS
+
+ #Add functions from C extension to main BLAS module
class << self
+ NMatrix::Internal::BLAS.singleton_methods.each do |m|
+ define_method m, NMatrix::Internal::BLAS.method(m).to_proc
+ end
+ end
+
+ class << self
#
# call-seq:
# gemm(a, b) -> NMatrix
# gemm(a, b, c) -> NMatrix
# gemm(a, b, c, alpha, beta) -> NMatrix
@@ -135,25 +143,24 @@
def gemv(a, x, y = nil, alpha = 1.0, beta = 0.0, transpose_a = false, m = nil, n = nil, lda = nil, incx = nil, incy = nil)
raise(ArgumentError, 'Expected dense NMatrices as first two arguments.') unless a.is_a?(NMatrix) and x.is_a?(NMatrix) and a.stype == :dense and x.stype == :dense
raise(ArgumentError, 'Expected nil or dense NMatrix as third argument.') unless y.nil? or (y.is_a?(NMatrix) and y.stype == :dense)
raise(ArgumentError, 'NMatrix dtype mismatch.') unless a.dtype == x.dtype and (y ? a.dtype == y.dtype : true)
- m ||= transpose_a ? a.shape[1] : a.shape[0]
- n ||= transpose_a ? a.shape[0] : a.shape[1]
+ m ||= transpose_a == :transpose ? a.shape[1] : a.shape[0]
+ n ||= transpose_a == :transpose ? a.shape[0] : a.shape[1]
+ raise(ArgumentError, "dimensions don't match") unless x.shape[0] == n && x.shape[1] == 1
+ if y
+ raise(ArgumentError, "dimensions don't match") unless y.shape[0] == m && y.shape[1] == 1
+ else
+ y = NMatrix.new([m,1], dtype: a.dtype)
+ end
+
lda ||= a.shape[1]
incx ||= 1
incy ||= 1
- # NM_COMPLEX64 and NM_COMPLEX128 both require complex alpha and beta.
- if a.dtype == :complex64 or a.dtype == :complex128
- alpha = Complex(1.0, 0.0) if alpha == 1.0
- beta = Complex(0.0, 0.0) if beta == 0.0
- end
-
- y ||= NMatrix.new([m, n], dtype: a.dtype)
-
::NMatrix::BLAS.cblas_gemv(transpose_a, m, n, alpha, a, lda, x, incx, beta, y, incy)
return y
end
@@ -216,11 +223,11 @@
# call-seq:
# rotg(ab) -> [Numeric, Numeric]
#
# Apply givens plane rotation to the coordinates (a,b), returning the cosine and sine of the angle theta.
#
- # Since the givens rotation includes a square root, integers and rationals are disallowed.
+ # Since the givens rotation includes a square root, integers are disallowed.
#
# * *Arguments* :
# - +ab+ -> NMatrix with two elements
# * *Returns* :
# - Array with the results, in the format [cos(theta), sin(theta)]
@@ -276,8 +283,23 @@
def nrm2(x, incx = 1, n = nil)
n ||= x.size / incx
raise(ArgumentError, "Expected dense NMatrix for arg 0") unless x.is_a?(NMatrix)
raise(RangeError, "n out of range") if n*incx > x.size || n*incx <= 0 || n <= 0
::NMatrix::BLAS.cblas_nrm2(n, x, incx)
+ end
+
+ # The following are functions that used to be implemented in C, but
+ # now require nmatrix-atlas or nmatrix-lapcke to run properly, so we can just
+ # implemented their stubs in Ruby.
+ def cblas_trmm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb)
+ raise(NotImplementedError,"cblas_trmm requires either the nmatrix-lapacke or nmatrix-atlas gem")
+ end
+
+ def cblas_syrk(order, uplo, trans, n, k, alpha, a, lda, beta, c, ldc)
+ raise(NotImplementedError,"cblas_syrk requires either the nmatrix-lapacke or nmatrix-atlas gem")
+ end
+
+ def cblas_herk(order, uplo, trans, n, k, alpha, a, lda, beta, c, ldc)
+ raise(NotImplementedError,"cblas_herk requires either the nmatrix-lapacke or nmatrix-atlas gem")
end
end
end