# Source code for arrayfire.bcast

#######################################################
# Copyright (c) 2015, ArrayFire
# All rights reserved.
#
# This file is distributed under 3-clause BSD license.
# The complete license agreement can be obtained at:
# http://arrayfire.com/licenses/BSD-3-Clause
########################################################

"""
Function to perform broadcasting operations.
"""

class _bcast(object):
    """
    Holder for the broadcast flag.

    The flag is stored on the class itself, so every instance reads and
    writes the same value.
    """

    # Shared state: starts switched off.
    _flag = False

    def get(self):
        """Return the current value of the shared flag."""
        current = _bcast._flag
        return current

    def set(self, flag):
        """Store `flag` as the new value of the shared flag."""
        _bcast._flag = flag

    def toggle(self):
        """Flip the shared flag by xor-ing it with True."""
        flipped = _bcast._flag ^ True
        _bcast._flag = flipped


# Module-level instance through which the flag is accessed.
_bcast_var = _bcast()

def broadcast(func, *args):
    """
    Function to perform broadcast operations.

    This function can be used directly or as an annotation in the following manner.

    Parameters
    ----------
    func : callable
        The function to run while the module-level broadcast flag is set.

    *args :
        Optional positional arguments. When at least one is given, `func`
        is called immediately with them and its result is returned.

    Returns
    -------
    The wrapped function when no extra arguments are given (annotation
    usage), otherwise the result of calling `func(*args)`.

    Note
    ----
    The broadcast flag is restored to its previous value when `func`
    returns or raises, so a failing call does not leave it switched on.

    Example
    -------

    Using broadcast as an annotation

    >>> import arrayfire as af
    >>> @af.broadcast
    ... def add(a, b):
    ...     return a + b
    ...
    >>> a = af.randu(2,3)
    >>> b = af.randu(2,1) # b is a different size
    >>> # Trying to add arrays of different sizes raises an exception
    >>> c = add(a, b) # This call does not raise an exception because of the annotation
    >>> af.display(a)
    [2 3 1 1]
        0.4107     0.9518     0.4198
        0.8224     0.1794     0.0081

    >>> af.display(b)
    [2 1 1 1]
        0.7269
        0.7104

    >>> af.display(c)
    [2 3 1 1]
        1.1377     1.6787     1.1467
        1.5328     0.8898     0.7185

    Using broadcast as function

    >>> import arrayfire as af
    >>> add = lambda a,b: a + b
    >>> a = af.randu(2,3)
    >>> b = af.randu(2,1) # b is a different size
    >>> # Trying to add arrays of different sizes raises an exception
    >>> c = af.broadcast(add, a, b) # This call does not raise an exception
    >>> af.display(a)
    [2 3 1 1]
        0.4107     0.9518     0.4198
        0.8224     0.1794     0.0081

    >>> af.display(b)
    [2 1 1 1]
        0.7269
        0.7104

    >>> af.display(c)
    [2 3 1 1]
        1.1377     1.6787     1.1467
        1.5328     0.8898     0.7185
    """
    # Imported locally: this module has no import section of its own.
    import functools

    @functools.wraps(func)
    def wrapper(*func_args, **func_kwargs):
        # Set the flag instead of toggling it: a toggle would switch the
        # flag *off* when this wrapper runs inside another broadcast call.
        previous = _bcast_var.get()
        _bcast_var.set(True)
        try:
            return func(*func_args, **func_kwargs)
        finally:
            # Restore even when func raises, so the flag never leaks.
            _bcast_var.set(previous)

    if not args:
        return wrapper
    return wrapper(*args)