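While writing ISP algorithms recently, I found that Python loops run extremely slowly. I eventually found a way online to speed them up: Numba. In my own tests the speedup was more than 100x.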

Introduction to Numba

Numba is a just-in-time (JIT) compiler for Python. It works best on code that uses NumPy arrays, NumPy functions, and loops.

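Numba reads the Python bytecode of a decorated function and combines it with information about the types of the function's input arguments. It analyzes and optimizes the code, then uses the LLVM compiler library to generate a machine-code version of the function tailored to your CPU's capabilities. That compiled version is then used every time the function is called.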

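Numba has to compile the function for the argument types it is given before it can run the machine-code version, and this takes time. Once compilation is done, however, Numba caches the machine-code version for that particular combination of argument types. If the function is called again with the same types, Numba reuses the cached version instead of compiling again. (This is why the first call is noticeably slower.)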

Usage

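Add the @jit decorator in front of the function.
Compilation options:

  • nopython: Numba has two compilation modes, nopython mode and object mode. Nopython mode produces much faster code but has limitations. @jit(nopython=True)
  • cache: To avoid paying the compilation time every time the Python program is started, you can tell Numba to write the compilation result to a file-based cache. @jit(cache=True)
  • parallel: Enables automatic parallelization (and related optimizations) for those operations in the function that are known to have parallel semantics. @jit(nopython=True, parallel=True)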

from numba import jit
import numpy as np

x = np.arange(100).reshape(10, 10)

@jit(nopython=True)  # Set "nopython" mode for best performance, equivalent to @njit
def go_fast(a):  # Function is compiled to machine code when called the first time
    trace = 0.0
    for i in range(a.shape[0]):  # Numba likes loops
        trace += np.tanh(a[i, i])  # Numba likes NumPy functions
    return a + trace  # Numba likes NumPy broadcasting

print(go_fast(x))

Example

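Take bad pixel correction as an example: the algorithm makes a decision for every single pixel, so the critical loop is moved into its own function and accelerated with Numba.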

import numpy as np
from numba import jit

# RawImageInfo, RawImageParams and split_raw_data come from the rest of the ISP project
def bad_pixel_correction(raw: RawImageInfo, params: RawImageParams):
    """
    function: bad_pixel_correction
        correct for the bad (dead, stuck, or hot) pixels
    input: raw:RawImageInfo() params:RawImageParams()
    The window is neighborhood_size * neighborhood_size. When the center value is
    larger than the largest neighbor or smaller than the smallest neighbor, it is
    replaced by that bound.
    This algorithm probably costs a fair amount of resolution.
    """
    neighborhood_size = params.get_size_for_bad_pixel_correction()
    if (neighborhood_size % 2) == 0:
        print("neighborhood_size should be an odd number, recommended value 3")
        return raw

    raw_data = raw.get_raw_data()
    raw_channel_data = list()

    if raw.get_color_space() == "raw":
        ret_img = RawImageInfo()
        ret_img.create_image('after bad pixel correction', raw_data.shape)
        # Separate out the quarter resolution images
        D = split_raw_data(raw_data)

        # number of pixels to be padded at the borders
        # no_of_pixel_pad = math.floor(neighborhood_size / 2.)
        no_of_pixel_pad = neighborhood_size // 2

        for idx in range(0, len(D)):  # perform same operation for each quarter

            # display progress
            print("bad pixel correction: Quarter " + str(idx+1) + " of 4")

            img = D[idx]
            width, height = img.shape[1], img.shape[0]

            # pad pixels at the borders
            img = np.pad(img,
                         (no_of_pixel_pad, no_of_pixel_pad),
                         'reflect')  # reflect would not repeat the border value

            # the per-pixel loop lives in a separate function so Numba can compile it
            raw_channel_data.append(bad_pixel_correction_subfunc(img, no_of_pixel_pad, width, height))

        # Regrouping the data
        ret_img.data[::2, ::2] = raw_channel_data[0]
        ret_img.data[::2, 1::2] = raw_channel_data[1]
        ret_img.data[1::2, ::2] = raw_channel_data[2]
        ret_img.data[1::2, 1::2] = raw_channel_data[3]
        return ret_img
    else:
        params.set_error_str("bad pixel correction need RAW data")
        return None

@jit(nopython=True)
def bad_pixel_correction_subfunc(img, no_of_pixel_pad, width, height):
    for i in range(no_of_pixel_pad, height + no_of_pixel_pad):
        for j in range(no_of_pixel_pad, width + no_of_pixel_pad):
            # save the middle pixel value
            mid_pixel_val = img[i, j]
            # extract the neighborhood (a view into img, not a copy)
            neighborhood = img[i - no_of_pixel_pad: i + no_of_pixel_pad+1,
                               j - no_of_pixel_pad: j + no_of_pixel_pad+1]

            # set the center pixel's value to that of the left pixel so the
            # center itself does not take part in the min/max below;
            # it does not matter whether the left or right pixel is used
            neighborhood[no_of_pixel_pad,
                         no_of_pixel_pad] = neighborhood[no_of_pixel_pad, no_of_pixel_pad-1]

            min_neighborhood = np.min(neighborhood)
            max_neighborhood = np.max(neighborhood)

            if (mid_pixel_val < min_neighborhood):
                img[i, j] = min_neighborhood
            elif (mid_pixel_val > max_neighborhood):
                img[i, j] = max_neighborhood
            else:
                # restore the original value that was overwritten through the view
                img[i, j] = mid_pixel_val

    # crop away the padding and return the corrected quarter image
    return img[no_of_pixel_pad: height + no_of_pixel_pad, no_of_pixel_pad: width + no_of_pixel_pad]