lib/nmatrix/blas.rb in nmatrix-0.2.0 vs lib/nmatrix/blas.rb in nmatrix-0.2.1
- old
+ new
@@ -69,11 +69,11 @@
# - +ArgumentError+ -> The dtype of the matrices must be equal.
#
def gemm(a, b, c = nil, alpha = 1.0, beta = 0.0, transpose_a = false, transpose_b = false, m = nil, n = nil, k = nil, lda = nil, ldb = nil, ldc = nil)
raise(ArgumentError, 'Expected dense NMatrices as first two arguments.') unless a.is_a?(NMatrix) and b.is_a?(NMatrix) and a.stype == :dense and b.stype == :dense
raise(ArgumentError, 'Expected nil or dense NMatrix as third argument.') unless c.nil? or (c.is_a?(NMatrix) and c.stype == :dense)
- raise(ArgumentError, 'NMatrix dtype mismatch.') unless a.dtype == b.dtype and (c ? a.dtype == c.dtype : true)
+ raise(ArgumentError, 'NMatrix dtype mismatch.') unless a.dtype == b.dtype and (c ? a.dtype == c.dtype : true)
# First, set m, n, and k, which depend on whether we're taking the
# transpose of a and b.
if c
m ||= c.shape[0]
@@ -91,11 +91,11 @@
m ||= a.shape[0]
k ||= a.shape[1]
end
n ||= transpose_b ? b.shape[0] : b.shape[1]
- c = NMatrix.new([m, n], dtype: a.dtype)
+ c = NMatrix.new([m, n], dtype: a.dtype)
end
# I think these are independent of whether or not a transpose occurs.
lda ||= a.shape[1]
ldb ||= b.shape[1]
@@ -141,11 +141,11 @@
# - ++ ->
#
def gemv(a, x, y = nil, alpha = 1.0, beta = 0.0, transpose_a = false, m = nil, n = nil, lda = nil, incx = nil, incy = nil)
raise(ArgumentError, 'Expected dense NMatrices as first two arguments.') unless a.is_a?(NMatrix) and x.is_a?(NMatrix) and a.stype == :dense and x.stype == :dense
raise(ArgumentError, 'Expected nil or dense NMatrix as third argument.') unless y.nil? or (y.is_a?(NMatrix) and y.stype == :dense)
- raise(ArgumentError, 'NMatrix dtype mismatch.') unless a.dtype == x.dtype and (y ? a.dtype == y.dtype : true)
+ raise(ArgumentError, 'NMatrix dtype mismatch.') unless a.dtype == x.dtype and (y ? a.dtype == y.dtype : true)
m ||= transpose_a == :transpose ? a.shape[1] : a.shape[0]
n ||= transpose_a == :transpose ? a.shape[0] : a.shape[1]
raise(ArgumentError, "dimensions don't match") unless x.shape[0] == n && x.shape[1] == 1
@@ -153,13 +153,13 @@
raise(ArgumentError, "dimensions don't match") unless y.shape[0] == m && y.shape[1] == 1
else
y = NMatrix.new([m,1], dtype: a.dtype)
end
- lda ||= a.shape[1]
- incx ||= 1
- incy ||= 1
+ lda ||= a.shape[1]
+ incx ||= 1
+ incy ||= 1
::NMatrix::BLAS.cblas_gemv(transpose_a, m, n, alpha, a, lda, x, incx, beta, y, incy)
return y
end
@@ -186,10 +186,10 @@
# - +ArgumentError+ -> NMatrix dtype mismatch.
# - +ArgumentError+ -> Need to supply n for non-standard incx, incy values.
#
def rot(x, y, c, s, incx = 1, incy = 1, n = nil, in_place=false)
raise(ArgumentError, 'Expected dense NMatrices as first two arguments.') unless x.is_a?(NMatrix) and y.is_a?(NMatrix) and x.stype == :dense and y.stype == :dense
- raise(ArgumentError, 'NMatrix dtype mismatch.') unless x.dtype == y.dtype
+ raise(ArgumentError, 'NMatrix dtype mismatch.') unless x.dtype == y.dtype
raise(ArgumentError, 'Need to supply n for non-standard incx, incy values') if n.nil? && incx != 1 && incx != -1 && incy != 1 && incy != -1
n ||= [x.size/incx.abs, y.size/incy.abs].min
if in_place