Python Deep Learning: U-Net Semantic Segmentation Model (Keras)
Preface
Recently, having lost my way a bit while looking for a direction, I decided to get to know more of the models used in computer vision tasks. U-Net, an interesting model for semantic segmentation, caught my attention, so I set out to reproduce it.
I. What Is Semantic Segmentation
A semantic segmentation task looks like the figure below:
In short, semantic segmentation marks each class in an image with a different color, one color per class. It is commonly used on medical images and satellite imagery.
So how are the pixels actually colored?
The output of semantic segmentation is much like that of an image classification network. In classification, the output is a one-dimensional one-hot vector over the classes, e.g. [0, 1, 0] for a three-class problem.
In semantic segmentation, the final output feature map is a three-dimensional tensor: its spatial size matches the original image and its channel count equals the number of classes, as shown in the figure below (image from Zhihu):
Each channel corresponds to one class, and the pixels marked in a channel indicate where that class appears in the image. Taking the argmax across channels at every pixel then collapses the stack into a single image, with a different color marking each class's location. Semantic segmentation is really just classification at a finer granularity: it classifies every single pixel to determine which class that pixel belongs to. That is semantic segmentation in a nutshell.
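To make the argmax step concrete, here is a tiny NumPy sketch (the scores and the color palette are made up purely for illustration): given per-pixel class scores of shape (H, W, num_classes), argmax over the channel axis yields one class index per pixel, which a palette then maps to a color.

import numpy as np

# Toy example: a 2x2 "image" scored for 3 classes (values are made up).
scores = np.array([[[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]],
                   [[0.2, 0.3, 0.5], [0.6, 0.3, 0.1]]])    # shape (2, 2, 3)

class_map = np.argmax(scores, axis=-1)                      # shape (2, 2), values in {0, 1, 2}
print(class_map)                                            # [[1 0]
                                                            #  [2 0]]

# Map each class index to a display color to obtain the segmentation image.
palette = np.array([[0, 0, 0], [255, 0, 0], [0, 255, 0]])   # one RGB color per class
color_image = palette[class_map]                            # shape (2, 2, 3)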
Now let's reproduce the U-Net model.
II. U-Net
1. Basic Principles
What is U-Net? Its network architecture is shown in the figure below:
The whole network is shaped like a "U" and can be split into two parts. The red box in the figure is the feature-extraction (encoder) part; like other convolutional networks, it extracts image features by stacking convolutions and compresses the feature maps with pooling. The blue box is the image-restoration (decoder) part (the name may not be strictly technical, but it conveys the idea), which restores the compressed feature maps through upsampling and convolution. The encoder can be any strong backbone, e.g. ResNet50 or VGG.
Note: since ResNet50 and VGG are rather large, this article uses MobileNet as the backbone feature extractor. To make U-Net easier to understand, a hand-built mini_unet is provided as well. For simplicity, the reproduction upsamples the compressed feature maps back to the same size as the input.
The code could not be pushed to GitHub (the upload kept failing), so it has been uploaded to Gitee for now: https://gitee.com/Boss-Jian/unet
2. mini_unet
mini_unet was built to help you understand the overall flow of a semantic segmentation network; it is not meant to be a strong model for real segmentation tasks. Here is the implementation:
from keras.layers import Input, Conv2D, Dropout, MaxPooling2D, Concatenate, UpSampling2D
from keras.models import Model


def unet_mini(n_classes=21, input_shape=(224, 224, 3)):
    img_input = Input(shape=input_shape)
    #------------------------------------------------------
    # encoder part
    # 224,224,3 -> 112,112,32
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(img_input)
    conv1 = Dropout(0.2)(conv1)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D((2, 2), strides=2)(conv1)
    # 112,112,32 -> 56,56,64
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Dropout(0.2)(conv2)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D((2, 2), strides=2)(conv2)
    # 56,56,64 -> 56,56,128
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Dropout(0.2)(conv3)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    #-------------------------------------------------
    # decoder part
    # 56,56,128 -> 112,112,128
    up1 = UpSampling2D(2)(conv3)
    # 112,112,128 -> 112,112,128+64
    up1 = Concatenate(axis=-1)([up1, conv2])
    # 112,112,192 -> 112,112,64
    conv4 = Conv2D(64, (3, 3), activation='relu', padding='same')(up1)
    conv4 = Dropout(0.2)(conv4)
    conv4 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv4)
    # 112,112,64 -> 224,224,64
    up2 = UpSampling2D(2)(conv4)
    # 224,224,64 -> 224,224,64+32
    up2 = Concatenate(axis=-1)([up2, conv1])
    # 224,224,96 -> 224,224,32
    conv5 = Conv2D(32, (3, 3), activation='relu', padding='same')(up2)
    conv5 = Dropout(0.2)(conv5)
    conv5 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv5)
    # adjust the channel count to the number of classes
    o = Conv2D(n_classes, 1, padding='same')(conv5)
    return Model(img_input, o, name="unet_mini")


if __name__ == "__main__":
    model = unet_mini()
    model.summary()
The encoder of mini_unet compresses the 224x224x3 input image down to a 56x56x128 feature map, and the decoder then upsamples it step by step back to 224x224x32. Finally, the convolution
o = Conv2D(n_classes,1,padding='same')(conv5)
adjusts the channel count of the feature map so that it equals the number of classes.
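As a quick sanity check (a sketch that assumes the file above is saved as unet_mini.py, which is how the training script later imports it), you can verify the output shape and collapse the prediction to per-pixel class indices with argmax:

import numpy as np
from unet_mini import unet_mini

model = unet_mini(n_classes=21, input_shape=(224, 224, 3))
print(model.output_shape)                  # (None, 224, 224, 21)

dummy = np.random.rand(1, 224, 224, 3).astype(np.float32)
pred = model.predict(dummy)                # raw per-class scores, shape (1, 224, 224, 21)
class_map = np.argmax(pred[0], axis=-1)    # (224, 224) array of class indices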
3. Mobilenet_unet
Mobilenet_unet uses MobileNet as the backbone feature extractor and loads pretrained weights to strengthen feature extraction. The decoder (restoration) part is the same as above. Below is the Mobilenet_unet network definition:
from keras.models import *
from keras.layers import *
import keras.backend as K
import keras

IMAGE_ORDERING = "channels_last"  # channels last


def relu6(x):
    return K.relu(x, max_value=6)


def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):
    channel_axis = 1 if IMAGE_ORDERING == 'channels_first' else -1
    filters = int(filters * alpha)
    x = ZeroPadding2D(padding=(1, 1), name='conv1_pad', data_format=IMAGE_ORDERING)(inputs)
    x = Conv2D(filters, kernel, data_format=IMAGE_ORDERING, padding='valid',
               use_bias=False, strides=strides, name='conv1')(x)
    x = BatchNormalization(axis=channel_axis, name='conv1_bn')(x)
    return Activation(relu6, name='conv1_relu')(x)


def _depthwise_conv_block(inputs, pointwise_conv_filters, alpha, depth_multiplier=1, strides=(1, 1), block_id=1):
    channel_axis = 1 if IMAGE_ORDERING == 'channels_first' else -1
    pointwise_conv_filters = int(pointwise_conv_filters * alpha)
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING, name='conv_pad_%d' % block_id)(inputs)
    x = DepthwiseConv2D((3, 3), data_format=IMAGE_ORDERING, padding='valid',
                        depth_multiplier=depth_multiplier, strides=strides,
                        use_bias=False, name='conv_dw_%d' % block_id)(x)
    x = BatchNormalization(axis=channel_axis, name='conv_dw_%d_bn' % block_id)(x)
    x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x)
    x = Conv2D(pointwise_conv_filters, (1, 1), data_format=IMAGE_ORDERING, padding='same',
               use_bias=False, strides=(1, 1), name='conv_pw_%d' % block_id)(x)
    x = BatchNormalization(axis=channel_axis, name='conv_pw_%d_bn' % block_id)(x)
    return Activation(relu6, name='conv_pw_%d_relu' % block_id)(x)


def get_mobilnet_eocoder(input_shape=(224, 224, 3), weights_path=""):
    # the input height and width must be multiples of 32
    assert input_shape[0] % 32 == 0
    assert input_shape[1] % 32 == 0
    alpha = 1.0
    depth_multiplier = 1
    img_input = Input(shape=input_shape)
    # (None, 224, 224, 3) -> (None, 112, 112, 64)
    x = _conv_block(img_input, 32, alpha, strides=(2, 2))
    x = _depthwise_conv_block(x, 64, alpha, depth_multiplier, block_id=1)
    f1 = x
    # (None, 112, 112, 64) -> (None, 56, 56, 128)
    x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, strides=(2, 2), block_id=2)
    x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, block_id=3)
    f2 = x
    # (None, 56, 56, 128) -> (None, 28, 28, 256)
    x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, strides=(2, 2), block_id=4)
    x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, block_id=5)
    f3 = x
    # (None, 28, 28, 256) -> (None, 14, 14, 512)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, strides=(2, 2), block_id=6)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=7)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=8)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=9)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=10)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=11)
    f4 = x
    # (None, 14, 14, 512) -> (None, 7, 7, 1024)
    x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, strides=(2, 2), block_id=12)
    x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, block_id=13)
    f5 = x
    # load the pretrained backbone weights
    if weights_path != "":
        Model(img_input, x).load_weights(weights_path, by_name=True, skip_mismatch=True)
    # f1: (None, 112, 112, 64)
    # f2: (None, 56, 56, 128)
    # f3: (None, 28, 28, 256)
    # f4: (None, 14, 14, 512)
    # f5: (None, 7, 7, 1024)
    return img_input, [f1, f2, f3, f4, f5]


def mobilenet_unet(num_classes=2, input_shape=(224, 224, 3)):
    # encoder
    img_input, levels = get_mobilnet_eocoder(input_shape=input_shape,
                                             weights_path="model_data/mobilenet_1_0_224_tf_no_top.h5")
    [f1, f2, f3, f4, f5] = levels
    # f1: (None, 112, 112, 64)
    # f2: (None, 56, 56, 128)
    # f3: (None, 28, 28, 256)
    # f4: (None, 14, 14, 512)
    # f5: (None, 7, 7, 1024)
    # decoder
    # (None, 14, 14, 512) -> (None, 14, 14, 512)
    o = f4
    o = ZeroPadding2D()(o)
    o = Conv2D(512, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING)(o)
    o = BatchNormalization()(o)
    # (None, 14, 14, 512) -> (None, 28, 28, 256)
    o = UpSampling2D(2)(o)
    o = Concatenate(axis=-1)([o, f3])
    o = ZeroPadding2D()(o)
    o = Conv2D(256, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING)(o)
    o = BatchNormalization()(o)
    # (None, 28, 28, 256) -> (None, 56, 56, 128)
    o = UpSampling2D(2)(o)
    o = Concatenate(axis=-1)([o, f2])
    o = ZeroPadding2D()(o)
    o = Conv2D(128, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING)(o)
    o = BatchNormalization()(o)
    # (None, 56, 56, 128) -> (None, 112, 112, 128)
    o = UpSampling2D(2)(o)
    o = Concatenate(axis=-1)([o, f1])
    o = ZeroPadding2D()(o)
    o = Conv2D(128, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING)(o)
    o = BatchNormalization()(o)
    # (None, 112, 112, 128) -> (None, 224, 224, num_classes)
    # upsample once more so the output matches the input size
    o = UpSampling2D(2)(o)
    o = ZeroPadding2D()(o)
    o = Conv2D(64, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING)(o)
    o = BatchNormalization()(o)
    o = Conv2D(num_classes, (3, 3), padding='same', data_format=IMAGE_ORDERING)(o)
    return Model(img_input, o)


if __name__ == "__main__":
    mobilenet_unet(input_shape=(512, 512, 3)).summary()
The feature-map size changes and the meaning of each step are annotated directly in the code; read through the comments carefully.
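If you want to inspect the five skip-connection feature maps that the decoder consumes, a small sketch like the one below builds only the encoder, passing weights_path="" so no pretrained file is needed (the module is assumed to be saved as mobilnet_unet.py, matching the import used by the training script), and prints their shapes:

import keras.backend as K
from mobilnet_unet import get_mobilnet_eocoder

img_input, feats = get_mobilnet_eocoder(input_shape=(224, 224, 3), weights_path="")
for name, feat in zip(["f1", "f2", "f3", "f4", "f5"], feats):
    # expected: f1 (None, 112, 112, 64) ... f5 (None, 7, 7, 1024)
    print(name, K.int_shape(feat))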
4. Data Loading
import math
import os
from random import shuffle

import cv2
import keras
import numpy as np
from PIL import Image


#-------------------------------
# convert an image to RGB
#-------------------------------
def cvtColor(image):
    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
        return image
    else:
        image = image.convert('RGB')
        return image


#-------------------------------
# normalize pixel values to [-1, 1]
#-------------------------------
def preprocess_input(image):
    image = image / 127.5 - 1
    return image


#---------------------------------------------------
# resize the input image, keeping the aspect ratio and padding with gray
#---------------------------------------------------
def resize_image(image, size):
    iw, ih = image.size
    w, h = size
    scale = min(w/iw, h/ih)
    nw = int(iw*scale)
    nh = int(ih*scale)
    image = image.resize((nw, nh), Image.BICUBIC)
    new_image = Image.new('RGB', size, (128, 128, 128))
    new_image.paste(image, ((w-nw)//2, (h-nh)//2))
    return new_image, nw, nh


class UnetDataset(keras.utils.Sequence):
    def __init__(self, annotation_lines, input_shape, batch_size, num_classes, train, dataset_path):
        self.annotation_lines = annotation_lines
        self.length = len(self.annotation_lines)
        self.input_shape = input_shape
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.train = train
        self.dataset_path = dataset_path

    def __len__(self):
        return math.ceil(len(self.annotation_lines) / float(self.batch_size))

    def __getitem__(self, index):
        # images and labels for this batch
        images = []
        targets = []
        # read one batch
        for i in range(index*self.batch_size, (index+1)*self.batch_size):
            # wrap the index around when it runs past the end of the dataset
            i = i % self.length
            name = self.annotation_lines[i].split()[0]
            # read from disk: 'jpg' is the input image, 'png' is the label (both stored as .png here)
            jpg = Image.open(os.path.join(os.path.join(self.dataset_path, 'Images'), name+'.png'))
            png = Image.open(os.path.join(os.path.join(self.dataset_path, 'Labels'), name+'.png'))
            #-------------------
            # data augmentation and normalization
            #-------------------
            jpg, png = self.get_random_data(jpg, png, self.input_shape, random=self.train)
            jpg = preprocess_input(np.array(jpg, np.float64))
            png = np.array(png)
            #-----------------------------------
            # the medical images outline cell edges;
            # pixels with values <= 127.5 are treated as target pixels
            #------------------------------------
            seg_labels = np.zeros_like(png)
            seg_labels[png <= 127.5] = 1
            #--------------------------------
            # convert to one-hot labels
            #--------------------------------
            seg_labels = np.eye(self.num_classes + 1)[seg_labels.reshape([-1])]
            seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))
            images.append(jpg)
            targets.append(seg_labels)
        images = np.array(images)
        targets = np.array(targets)
        return images, targets

    def rand(self, a=0, b=1):
        return np.random.rand() * (b - a) + a

    def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
        image = cvtColor(image)
        label = Image.fromarray(np.array(label))
        h, w = input_shape

        if not random:
            # validation mode: deterministic letterbox resize only
            iw, ih = image.size
            scale = min(w/iw, h/ih)
            nw = int(iw*scale)
            nh = int(ih*scale)
            image = image.resize((nw, nh), Image.BICUBIC)
            new_image = Image.new('RGB', [w, h], (128, 128, 128))
            new_image.paste(image, ((w-nw)//2, (h-nh)//2))
            label = label.resize((nw, nh), Image.NEAREST)
            new_label = Image.new('L', [w, h], (0))
            new_label.paste(label, ((w-nw)//2, (h-nh)//2))
            return new_image, new_label

        # resize image and label with random aspect-ratio jitter and scale
        rand_jit1 = self.rand(1-jitter, 1+jitter)
        rand_jit2 = self.rand(1-jitter, 1+jitter)
        new_ar = w/h * rand_jit1/rand_jit2
        scale = self.rand(0.25, 2)
        if new_ar < 1:
            nh = int(scale*h)
            nw = int(nh*new_ar)
        else:
            nw = int(scale*w)
            nh = int(nw/new_ar)
        image = image.resize((nw, nh), Image.BICUBIC)
        label = label.resize((nw, nh), Image.NEAREST)

        # random horizontal flip
        flip = self.rand() < .5
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
            label = label.transpose(Image.FLIP_LEFT_RIGHT)

        # place the resized image and label on a gray canvas at a random offset
        dx = int(self.rand(0, w-nw))
        dy = int(self.rand(0, h-nh))
        new_image = Image.new('RGB', (w, h), (128, 128, 128))
        new_label = Image.new('L', (w, h), (0))
        new_image.paste(image, (dx, dy))
        new_label.paste(label, (dx, dy))
        image = new_image
        label = new_label

        # distort the image colors in HSV space
        hue = self.rand(-hue, hue)
        sat = self.rand(1, sat) if self.rand() < .5 else 1/self.rand(1, sat)
        val = self.rand(1, val) if self.rand() < .5 else 1/self.rand(1, val)
        x = cv2.cvtColor(np.array(image, np.float32)/255, cv2.COLOR_RGB2HSV)
        x[..., 0] += hue*360
        x[..., 0][x[..., 0] > 1] -= 1
        x[..., 0][x[..., 0] < 0] += 1
        x[..., 1] *= sat
        x[..., 2] *= val
        x[x[:, :, 0] > 360, 0] = 360
        x[:, :, 1:][x[:, :, 1:] > 1] = 1
        x[x < 0] = 0
        image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255
        return image_data, label

    def on_epoch_begin(self):
        shuffle(self.annotation_lines)
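The np.eye trick used in __getitem__ to one-hot encode the label map deserves a closer look. Here is a toy illustration with made-up values (num_classes = 2 plus the extra channel, mirroring the num_classes + 1 used by the loader):

import numpy as np

num_classes = 2
# A made-up 2x2 label map: 1 marks target pixels, 0 marks everything else.
seg_labels = np.array([[1, 0],
                       [0, 1]])

# Index the identity matrix row by row to one-hot encode every pixel,
# then reshape back to (H, W, num_classes + 1).
one_hot = np.eye(num_classes + 1)[seg_labels.reshape(-1)]
one_hot = one_hot.reshape((2, 2, num_classes + 1))
print(one_hot[0, 0])   # [0. 1. 0.]  -> pixel (0, 0) belongs to class 1
print(one_hot[0, 1])   # [1. 0. 0.]  -> pixel (0, 1) belongs to class 0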
Training code:
import os

import numpy as np
from tensorflow.python.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from keras.optimizers import Adam
from keras import backend as K
from keras import backend

from unet_mini import unet_mini          # swap this in if you want to train the mini model instead
from mobilnet_unet import mobilenet_unet
from callbacks import ExponentDecayScheduler, LossHistory
from data_loader import UnetDataset


#--------------------------------------
# cross-entropy loss; cls_weights holds the per-class weights
# (y_pred is expected to already contain per-class probabilities)
#--------------------------------------
def CE(cls_weights):
    cls_weights = np.reshape(cls_weights, [1, 1, 1, -1])
    def _CE(y_true, y_pred):
        y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
        CE_loss = - y_true[..., :-1] * K.log(y_pred) * cls_weights
        CE_loss = K.mean(K.sum(CE_loss, axis=-1))
        # dice_loss = tf.Print(CE_loss, [CE_loss])
        return CE_loss
    return _CE


def f_score(beta=1, smooth=1e-5, threhold=0.5):
    def _f_score(y_true, y_pred):
        y_pred = backend.greater(y_pred, threhold)
        y_pred = backend.cast(y_pred, backend.floatx())
        tp = backend.sum(y_true[..., :-1] * y_pred, axis=[0, 1, 2])
        fp = backend.sum(y_pred, axis=[0, 1, 2]) - tp
        fn = backend.sum(y_true[..., :-1], axis=[0, 1, 2]) - tp
        score = ((1 + beta ** 2) * tp + smooth) \
                / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
        return score
    return _f_score


def train():
    #-------------------------
    # the cell images are segmented into two classes: cell wall and everything else
    # initialize training hyperparameters
    #-------------------------
    num_classes = 2
    input_shape = (512, 512, 3)
    batch_size = 4
    learn_rate = 1e-4
    start_epoch = 0          # epoch to resume training from
    end_epoch = 100
    num_workers = 4
    dataset_path = 'Medical_Datasets'

    model = mobilenet_unet(num_classes, input_shape=input_shape)
    model.summary()

    # read the list of training sample names
    with open(os.path.join(dataset_path, "ImageSets/Segmentation/train.txt"), "r") as f:
        train_lines = f.readlines()

    logging = TensorBoard(log_dir='logs/')
    checkpoint = ModelCheckpoint('logs/ep{epoch:03d}-loss{loss:.3f}.h5',
                                 monitor='loss', save_weights_only=True,
                                 save_best_only=False, period=1)
    reduce_lr = ExponentDecayScheduler(decay_rate=0.96, verbose=1)
    early_stopping = EarlyStopping(monitor='loss', min_delta=0, patience=10, verbose=1)
    loss_history = LossHistory('logs/', val_loss_flag=False)

    epoch_step = len(train_lines) // batch_size
    cls_weights = np.ones([num_classes], np.float32)
    loss = CE(cls_weights)
    model.compile(loss=loss, optimizer=Adam(lr=learn_rate), metrics=[f_score()])

    train_dataloader = UnetDataset(train_lines, input_shape[:2], batch_size, num_classes, True, dataset_path)

    print('Train on {} samples, with batch size {}.'.format(len(train_lines), batch_size))
    model.fit_generator(
        generator=train_dataloader,
        steps_per_epoch=epoch_step,
        epochs=end_epoch,
        initial_epoch=start_epoch,
        # use_multiprocessing = True if num_workers > 1 else False,
        workers=num_workers,
        callbacks=[logging, checkpoint, early_stopping, reduce_lr, loss_history]
    )


if __name__ == "__main__":
    train()
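As a side note, the f_score metric above with beta = 1 is just the F1 (Dice) coefficient computed from thresholded predictions. A quick numeric check with made-up pixel counts:

# Toy numeric check of the F-score formula used above (beta = 1, i.e. F1/Dice).
beta, smooth = 1, 1e-5
tp, fp, fn = 80.0, 10.0, 20.0   # made-up true-positive / false-positive / false-negative counts
score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
print(round(score, 4))          # 0.8421, i.e. 2*80 / (2*80 + 20 + 10)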
The final prediction results:
If you are interested in the complete code, download it from the repository linked above and read it there; there is too much of it to paste in full without making this post overly long.
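For reference, inference follows the usual pattern: preprocess the image the same way as in training, run the model, and take the argmax over the channel axis. The sketch below makes several assumptions: the checkpoint filename and test image name are hypothetical, the helpers come from the data-loading module above (assumed saved as data_loader.py, matching the training import), and the pretrained MobileNet backbone file referenced inside mobilenet_unet is still in place.

import numpy as np
from PIL import Image

from mobilnet_unet import mobilenet_unet
from data_loader import cvtColor, resize_image, preprocess_input

num_classes = 2
input_shape = (512, 512, 3)

model = mobilenet_unet(num_classes, input_shape=input_shape)
model.load_weights('logs/ep100-loss0.050.h5')            # hypothetical checkpoint name

image = cvtColor(Image.open('cell.png'))                 # hypothetical test image
image, nw, nh = resize_image(image, (input_shape[1], input_shape[0]))
x = preprocess_input(np.array(image, np.float32))[np.newaxis, ...]

pred = model.predict(x)[0]                               # (512, 512, num_classes) raw scores
class_map = np.argmax(pred, axis=-1).astype(np.uint8)    # per-pixel class indices
Image.fromarray(class_map * 255).save('prediction.png')  # white marks the target class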
And that is a simple semantic segmentation task.
References
https://github.com/bubbliiiing/unet-keras
https://github.com/divamgupta/image-segmentation-keras