Build Custom Dask.Array Function¶
As discussed in the array design document to create a
dask Array object we need the following:
- A dask graph
- A name specifying a set of keys within that graph
- A
chunkstuple giving chunk shape information - A NumPy dtype
Often dask.array functions take other Array objects as inputs along
with parameters, add tasks to a new dask dictionary, create a new chunks
tuple, and then construct and return a new Array object. The hard parts
are invariably creating the right tasks and creating a new chunks tuple.
Careful review of the array design document is suggested.
Example eye¶
Consider this simple example with the eye function.
from dask.base import tokenize
def eye(n, blocksize):
chunks = ((blocksize,) * n // blocksize,
(blocksize,) * n // blocksize)
name = 'eye-' + tokenize(n, blocksize) # unique identifier
dsk = {(name, i, j): (np.eye, blocksize)
if i == j else
(np.zeros, (blocksize, blocksize))
for i in range(n // blocksize)
for j in range(n // blocksize)}
dtype = np.eye(0).dtype # take dtype default from numpy
return Array(dsk, name, chunks, dtype)
This example is particularly simple because it doesn’t take any Array
objects as input.
Example diag¶
Consider the function diag that takes a 1d vector and produces a 2d matrix
with the values of the vector along the diagonal. Consider the case where the
input is a 1d array with chunk sizes (2, 3, 4) in the first dimension like
this:
[x_0, x_1], [x_2, x_3, x_4], [x_5, x_6, x_7, x_8]
We need to create a 2d matrix with chunks equal to ((2, 3, 4), (2, 3, 4))
where the ith block along the diagonal of the output is the result of calling
np.diag on the ith block of the input and all other blocks are zero.
from dask.base import tokenize
def diag(v):
"""Construct a diagonal array, with ``v`` on the diagonal."""
assert v.ndim == 1
chunks = (v.chunks[0], v.chunks[0]) # repeat chunks twice
name = 'diag-' + tokenize(v) # unique identifier
dsk = {(name, i, j): (np.diag, (v.name, i))
if i == j else
(np.zeros, (v.chunks[0][i], v.chunks[0][j]))
for i in range(len(v.chunks[0]))
for j in range(len(v.chunks[0]))}
dsk.update(v.dask) # include dask graph of the input
dtype = v.dtype # output has the same dtype as the input
return Array(dsk, name, chunks, dtype)
>>> x = da.arange(9, chunks=((2, 3, 4),))
>>> x
dask.array<arange-1, shape=(9,), chunks=((2, 3, 4)), dtype=int64>
>>> M = diag(x)
>>> M
dask.array<diag-2, shape=(9, 9), chunks=((2, 3, 4), (2, 3, 4)), dtype=int64>
>>> M.compute()
array([[0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 2, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 3, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 4, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 5, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 6, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 7, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 8]])