use super::defines::{AfError, Backend, DType};
use super::dim4::Dim4;
use super::error::HANDLE_ERROR;
use super::util::{af_array, dim_t, void_ptr, HasAfEnum};
use libc::{c_char, c_int, c_longlong, c_uint, c_void};
use std::ffi::CString;
use std::marker::PhantomData;
// Raw declarations of the C array API used by this module.
//
// Every function returns a `c_int` status; callers in this file convert it with
// `AfError::from` and hand it to `HANDLE_ERROR`. Results are written through the
// `*mut` pointer arguments, which callers point at local variables.
extern "C" {
// Handle construction: used by `Array::new`, `Array::new_empty` and
// `Array::new_from_device_ptr` respectively.
fn af_create_array(
out: *mut af_array,
data: *const c_void,
ndims: c_uint,
dims: *const dim_t,
aftype: c_uint,
) -> c_int;
fn af_create_handle(
out: *mut af_array,
ndims: c_uint,
dims: *const dim_t,
aftype: c_uint,
) -> c_int;
fn af_device_array(
out: *mut af_array,
data: *mut c_void,
ndims: c_uint,
dims: *const dim_t,
aftype: c_uint,
) -> c_int;
// Metadata queries on an existing handle.
fn af_get_elements(out: *mut dim_t, arr: af_array) -> c_int;
fn af_get_type(out: *mut c_uint, arr: af_array) -> c_int;
fn af_get_dims(
dim0: *mut c_longlong,
dim1: *mut c_longlong,
dim2: *mut c_longlong,
dim3: *mut c_longlong,
arr: af_array,
) -> c_int;
fn af_get_numdims(result: *mut c_uint, arr: af_array) -> c_int;
// Boolean predicates; each is wrapped by a method generated with `is_func!`.
fn af_is_empty(result: *mut bool, arr: af_array) -> c_int;
fn af_is_scalar(result: *mut bool, arr: af_array) -> c_int;
fn af_is_row(result: *mut bool, arr: af_array) -> c_int;
fn af_is_column(result: *mut bool, arr: af_array) -> c_int;
fn af_is_vector(result: *mut bool, arr: af_array) -> c_int;
fn af_is_complex(result: *mut bool, arr: af_array) -> c_int;
fn af_is_real(result: *mut bool, arr: af_array) -> c_int;
fn af_is_double(result: *mut bool, arr: af_array) -> c_int;
fn af_is_single(result: *mut bool, arr: af_array) -> c_int;
fn af_is_half(result: *mut bool, arr: af_array) -> c_int;
fn af_is_integer(result: *mut bool, arr: af_array) -> c_int;
fn af_is_bool(result: *mut bool, arr: af_array) -> c_int;
fn af_is_realfloating(result: *mut bool, arr: af_array) -> c_int;
fn af_is_floating(result: *mut bool, arr: af_array) -> c_int;
fn af_is_linear(result: *mut bool, arr: af_array) -> c_int;
fn af_is_owner(result: *mut bool, arr: af_array) -> c_int;
fn af_is_sparse(result: *mut bool, arr: af_array) -> c_int;
// Copies array contents into the caller-provided buffer (see `Array::host`).
fn af_get_data_ptr(data: *mut c_void, arr: af_array) -> c_int;
// Evaluation control.
fn af_eval(arr: af_array) -> c_int;
fn af_eval_multiple(num: c_int, arrays: *const af_array) -> c_int;
fn af_set_manual_eval_flag(flag: c_int) -> c_int;
fn af_get_manual_eval_flag(flag: *mut c_int) -> c_int;
// Handle lifetime: `Clone` uses retain, `Array::copy` uses copy, `Drop` uses release.
fn af_retain_array(out: *mut af_array, arr: af_array) -> c_int;
fn af_copy_array(out: *mut af_array, arr: af_array) -> c_int;
fn af_release_array(arr: af_array) -> c_int;
fn af_print_array_gen(exp: *const c_char, arr: af_array, precision: c_int) -> c_int;
fn af_cast(out: *mut af_array, arr: af_array, aftype: c_uint) -> c_int;
fn af_get_backend_id(backend: *mut c_uint, input: af_array) -> c_int;
fn af_get_device_id(device: *mut c_int, input: af_array) -> c_int;
fn af_create_strided_array(
arr: *mut af_array,
data: *const c_void,
offset: dim_t,
ndims: c_uint,
dims: *const dim_t,
strides: *const dim_t,
aftype: c_uint,
stype: c_uint,
) -> c_int;
fn af_get_strides(
s0: *mut dim_t,
s1: *mut dim_t,
s2: *mut dim_t,
s3: *mut dim_t,
arr: af_array,
) -> c_int;
fn af_get_offset(offset: *mut dim_t, arr: af_array) -> c_int;
// Locking and raw device access (see `Array::lock`, `unlock`, `device_ptr`).
fn af_lock_array(arr: af_array) -> c_int;
fn af_unlock_array(arr: af_array) -> c_int;
fn af_get_device_ptr(ptr: *mut void_ptr, arr: af_array) -> c_int;
fn af_get_allocated_bytes(result: *mut usize, arr: af_array) -> c_int;
}
/// Typed owner of a raw `af_array` handle.
///
/// `T` records the element type at compile time only; it is carried through
/// `PhantomData`, so no value of `T` is stored in the struct.
pub struct Array<T>
where
    T: HasAfEnum,
{
    /// Raw handle passed to every FFI call made on behalf of this array.
    handle: af_array,
    /// Ties the element type `T` to the otherwise untyped handle.
    _marker: PhantomData<T>,
}

// NOTE(review): these impls assert that a handle may be moved to, and shared
// between, threads. That guarantee has to come from the underlying library and
// is not visible in this file; confirm before relying on it.
unsafe impl<T> Send for Array<T> where T: HasAfEnum {}
unsafe impl<T> Sync for Array<T> where T: HasAfEnum {}
/// Generates a `pub fn $fn_name(&self) -> bool` predicate method.
///
/// The generated method calls the FFI query `$ffi_fn` with a pointer to a local
/// `bool` and `self.handle`, reports the returned status through `HANDLE_ERROR`,
/// and returns the flag. `$doc_str` becomes the method's rustdoc text.
macro_rules! is_func {
    ($doc_str:expr, $fn_name:ident, $ffi_fn:ident) => {
        #[doc = $doc_str]
        pub fn $fn_name(&self) -> bool {
            let mut answer = false;
            let status = unsafe { $ffi_fn(&mut answer, self.handle) };
            HANDLE_ERROR(AfError::from(status));
            answer
        }
    };
}
impl<T> Array<T>
where
T: HasAfEnum,
{
/// Creates an `Array` from host data by calling `af_create_array`.
///
/// # Parameters
/// - `slice`: host elements to hand to the backend.
/// - `dims`: shape of the new array.
///
/// The C API receives only a pointer and the dimensions, so it cannot detect a
/// buffer that is too short. This function therefore checks that `slice` holds
/// at least as many elements as the product of all extents in `dims`. A longer
/// slice is still accepted, as before.
///
/// # Panics
/// If `slice` is too short, `AfError::ERR_SIZE` is reported through
/// `HANDLE_ERROR`. Should the installed handler return, this function panics
/// rather than pass an undersized buffer across the FFI boundary.
pub fn new(slice: &[T], dims: Dim4) -> Self {
    // Widen and saturate so that absurd extents cannot wrap the product around.
    let required = dims
        .get()
        .iter()
        .fold(1u128, |acc, &extent| acc.saturating_mul(extent as u128));
    if (slice.len() as u128) < required {
        HANDLE_ERROR(AfError::ERR_SIZE);
        panic!(
            "Array::new: slice holds {} elements but dims require {}",
            slice.len(),
            required
        );
    }

    let aftype = T::get_af_dtype();
    let mut out: af_array = std::ptr::null_mut();
    let status = unsafe {
        af_create_array(
            &mut out,
            slice.as_ptr() as *const c_void,
            dims.ndims() as c_uint,
            dims.get().as_ptr() as *const c_longlong,
            aftype as c_uint,
        )
    };
    HANDLE_ERROR(AfError::from(status));
    out.into()
}
/// Creates an `Array` from host data with an explicit offset and strides by
/// calling `af_create_strided_array`.
///
/// # Parameters
/// - `slice`: host elements whose pointer is passed to the backend.
/// - `offset`: forwarded as the `offset` argument.
/// - `dims`: shape of the new array.
/// - `strides`: per-dimension strides.
///
/// The trailing `stype` argument is always passed as `1`.
/// NOTE(review): presumably this selects host memory as the source; confirm
/// against the C API. The length of `slice` is not validated against
/// `offset`/`dims`/`strides` here.
pub fn new_strided(slice: &[T], offset: i64, dims: Dim4, strides: Dim4) -> Self {
    let aftype = T::get_af_dtype();
    let mut out: af_array = std::ptr::null_mut();
    let status = unsafe {
        af_create_strided_array(
            &mut out,
            slice.as_ptr() as *const c_void,
            offset as dim_t,
            dims.ndims() as c_uint,
            dims.get().as_ptr() as *const c_longlong,
            strides.get().as_ptr() as *const c_longlong,
            aftype as c_uint,
            1 as c_uint,
        )
    };
    HANDLE_ERROR(AfError::from(status));
    out.into()
}
/// Creates an `Array` handle of shape `dims` and element type `T` by calling
/// `af_create_handle`; no host data is supplied.
pub fn new_empty(dims: Dim4) -> Self {
    let aftype = T::get_af_dtype();
    let mut out: af_array = std::ptr::null_mut();
    let status = unsafe {
        af_create_handle(
            &mut out,
            dims.ndims() as c_uint,
            dims.get().as_ptr() as *const c_longlong,
            aftype as c_uint,
        )
    };
    HANDLE_ERROR(AfError::from(status));
    out.into()
}
/// Wraps an existing device pointer in an `Array` by calling `af_device_array`.
///
/// # Parameters
/// - `dev_ptr`: pointer forwarded unchanged as the `data` argument.
/// - `dims`: shape of the resulting array.
///
/// NOTE(review): `dev_ptr` is not validated here; the caller is presumably
/// responsible for it pointing at suitable device memory. Confirm.
pub fn new_from_device_ptr(dev_ptr: *mut T, dims: Dim4) -> Self {
    let aftype = T::get_af_dtype();
    let mut out: af_array = std::ptr::null_mut();
    let status = unsafe {
        af_device_array(
            &mut out,
            dev_ptr as *mut c_void,
            dims.ndims() as c_uint,
            dims.get().as_ptr() as *const dim_t,
            aftype as c_uint,
        )
    };
    HANDLE_ERROR(AfError::from(status));
    out.into()
}
/// Returns the backend this array lives on, as reported by `af_get_backend_id`.
///
/// The status is reported through `HANDLE_ERROR`; if the query failed or the id
/// is not recognised, `Backend::DEFAULT` is returned.
pub fn get_backend(&self) -> Backend {
    let mut backend_id: u32 = 0;
    let status = unsafe { af_get_backend_id(&mut backend_id, self.handle) };
    HANDLE_ERROR(AfError::from(status));
    if status != 0 {
        return Backend::DEFAULT;
    }
    match backend_id {
        1 => Backend::CPU,
        2 => Backend::CUDA,
        // ArrayFire's backend ids are bit flags, so OpenCL is 4. The previous
        // mapping only accepted 3, which made OpenCL arrays report DEFAULT.
        // 3 is kept so any caller that relied on the old mapping is unaffected.
        3 | 4 => Backend::OPENCL,
        _ => Backend::DEFAULT,
    }
}
pub fn get_device_id(&self) -> i32 {
unsafe {
let mut ret_val: i32 = 0;
let err_val = af_get_device_id(&mut ret_val as *mut c_int, self.handle);
HANDLE_ERROR(AfError::from(err_val));
ret_val
}
}
/// Returns the element count reported by `af_get_elements`, converted to `usize`.
pub fn elements(&self) -> usize {
    let mut count: dim_t = 0;
    let status = unsafe { af_get_elements(&mut count, self.handle) };
    HANDLE_ERROR(AfError::from(status));
    count as usize
}
/// Returns the runtime element type of the array, converting the raw value from
/// `af_get_type` with `DType::from`.
pub fn get_type(&self) -> DType {
    let mut raw_type: u32 = 0;
    let status = unsafe { af_get_type(&mut raw_type, self.handle) };
    HANDLE_ERROR(AfError::from(status));
    DType::from(raw_type)
}
/// Returns the four extents reported by `af_get_dims` as a `Dim4`.
pub fn dims(&self) -> Dim4 {
    let (mut d0, mut d1, mut d2, mut d3) = (0i64, 0i64, 0i64, 0i64);
    let status = unsafe { af_get_dims(&mut d0, &mut d1, &mut d2, &mut d3, self.handle) };
    HANDLE_ERROR(AfError::from(status));
    Dim4::new(&[d0 as u64, d1 as u64, d2 as u64, d3 as u64])
}
/// Returns the four strides reported by `af_get_strides` as a `Dim4`.
pub fn strides(&self) -> Dim4 {
    let (mut s0, mut s1, mut s2, mut s3) = (0i64, 0i64, 0i64, 0i64);
    let status = unsafe { af_get_strides(&mut s0, &mut s1, &mut s2, &mut s3, self.handle) };
    HANDLE_ERROR(AfError::from(status));
    Dim4::new(&[s0 as u64, s1 as u64, s2 as u64, s3 as u64])
}
pub fn numdims(&self) -> u32 {
unsafe {
let mut ret_val: u32 = 0;
let err_val = af_get_numdims(&mut ret_val as *mut c_uint, self.handle);
HANDLE_ERROR(AfError::from(err_val));
ret_val
}
}
pub fn offset(&self) -> i64 {
unsafe {
let mut ret_val: i64 = 0;
let err_val = af_get_offset(&mut ret_val as *mut dim_t, self.handle);
HANDLE_ERROR(AfError::from(err_val));
ret_val
}
}
/// Returns the raw handle without transferring ownership.
///
/// # Safety
/// The `Drop` implementation of `Array` releases this handle, so the returned
/// value must not be used after `self` is dropped, and the caller must not
/// release it independently.
pub unsafe fn get(&self) -> af_array {
self.handle
}
/// Replaces the stored raw handle with `handle`.
///
/// NOTE(review): the previously stored handle is overwritten without being
/// released here, so unless the caller released it some other way it appears
/// to leak. Confirm the intended ownership contract.
pub fn set(&mut self, handle: af_array) {
self.handle = handle;
}
/// Copies the array contents into the host buffer `data` via `af_get_data_ptr`.
///
/// # Parameters
/// - `data`: destination buffer; its length must equal `self.elements()`.
///
/// On a length mismatch `AfError::ERR_SIZE` is reported through `HANDLE_ERROR`
/// and nothing is copied.
///
/// NOTE(review): only the element *count* is checked. `O` is independent of
/// `T`, so a destination element type smaller than the array's element type
/// would still be overrun; confirm whether callers rely on `O != T`.
pub fn host<O: HasAfEnum>(&self, data: &mut [O]) {
    if data.len() != self.elements() {
        HANDLE_ERROR(AfError::ERR_SIZE);
        // An installed error handler may return instead of panicking; never let
        // the backend write into a buffer of the wrong length.
        return;
    }
    let status = unsafe { af_get_data_ptr(data.as_mut_ptr() as *mut c_void, self.handle) };
    HANDLE_ERROR(AfError::from(status));
}
/// Calls `af_eval` on this array and reports the status through `HANDLE_ERROR`.
pub fn eval(&self) {
    let status = unsafe { af_eval(self.handle) };
    HANDLE_ERROR(AfError::from(status));
}
/// Returns a new `Array` produced by `af_copy_array`.
///
/// This differs from `Clone`, which goes through `af_retain_array` instead.
pub fn copy(&self) -> Self {
    let mut duplicate: af_array = std::ptr::null_mut();
    let status = unsafe { af_copy_array(&mut duplicate, self.handle) };
    HANDLE_ERROR(AfError::from(status));
    duplicate.into()
}
// Boolean predicate methods. Each `is_func!` invocation expands to a
// `pub fn <name>(&self) -> bool` that forwards to the named FFI query; the
// string literal becomes that method's rustdoc.
is_func!("Check if Array is empty", is_empty, af_is_empty);
is_func!("Check if Array is scalar", is_scalar, af_is_scalar);
is_func!("Check if Array is a row", is_row, af_is_row);
is_func!("Check if Array is a column", is_column, af_is_column);
is_func!("Check if Array is a vector", is_vector, af_is_vector);
is_func!(
"Check if Array is of real (not complex) type",
is_real,
af_is_real
);
is_func!(
"Check if Array is of complex type",
is_complex,
af_is_complex
);
is_func!(
"Check if Array's numerical type is of double precision",
is_double,
af_is_double
);
is_func!(
"Check if Array's numerical type is of single precision",
is_single,
af_is_single
);
is_func!(
"Check if Array's numerical type is of half precision",
is_half,
af_is_half
);
is_func!(
"Check if Array is of integral type",
is_integer,
af_is_integer
);
is_func!("Check if Array is of boolean type", is_bool, af_is_bool);
is_func!(
"Check if Array is floating point real(not complex) data type",
is_realfloating,
af_is_realfloating
);
is_func!(
"Check if Array is floating point type, either real or complex data",
is_floating,
af_is_floating
);
is_func!(
"Check if Array's memory layout is continuous and one dimensional",
is_linear,
af_is_linear
);
is_func!("Check if Array is a sparse matrix", is_sparse, af_is_sparse);
is_func!(
"Check if Array's memory is owned by it and not a view of another Array",
is_owner,
af_is_owner
);
/// Returns a new `Array<O>` produced by `af_cast`, with `O`'s dtype as the
/// target type.
pub fn cast<O: HasAfEnum>(&self) -> Array<O> {
    let target = O::get_af_dtype();
    let mut converted: af_array = std::ptr::null_mut();
    let status = unsafe { af_cast(&mut converted, self.handle, target as c_uint) };
    HANDLE_ERROR(AfError::from(status));
    converted.into()
}
/// Calls `af_lock_array` on this array and reports the status through
/// `HANDLE_ERROR`. The matching release is `unlock`.
pub fn lock(&self) {
    let status = unsafe { af_lock_array(self.handle) };
    HANDLE_ERROR(AfError::from(status));
}
/// Calls `af_unlock_array` on this array and reports the status through
/// `HANDLE_ERROR`.
pub fn unlock(&self) {
    let status = unsafe { af_unlock_array(self.handle) };
    HANDLE_ERROR(AfError::from(status));
}
/// Returns the raw pointer reported by `af_get_device_ptr` for this array.
///
/// # Safety
/// The returned pointer is untyped and unchecked.
/// NOTE(review): the caller presumably has to keep it within the array's
/// lifetime and coordinate with `unlock`; confirm against the C API contract.
pub unsafe fn device_ptr(&self) -> void_ptr {
    let mut raw: void_ptr = std::ptr::null_mut();
    let status = af_get_device_ptr(&mut raw, self.handle);
    HANDLE_ERROR(AfError::from(status));
    raw
}
/// Returns the byte count reported by `af_get_allocated_bytes` for this array.
pub fn get_allocated_bytes(&self) -> usize {
    let mut bytes: usize = 0;
    let status = unsafe { af_get_allocated_bytes(&mut bytes, self.handle) };
    HANDLE_ERROR(AfError::from(status));
    bytes
}
}
/// Wraps a raw handle in a typed `Array<T>`, taking ownership of it: the
/// resulting value's `Drop` implementation releases the handle.
///
/// Implemented as `From` rather than `Into`; the standard blanket impl supplies
/// `Into<Array<T>> for af_array`, so existing `handle.into()` call sites keep
/// working unchanged.
impl<T: HasAfEnum> From<af_array> for Array<T> {
    fn from(handle: af_array) -> Self {
        Array {
            handle,
            _marker: PhantomData,
        }
    }
}
/// Cloning goes through `af_retain_array` and wraps the handle it returns.
///
/// Use `Array::copy` for the `af_copy_array` path instead.
impl<T> Clone for Array<T>
where
    T: HasAfEnum,
{
    /// # Panics
    /// Panics if `af_retain_array` returns a non-zero status.
    fn clone(&self) -> Self {
        let mut retained: af_array = std::ptr::null_mut();
        let status = unsafe { af_retain_array(&mut retained, self.handle) };
        if status != 0 {
            panic!("Weak copy of Array failed with error code: {}", status);
        }
        retained.into()
    }
}
/// Releases the wrapped handle with `af_release_array` when the `Array` goes
/// out of scope.
impl<T> Drop for Array<T>
where
    T: HasAfEnum,
{
    /// # Panics
    /// Panics if `af_release_array` returns a non-zero status, unless the thread
    /// is already unwinding. In that case the failure is written to stderr
    /// instead, because a second panic during unwinding aborts the process and
    /// hides the original panic.
    fn drop(&mut self) {
        let status = unsafe { af_release_array(self.handle) };
        if status == 0 {
            return;
        }
        if std::thread::panicking() {
            eprintln!("Array<T> drop failed with error code: {}", status);
        } else {
            panic!("Array<T> drop failed with error code: {}", status);
        }
    }
}
/// Prints `input` through `af_print_array_gen` with an empty label and a
/// precision argument of `4`.
pub fn print<T: HasAfEnum>(input: &Array<T>) {
    let label = CString::new("").unwrap();
    let status = unsafe { af_print_array_gen(label.as_ptr(), input.get(), 4) };
    HANDLE_ERROR(AfError::from(status));
}
/// Prints `input` through `af_print_array_gen` with `msg` as the label.
///
/// # Parameters
/// - `msg`: label text passed to the C API as a NUL-terminated string.
/// - `input`: array to print.
/// - `precision`: forwarded as the precision argument; `None` means `4`.
///
/// # Panics
/// Panics if `msg` contains an interior NUL byte.
pub fn print_gen<T: HasAfEnum>(msg: String, input: &Array<T>, precision: Option<i32>) {
    let label = CString::new(msg).unwrap();
    let digits = precision.unwrap_or(4) as c_int;
    let status = unsafe { af_print_array_gen(label.as_ptr(), input.get(), digits) };
    HANDLE_ERROR(AfError::from(status));
}
/// Passes the raw handles of all `inputs` to `af_eval_multiple` in one call.
pub fn eval_multiple<T: HasAfEnum>(inputs: Vec<&Array<T>>) {
    // Gather the raw handles into a contiguous buffer for the C call.
    let handles: Vec<af_array> = unsafe { inputs.iter().map(|arr| arr.get()).collect() };
    let status = unsafe { af_eval_multiple(handles.len() as c_int, handles.as_ptr()) };
    HANDLE_ERROR(AfError::from(status));
}
/// Forwards `flag` (as `0` or `1`) to `af_set_manual_eval_flag` and reports the
/// status through `HANDLE_ERROR`.
pub fn set_manual_eval(flag: bool) {
    let raw_flag = flag as c_int;
    let status = unsafe { af_set_manual_eval_flag(raw_flag) };
    HANDLE_ERROR(AfError::from(status));
}
pub fn is_eval_manual() -> bool {
unsafe {
let mut ret_val: i32 = 0;
let err_val = af_get_manual_eval_flag(&mut ret_val as *mut c_int);
HANDLE_ERROR(AfError::from(err_val));
ret_val > 0
}
}
/// Serde support for `Array<T>`, compiled only with the `afserde` feature.
///
/// Arrays are serialized as an `ArrayOnHost` record (dtype, shape and a flat
/// copy of the elements) and rebuilt from the same record on deserialization.
#[cfg(feature = "afserde")]
mod afserde {
    use super::{Array, DType, Dim4, HasAfEnum};
    use serde::de::{Deserializer, Error, Unexpected};
    use serde::ser::Serializer;
    use serde::{Deserialize, Serialize};

    /// Host-side snapshot of an array; this is the actual wire format.
    #[derive(Debug, Serialize, Deserialize)]
    struct ArrayOnHost<T: HasAfEnum + std::fmt::Debug> {
        /// Runtime element type recorded at serialization time.
        dtype: DType,
        /// Shape of the array.
        shape: Dim4,
        /// Elements copied to the host with `Array::host`.
        data: Vec<T>,
    }

    impl<T> Serialize for Array<T>
    where
        T: std::default::Default + std::clone::Clone + Serialize + HasAfEnum + std::fmt::Debug,
    {
        /// Copies the array to a host buffer and serializes it as `ArrayOnHost`.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let mut host_data = vec![T::default(); self.elements()];
            self.host(&mut host_data);
            let snapshot = ArrayOnHost {
                dtype: self.get_type(),
                shape: self.dims(),
                data: host_data,
            };
            snapshot.serialize(serializer)
        }
    }

    impl<'de, T> Deserialize<'de> for Array<T>
    where
        T: Deserialize<'de> + HasAfEnum + std::fmt::Debug,
    {
        /// Rebuilds an array from an `ArrayOnHost` record.
        ///
        /// # Errors
        /// - `invalid_value` when the recorded dtype differs from `T`'s dtype.
        /// - `invalid_length` when `data` has fewer elements than `shape`
        ///   implies. The record is untrusted input, and an undersized buffer
        ///   must never be handed to the array constructor.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let ArrayOnHost {
                dtype: read_dtype,
                shape,
                data,
            } = ArrayOnHost::<T>::deserialize(deserializer)?;

            let expected_dtype = T::get_af_dtype();
            if expected_dtype != read_dtype {
                let error_msg = format!(
                    "data type is {:?}, deserialized type is {:?}",
                    expected_dtype, read_dtype
                );
                return Err(Error::invalid_value(Unexpected::Enum, &error_msg.as_str()));
            }

            // Widen and saturate so that hostile extents cannot wrap the product.
            let required = shape
                .get()
                .iter()
                .fold(1u128, |acc, &extent| acc.saturating_mul(extent as u128));
            if (data.len() as u128) < required {
                let error_msg = format!(
                    "at least {} elements to match the serialized shape",
                    required
                );
                return Err(Error::invalid_length(data.len(), &error_msg.as_str()));
            }

            Ok(Array::<T>::new(&data, shape))
        }
    }
}
/// Threading and (optionally) serde round-trip tests for `Array`.
///
/// Worker threads are always joined with `expect`, so a panic inside a worker
/// fails the test instead of being silently discarded.
#[cfg(test)]
mod tests {
    use super::super::array::print;
    use super::super::data::constant;
    use super::super::device::{info, set_device, sync};
    use crate::dim4;
    use std::sync::{mpsc, Arc, RwLock};
    use std::thread;

    /// Moves a mutable array into a worker thread and modifies it there.
    #[test]
    fn thread_move_array() {
        set_device(0);
        info();
        let mut a = constant(1, dim4!(3, 3));
        let handle = thread::spawn(move || {
            set_device(0);
            println!("\nFrom thread {:?}", thread::current().id());
            a += constant(2, dim4!(3, 3));
            print(&a);
        });
        handle.join().unwrap();
    }

    /// Moves an array into a worker thread and only reads it there.
    #[test]
    fn thread_borrow_array() {
        set_device(0);
        info();
        let a = constant(1i32, dim4!(3, 3));
        let handle = thread::spawn(move || {
            set_device(0);
            println!("\nFrom thread {:?}", thread::current().id());
            print(&a);
        });
        handle.join().unwrap();
    }

    /// Arithmetic operation each worker thread applies to its two operands.
    #[derive(Debug, Copy, Clone)]
    enum Op {
        Add,
        Sub,
        Div,
        Mul,
    }

    /// Every worker gets its own clones of `a` and `b` and combines them.
    #[test]
    fn read_from_multiple_threads() {
        let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];
        set_device(0);
        let a = constant(1.0f32, dim4!(3, 3));
        let b = constant(2.0f32, dim4!(3, 3));
        let threads: Vec<_> = ops
            .into_iter()
            .map(|op| {
                let x = a.clone();
                let y = b.clone();
                thread::spawn(move || {
                    set_device(0);
                    match op {
                        Op::Add => {
                            let _c = x + y;
                        }
                        Op::Sub => {
                            let _c = x - y;
                        }
                        Op::Div => {
                            let _c = x / y;
                        }
                        Op::Mul => {
                            let _c = x * y;
                        }
                    }
                    sync(0);
                    thread::sleep(std::time::Duration::new(1, 0));
                })
            })
            .collect();
        for child in threads {
            // Surface worker panics instead of discarding the join result.
            child.join().expect("worker thread panicked");
        }
    }

    /// Workers accumulate into a shared array guarded by an `RwLock`.
    #[test]
    fn access_using_rwlock() {
        let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];
        set_device(0);
        let c = constant(0.0f32, dim4!(3, 3));
        let a = constant(1.0f32, dim4!(3, 3));
        let b = constant(2.0f32, dim4!(3, 3));
        let c_lock = Arc::new(RwLock::new(c));
        let threads: Vec<_> = ops
            .into_iter()
            .map(|op| {
                let x = a.clone();
                let y = b.clone();
                let wlock = c_lock.clone();
                thread::spawn(move || {
                    set_device(0);
                    // A poisoned lock means another worker already panicked;
                    // that panic is reported when the workers are joined.
                    if let Ok(mut c_guard) = wlock.write() {
                        match op {
                            Op::Add => {
                                *c_guard += x + y;
                            }
                            Op::Sub => {
                                *c_guard += x - y;
                            }
                            Op::Div => {
                                *c_guard += x / y;
                            }
                            Op::Mul => {
                                *c_guard += x * y;
                            }
                        }
                    }
                })
            })
            .collect();
        for child in threads {
            child.join().expect("worker thread panicked");
        }
    }

    /// Workers send their results over a channel; the main thread accumulates.
    #[test]
    fn accum_using_channel() {
        let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];
        let ops_len: usize = ops.len();
        set_device(0);
        let mut c = constant(0.0f32, dim4!(3, 3));
        let a = constant(1.0f32, dim4!(3, 3));
        let b = constant(2.0f32, dim4!(3, 3));
        let (tx, rx) = mpsc::channel();
        let threads: Vec<_> = ops
            .into_iter()
            .map(|op| {
                let x = a.clone();
                let y = b.clone();
                let tx_clone = tx.clone();
                thread::spawn(move || {
                    set_device(0);
                    let c = match op {
                        Op::Add => x + y,
                        Op::Sub => x - y,
                        Op::Div => x / y,
                        Op::Mul => x * y,
                    };
                    tx_clone.send(c).unwrap();
                })
            })
            .collect();
        for _i in 0..ops_len {
            c += rx.recv().unwrap();
        }
        for child in threads {
            child.join().expect("worker thread panicked");
        }
    }

    /// Round-trip tests for the serde implementations.
    #[cfg(feature = "afserde")]
    mod serde_tests {
        use super::super::Array;
        use crate::algorithm::sum_all;
        use crate::randu;

        /// JSON round trip must reproduce the input exactly.
        #[test]
        fn array_serde_json() {
            let input = randu!(u8; 2, 2);
            // Fail on the serialization error itself rather than feeding the
            // error text to the deserializer.
            let serd = serde_json::to_string(&input).expect("JSON serialization failed");
            let deserd: Array<u8> = serde_json::from_str(&serd).unwrap();
            assert_eq!(sum_all(&(input - deserd)), (0u32, 0u32));
        }

        /// bincode round trip must reproduce the input exactly.
        #[test]
        fn array_serde_bincode() {
            let input = randu!(u8; 2, 2);
            // Fail on the serialization error itself rather than decoding an
            // empty buffer.
            let encoded = bincode::serialize(&input).expect("bincode serialization failed");
            let decoded: Array<u8> = bincode::deserialize(&encoded).unwrap();
            assert_eq!(sum_all(&(input - decoded)), (0u32, 0u32));
        }
    }
}