示例
Mandelbrot
1from timeit import default_timer as timer
2try:
3 from matplotlib.pylab import imshow, show
4 have_mpl = True
5except ImportError:
6 have_mpl = False
7import numpy as np
8from numba import jit
9
@jit(nopython=True)
def mandel(x, y, max_iters):
    """
    Return the escape iteration of the complex number ``x + y*j``.

    Iterates ``z = z*z + c`` starting from ``z = 0`` with ``c = x + y*j``.
    If ``|z|**2`` reaches 4 (i.e. ``|z| >= 2``) on the 0-based iteration
    ``i``, ``i`` is returned.  If the point has not escaped after
    ``max_iters`` iterations it is treated as a candidate member of the
    Mandelbrot set and the sentinel 255 is returned.
    """
    c = complex(x, y)
    z = 0.0j
    for i in range(max_iters):
        z = z * z + c
        # Compare the squared magnitude with 2**2 to avoid a square root.
        if (z.real * z.real + z.imag * z.imag) >= 4:
            return i

    # "Did not escape" sentinel; 255 is the largest value a uint8 pixel holds.
    return 255
26
@jit(nopython=True)
def create_fractal(min_x, max_x, min_y, max_y, image, iters):
    """
    Fill *image* with Mandelbrot escape values over a rectangle of the plane.

    Column ``x`` of *image* maps to the real part ``min_x + x * pixel_size_x``
    and row ``y`` to the imaginary part ``min_y + y * pixel_size_y``; each
    pixel ``image[y, x]`` is set to ``mandel(real, imag, iters)``.

    Parameters
    ----------
    min_x, max_x : real-axis bounds of the rectangle.
    min_y, max_y : imaginary-axis bounds of the rectangle.
    image : 2-D array indexed ``image[row, column]``, written in place.
    iters : maximum iteration count passed to ``mandel``.

    Returns
    -------
    The same *image* object, after it has been filled in place.
    """
    height = image.shape[0]
    width = image.shape[1]

    pixel_size_x = (max_x - min_x) / width
    pixel_size_y = (max_y - min_y) / height
    # Rows in the outer loop, columns in the inner loop: for a row-major
    # (C-ordered) array such as np.zeros' default, the inner loop then
    # touches consecutive memory instead of striding a whole row per step.
    for y in range(height):
        imag = min_y + y * pixel_size_y
        for x in range(width):
            real = min_x + x * pixel_size_x
            image[y, x] = mandel(real, imag, iters)

    return image
42
image = np.zeros((500 * 2, 750 * 2), dtype=np.uint8)

# Call once on a tiny image with the same argument types so Numba's JIT
# compilation happens here and is not included in the timing below.
create_fractal(-2.0, 1.0, -1.0, 1.0, np.zeros((1, 1), dtype=np.uint8), 20)

s = timer()
create_fractal(-2.0, 1.0, -1.0, 1.0, image, 20)
e = timer()
print(e - s)
if have_mpl:
    imshow(image)
    show()
移动平均
1import numpy as np
2
3from numba import guvectorize
4
@guvectorize(['void(float64[:], intp[:], float64[:])'],
             '(n),()->(n)')
def move_mean(a, window_arr, out):
    """
    Trailing moving mean of *a*, window width taken from ``window_arr[0]``.

    ``out[i]`` is the mean of ``a[max(0, i - w + 1):i + 1]``. The first ``w``
    outputs average the growing prefix. The remaining outputs use a sliding
    window of ``w`` elements, maintained with a running sum.

    The width is clamped to ``len(a)``: a window wider than the input (or an
    empty input) yields the cumulative mean instead of reading and writing
    past the end of the arrays.

    NOTE(review): widths below 1 are not supported. With width 0, ``count``
    stays 0 and is used as the divisor.
    """
    window_width = min(window_arr[0], len(a))
    asum = 0.0
    count = 0
    # Warm-up: the window is still filling, so divide by the elements seen.
    for i in range(window_width):
        asum += a[i]
        count += 1
        out[i] = asum / count
    # Steady state: add the incoming element and drop the one leaving.
    for i in range(window_width, len(a)):
        asum += a[i] - a[i - window_width]
        out[i] = asum / count
18
# Two rows of ten float64 values; move_mean runs independently on each row
# (its core dimension is "(n)") with the window 3 broadcast to both.
arr = np.arange(20.0).reshape(2, -1)
print(arr)
print(move_mean(arr, 3))
多线程
下面的代码展示了使用 nogil 特性时可能获得的性能提升。例如，在一台 4 核机器上运行时，打印出了以下结果：
numpy (1 thread) 145 ms
numba (1 thread) 128 ms
numba (4 threads) 35 ms
备注
如有需要，也可以使用标准库中的 concurrent.futures 模块，而不必手动创建线程并分派任务。
1import math
2import threading
3from timeit import repeat
4
5import numpy as np
6from numba import jit
7
# Number of worker threads used by the multithreaded variant.
nthreads = 4
# Number of elements in each random input array.
size = 1_000_000
10
def func_np(a, b):
    """
    Reference implementation in plain NumPy.

    Returns ``exp(2.1 * a + 3.2 * b)`` evaluated elementwise.
    """
    exponent = 2.1 * a + 3.2 * b
    return np.exp(exponent)
16
@jit('void(double[:], double[:], double[:])', nopython=True,
     nogil=True)
def inner_func_nb(result, a, b):
    """
    Kernel under test: ``result[k] = exp(2.1 * a[k] + 3.2 * b[k])``.

    Compiled for the explicit 1-D float64 signature with ``nogil=True``, so
    several threads can run it concurrently on separate slices. The loop
    runs over ``len(result)``; *a* and *b* are assumed to be at least that
    long (no check is made here).
    """
    n = len(result)
    for k in range(n):
        result[k] = math.exp(2.1 * a[k] + 3.2 * b[k])
25
def timefunc(correct, s, func, *args, **kwargs):
    """
    Benchmark *func* and print out its runtime.

    Prints *s* left-justified to 20 characters, followed by the best of 2
    timing rounds. Each round is 5 consecutive calls of
    ``func(*args, **kwargs)``, so the printed figure is the total for 5
    calls, in milliseconds.

    Parameters
    ----------
    correct : expected result compared with ``np.allclose``, or None to
        skip the check.
    s : label printed before the timing.
    func, args, kwargs : the callable under test and its arguments.

    Returns
    -------
    The result of the first (untimed) call.

    Raises
    ------
    AssertionError
        If *correct* is given and the result does not match it.
    """
    print(s.ljust(20), end=" ")
    # Call once before timing so any lazy (JIT) compilation is not measured;
    # the result doubles as the correctness sample.
    res = func(*args, **kwargs)
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if correct is not None and not np.allclose(res, correct):
        raise AssertionError((res, correct))
    best = min(repeat(lambda: func(*args, **kwargs), number=5, repeat=2))
    print('{:>5.0f} ms'.format(best * 1000))
    return res
40
def make_singlethread(inner_func):
    """
    Wrap *inner_func* so it runs in the calling thread with a fresh output.

    The returned callable takes the input arrays and allocates a float64
    output with the length of the first input. It calls
    ``inner_func(output, *inputs)`` and returns the output.
    """
    def func(*args):
        out = np.empty(len(args[0]), dtype=np.float64)
        inner_func(out, *args)
        return out
    return func
51
def make_multithread(inner_func, numthreads):
    """
    Run the given function inside *numthreads* threads, splitting
    its arguments into equal-sized chunks.

    The returned callable first allocates a float64 output with the length
    of its first argument. It then slices the output and every input into
    *numthreads* contiguous chunks of ``ceil(length / numthreads)`` elements;
    the trailing chunks may be shorter or empty. Each
    ``inner_func(output_chunk, *input_chunks)`` call runs on a worker thread.
    Slices of the NumPy output are views, so each worker writes straight
    into its part of the shared result.

    A real speed-up needs *inner_func* to release the GIL (e.g. a Numba
    function compiled with ``nogil=True``).

    If *inner_func* raises in any worker, the exception is re-raised in the
    caller after all chunks have finished. It is no longer lost in the
    thread while a partially uninitialised array is returned.

    Raises
    ------
    ValueError
        If *numthreads* is less than 1.
    """
    from concurrent.futures import ThreadPoolExecutor

    if numthreads < 1:
        raise ValueError(
            "numthreads must be at least 1, got %r" % (numthreads,))

    def func_mt(*args):
        length = len(args[0])
        result = np.empty(length, dtype=np.float64)
        args = (result,) + args
        chunklen = (length + numthreads - 1) // numthreads
        # Create argument tuples for each input chunk
        chunks = [[arg[i * chunklen:(i + 1) * chunklen] for arg in args]
                  for i in range(numthreads)]
        # One task per chunk; leaving the ``with`` block waits for them all.
        with ThreadPoolExecutor(max_workers=numthreads) as executor:
            futures = [executor.submit(inner_func, *chunk)
                       for chunk in chunks]
        # result() re-raises any exception raised inside a worker.
        for future in futures:
            future.result()
        return result
    return func_mt
74
# Wrap the Numba kernel for single- and multi-threaded execution.
func_nb, func_nb_mt = (make_singlethread(inner_func_nb),
                       make_multithread(inner_func_nb, nthreads))

# Random inputs shared by all three variants (a is drawn before b).
a = np.random.rand(size)
b = np.random.rand(size)

# The NumPy result is the reference the Numba variants are checked against.
correct = timefunc(None, "numpy (1 thread)", func_np, a, b)
timefunc(correct, "numba (1 thread)", func_nb, a, b)
timefunc(correct, "numba ({} threads)".format(nthreads), func_nb_mt, a, b)