#######################################################
# Copyright (c) 2015, ArrayFire
# All rights reserved.
#
# This file is distributed under 3-clause BSD license.
# The complete license agreement can be obtained at:
# http://arrayfire.com/licenses/BSD-3-Clause
########################################################
"""
Math functions (sin, sqrt, exp, etc).
"""
from .library import *
from .array import *
from .bcast import _bcast_var
from .util import _is_number
def _arith_binary_func(lhs, rhs, c_func):
    """
    Apply a binary backend function to two operands.

    Parameters
    ----------
    lhs : af.Array or scalar
        Left operand.
    rhs : af.Array or scalar
        Right operand.
    c_func : callable
        Backend function called as
        `c_func(out_pointer, lhs_handle, rhs_handle, bcast_flag)`.
        Its return value is passed to `safe_call`.

    Returns
    -------
    out : af.Array
        Array whose handle was filled in by `c_func`.

    Raises
    ------
    TypeError
        If neither operand is an af.Array, or if `lhs` is an af.Array and
        `rhs` is neither an af.Array nor a number.

    Note
    ----
    A scalar operand is expanded with `constant_array` to the dimensions of
    the array operand, using the dtype chosen by `implicit_dtype`.
    """
    lhs_is_array = isinstance(lhs, Array)
    rhs_is_array = isinstance(rhs, Array)

    if not (lhs_is_array or rhs_is_array):
        raise TypeError("Atleast one input needs to be of type arrayfire.array")

    out = Array()

    if lhs_is_array and rhs_is_array:
        lhs_handle, rhs_handle = lhs.arr, rhs.arr
    elif lhs_is_array:
        # Previously a non-numeric, non-Array rhs fell through to the branch
        # meant for "rhs is an Array" and failed on rhs.dims() with an
        # unrelated AttributeError. Reject it explicitly instead.
        if not _is_number(rhs):
            raise TypeError("rhs needs to be of type arrayfire.array or a scalar number")
        ldims = dim4_to_tuple(lhs.dims())
        rty = implicit_dtype(rhs, lhs.type())
        # Keep the temporary wrapped in an Array that stays referenced until
        # after the backend call below.
        other = Array()
        other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty)
        lhs_handle, rhs_handle = lhs.arr, other.arr
    else:
        # Only rhs is an Array; lhs is treated as the scalar operand.
        rdims = dim4_to_tuple(rhs.dims())
        lty = implicit_dtype(lhs, rhs.type())
        other = Array()
        other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], lty)
        lhs_handle, rhs_handle = other.arr, rhs.arr

    safe_call(c_func(c_pointer(out.arr), lhs_handle, rhs_handle, _bcast_var.get()))
    return out
def _arith_unary_func(a, c_func):
    """
    Apply a unary backend function to an array.

    Parameters
    ----------
    a : af.Array
        Input array; its `arr` handle is passed to `c_func`.
    c_func : callable
        Backend function called as `c_func(out_pointer, a_handle)`.
        Its return value is passed to `safe_call`.

    Returns
    -------
    out : af.Array
        Array whose handle was filled in by `c_func`.
    """
    result = Array()
    ret = c_func(c_pointer(result.arr), a.arr)
    safe_call(ret)
    return result
def cast(a, dtype):
    """
    Convert an array to the requested type.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.
    dtype : af.Dtype
        Must be one of the following:

        - Dtype.f32 for float
        - Dtype.f64 for double
        - Dtype.b8  for bool
        - Dtype.u8  for unsigned char
        - Dtype.s32 for signed 32 bit integer
        - Dtype.u32 for unsigned 32 bit integer
        - Dtype.s64 for signed 64 bit integer
        - Dtype.u64 for unsigned 64 bit integer
        - Dtype.c32 for 32 bit complex number
        - Dtype.c64 for 64 bit complex number

    Returns
    -------
    out : af.Array
        Array containing the values from `a` after converting to `dtype`.
    """
    result = Array()
    ret = backend.get().af_cast(c_pointer(result.arr), a.arr, dtype.value)
    safe_call(ret)
    return result
def minof(lhs, rhs):
    """
    Compute the element-wise minimum of two inputs.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        Array containing the minimum value at each location of the inputs.

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    """
    c_func = backend.get().af_minof
    return _arith_binary_func(lhs, rhs, c_func)
def maxof(lhs, rhs):
    """
    Compute the element-wise maximum of two inputs.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        Array containing the maximum value at each location of the inputs.

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    """
    c_func = backend.get().af_maxof
    return _arith_binary_func(lhs, rhs, c_func)
def clamp(val, low, high):
    """
    Clamp the input value between low and high.

    Parameters
    ----------
    val : af.Array
        Multi dimensional arrayfire array to be clamped.
    low : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number denoting the
        lower value(s).
    high : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number denoting the
        higher value(s).

    Returns
    -------
    out : af.Array
        Array whose handle was filled in by the backend `af_clamp` call.

    Note
    ----
    A scalar bound is expanded with `constant_array` to the dimensions and
    type of `val`.
    """
    out = Array()
    vdims = dim4_to_tuple(val.dims())
    vty = val.type()

    def _as_array(bound):
        # Scalars are wrapped in an Array, the same way _arith_binary_func
        # wraps its temporaries, instead of being kept as a bare handle.
        # NOTE(review): presumably Array releases its handle when it is
        # destroyed, which would make this a leak fix - confirm in array.py.
        if isinstance(bound, Array):
            return bound
        wrapped = Array()
        wrapped.arr = constant_array(bound, vdims[0], vdims[1], vdims[2], vdims[3], vty)
        return wrapped

    # Both wrappers stay referenced until the backend call has completed.
    low_array = _as_array(low)
    high_array = _as_array(high)

    safe_call(backend.get().af_clamp(c_pointer(out.arr), val.arr,
                                     low_array.arr, high_array.arr,
                                     _bcast_var.get()))
    return out
def mod(lhs, rhs):
    """
    Compute the modulus.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        Contains the moduli after dividing each value of `lhs` with those
        in `rhs`.

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    """
    c_func = backend.get().af_mod
    return _arith_binary_func(lhs, rhs, c_func)
def rem(lhs, rhs):
    """
    Compute the remainder.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        Contains the remainders after dividing each value of `lhs` with
        those in `rhs`.

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    """
    c_func = backend.get().af_rem
    return _arith_binary_func(lhs, rhs, c_func)
def abs(a):
    """
    Compute the absolute values.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Contains the absolute values of the inputs.
    """
    c_func = backend.get().af_abs
    return _arith_unary_func(a, c_func)
def arg(a):
    """
    Compute the theta value of the inputs in polar co-ordinates.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Contains the theta values.
    """
    c_func = backend.get().af_arg
    return _arith_unary_func(a, c_func)
def sign(a):
    """
    Compute the sign of the inputs.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing 1 for negative values, 0 otherwise.
    """
    c_func = backend.get().af_sign
    return _arith_unary_func(a, c_func)
def round(a):
    """
    Round each value to the nearest integer.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the values rounded to nearest integer.
    """
    c_func = backend.get().af_round
    return _arith_unary_func(a, c_func)
def trunc(a):
    """
    Round each value towards zero.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the truncated values.
    """
    c_func = backend.get().af_trunc
    return _arith_unary_func(a, c_func)
def floor(a):
    """
    Round each value towards a smaller integer.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the floored values.
    """
    c_func = backend.get().af_floor
    return _arith_unary_func(a, c_func)
def ceil(a):
    """
    Round each value towards a bigger integer.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the ceiled values.
    """
    c_func = backend.get().af_ceil
    return _arith_unary_func(a, c_func)
def hypot(lhs, rhs):
    """
    Find the value of the hypotenuse.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        Contains the value of `sqrt(lhs**2 + rhs**2)`.

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    """
    c_func = backend.get().af_hypot
    return _arith_binary_func(lhs, rhs, c_func)
def sin(a):
    """
    Compute the sine of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the sine of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_sin
    return _arith_unary_func(a, c_func)
def cos(a):
    """
    Compute the cosine of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the cosine of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_cos
    return _arith_unary_func(a, c_func)
def tan(a):
    """
    Compute the tangent of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the tangent of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_tan
    return _arith_unary_func(a, c_func)
def asin(a):
    """
    Compute the arc sine of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the arc sine of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_asin
    return _arith_unary_func(a, c_func)
def acos(a):
    """
    Compute the arc cosine of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the arc cosine of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_acos
    return _arith_unary_func(a, c_func)
def atan(a):
    """
    Compute the arc tangent of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the arc tangent of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_atan
    return _arith_unary_func(a, c_func)
def atan2(lhs, rhs):
    """
    Compute the arc tangent using two values.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        Contains the arc tan values where:

        - `lhs` contains the sine values.
        - `rhs` contains the cosine values.

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    """
    c_func = backend.get().af_atan2
    return _arith_binary_func(lhs, rhs, c_func)
def cplx(lhs, rhs=None):
    """
    Build a complex array from real inputs.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : optional: af.Array or scalar. default: None.
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        Contains complex values whose

        - real values contain values from `lhs`
        - imaginary values contain values from `rhs` (0 if `rhs` is None)

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array; when `rhs` is
      None, `lhs` is handed to the unary helper and so must be an af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    """
    if rhs is not None:
        return _arith_binary_func(lhs, rhs, backend.get().af_cplx2)
    return _arith_unary_func(lhs, backend.get().af_cplx)
def real(a):
    """
    Extract the real values of the input.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the real values from `a`.
    """
    c_func = backend.get().af_real
    return _arith_unary_func(a, c_func)
def imag(a):
    """
    Extract the imaginary values of the input.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the imaginary values from `a`.
    """
    c_func = backend.get().af_imag
    return _arith_unary_func(a, c_func)
def conjg(a):
    """
    Compute the complex conjugate values of the input.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing complex conjugate values from `a`.
    """
    c_func = backend.get().af_conjg
    return _arith_unary_func(a, c_func)
def sinh(a):
    """
    Compute the hyperbolic sine of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the hyperbolic sine of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_sinh
    return _arith_unary_func(a, c_func)
def cosh(a):
    """
    Compute the hyperbolic cosine of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the hyperbolic cosine of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_cosh
    return _arith_unary_func(a, c_func)
def tanh(a):
    """
    Compute the hyperbolic tangent of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the hyperbolic tangent of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_tanh
    return _arith_unary_func(a, c_func)
def asinh(a):
    """
    Compute the arc hyperbolic sine of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the arc hyperbolic sine of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_asinh
    return _arith_unary_func(a, c_func)
def acosh(a):
    """
    Compute the arc hyperbolic cosine of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the arc hyperbolic cosine of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_acosh
    return _arith_unary_func(a, c_func)
def atanh(a):
    """
    Compute the arc hyperbolic tangent of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the arc hyperbolic tangent of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_atanh
    return _arith_unary_func(a, c_func)
def root(lhs, rhs):
    """
    Compute the root values of two inputs at each location.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        Array containing the value of `lhs ** (1/rhs)`

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    """
    c_func = backend.get().af_root
    return _arith_binary_func(lhs, rhs, c_func)
def pow(lhs, rhs):
    """
    Compute the power of two inputs at each location.

    Parameters
    ----------
    lhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.
    rhs : af.Array or scalar
        Multi dimensional arrayfire array or a scalar number.

    Returns
    -------
    out : af.Array
        Array containing the value of `lhs ** (rhs)`

    Note
    ----
    - At least one of `lhs` and `rhs` needs to be af.Array.
    - If `lhs` and `rhs` are both af.Array, they must be of same size.
    """
    c_func = backend.get().af_pow
    return _arith_binary_func(lhs, rhs, c_func)
def pow2(a):
    """
    Raise 2 to the power of each element in the input.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array where each element is 2 raised to power of the corresponding
        value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_pow2
    return _arith_unary_func(a, c_func)
def sigmoid(a):
    """
    Apply the sigmoid function to each element in the input.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array where each element is the output of a sigmoid function for
        the corresponding value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_sigmoid
    return _arith_unary_func(a, c_func)
def exp(a):
    """
    Compute the exponential of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the exponential of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_exp
    return _arith_unary_func(a, c_func)
def expm1(a):
    """
    Exponential of each element in the array minus 1.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the values of `exp(a) - 1`.

    Note
    ----
    - `a` must not be complex.
    - This function provides a more stable result for small values of `a`.
    """
    c_func = backend.get().af_expm1
    return _arith_unary_func(a, c_func)
def erf(a):
    """
    Compute the error function of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the error function of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_erf
    return _arith_unary_func(a, c_func)
def erfc(a):
    """
    Compute the complementary error function of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the complementary error function of each value
        from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_erfc
    return _arith_unary_func(a, c_func)
def log(a):
    """
    Compute the natural logarithm of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the natural logarithm of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_log
    return _arith_unary_func(a, c_func)
def log1p(a):
    """
    Natural logarithm of (1 + each element in the array).

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the values of `log(1 + a)`.

    Note
    ----
    - `a` must not be complex.
    - This function provides a more stable result for small values of `a`.
    """
    c_func = backend.get().af_log1p
    return _arith_unary_func(a, c_func)
def log10(a):
    """
    Compute the logarithm base 10 of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the logarithm base 10 of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_log10
    return _arith_unary_func(a, c_func)
def log2(a):
    """
    Compute the logarithm base 2 of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the logarithm base 2 of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_log2
    return _arith_unary_func(a, c_func)
def sqrt(a):
    """
    Compute the square root of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the square root of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_sqrt
    return _arith_unary_func(a, c_func)
def rsqrt(a):
    """
    Compute the reciprocal or inverse square root of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the inverse square root of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_rsqrt
    return _arith_unary_func(a, c_func)
def cbrt(a):
    """
    Compute the cube root of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the cube root of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_cbrt
    return _arith_unary_func(a, c_func)
def factorial(a):
    """
    Compute the factorial of each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the factorial of each value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_factorial
    return _arith_unary_func(a, c_func)
def tgamma(a):
    """
    Evaluate the gamma function for each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the output of gamma function of each value
        from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_tgamma
    return _arith_unary_func(a, c_func)
def lgamma(a):
    """
    Evaluate the logarithm of gamma function for each element in the array.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the output of logarithm of gamma function of each
        value from `a`.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_lgamma
    return _arith_unary_func(a, c_func)
def iszero(a):
    """
    Test whether each element of the input is zero.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the output after checking if each value of `a`
        is 0.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_iszero
    return _arith_unary_func(a, c_func)
def isinf(a):
    """
    Test whether each element of the input is infinity.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the output after checking if each value of `a`
        is infinite.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_isinf
    return _arith_unary_func(a, c_func)
def isnan(a):
    """
    Test whether each element of the input is NaN.

    Parameters
    ----------
    a : af.Array
        Multi dimensional arrayfire array.

    Returns
    -------
    out : af.Array
        Array containing the output after checking if each value of `a`
        is NaN.

    Note
    ----
    `a` must not be complex.
    """
    c_func = backend.get().af_isnan
    return _arith_unary_func(a, c_func)