Python多進程並行計算方法詳解
處理大量數據或多個文件時,用for循環逐個處理速度非常慢,此時可以用多進程並行計算來加快運行速度。常用的模塊為multiprocessing和joblib,下面對這兩個包中我常用的方法進行說明。(注意:由於CPython的GIL限制,CPU密集型計算用多線程無法提速,因此本文實際使用的是多進程;joblib默認的loky後端同樣基於進程。)
1、模塊安裝
pip install joblib
# multiprocessing是Python標準庫,無需安裝;示例代碼另外依賴numpy和GDAL的Python綁定(osgeo)
2、以分塊計算NDVI為例
首先導入需要的包
import numpy as np from osgeo import gdal import time from multiprocessing import cpu_count from multiprocessing import Pool from joblib import Parallel, delayed
定義GdalUtil工具類,用於讀寫遙感柵格數據
class GdalUtil:
    """Static GDAL helpers for reading and writing raster files.

    Every method is a ``@staticmethod``; the class is only a namespace.
    """

    @staticmethod
    def _open_dataset(raster_path):
        """Open *raster_path* read-only and return the GDAL dataset.

        Registers all GDAL drivers and sets the ``gdal_FILENAME_IS_UTF8``
        config option before opening, so every public reader shares one
        open/validate path.

        :param raster_path: path of the raster file to open.
        :return: the opened ``gdal.Dataset``.
        :raises RuntimeError: if GDAL cannot open the file (``gdal.Open``
            returned ``None``).
        """
        gdal.AllRegister()
        gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')
        dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
        if dataset is None:
            # Fail here with a clear message instead of printing and then
            # crashing later on an attribute access of None.
            raise RuntimeError('打開圖像{0} 失敗.'.format(raster_path))
        return dataset

    @staticmethod
    def read_file(raster_file, read_band=None):
        """Read the full extent of a raster into an array.

        :param raster_file: path of the raster to read.
        :param read_band: 1-based band index to read; ``None`` reads all
            bands through ``Dataset.ReadAsArray``.
        :return: the array returned by GDAL's ``ReadAsArray``.
        :raises RuntimeError: if the file cannot be opened.
        """
        dataset = GdalUtil._open_dataset(raster_file)
        raster_width = dataset.RasterXSize   # number of columns
        raster_height = dataset.RasterYSize  # number of rows
        if read_band is None:
            return dataset.ReadAsArray(0, 0, raster_width, raster_height)
        band = dataset.GetRasterBand(read_band)
        return band.ReadAsArray(0, 0, raster_width, raster_height)

    @staticmethod
    def read_block_data(dataset, band_num, cols_read, rows_read, start_col=0, start_row=0):
        """Read a rectangular window of one band from an already-open dataset.

        :param dataset: an open ``gdal.Dataset``.
        :param band_num: 1-based band index.
        :param cols_read: window width in pixels.
        :param rows_read: window height in pixels.
        :param start_col: column offset of the window's top-left pixel.
        :param start_row: row offset of the window's top-left pixel.
        :return: the array returned by ``Band.ReadAsArray``.
        """
        band = dataset.GetRasterBand(band_num)
        return band.ReadAsArray(start_col, start_row, cols_read, rows_read)

    @staticmethod
    def get_raster_band(raster_path):
        """Return the number of bands in *raster_path*.

        :raises RuntimeError: if the file cannot be opened.
        """
        return GdalUtil._open_dataset(raster_path).RasterCount

    @staticmethod
    def get_file_size(raster_path):
        """Return the raster size as ``(columns, rows)``.

        :raises RuntimeError: if the file cannot be opened.
        """
        dataset = GdalUtil._open_dataset(raster_path)
        return dataset.RasterXSize, dataset.RasterYSize

    @staticmethod
    def get_file_geotransform(raster_path):
        """Return the affine geotransform of *raster_path*.

        :raises RuntimeError: if the file cannot be opened.
        """
        return GdalUtil._open_dataset(raster_path).GetGeoTransform()

    @staticmethod
    def get_file_proj(raster_path):
        """Return the spatial reference (projection string) of *raster_path*.

        :raises RuntimeError: if the file cannot be opened.
        """
        return GdalUtil._open_dataset(raster_path).GetProjection()

    @staticmethod
    def write_file(dataset, geotransform, project, output_path, out_format='GTiff', eType=gdal.GDT_Float32):
        """Write an array to a new raster file.

        :param dataset: array to write. A 3-D array is treated as
            ``(bands, rows, cols)``; otherwise it is unpacked as
            ``(rows, cols)`` and written to band 1.
        :param geotransform: affine geotransform for the output.
        :param project: projection string for the output.
        :param output_path: destination file path.
        :param out_format: GDAL driver short name (default ``'GTiff'``).
        :param eType: GDAL pixel data type of the output bands.
        :return: None. If no driver named *out_format* exists, a message is
            printed and nothing is written.
        :raises RuntimeError: if the driver fails to create *output_path*.
        """
        is_multiband = np.ndim(dataset) == 3
        if is_multiband:
            out_band, out_rows, out_cols = dataset.shape
        else:
            out_band = 1
            out_rows, out_cols = dataset.shape

        out_driver = gdal.GetDriverByName(out_format)
        if out_driver is None:
            print('格式%s 不支持Create()方法.' % out_format)
            return

        out_dataset = out_driver.Create(output_path, xsize=out_cols, ysize=out_rows,
                                        bands=out_band, eType=eType)
        if out_dataset is None:
            raise RuntimeError('創建輸出圖像{0} 失敗.'.format(output_path))
        out_dataset.SetGeoTransform(geotransform)
        out_dataset.SetProjection(project)

        if is_multiband:
            # Iterate whenever the input is 3-D, including a single-band
            # (1, rows, cols) array, which WriteArray cannot take whole.
            for i in range(out_band):
                out_dataset.GetRasterBand(i + 1).WriteArray(dataset[i])
        else:
            out_dataset.GetRasterBand(1).WriteArray(dataset)
        # Dropping the last reference closes the dataset and flushes it to disk.
        del out_dataset
定義計算NDVI的函數
def cal_ndvi(multi, nir_band=4, red_band=3):
    """Compute NDVI for one rectangular block of a raster.

    :param multi: sequence ``[raster_path, start_col, start_row, cols_step,
        rows_step]``, i.e. the file path, the column and row offsets of the
        block's top-left pixel, then the block width and height in pixels.
    :param nir_band: 1-based index of the near-infrared band (default 4).
    :param red_band: 1-based index of the red band (default 3).
    :return: float64 NDVI array for the block. Pixels where NDVI is undefined
        (NIR + red == 0) or falls outside ``[-1, 1.5]`` are set to 0.
    :raises RuntimeError: if GDAL cannot open the raster.
    """
    input_file, start_col, start_row, cols_step, rows_step = multi
    dataset = gdal.Open(input_file, gdal.GA_ReadOnly)
    if dataset is None:
        raise RuntimeError('打開圖像{0} 失敗.'.format(input_file))

    # Convert to float before any arithmetic. If the bands are stored as
    # unsigned integers (common for GF-1 imagery), nir - red would wrap around
    # to a huge positive number instead of going negative, and the range
    # filter below would then zero out every legitimately negative NDVI.
    nir_data = GdalUtil.read_block_data(dataset, nir_band, cols_step, rows_step,
                                        start_col=start_col, start_row=start_row).astype(np.float64)
    red_data = GdalUtil.read_block_data(dataset, red_band, cols_step, rows_step,
                                        start_col=start_col, start_row=start_row).astype(np.float64)

    # 0/0 on nodata pixels is expected; silence the warning and clean up below.
    with np.errstate(divide='ignore', invalid='ignore'):
        ndvi = (nir_data - red_data) / (nir_data + red_data)
    ndvi[~np.isfinite(ndvi) | (ndvi > 1.5) | (ndvi < -1)] = 0
    return ndvi
定義主函數
if __name__ == "__main__":
    input_file = r'D:\originalData\GF1\namucuo2021.tif'
    output_file = r'D:\originalData\GF1\namucuo2021_ndvi.tif'
    # Parallel backend to use: 'joblib' or 'multiprocessing'.
    method = 'joblib'
    # method = 'multiprocessing'

    # Basic raster metadata.
    raster_cols, raster_rows = GdalUtil.get_file_size(input_file)
    geotransform = GdalUtil.get_file_geotransform(input_file)
    project = GdalUtil.get_file_proj(input_file)

    # Block size in pixels. Each task reopens the file in cal_ndvi, so larger
    # blocks mean fewer tasks and less per-task overhead.
    rows_block_size = 50
    cols_block_size = 50
    # One task per block, row-major; edge blocks are clipped to the raster.
    multi = [
        [input_file, i, j,
         min(cols_block_size, raster_cols - i),
         min(rows_block_size, raster_rows - j)]
        for j in range(0, raster_rows, rows_block_size)
        for i in range(0, raster_cols, cols_block_size)
    ]

    t1 = time.perf_counter()
    if method == 'multiprocessing':
        # Leave one core for the parent, but never request 0 workers:
        # Pool(processes=0) raises ValueError on a single-core machine.
        with Pool(processes=max(1, cpu_count() - 1)) as pool:
            # map() takes one iterable and returns results in input order.
            res = pool.map(cal_ndvi, multi)
    else:
        # n_jobs=-1 lets joblib use all available CPUs.
        res = Parallel(n_jobs=-1)(delayed(cal_ndvi)(input_list) for input_list in multi)
    t2 = time.perf_counter()
    print("Total time:" + str(t2 - t1))

    # Stitch the per-block results back into a full-size array.
    out_data = np.zeros([raster_rows, raster_cols], dtype='float')
    for result, (_, start_col, start_row, cols_step, rows_step) in zip(res, multi):
        out_data[start_row:start_row + rows_step, start_col:start_col + cols_step] = result
    GdalUtil.write_file(out_data, geotransform, project, output_file)
使用雙重for循環時,如果內外兩層都用multiprocessing.Pool並行會報錯:Pool的工作進程是守護進程(daemon),不允許再創建子進程(報錯信息為“daemonic processes are not allowed to have children”)。這時可以外層循環使用joblib方法,內層循環改用multiprocessing方法,即可避免報錯。
到此這篇關於python多線程方法詳解的文章就介紹到這了,更多相關python多線程內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章,希望大家以後多多支持WalkonNet!
推薦閱讀:
- Python計算多幅圖像柵格值的平均值
- 使用python進行nc轉tif的3種情況解決
- 基於Python實現nc批量轉tif格式
- python 使用GDAL實現柵格tif轉矢量shp的方式小結
- R語言中Fisher判別的使用方法