lib/nmatrix/blas.rb in nmatrix-0.0.3 vs lib/nmatrix/blas.rb in nmatrix-0.0.4
- old
+ new
@@ -1,70 +1,196 @@
-module NMatrix::BLAS
+#--
+# = NMatrix
+#
+# A linear algebra library for scientific computation in Ruby.
+# NMatrix is part of SciRuby.
+#
+# NMatrix was originally inspired by and derived from NArray, by
+# Masahiro Tanaka: http://narray.rubyforge.org
+#
+# == Copyright Information
+#
+# SciRuby is Copyright (c) 2010 - 2013, Ruby Science Foundation
+# NMatrix is Copyright (c) 2013, Ruby Science Foundation
+#
+# Please see LICENSE.txt for additional copyright notices.
+#
+# == Contributing
+#
+# By contributing source code to SciRuby, you agree to be bound by
+# our Contributor Agreement:
+#
+# * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
+#
+# == blas.rb
+#
+# This file contains the safer accessors for the BLAS functions
+# supported by NMatrix.
+#++
+module NMatrix::BLAS
class << self
-
+ #
+ # call-seq:
+ # gemm(a, b) -> NMatrix
+ # gemm(a, b, c) -> NMatrix
+ # gemm(a, b, c, alpha, beta) -> NMatrix
+ #
+ # Updates the value of C via the matrix multiplication
+ # C = (alpha * A * B) + (beta * C)
+ # where +alpha+ and +beta+ are scalar values.
+ #
+ # * *Arguments* :
+ # - +a+ -> Matrix A.
+ # - +b+ -> Matrix B.
+ # - +c+ -> Matrix C.
+ # - +alpha+ -> A scalar value that multiplies A * B.
+ # - +beta+ -> A scalar value that multiplies C.
+ # - +transpose_a+ ->
+ # - +transpose_b+ ->
+ # - +m+ ->
+ # - +n+ ->
+ # - +k+ ->
+ # - +lda+ ->
+ # - +ldb+ ->
+ # - +ldc+ ->
+ # * *Returns* :
+ # - A NMatrix equal to (alpha * A * B) + (beta * C).
+ # * *Raises* :
+ # - +ArgumentError+ -> +a+ and +b+ must be dense matrices.
+ # - +ArgumentError+ -> +c+ must be +nil+ or a dense matrix.
+ # - +ArgumentError+ -> The dtype of the matrices must be equal.
+ #
def gemm(a, b, c = nil, alpha = 1.0, beta = 0.0, transpose_a = false, transpose_b = false, m = nil, n = nil, k = nil, lda = nil, ldb = nil, ldc = nil)
- raise ArgumentError, 'Expected dense NMatrices as first two arguments.' unless a.is_a?(NMatrix) and b.is_a?(NMatrix) and a.stype == :dense and b.stype == :dense
- raise ArgumentError, 'Expected nil or dense NMatrix as third argument.' unless c.nil? or (c.is_a?(NMatrix) and c.stype == :dense)
- raise ArgumentError, 'NMatrix dtype mismatch.' unless a.dtype == b.dtype and (c ? a.dtype == c.dtype : true)
+ raise ArgumentError, 'Expected dense NMatrices as first two arguments.' unless a.is_a?(NMatrix) and b.is_a?(NMatrix) and a.stype == :dense and b.stype == :dense
+ raise ArgumentError, 'Expected nil or dense NMatrix as third argument.' unless c.nil? or (c.is_a?(NMatrix) and c.stype == :dense)
+ raise ArgumentError, 'NMatrix dtype mismatch.' unless a.dtype == b.dtype and (c ? a.dtype == c.dtype : true)
- # First, set m, n, and k, which depend on whether we're taking the
- # transpose of a and b.
- if c
- m ||= c.shape[0]
- n ||= c.shape[1]
- k ||= transpose_a ? a.shape[0] : a.shape[1]
+ # First, set m, n, and k, which depend on whether we're taking the
+ # transpose of a and b.
+ if c
+ m ||= c.shape[0]
+ n ||= c.shape[1]
+ k ||= transpose_a ? a.shape[0] : a.shape[1]
- else
- if transpose_a
- # Either :transpose or :complex_conjugate.
- m ||= a.shape[1]
- k ||= a.shape[0]
+ else
+ if transpose_a
+ # Either :transpose or :complex_conjugate.
+ m ||= a.shape[1]
+ k ||= a.shape[0]
- else
- # No transpose.
- m ||= a.shape[0]
- k ||= a.shape[1]
- end
+ else
+ # No transpose.
+ m ||= a.shape[0]
+ k ||= a.shape[1]
+ end
- n ||= transpose_b ? b.shape[0] : b.shape[1]
- c = NMatrix.new([m, n], a.dtype)
- end
+ n ||= transpose_b ? b.shape[0] : b.shape[1]
+ c = NMatrix.new([m, n], a.dtype)
+ end
- # I think these are independent of whether or not a transpose occurs.
- lda ||= a.shape[1]
- ldb ||= b.shape[1]
- ldc ||= c.shape[1]
+ # I think these are independent of whether or not a transpose occurs.
+ lda ||= a.shape[1]
+ ldb ||= b.shape[1]
+ ldc ||= c.shape[1]
- # NM_COMPLEX64 and NM_COMPLEX128 both require complex alpha and beta.
- if a.dtype == :complex64 or a.dtype == :complex128
- alpha = Complex(1.0, 0.0) if alpha == 1.0
- beta = Complex(0.0, 0.0) if beta == 0.0
- end
+ # NM_COMPLEX64 and NM_COMPLEX128 both require complex alpha and beta.
+ if a.dtype == :complex64 or a.dtype == :complex128
+ alpha = Complex(1.0, 0.0) if alpha == 1.0
+ beta = Complex(0.0, 0.0) if beta == 0.0
+ end
- # For argument descriptions, see: http://www.netlib.org/blas/dgemm.f
- ::NMatrix::BLAS.cblas_gemm(:row, transpose_a, transpose_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
+ # For argument descriptions, see: http://www.netlib.org/blas/dgemm.f
+ ::NMatrix::BLAS.cblas_gemm(:row, transpose_a, transpose_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
- return c
+ return c
end
+ #
+ # call-seq:
+ # gemv(a, x) -> NVector
+ # gemv(a, x, y) -> NVector
+ # gemv(a, x, y, alpha, beta) -> NVector
+ #
+ # Implements matrix-vector product via
+ # y = (alpha * A * x) + (beta * y)
+ # where +alpha+ and +beta+ are scalar values.
+ #
+ # * *Arguments* :
+ # - +a+ -> Matrix A.
+ # - +x+ -> Vector x.
+ # - +y+ -> Vector y.
+ # - +alpha+ -> A scalar value that multiplies A * x.
+ # - +beta+ -> A scalar value that multiplies y.
+ # - +transpose_a+ ->
+ # - +m+ ->
+ # - +n+ ->
+ # - +lda+ ->
+ # - +incx+ ->
+ # - +incy+ ->
+ # * *Returns* :
+ # -
+ # * *Raises* :
+ # - ++ ->
+ #
def gemv(a, x, y = nil, alpha = 1.0, beta = 0.0, transpose_a = false, m = nil, n = nil, lda = nil, incx = nil, incy = nil)
- m ||= transpose_a ? a.shape[1] : a.shape[0]
- n ||= transpose_a ? a.shape[0] : a.shape[1]
+ raise ArgumentError, 'Expected dense NMatrices as first two arguments.' unless a.is_a?(NMatrix) and x.is_a?(NMatrix) and a.stype == :dense and x.stype == :dense
+ raise ArgumentError, 'Expected nil or dense NMatrix as third argument.' unless y.nil? or (y.is_a?(NMatrix) and y.stype == :dense)
+ raise ArgumentError, 'NMatrix dtype mismatch.' unless a.dtype == x.dtype and (y ? a.dtype == y.dtype : true)
- lda ||= a.shape[1]
- incx ||= 1
- incy ||= 1
+ m ||= transpose_a ? a.shape[1] : a.shape[0]
+ n ||= transpose_a ? a.shape[0] : a.shape[1]
- # NM_COMPLEX64 and NM_COMPLEX128 both require complex alpha and beta.
- if a.dtype == :complex64 or a.dtype == :complex128
- alpha = Complex(1.0, 0.0) if alpha == 1.0
- beta = Complex(0.0, 0.0) if beta == 0.0
- end
+ lda ||= a.shape[1]
+ incx ||= 1
+ incy ||= 1
- ::NMatrix::BLAS.cblas_gemv(transpose_a, m, n, alpha, a, lda, x, incx, beta, y, incy)
+ # NM_COMPLEX64 and NM_COMPLEX128 both require complex alpha and beta.
+ if a.dtype == :complex64 or a.dtype == :complex128
+ alpha = Complex(1.0, 0.0) if alpha == 1.0
+ beta = Complex(0.0, 0.0) if beta == 0.0
+ end
- return y
+ y ||= NMatrix.new([m, n], a.dtype)
+
+ ::NMatrix::BLAS.cblas_gemv(transpose_a, m, n, alpha, a, lda, x, incx, beta, y, incy)
+
+ return y
end
+ #
+ # call-seq:
+ # rot(x, y, c, s)
+ #
+ # Apply plane rotation.
+ #
+ # * *Arguments* :
+ # - +x+ ->
+ # - +y+ ->
+ # - +s+ ->
+ # - +c+ ->
+ # - +incx+ ->
+ # - +incy+ ->
+ # - +n+ ->
+ # * *Returns* :
+ # - Array with the results, in the format [xx, yy]
+ # * *Raises* :
+ # - +ArgumentError+ -> Expected dense NMatrices as first two arguments.
+ # - +ArgumentError+ -> Nmatrix dtype mismatch.
+ # - +ArgumentError+ -> Need to supply n for non-standard incx, incy values.
+ #
+ def rot(x, y, c, s, incx = 1, incy = 1, n = nil)
+ raise ArgumentError, 'Expected dense NMatrices as first two arguments.' unless x.is_a?(NMatrix) and y.is_a?(NMatrix) and x.stype == :dense and y.stype == :dense
+ raise ArgumentError, 'NMatrix dtype mismatch.' unless x.dtype == y.dtype
+ raise ArgumentError, 'Need to supply n for non-standard incx, incy values' if n.nil? && incx != 1 && incx != -1 && incy != 1 && incy != -1
+
+ n ||= x.size > y.size ? y.size : x.size
+
+ xx = x.clone
+ yy = y.clone
+
+ ::NMatrix::BLAS.cblas_rot(n, xx, incx, yy, incy, c, s)
+
+ return [xx,yy]
+ end
end
end