use super::core::{
af_array, dim_t, AfError, Array, FloatingPoint, HasAfEnum, SparseFormat, HANDLE_ERROR,
};
use libc::{c_int, c_uint, c_void};
// Raw FFI declarations for ArrayFire's sparse-array C API.
// Every function returns its status as a `c_int` error code; the safe wrappers
// in this file convert it with `AfError::from` and pass it to `HANDLE_ERROR`.
// Result handles/values are written through the `*mut` out-parameters.
extern "C" {
// Builds a sparse array from existing value / row-index / column-index arrays.
// `stype` carries a `SparseFormat` cast to `c_uint`.
fn af_create_sparse_array(
out: *mut af_array,
nRows: dim_t,
nCols: dim_t,
vals: af_array,
rowIdx: af_array,
colIdx: af_array,
stype: c_uint,
) -> c_int;
// Builds a sparse array from raw buffers holding `nNZ` non-zero values.
// `aftype` is the element dtype and `stype` the storage format (both cast to
// `c_uint`); `src` selects where the buffers live (`sparse_from_host` passes 1).
fn af_create_sparse_array_from_ptr(
out: *mut af_array,
nRows: dim_t,
nCols: dim_t,
nNZ: dim_t,
values: *const c_void,
rowIdx: *const c_int,
colIdx: *const c_int,
aftype: c_uint,
stype: c_uint,
src: c_uint,
) -> c_int;
// Converts a dense array handle into sparse storage of type `stype`.
fn af_create_sparse_array_from_dense(
out: *mut af_array,
dense: af_array,
stype: c_uint,
) -> c_int;
// Re-encodes a sparse array into the destination storage format `dstStrge`.
fn af_sparse_convert_to(out: *mut af_array, input: af_array, dstStrge: c_uint) -> c_int;
// Expands a sparse array into a dense one.
fn af_sparse_to_dense(out: *mut af_array, sparse: af_array) -> c_int;
// Retrieves values, row indices, column indices and storage type in one call.
fn af_sparse_get_info(
vals: *mut af_array,
rIdx: *mut af_array,
cIdx: *mut af_array,
stype: *mut c_uint,
input: af_array,
) -> c_int;
// Individual accessors for the components returned by `af_sparse_get_info`.
fn af_sparse_get_values(out: *mut af_array, input: af_array) -> c_int;
fn af_sparse_get_row_idx(out: *mut af_array, input: af_array) -> c_int;
fn af_sparse_get_col_idx(out: *mut af_array, input: af_array) -> c_int;
// Writes the non-zero count into `out`.
fn af_sparse_get_nnz(out: *mut dim_t, input: af_array) -> c_int;
// Writes the storage format (as `c_uint`) into `out`.
fn af_sparse_get_storage(out: *mut c_uint, input: af_array) -> c_int;
}
/// Builds a sparse array from existing `Array`s holding the non-zero values and
/// their row/column indices.
///
/// # Parameters
/// - `rows`, `cols`: dimensions of the resulting sparse matrix.
/// - `values`: the non-zero entries.
/// - `row_indices`, `col_indices`: index arrays laid out according to `format`.
/// - `format`: storage format the index arrays are encoded in.
///
/// # Errors
/// Any non-success status reported by ArrayFire is passed to `HANDLE_ERROR`.
pub fn sparse<T>(
    rows: u64,
    cols: u64,
    values: &Array<T>,
    row_indices: &Array<i32>,
    col_indices: &Array<i32>,
    format: SparseFormat,
) -> Array<T>
where
    T: HasAfEnum + FloatingPoint,
{
    let mut handle: af_array = std::ptr::null_mut();
    // SAFETY: `handle` is a live local for the whole call, and every input handle
    // is borrowed from an `Array` that outlives the call.
    let status = unsafe {
        af_create_sparse_array(
            &mut handle,
            rows as dim_t,
            cols as dim_t,
            values.get(),
            row_indices.get(),
            col_indices.get(),
            format as c_uint,
        )
    };
    HANDLE_ERROR(AfError::from(status));
    handle.into()
}
pub fn sparse_from_host<T>(
rows: u64,
cols: u64,
nzz: u64,
values: &[T],
row_indices: &[i32],
col_indices: &[i32],
format: SparseFormat,
) -> Array<T>
where
T: HasAfEnum + FloatingPoint,
{
let aftype = T::get_af_dtype();
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let err_val = af_create_sparse_array_from_ptr(
&mut temp as *mut af_array,
rows as dim_t,
cols as dim_t,
nzz as dim_t,
values.as_ptr() as *const c_void,
row_indices.as_ptr() as *const c_int,
col_indices.as_ptr() as *const c_int,
aftype as c_uint,
format as c_uint,
1,
);
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
}
/// Converts a dense array into a sparse array stored in `format`.
///
/// # Errors
/// Any non-success status reported by ArrayFire is passed to `HANDLE_ERROR`.
pub fn sparse_from_dense<T>(dense: &Array<T>, format: SparseFormat) -> Array<T>
where
    T: HasAfEnum + FloatingPoint,
{
    let storage = format as c_uint;
    let mut sparse_handle: af_array = std::ptr::null_mut();
    // SAFETY: `sparse_handle` is a live local for the call and `dense` is borrowed
    // for its whole duration.
    let status =
        unsafe { af_create_sparse_array_from_dense(&mut sparse_handle, dense.get(), storage) };
    HANDLE_ERROR(AfError::from(status));
    sparse_handle.into()
}
/// Re-encodes the sparse array `input` into the storage `format` requested.
///
/// # Errors
/// Any non-success status reported by ArrayFire is passed to `HANDLE_ERROR`.
pub fn sparse_convert_to<T>(input: &Array<T>, format: SparseFormat) -> Array<T>
where
    T: HasAfEnum + FloatingPoint,
{
    let mut converted: af_array = std::ptr::null_mut();
    // SAFETY: `converted` is a live local for the call and `input` is borrowed
    // for its whole duration.
    let code = unsafe { af_sparse_convert_to(&mut converted, input.get(), format as c_uint) };
    HANDLE_ERROR(AfError::from(code));
    converted.into()
}
/// Expands the sparse array `input` into a dense array.
///
/// # Errors
/// Any non-success status reported by ArrayFire is passed to `HANDLE_ERROR`.
pub fn sparse_to_dense<T>(input: &Array<T>) -> Array<T>
where
    T: HasAfEnum + FloatingPoint,
{
    let mut dense: af_array = std::ptr::null_mut();
    // SAFETY: `dense` is a live local for the call and `input` is borrowed for
    // its whole duration.
    let status = unsafe { af_sparse_to_dense(&mut dense, input.get()) };
    HANDLE_ERROR(AfError::from(status));
    dense.into()
}
/// Returns the components of the sparse array `input` in a single call.
///
/// # Returns
/// A tuple `(values, row_indices, col_indices, format)`.
///
/// # Errors
/// Any non-success status reported by ArrayFire is passed to `HANDLE_ERROR`.
pub fn sparse_get_info<T>(input: &Array<T>) -> (Array<T>, Array<i32>, Array<i32>, SparseFormat)
where
    T: HasAfEnum + FloatingPoint,
{
    let (mut values, mut row_idx, mut col_idx): (af_array, af_array, af_array) =
        (std::ptr::null_mut(), std::ptr::null_mut(), std::ptr::null_mut());
    let mut storage: c_uint = 0;
    // SAFETY: all four out-pointers refer to live locals for the duration of the
    // call and `input` is borrowed throughout.
    let status = unsafe {
        af_sparse_get_info(
            &mut values,
            &mut row_idx,
            &mut col_idx,
            &mut storage,
            input.get(),
        )
    };
    HANDLE_ERROR(AfError::from(status));
    let format = SparseFormat::from(storage);
    (values.into(), row_idx.into(), col_idx.into(), format)
}
/// Returns the non-zero values stored in the sparse array `input`.
///
/// # Errors
/// Any non-success status reported by ArrayFire is passed to `HANDLE_ERROR`.
pub fn sparse_get_values<T>(input: &Array<T>) -> Array<T>
where
    T: HasAfEnum + FloatingPoint,
{
    let mut nonzeros: af_array = std::ptr::null_mut();
    // SAFETY: `nonzeros` is a live local for the call and `input` is borrowed
    // for its whole duration.
    let code = unsafe { af_sparse_get_values(&mut nonzeros, input.get()) };
    HANDLE_ERROR(AfError::from(code));
    nonzeros.into()
}
/// Returns the row-index array of the sparse array `input`.
///
/// # Errors
/// Any non-success status reported by ArrayFire is passed to `HANDLE_ERROR`.
pub fn sparse_get_row_indices<T>(input: &Array<T>) -> Array<i32>
where
    T: HasAfEnum + FloatingPoint,
{
    let mut row_idx: af_array = std::ptr::null_mut();
    // SAFETY: `row_idx` is a live local for the call and `input` is borrowed
    // for its whole duration.
    let code = unsafe { af_sparse_get_row_idx(&mut row_idx, input.get()) };
    HANDLE_ERROR(AfError::from(code));
    row_idx.into()
}
/// Returns the column-index array of the sparse array `input`.
///
/// # Errors
/// Any non-success status reported by ArrayFire is passed to `HANDLE_ERROR`.
pub fn sparse_get_col_indices<T>(input: &Array<T>) -> Array<i32>
where
    T: HasAfEnum + FloatingPoint,
{
    let mut col_idx: af_array = std::ptr::null_mut();
    // SAFETY: `col_idx` is a live local for the call and `input` is borrowed
    // for its whole duration.
    let code = unsafe { af_sparse_get_col_idx(&mut col_idx, input.get()) };
    HANDLE_ERROR(AfError::from(code));
    col_idx.into()
}
/// Returns the number of non-zero elements stored in the sparse array `input`.
///
/// # Errors
/// Any non-success status reported by ArrayFire is passed to `HANDLE_ERROR`.
pub fn sparse_get_nnz<T: HasAfEnum>(input: &Array<T>) -> i64 {
    let mut nnz: i64 = 0;
    // SAFETY: `nnz` is a live local of the same width as `dim_t` (the cast below
    // requires that) and `input` is borrowed for the whole call.
    let status = unsafe { af_sparse_get_nnz(&mut nnz as *mut dim_t, input.get()) };
    HANDLE_ERROR(AfError::from(status));
    nnz
}
/// Returns the storage format of the sparse array `input`.
///
/// # Errors
/// Any non-success status reported by ArrayFire is passed to `HANDLE_ERROR`.
pub fn sparse_get_format<T: HasAfEnum>(input: &Array<T>) -> SparseFormat {
    let mut storage: c_uint = 0;
    // SAFETY: `storage` is a live local for the call and `input` is borrowed
    // for its whole duration.
    let status = unsafe { af_sparse_get_storage(&mut storage, input.get()) };
    HANDLE_ERROR(AfError::from(status));
    SparseFormat::from(storage)
}