# Source code for arrayfire.arith

#######################################################
# Copyright (c) 2015, ArrayFire
# All rights reserved.
#
# This file is distributed under 3-clause BSD license.
# The complete license agreement can be obtained at:
# http://arrayfire.com/licenses/BSD-3-Clause
########################################################

"""
Math functions (sin, sqrt, exp, etc).
"""

from .library import *
from .array import *
from .bcast import _bcast_var
from .util import _is_number

def _arith_binary_func(lhs, rhs, c_func):
    """
    Apply a binary ArrayFire C function to two inputs.

    Parameters
    ----------
    lhs : af.Array or scalar
        Left operand.
    rhs : af.Array or scalar
        Right operand.
    c_func : callable
        Backend C function, called as
        ``c_func(out_ptr, lhs_arr, rhs_arr, batch_flag)`` where the batch
        flag comes from ``_bcast_var.get()``.

    Returns
    -------
    out : af.Array
        Array produced by ``c_func``.

    Raises
    ------
    TypeError
        If neither input is an af.Array, or if the input that is not an
        af.Array is not a scalar number.

    Note
    ----
    A scalar operand is expanded into a constant array with the dimensions
    of the array operand and a dtype picked by ``implicit_dtype``.
    """
    out = Array()

    is_left_array = isinstance(lhs, Array)
    is_right_array = isinstance(rhs, Array)

    if not (is_left_array or is_right_array):
        raise TypeError("Atleast one input needs to be of type arrayfire.array")

    if is_left_array and is_right_array:
        safe_call(c_func(c_pointer(out.arr), lhs.arr, rhs.arr, _bcast_var.get()))

    elif is_left_array:
        # Validate explicitly so a non-numeric rhs gets a clear TypeError
        # instead of an unrelated failure further down.
        if not _is_number(rhs):
            raise TypeError("rhs must be an arrayfire.Array or a scalar number")
        ldims = dim4_to_tuple(lhs.dims())
        rty = implicit_dtype(rhs, lhs.type())
        # Wrap the constant in an Array (same pattern for both scalar sides).
        other = Array()
        other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty)
        safe_call(c_func(c_pointer(out.arr), lhs.arr, other.arr, _bcast_var.get()))

    else:
        if not _is_number(lhs):
            raise TypeError("lhs must be an arrayfire.Array or a scalar number")
        rdims = dim4_to_tuple(rhs.dims())
        lty = implicit_dtype(lhs, rhs.type())
        other = Array()
        other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], lty)
        safe_call(c_func(c_pointer(out.arr), other.arr, rhs.arr, _bcast_var.get()))

    return out

def _arith_unary_func(a, c_func):
    """
    Apply a unary ArrayFire C function to an array.

    Parameters
    ----------
    a : af.Array
        Input array.
    c_func : callable
        Backend C function, called as ``c_func(out_ptr, a.arr)``.

    Returns
    -------
    out : af.Array
        Array written by ``c_func``.
    """
    result = Array()
    status = c_func(c_pointer(result.arr), a.arr)
    safe_call(status)
    return result

def cast(a, dtype):
    """
    Convert an array to the requested data type.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.
    dtype : af.Dtype
        Target type; one of:
            - Dtype.f32 for float
            - Dtype.f64 for double
            - Dtype.b8 for bool
            - Dtype.u8 for unsigned char
            - Dtype.s32 for signed 32 bit integer
            - Dtype.u32 for unsigned 32 bit integer
            - Dtype.s64 for signed 64 bit integer
            - Dtype.u64 for unsigned 64 bit integer
            - Dtype.c32 for 32 bit complex number
            - Dtype.c64 for 64 bit complex number

    Returns
    -------
    out : af.Array
        The values of `a` converted to `dtype`.
    """
    converted = Array()
    af_cast = backend.get().af_cast
    safe_call(af_cast(c_pointer(converted.arr), a.arr, dtype.value))
    return converted
def minof(lhs, rhs):
    """
    Element-wise minimum of two inputs.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        The smaller of the two inputs at each location.

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    """
    af_fn = backend.get().af_minof
    return _arith_binary_func(lhs, rhs, af_fn)
def maxof(lhs, rhs):
    """
    Element-wise maximum of two inputs.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        The larger of the two inputs at each location.

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    """
    af_fn = backend.get().af_maxof
    return _arith_binary_func(lhs, rhs, af_fn)
def clamp(val, low, high):
    """
    Clamp the input values between low and high.

    Parameters
    ----------
    val : af.Array
        Multi dimensional arrayfire array to be clamped.
    low : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number denoting the
        lower value(s).
    high : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number denoting the
        higher value(s).

    Returns
    -------
    out : af.Array
        The values of `val` clamped to the range [`low`, `high`].

    Note
    ----
    A scalar bound is expanded into a constant array with the dimensions
    and dtype of `val`.
    """
    out = Array()

    vdims = dim4_to_tuple(val.dims())
    vty = val.type()

    def _as_array(bound):
        # Return `bound` as an af.Array, wrapping scalar constants in an Array
        # object (as _arith_binary_func does) instead of keeping a raw handle.
        if isinstance(bound, Array):
            return bound
        wrapped = Array()
        wrapped.arr = constant_array(bound, vdims[0], vdims[1], vdims[2], vdims[3], vty)
        return wrapped

    # Keep references to the bound arrays alive until after the C call.
    low_array = _as_array(low)
    high_array = _as_array(high)

    safe_call(backend.get().af_clamp(c_pointer(out.arr), val.arr,
                                     low_array.arr, high_array.arr,
                                     _bcast_var.get()))
    return out
def mod(lhs, rhs):
    """
    Element-wise modulus.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        The moduli after dividing each value of `lhs` by those in `rhs`.

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    """
    af_fn = backend.get().af_mod
    return _arith_binary_func(lhs, rhs, af_fn)
def rem(lhs, rhs):
    """
    Element-wise remainder.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        The remainders after dividing each value of `lhs` by those in `rhs`.

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    """
    af_fn = backend.get().af_rem
    return _arith_binary_func(lhs, rhs, af_fn)
def abs(a):
    """
    Element-wise absolute value.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The absolute value of each element of `a`.
    """
    af_fn = backend.get().af_abs
    return _arith_unary_func(a, af_fn)
def arg(a):
    """
    Theta (angle) of each input in polar co-ordinates.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The theta values.
    """
    af_fn = backend.get().af_arg
    return _arith_unary_func(a, af_fn)
def sign(a):
    """
    Sign of each input.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        1 where the value is negative, 0 otherwise.
    """
    af_fn = backend.get().af_sign
    return _arith_unary_func(a, af_fn)
def round(a):
    """
    Round each element to the nearest integer.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The values of `a` rounded to the nearest integer.
    """
    af_fn = backend.get().af_round
    return _arith_unary_func(a, af_fn)
def trunc(a):
    """
    Round each element towards zero.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The truncated values of `a`.
    """
    af_fn = backend.get().af_trunc
    return _arith_unary_func(a, af_fn)
def floor(a):
    """
    Round each element towards the smaller integer.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The floored values of `a`.
    """
    af_fn = backend.get().af_floor
    return _arith_unary_func(a, af_fn)
def ceil(a):
    """
    Round each element towards the bigger integer.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The ceiled values of `a`.
    """
    af_fn = backend.get().af_ceil
    return _arith_unary_func(a, af_fn)
def hypot(lhs, rhs):
    """
    Find the length of the hypotenuse at each location.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        Contains the value of `sqrt(lhs**2 + rhs**2)`.

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    """
    return _arith_binary_func(lhs, rhs, backend.get().af_hypot)
def sin(a):
    """
    Element-wise sine.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The sine of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_sin
    return _arith_unary_func(a, af_fn)
def cos(a):
    """
    Element-wise cosine.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The cosine of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_cos
    return _arith_unary_func(a, af_fn)
def tan(a):
    """
    Element-wise tangent.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The tangent of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_tan
    return _arith_unary_func(a, af_fn)
def asin(a):
    """
    Element-wise arc sine.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The arc sine of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_asin
    return _arith_unary_func(a, af_fn)
def acos(a):
    """
    Element-wise arc cosine.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The arc cosine of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_acos
    return _arith_unary_func(a, af_fn)
def atan(a):
    """
    Element-wise arc tangent.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The arc tangent of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_atan
    return _arith_unary_func(a, af_fn)
def atan2(lhs, rhs):
    """
    Two-argument arc tangent.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        Arc tangent values, where:
            - `lhs` contains the sine values.
            - `rhs` contains the cosine values.

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    """
    af_fn = backend.get().af_atan2
    return _arith_binary_func(lhs, rhs, af_fn)
def cplx(lhs, rhs=None):
    """
    Build a complex array from real inputs.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : optional: af.Array or scalar. default: None.
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        Complex values whose
            - real parts come from `lhs`
            - imaginary parts come from `rhs` (0 if `rhs` is None)

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    - When `rhs` is None, `lhs` is passed to the unary path, which reads
      `lhs.arr`, so it must be an af.Array.
    """
    if rhs is not None:
        return _arith_binary_func(lhs, rhs, backend.get().af_cplx2)
    return _arith_unary_func(lhs, backend.get().af_cplx)
def real(a):
    """
    Real part of each element.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The real values from `a`.
    """
    af_fn = backend.get().af_real
    return _arith_unary_func(a, af_fn)
def imag(a):
    """
    Imaginary part of each element.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The imaginary values from `a`.
    """
    af_fn = backend.get().af_imag
    return _arith_unary_func(a, af_fn)
def conjg(a):
    """
    Complex conjugate of each element.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The complex conjugate values from `a`.
    """
    af_fn = backend.get().af_conjg
    return _arith_unary_func(a, af_fn)
def sinh(a):
    """
    Element-wise hyperbolic sine.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The hyperbolic sine of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_sinh
    return _arith_unary_func(a, af_fn)
def cosh(a):
    """
    Element-wise hyperbolic cosine.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The hyperbolic cosine of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_cosh
    return _arith_unary_func(a, af_fn)
def tanh(a):
    """
    Element-wise hyperbolic tangent.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The hyperbolic tangent of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_tanh
    return _arith_unary_func(a, af_fn)
def asinh(a):
    """
    Element-wise inverse hyperbolic sine.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The arc hyperbolic sine of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_asinh
    return _arith_unary_func(a, af_fn)
def acosh(a):
    """
    Element-wise inverse hyperbolic cosine.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The arc hyperbolic cosine of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_acosh
    return _arith_unary_func(a, af_fn)
def atanh(a):
    """
    Element-wise inverse hyperbolic tangent.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The arc hyperbolic tangent of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_atanh
    return _arith_unary_func(a, af_fn)
def root(lhs, rhs):
    """
    Element-wise root of two inputs.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        The value of `lhs ** (1/rhs)` at each location.

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    """
    af_fn = backend.get().af_root
    return _arith_binary_func(lhs, rhs, af_fn)
def pow(lhs, rhs):
    """
    Element-wise power of two inputs.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        The value of `lhs ** (rhs)` at each location.

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    """
    af_fn = backend.get().af_pow
    return _arith_binary_func(lhs, rhs, af_fn)
def pow2(a):
    """
    Raise 2 to the power of each element.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Each element is 2 raised to the power of the matching value in `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_pow2
    return _arith_unary_func(a, af_fn)
def sigmoid(a):
    """
    Apply the sigmoid function to each element in the input.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array where each element is the output of a sigmoid function for
        the corresponding value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    return _arith_unary_func(a, backend.get().af_sigmoid)
def exp(a):
    """
    Element-wise exponential.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The exponential of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_exp
    return _arith_unary_func(a, af_fn)
def expm1(a):
    """
    Exponential of each element in the array minus 1.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing `exp(a) - 1` for each value from `a`.

    Note
    ----
    - `a` must not be complex.
    - This function provides a more stable result for small values of `a`.
    """
    return _arith_unary_func(a, backend.get().af_expm1)
def erf(a):
    """
    Element-wise error function.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The error function of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_erf
    return _arith_unary_func(a, af_fn)
def erfc(a):
    """
    Element-wise complementary error function.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The complementary error function of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_erfc
    return _arith_unary_func(a, af_fn)
def log(a):
    """
    Element-wise natural logarithm.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The natural logarithm of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_log
    return _arith_unary_func(a, af_fn)
def log1p(a):
    """
    Logarithm of 1 plus each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the values of `log(1 + a)`.

    Note
    ----
    - `a` must not be complex.
    - This function provides a more stable result for small values of `a`.
    """
    return _arith_unary_func(a, backend.get().af_log1p)
def log10(a):
    """
    Element-wise base-10 logarithm.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The logarithm base 10 of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_log10
    return _arith_unary_func(a, af_fn)
def log2(a):
    """
    Element-wise base-2 logarithm.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The logarithm base 2 of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_log2
    return _arith_unary_func(a, af_fn)
def sqrt(a):
    """
    Element-wise square root.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The square root of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_sqrt
    return _arith_unary_func(a, af_fn)
def rsqrt(a):
    """
    Element-wise reciprocal (inverse) square root.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The inverse square root of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_rsqrt
    return _arith_unary_func(a, af_fn)
def cbrt(a):
    """
    Element-wise cube root.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The cube root of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_cbrt
    return _arith_unary_func(a, af_fn)
def factorial(a):
    """
    Element-wise factorial.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The factorial of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_factorial
    return _arith_unary_func(a, af_fn)
def tgamma(a):
    """
    Element-wise gamma function.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The gamma function applied to each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_tgamma
    return _arith_unary_func(a, af_fn)
def lgamma(a):
    """
    Element-wise logarithm of the gamma function.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The logarithm of the gamma function of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_lgamma
    return _arith_unary_func(a, af_fn)
def iszero(a):
    """
    Test each element for zero.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The result of checking whether each value of `a` is 0.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_iszero
    return _arith_unary_func(a, af_fn)
def isinf(a):
    """
    Test each element for infinity.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The result of checking whether each value of `a` is infinite.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_isinf
    return _arith_unary_func(a, af_fn)
def isnan(a):
    """
    Test each element for NaN.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        The result of checking whether each value of `a` is NaN.

    Note
    ----
    `a` must not be complex.
    """
    af_fn = backend.get().af_isnan
    return _arith_unary_func(a, af_fn)