第五章–卷積神經網路基礎–八股法搭建卷積神經網路
本講目標:
??介紹神經網路基本概念,用八股法實作卷積神經網路(以cifar10為例,本節建立的框架作為后續網路的baseline,在baseline中修改實作其他網路),參考視頻,
卷積神經網路基礎
- 0.回顧全連接神經網路
- 1.卷積計算程序
- 1.1-卷積概念
- 1.2-卷積核的表示
- 1.3-單通道影像卷積計算
- 1.4-RGB通道影像卷積計算
- 2.感受野
- 2.1-感受野(Receptive Field)計算
- 3.全零填充
- 3.1-全零填充計算
- 3.2-全零填充TF描述
- 4.TF描述卷積層
- 5.批標準化(Batch Normalization,BN)
- 6.池化層
- 7.舍棄層(Dropout)
- 8.搭建卷積神經網路(CBAPB)
- 8.1-CBAPB組成
- 8.2-CBAPB示例
- 9.八股法搭建完整卷積神經網路
- 9.1-完整代碼
- 9.2-輸出結果
- 10.總結八股創建神經網路
- 10.1 import
- 10.2 train,test
- 10.3搭建神經網路結構----Sequential or class
- 10.4model.compile
- 10.5 model.fit
- 10.6 model.summary
0.回顧全連接神經網路
??每個神經元與前后相鄰層的每一個神經元都有連接關系,輸入是特征,輸出為預測的結果,全連接網路的引數個數為:
第一層引數:784x128+128
第二層引數:128x10+10
總共101770個引數
??但在實際應用中,圖片大多數高解析度的多通道圖片,如下圖所示:
??如此直接輸入到全連接網路,會使得待優化的引數過多,容易導致模型的過擬合,
??為了解決引數量過大而導致模型過擬合的問題,一般不會將原始影像直接輸入,而是對影像進行特征提取,再將提取到的特征輸入到全連接網路,如下圖所示,是將汽車圖片經過多次特征提取后再送入全連接網路,
1.卷積計算程序
1.1-卷積概念
??卷積計算可認為是一種有效提取影像特征的方法,一般會用一個正方形的卷積核,按指定步長,在輸入特征圖上滑動,遍歷輸入特征圖中的每個像素點,
??每一個步長,卷積核會與輸入特征圖出現重合區域,重合區域對應元素相乘、求和再加上偏置項得到輸出特征的一個像素點,詳細計算如下圖所示:
??輸入特征圖的深度(channel),決定了當前層卷積核的深度),當前層卷積核的個數,決定了當前層輸出特征圖的深度,
1.2-卷積核的表示
左圖是單通道的3x3卷積,共10個引數;
中間是三通道的3x3卷積,共28個引數;
右圖是三通道的5x5卷積,共76個引數,
1.3-單通道影像卷積計算
1.4-RGB通道影像卷積計算
2.感受野
2.1-感受野(Receptive Field)計算
??感受野是指卷積神經網路個輸出特征圖中的每個像素點,在原始輸入圖片上映射區域的大小,
??感受野的相關概念及大小選型可以參考這篇文章,
3.全零填充
3.1-全零填充計算
??為了保持輸出影像尺寸與輸入影像一致,經常會在輸入影像周圍進行全零填充,如下圖所示,在5x5的輸入影像周圍填0,則輸出的尺寸仍為5x5,
3.2-全零填充TF描述
??tensorflow描述全零填充用引數padding=’SAME’或padding=’VALID’表示,
4.TF描述卷積層
tf.keras.layers.Conv2D (
filters = 卷積核個數,
kernel_size = 卷積核尺寸,
strides = 滑動步長,
padding = “same” or “valid”, #使用全零填充是“same”,不使用是“valid”(默認)
activation = “ relu ” or “ sigmoid ” or “ tanh ” or “ softmax”等 , # 如有BN 此處不寫
input_shape = (高, 寬 , 通道數) #輸入特征圖維度,可省略
)
??卷積層的表示如下:
5.批標準化(Batch Normalization,BN)
??標準化:使資料符合0均值1為標準差的分布
??批標準化:對一小批資料(batch),做標準化處理
6.池化層
?池化用于減少特征數量;
?最大值池化可提取圖片紋理;
tf.keras.layers.MaxPool2D(
pool_size= 池化核尺寸,
strides= 池化步長,#默認為pool_size
padding=‘valid’or‘same’ #(默認)“valid”
)
?均值池化可保留背景特征;
tf.keras.layers.AveragePooling2D(
pool_size= 池化核尺寸,
strides= 池化步長,#默認為pool_size
padding=‘valid’or‘same’ # (默認)“valid”)
)
7.舍棄層(Dropout)
??神經網路訓練時,將一部分神經元按照一定概率從神經網路中暫時舍棄,
??神經網路使用時,被舍棄的神經元恢復連接,
8.搭建卷積神經網路(CBAPB)
8.1-CBAPB組成
??利用上述知識,就可以構建出基本的卷積神經網路(CNN)了,其核心思路為在 CNN中利用卷積核(kernel)提取特征后,送入全連接網路,
??CNN 模型的主要模塊:一般包括上述的卷積層(Conv2D)、BN 層、激活函式(Activation)、池化層(Pooling)、失活(舍棄)層(Dropout)以及全連接層,
??故取特征提取部分的各個模塊的首字母,組成CBPAB,
??牢記以下五個部分組成卷積:
卷積就是特征提取器:CBAPB
Conv2D、BatchNormalization、Activation、Pooling、Dropout
卷積就是特征提取器:CBAPB
Conv2D、BatchNormalization、Activation、Pooling、Dropout
卷積就是特征提取器:CBAPB
Conv2D、BatchNormalization、Activation、Pooling、Dropout
卷積就是特征提取器:CBAPB
Conv2D、BatchNormalization、Activation、Pooling、Dropout
卷積就是特征提取器:CBAPB
Conv2D、BatchNormalization、Activation、Pooling、Dropout
卷積就是特征提取器:CBAPB
Conv2D、BatchNormalization、Activation、Pooling、Dropout
卷積就是特征提取器:CBAPB
Conv2D、BatchNormalization、Activation、Pooling、Dropout
卷積就是特征提取器:CBAPB
Conv2D、BatchNormalization、Activation、Pooling、Dropout
卷積就是特征提取器:CBAPB
Conv2D、BatchNormalization、Activation、Pooling、Dropout
8.2-CBAPB示例
?搭建如下網路,記住CBAPB
?卷積就是特征提取器,就是CBAPB,
卷積就是特征提取器:CBAPB
9.八股法搭建完整卷積神經網路
9.1-完整代碼
#六步法第一步->import匯入需要的包
import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D,BatchNormalization,Activation,MaxPool2D,Dropout,Flatten,Dense
from tensorflow.keras import Model
np.set_printoptions(threshold=np.inf)
#六步法第二步->輸入train,test
cifar10=tf.keras.datasets.cifar10
(x_train,y_train),(x_test,y_test)=cifar10.load_data()
x_train,x_test=x_train/255.,x_test/255.
#六步法第三步->Model搭建網路結構
class Baseline(Model):
def __init__(self):
super(Baseline,self).__init__()
#卷積就是CBAPD
#C
self.c1=Conv2D(filters=6,kernel_size=(5,5),padding='same')#卷積層
#B
self.b1=BatchNormalization()#BN層
#A
self.a1=Activation('relu')#激活層
#P
self.p1=MaxPool2D(pool_size=(2,2),strides=2,padding='same')#池化層
#D
self.d1=Dropout(0.2)#dropout層
self.flatten=Flatten()
self.f1=Dense(128,activation='relu')
self.d2=Dropout(0.2)
self.f2=Dense(10,activation='softmax')
def call(self,x):
x=self.c1(x)
x=self.b1(x)
x=self.a1(x)
x=self.p1(x)
x=self.d1(x)
x=self.flatten(x)
x=self.f1(x)
x=self.d2(x)
y=self.f2(x)
return y
model=Baseline()
#六步法第四步->compile配置模型
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
#加載已經有的模型,如果沒有則不加載
checkpoint_save_path='./checkpoint/Baseline.ckpt'
if os.path.exists(checkpoint_save_path+'.index'):
print('-------load the model-------')
model.load_weights(checkpoint_save_path)
#撰寫自動保存模型的回呼函式
cp_callback=tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_save_path,
save_weights_only=True,
save_best_only=True
)
#六步法第五步->fit訓練模型
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
callbacks=[cp_callback])
#六步法第五步->summary查看網路模型
model.summary()
# print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
file.write(str(v.name) + '\n')
file.write(str(v.shape) + '\n')
file.write(str(v.numpy()) + '\n')
file.close()
############################################### show ###############################################
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.subplot(1,2,1)
plt.plot(acc,label='Training Accuracy')
plt.plot(val_acc,label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
9.2-輸出結果
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 3s 0us/step
170508288/170498071 [==============================] - 3s 0us/step
Epoch 1/5
1563/1563 [==============================] - 19s 6ms/step - loss: 1.6454 - sparse_categorical_accuracy: 0.4035 - val_loss: 1.4076 - val_sparse_categorical_accuracy: 0.5007
Epoch 2/5
1563/1563 [==============================] - 9s 6ms/step - loss: 1.3999 - sparse_categorical_accuracy: 0.4932 - val_loss: 1.2718 - val_sparse_categorical_accuracy: 0.5448
Epoch 3/5
1563/1563 [==============================] - 9s 6ms/step - loss: 1.3375 - sparse_categorical_accuracy: 0.5205 - val_loss: 1.2459 - val_sparse_categorical_accuracy: 0.5546
Epoch 4/5
1563/1563 [==============================] - 9s 6ms/step - loss: 1.2863 - sparse_categorical_accuracy: 0.5368 - val_loss: 1.3322 - val_sparse_categorical_accuracy: 0.5338
Epoch 5/5
1563/1563 [==============================] - 9s 5ms/step - loss: 1.2570 - sparse_categorical_accuracy: 0.5509 - val_loss: 1.1740 - val_sparse_categorical_accuracy: 0.5871
Model: "baseline"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) multiple 456
batch_normalization (BatchN multiple 24
ormalization)
activation (Activation) multiple 0
max_pooling2d (MaxPooling2D multiple 0
)
dropout (Dropout) multiple 0
flatten (Flatten) multiple 0
dense (Dense) multiple 196736
dropout_1 (Dropout) multiple 0
dense_1 (Dense) multiple 1290
=================================================================
Total params: 198,506
Trainable params: 198,494
Non-trainable params: 12
________________________________________________________________
10.總結八股創建神經網路
10.1 import
??引入tensorflow、keras、numpy、matplotlib等庫,可以同時引入keras中的layers、models等庫方便呼叫內部的API,
10.2 train,test
??讀取資料集,可以來源于框架本身帶的資料集,如mnist,fashion,鳶尾花資料集等,或者實際應用中需要自己制作資料集,
10.3搭建神經網路結構----Sequential or class
??當網路結構比較簡單時,可以利用keras中的tf.keras.Sequential 來搭建順序網路模型,
但當網路不是簡單的順序模型時(如殘差網路),則需要用class來定義自己的網路結構,
10.4model.compile
??對搭建好的網路進行編譯,需要指定優化器(Adam、sgd、RMSdrop)、損失函式(交叉熵、均方差)以及需要記錄的準確率和損失值(acc/loss)
10.5 model.fit
??指定訓練資料、驗證資料、迭代輪數、批量大小等等,由于神經網路的引數量和計算量一般都比較大,訓練所需的時間也會比較長,所以會在這里加入斷點續訓以及模型引數的保存等等,使得訓練更加方便,同時防止程式意外停止導致資料丟失的情況,
10.6 model.summary
??將神經網路的模型具體資訊列印出來,包括網路結構、網路各層引數等,便于對網路進行瀏覽和檢查,
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/438653.html
標籤:AI
上一篇:資料分析工具Pandas