J'essaie de créer avec numba une fonction qui renvoie un tableau numpy évalué sur un autre tableau.Je posterai un code simple sans njit:

import numpy as np
import numba as nb

def prueba(arr, eva):
    mask = []
    for i in range(len(arr)):
        mask.append(arr[i])
    return eva[mask]

Cela fonctionne correctement, comme prévu:

>>> prueba(np.array([1,2,3]), np.array([5,6,7,8,9,10]))
array([6, 7, 8])

Néanmoins, lorsque j'essaye de le compiler avec numba en mode nopython (@njit), cela génère une erreur

@nb.njit
def prueba(arr, eva):
    mask = []
    for i in range(len(arr)):
        mask.append(arr[i])
    return eva[mask]

>>> prueba(np.array([1,2,3]), np.array([5,6,7,8,9,10]))
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
<ipython-input-9-111474f08921> in <module>
----> 1 prueba(np.array([1,2,3]), np.array([5,6,7,8,9,10]))

~/.local/lib/python3.7/site-packages/numba/dispatcher.py in _compile_for_args(self, *args, **kws)
    399                 e.patch_message(msg)
    400 
--> 401             error_rewrite(e, 'typing')
    402         except errors.UnsupportedError as e:
    403             # Something unsupported is present in the user code, add help info

~/.local/lib/python3.7/site-packages/numba/dispatcher.py in error_rewrite(e, issue_type)
    342                 raise e
    343             else:
--> 344                 reraise(type(e), e, None)
    345 
    346         argtypes = []

~/.local/lib/python3.7/site-packages/numba/six.py in reraise(tp, value, tb)
    666             value = tp()
    667         if value.__traceback__ is not tb:
--> 668             raise value.with_traceback(tb)
    669         raise value
    670 

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(int64, 1d, C), list(int64))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
In definition 2:
    All templates rejected with literals.
In definition 3:
    All templates rejected without literals.
In definition 4:
    All templates rejected with literals.
In definition 5:
    All templates rejected without literals.
In definition 6:
    All templates rejected with literals.
In definition 7:
    All templates rejected without literals.
In definition 8:
    All templates rejected with literals.
In definition 9:
    All templates rejected without literals.
In definition 10:
    All templates rejected with literals.
In definition 11:
    All templates rejected without literals.
In definition 12:
    TypeError: unsupported array index type list(int64) in [list(int64)]
    raised from /home/donielix/.local/lib/python3.7/site-packages/numba/typing/arraydecl.py:71
In definition 13:
    TypeError: unsupported array index type list(int64) in [list(int64)]
    raised from /home/donielix/.local/lib/python3.7/site-packages/numba/typing/arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of intrinsic-call at <ipython-input-8-1b5c9f1a65d5> (6)
[2] During: typing of static-get-item at <ipython-input-8-1b5c9f1a65d5> (6)

File "<ipython-input-8-1b5c9f1a65d5>", line 6:
def prueba(arr, eva):
    <source elided>
        mask.append(arr[i])
    return eva[mask]
    ^

Ma question est donc la suivante: pourquoi ce code simple donne-t-il une erreur inattendue? Et comment dois-je contourner ce problème?

0
Dani 20 avril 2020 à 19:28

2 réponses

Meilleure réponse

Directement à partir de la documentation:

Un sous-ensemble d'indexation avancée est également pris en charge: un seul index est autorisé, et il doit s'agir d'un tableau unidimensionnel (il peut être combiné avec un nombre arbitraire d'indices de base). https://numba.pydata.org/numba- doc / dev / reference / numpysupported.html # array-access

Par conséquent, pour que votre code fonctionne, vous devez convertir mask en numpy array:

@nb.njit
def prueba(arr, eva):
    mask = []
    for i in range(len(arr)):
        mask.append(arr[i])
    mask_as_array = np.array(mask)
    return eva[mask_as_array]

prueba(np.array([1,2,3]), np.array([5,6,7,8,9,10]))
1
Ralvi Isufaj 20 avril 2020 à 19:51

Votre indexation à l'aide de numpy:

In [181]: a, b = np.array([1,2,3]), np.array([5,6,7,8,9,10])                                           
In [182]: b[a]                                                                                         
Out[182]: array([6, 7, 8])
In [183]: def foo(arr, eva): 
     ...:     return eva[arr] 
     ...:                                                                                              
In [184]: foo(a,b)                                                                                     
Out[184]: array([6, 7, 8])
In [186]: timeit foo(a,b)                                                                              
350 ns ± 9.98 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

Essayer de le répliquer (et peut-être l'accélérer) avec numba:

In [185]: import numba                                                                                 

In [187]: @numba.njit 
     ...: def foo1(arr,eva): 
     ...:     return eva[arr] 
     ...:                                                                                              
In [188]: foo1(a,b)                                                                                    
Out[188]: array([6, 7, 8])
In [189]: timeit foo1(a,b)                                                                             
968 ns ± 19.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [190]: @numba.njit 
     ...: def foo2(arr,eva): 
     ...:     res = np.empty(len(arr), eva.dtype) 
     ...:     for i in range(len(arr)): 
     ...:         res[i] = b[a[i]] 
     ...:     return res 

In [191]: foo2(a,b)                                                                                    
Out[191]: array([6, 7, 8])
In [192]: timeit foo2(a,b)                                                                             
941 ns ± 7.91 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [193]: @numba.njit 
     ...: def foo2(arr,eva): 
     ...:     res = np.empty(len(arr), eva.dtype) 
     ...:     for i,v in enumerate(a): 
     ...:         res[i] = b[v] 
     ...:     return res 

In [194]: foo2(a,b)                                                                                    
Out[194]: array([6, 7, 8])
In [195]: timeit foo2(a,b)                                                                             
941 ns ± 17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

Il ne sert à rien d'essayer de remplacer une fonctionnalité numpy de base par numba.

Quelqu'un avec plus d'expérience numba pourrait améliorer cela.

Éditer

Comme je l'ai observé au départ, numba n'aime pas indexer avec une liste. La conversion de la liste en tableau fonctionne:

In [196]: @numba.njit 
     ...: def prueba(arr, eva): 
     ...:     mask = [] 
     ...:     for i in range(len(arr)): 
     ...:         mask.append(arr[i]) 
     ...:     mask = np.array(mask) 
     ...:     return eva[mask] 
     ...:                                                                                              
In [197]: prueba(a,b)                                                                                  
Out[197]: array([6, 7, 8])
In [198]: timeit prueba(a,b)                                                                           
1.5 µs ± 4.79 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
0
hpaulj 20 avril 2020 à 19:52