Source code for arrayfire.blas

#######################################################
# Copyright (c) 2015, ArrayFire
# All rights reserved.
#
# This file is distributed under 3-clause BSD license.
# The complete license agreement can be obtained at:
# http://arrayfire.com/licenses/BSD-3-Clause
########################################################

"""
BLAS functions (matmul, dot, etc)
"""

from .library import *
from .array import *

[docs]def matmul(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE): """ Generalized matrix multiplication for two matrices. Parameters ---------- lhs : af.Array A 2 dimensional, real or complex arrayfire array. rhs : af.Array A 2 dimensional, real or complex arrayfire array. lhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE. Can be one of - af.MATPROP.NONE - If no op should be done on `lhs`. - af.MATPROP.TRANS - If `lhs` has to be transposed before multiplying. - af.MATPROP.CTRANS - If `lhs` has to be hermitian transposed before multiplying. rhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE. Can be one of - af.MATPROP.NONE - If no op should be done on `rhs`. - af.MATPROP.TRANS - If `rhs` has to be transposed before multiplying. - af.MATPROP.CTRANS - If `rhs` has to be hermitian transposed before multiplying. Returns ------- out : af.Array Output of the matrix multiplication on `lhs` and `rhs`. Note ----- - The data types of `lhs` and `rhs` should be the same. - Batches are not supported. """ out = Array() safe_call(backend.get().af_matmul(c_pointer(out.arr), lhs.arr, rhs.arr, lhs_opts.value, rhs_opts.value)) return out
[docs]def matmulTN(lhs, rhs): """ Matrix multiplication after transposing the first matrix. Parameters ---------- lhs : af.Array A 2 dimensional, real or complex arrayfire array. rhs : af.Array A 2 dimensional, real or complex arrayfire array. Returns ------- out : af.Array Output of the matrix multiplication on `transpose(lhs)` and `rhs`. Note ----- - The data types of `lhs` and `rhs` should be the same. - Batches are not supported. """ out = Array() safe_call(backend.get().af_matmul(c_pointer(out.arr), lhs.arr, rhs.arr, MATPROP.TRANS.value, MATPROP.NONE.value)) return out
[docs]def matmulNT(lhs, rhs): """ Matrix multiplication after transposing the second matrix. Parameters ---------- lhs : af.Array A 2 dimensional, real or complex arrayfire array. rhs : af.Array A 2 dimensional, real or complex arrayfire array. Returns ------- out : af.Array Output of the matrix multiplication on `lhs` and `transpose(rhs)`. Note ----- - The data types of `lhs` and `rhs` should be the same. - Batches are not supported. """ out = Array() safe_call(backend.get().af_matmul(c_pointer(out.arr), lhs.arr, rhs.arr, MATPROP.NONE.value, MATPROP.TRANS.value)) return out
[docs]def matmulTT(lhs, rhs): """ Matrix multiplication after transposing both inputs. Parameters ---------- lhs : af.Array A 2 dimensional, real or complex arrayfire array. rhs : af.Array A 2 dimensional, real or complex arrayfire array. Returns ------- out : af.Array Output of the matrix multiplication on `transpose(lhs)` and `transpose(rhs)`. Note ----- - The data types of `lhs` and `rhs` should be the same. - Batches are not supported. """ out = Array() safe_call(backend.get().af_matmul(c_pointer(out.arr), lhs.arr, rhs.arr, MATPROP.TRANS.value, MATPROP.TRANS.value)) return out
[docs]def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE, return_scalar = False): """ Dot product of two input vectors. Parameters ---------- lhs : af.Array A 1 dimensional, real or complex arrayfire array. rhs : af.Array A 1 dimensional, real or complex arrayfire array. lhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE. Can be one of - af.MATPROP.NONE - If no op should be done on `lhs`. - No other options are currently supported. rhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE. Can be one of - af.MATPROP.NONE - If no op should be done on `rhs`. - No other options are currently supported. return_scalar: optional: bool. default: False. - When set to true, the input arrays are flattened and the output is a scalar Returns ------- out : af.Array or scalar Output of dot product of `lhs` and `rhs`. Note ----- - The data types of `lhs` and `rhs` should be the same. - Batches are not supported. """ if return_scalar: real = c_double_t(0) imag = c_double_t(0) safe_call(backend.get().af_dot_all(c_pointer(real), c_pointer(imag), lhs.arr, rhs.arr, lhs_opts.value, rhs_opts.value)) real = real.value imag = imag.value return real if imag == 0 else real + imag * 1j else: out = Array() safe_call(backend.get().af_dot(c_pointer(out.arr), lhs.arr, rhs.arr, lhs_opts.value, rhs_opts.value)) return out
[docs]def gemm(lhs, rhs, alpha=1.0, beta=0.0, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE, C=None): """ BLAS general matrix multiply (GEMM) of two af_array objects. This provides a general interface to the BLAS level 3 general matrix multiply (GEMM), which is generally defined as: C = alpha * opA(A) opB(B) + beta * C where alpha and beta are both scalars; A and B are the matrix multiply operands; and opA and opB are noop (if AF_MAT_NONE) or transpose (if AF_MAT_TRANS) operations on A or B before the actual GEMM operation. Batched GEMM is supported if at least either A or B have more than two dimensions (see af::matmul for more details on broadcasting). However, only one alpha and one beta can be used for all of the batched matrix operands. Parameters ---------- lhs : af.Array A 2 dimensional, real or complex arrayfire array. rhs : af.Array A 2 dimensional, real or complex arrayfire array. alpha : scalar beta : scalar lhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE. Can be one of - af.MATPROP.NONE - If no op should be done on `lhs`. - af.MATPROP.TRANS - If `lhs` has to be transposed before multiplying. - af.MATPROP.CTRANS - If `lhs` has to be hermitian transposed before multiplying. rhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE. Can be one of - af.MATPROP.NONE - If no op should be done on `rhs`. - af.MATPROP.TRANS - If `rhs` has to be transposed before multiplying. - af.MATPROP.CTRANS - If `rhs` has to be hermitian transposed before multiplying. Returns ------- out : af.Array Output of the matrix multiplication on `lhs` and `rhs`. Note ----- - The data types of `lhs` and `rhs` should be the same. - Batches are not supported. """ if C is None: out = Array() else: out = C ltype = lhs.dtype() if ltype == Dtype.f32: aptr = c_cast(c_pointer(c_float_t(alpha)),c_void_ptr_t) bptr = c_cast(c_pointer(c_float_t(beta)), c_void_ptr_t) elif ltype == Dtype.c32: if isinstance(alpha, af_cfloat_t): aptr = c_cast(c_pointer(alpha), c_void_ptr_t) elif isinstance(alpha, tuple): aptr = c_cast(c_pointer(af_cfloat_t(alpha[0], alpha[1])), c_void_ptr_t) else: aptr = c_cast(c_pointer(af_cfloat_t(alpha)), c_void_ptr_t) if isinstance(beta, af_cfloat_t): bptr = c_cast(c_pointer(beta), c_void_ptr_t) elif isinstance(beta, tuple): bptr = c_cast(c_pointer(af_cfloat_t(beta[0], beta[1])), c_void_ptr_t) else: bptr = c_cast(c_pointer(af_cfloat_t(beta)), c_void_ptr_t) elif ltype == Dtype.f64: aptr = c_cast(c_pointer(c_double_t(alpha)),c_void_ptr_t) bptr = c_cast(c_pointer(c_double_t(beta)), c_void_ptr_t) elif ltype == Dtype.c64: if isinstance(alpha, af_cdouble_t): aptr = c_cast(c_pointer(alpha), c_void_ptr_t) elif isinstance(alpha, tuple): aptr = c_cast(c_pointer(af_cdouble_t(alpha[0], alpha[1])), c_void_ptr_t) else: aptr = c_cast(c_pointer(af_cdouble_t(alpha)), c_void_ptr_t) if isinstance(beta, af_cdouble_t): bptr = c_cast(c_pointer(beta), c_void_ptr_t) elif isinstance(beta, tuple): bptr = c_cast(c_pointer(af_cdouble_t(beta[0], beta[1])), c_void_ptr_t) else: bptr = c_cast(c_pointer(af_cdouble_t(beta)), c_void_ptr_t) elif ltype == Dtype.f16: raise TypeError("fp16 currently unsupported gemm() input type") else: raise TypeError("unsupported input type") safe_call(backend.get().af_gemm(c_pointer(out.arr), lhs_opts.value, rhs_opts.value, aptr, lhs.arr, rhs.arr, bptr)) return out