Source code for arrayfire.util

#######################################################
# Copyright (c) 2015, ArrayFire
# All rights reserved.
#
# This file is distributed under 3-clause BSD license.
# The complete license agreement can be obtained at:
# http://arrayfire.com/licenses/BSD-3-Clause
########################################################

"""
Utility functions to help with Array metadata.
"""

from .library import *
import numbers

def dim4(d0=1, d1=1, d2=1, d3=1):
    """
    Pack up to four dimension extents into a `c_dim_t * 4` instance.

    Parameters
    ----------
    d0, d1, d2, d3 : optional
        Extent along each axis; each value is passed to `c_dim_t`.
        An axis given as None is left at 1.

    Returns
    -------
    An instance of `c_dim_t * 4` holding the four extents.
    """
    requested = (d0, d1, d2, d3)
    # Start from all ones so that None entries need no further handling.
    dims = (c_dim_t * 4)(1, 1, 1, 1)
    for axis in range(4):
        if requested[axis] is None:
            continue
        dims[axis] = c_dim_t(requested[axis])
    return dims
def _is_number(a):
    """Return True when `a` is an instance of `numbers.Number`."""
    number_base = numbers.Number
    return isinstance(a, number_base)
def number_dtype(a):
    """
    Select the Dtype that corresponds to the value `a`.

    Parameters
    ----------
    a : bool, int, float, complex, or an object exposing `dtype.char`

    Returns
    -------
    Dtype.b8, Dtype.s64, Dtype.f64 or Dtype.c64 for bool, int, float and
    complex respectively; otherwise `to_dtype[a.dtype.char]`.
    """
    # Order matters: bool is a subclass of int, so it has to be tried first.
    builtin_kinds = (
        (bool, Dtype.b8),
        (int, Dtype.s64),
        (float, Dtype.f64),
        (complex, Dtype.c64),
    )
    for kind, dtype in builtin_kinds:
        if isinstance(a, kind):
            return dtype
    # Not a builtin scalar: fall back to the object's own type code.
    return to_dtype[a.dtype.char]
def implicit_dtype(number, a_dtype):
    """
    Choose the Dtype to use for a scalar combined with an array.

    Parameters
    ----------
    number : value accepted by `number_dtype`
    a_dtype : the array's type, compared against `Dtype.<x>.value`

    Returns
    -------
    Dtype.f32 when the scalar maps to f64 and the array is f32 or c32;
    Dtype.c32 when the scalar maps to c64 and the array is f32 or c32;
    otherwise the Dtype given by `number_dtype(number)`.
    """
    scalar_dtype = number_dtype(number)
    scalar_code = scalar_dtype.value

    # Only double-precision scalars are candidates for demotion.
    if scalar_code == Dtype.f64.value:
        narrowed = Dtype.f32
    elif scalar_code == Dtype.c64.value:
        narrowed = Dtype.c32
    else:
        return scalar_dtype

    # Demote only when the array itself is single precision.
    array_is_single = a_dtype == Dtype.f32.value or a_dtype == Dtype.c32.value
    return narrowed if array_is_single else scalar_dtype
def dim4_to_tuple(dims, default=1):
    """
    Expand a tuple of dimensions to exactly four entries.

    Parameters
    ----------
    dims : tuple
        Leading dimensions; must be a tuple (checked with an assert).
    default : number or None, optional
        Value used for the positions `dims` does not cover. When it is
        not None it must satisfy `_is_number` (checked with an assert).

    Returns
    -------
    A 4-tuple: the entries of `dims` followed by `default` padding.

    Raises
    ------
    IndexError if `dims` has more than four entries.
    """
    assert isinstance(dims, tuple)
    if default is not None:
        assert _is_number(default)

    padded = [default, default, default, default]
    # Overwrite position by position so that a fifth entry raises IndexError.
    for position, extent in enumerate(dims):
        padded[position] = extent
    return tuple(padded)
def to_str(c_str):
    """
    Decode the bytes held in `c_str.value` as UTF-8 and return a `str`.
    """
    raw = c_str.value
    text = raw.decode('utf-8')
    return str(text)
def safe_call(af_error):
    """
    Raise RuntimeError when `af_error` is not `ERR.NONE.value`.

    Parameters
    ----------
    af_error : status code returned by a backend call

    Raises
    ------
    RuntimeError carrying the text fetched with `af_get_last_error`.
    """
    if af_error == ERR.NONE.value:
        return

    # The call failed: ask the backend for the message of the last error.
    message = c_char_ptr_t(0)
    message_len = c_dim_t(0)
    lib = backend.get()
    lib.af_get_last_error(c_pointer(message), c_pointer(message_len))
    raise RuntimeError(to_str(message))
def get_version():
    """
    Function to get the version of arrayfire.

    Returns
    -------
    A (major, minor, patch) tuple of the values written by `af_get_version`.
    """
    # One output slot each for major, minor and patch, in that order.
    parts = [c_int_t(0), c_int_t(0), c_int_t(0)]
    status = backend.get().af_get_version(*[c_pointer(part) for part in parts])
    safe_call(status)
    return tuple(part.value for part in parts)
def get_reversion():
    """
    Function to get the revision hash of the library.

    Returns
    -------
    The result of `af_get_revision` converted with `to_str`.
    """
    revision = backend.get().af_get_revision()
    return to_str(revision)
# Type code -> Dtype. `number_dtype` indexes this table with `dtype.char`.
to_dtype = {
    'f': Dtype.f32,
    'd': Dtype.f64,
    'b': Dtype.b8,
    'B': Dtype.u8,
    'h': Dtype.s16,
    'H': Dtype.u16,
    'i': Dtype.s32,
    'I': Dtype.u32,
    'l': Dtype.s64,
    'L': Dtype.u64,
    'F': Dtype.c32,
    'D': Dtype.c64,
    'hf': Dtype.f16,
}

# Dtype value -> type code: the exact inverse of `to_dtype`.
to_typecode = {dtype.value: code for code, dtype in to_dtype.items()}

# Dtype value -> C type object used to hold one element of that type.
to_c_type = {
    Dtype.f32.value: c_float_t,
    Dtype.f64.value: c_double_t,
    Dtype.b8.value: c_char_t,
    Dtype.u8.value: c_uchar_t,
    Dtype.s16.value: c_short_t,
    Dtype.u16.value: c_ushort_t,
    Dtype.s32.value: c_int_t,
    Dtype.u32.value: c_uint_t,
    Dtype.s64.value: c_longlong_t,
    Dtype.u64.value: c_ulonglong_t,
    # Complex values are represented as a pair of floats / doubles.
    Dtype.c32.value: c_float_t * 2,
    Dtype.c64.value: c_double_t * 2,
    # NOTE(review): f16 maps to an unsigned short, presumably as a 16-bit
    # container for the raw half-precision bits - confirm.
    Dtype.f16.value: c_ushort_t,
}

# Dtype value -> human-readable type name.
to_typename = {
    Dtype.f32.value: 'float',
    Dtype.f64.value: 'double',
    Dtype.b8.value: 'bool',
    Dtype.u8.value: 'unsigned char',
    Dtype.s16.value: 'short int',
    Dtype.u16.value: 'unsigned short int',
    Dtype.s32.value: 'int',
    Dtype.u32.value: 'unsigned int',
    Dtype.s64.value: 'long int',
    Dtype.u64.value: 'unsigned long int',
    Dtype.c32.value: 'float complex',
    Dtype.c64.value: 'double complex',
    Dtype.f16.value: 'half',
}