diff --git a/src/Main.py b/src/Main.py
index 1fe61c5..effe030 100644
--- a/src/Main.py
+++ b/src/Main.py
@@ -6,10 +6,10 @@
 from sklearn.model_selection import train_test_split
 from sklearn.utils import shuffle
-from src.utils import get_data
+from src.utils import get_data, image_h_w
 from src.UNetKeras import UNetKeras
 import tensorflow as tf
-
+BATCH_SIZE = 1
 
 if __name__ == "__main__":
     print("========= Get data =========")
     X, y = get_data()
@@ -19,27 +19,27 @@
     del X, y
 
     print("========= Build model =========")
-    model = UNetKeras()
+    model = UNetKeras(height=image_h_w, width=image_h_w)
     model.compile()
 
     print("========= Start train model =========")
     ModelCheckpoint = tf.keras.callbacks.ModelCheckpoint("model/val_best_model.h5", monitor="val_loss", verbose=1, save_best_only=True)
-    model.fit(X_train, y_train, batch_size=1, epochs=10, validation_split=0.01, callbacks=[ModelCheckpoint])
+    model.fit(X_train, y_train, batch_size=BATCH_SIZE, epochs=10, validation_split=0.01, callbacks=[ModelCheckpoint])
 
     print("========= Save last time model =========")
     model.model.save("model/model_final_time.h5")
 
     print("========= Model evaluate Start =========")
     print("========= Test the last saved model =========")
-    model.model.evaluate(X_test, y_test, batch_size=32)
-    pred = model.predict(X_test)
+    model.model.evaluate(X_test, y_test, batch_size=BATCH_SIZE)
+    pred = model.predict(X_test, batch_size=BATCH_SIZE)
 
     import matplotlib.pyplot as plt
     import numpy as np
     import os
-    images_test = np.reshape(np.argmax(y_test, axis=-1), (-1, 512, 512))
-    images_pred = np.reshape(pred, (-1, 512, 512))
+    images_test = np.reshape(np.argmax(y_test, axis=-1), (-1, image_h_w, image_h_w))
+    images_pred = np.reshape(pred, (-1, image_h_w, image_h_w))
     predict_file_path = ".predict_images_last_saved_model/"
     if not os.path.exists(predict_file_path):
         os.mkdir(predict_file_path)
 
@@ -59,9 +59,9 @@
     os.mkdir(predict_file_path)
 
     val_best_model = tf.keras.models.load_model("model/val_best_model.h5")
-    val_best_model.evaluate(X_test, y_test, batch_size=32)
-    pred = val_best_model.predict(X_test)
-    images_pred = np.reshape(np.argmax(pred, axis=-1), (-1, 512, 512))
+    val_best_model.evaluate(X_test, y_test, batch_size=BATCH_SIZE)
+    pred = val_best_model.predict(X_test, batch_size=BATCH_SIZE)
+    images_pred = np.reshape(np.argmax(pred, axis=-1), (-1, image_h_w, image_h_w))
 
     for i in range(len(images_pred)):
         # plt.subplot(1, 2, 1)
diff --git a/src/utils.py b/src/utils.py
index 1d90b4b..8d1234c 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -8,6 +8,9 @@
 import numpy as np
 import h5py
 from scipy import ndimage
+import scipy
+
+image_h_w = 256
 
 
 # Load data
@@ -24,7 +27,8 @@
     for i in range(1, 2201):
         now_file_path = "../data/Image/IM" + str(i) + ".png"
         image = np.array(ndimage.imread(now_file_path, flatten=False))
-        images.append(image)  # images shape=(m,64,64,3)
+        image = scipy.misc.imresize(image, size=(image_h_w, image_h_w))
+        images.append(image)
     images = np.array(images, copy=True)
     file = h5py.File('../data/images.h5', 'w')  # create the HDF5 file
     file.create_dataset('images', data=images)  # write to it
@@ -50,7 +54,7 @@
     labels = flie.get("labels")
     labels = np.array(labels, dtype=np.float32)
 
-    images = images/255.
+    images = images / 255.
     train_image = np.expand_dims(images, -1)
     print(train_image.shape)
     train_label = np.expand_dims(labels, -1)
@@ -89,4 +93,3 @@
     else:
         raise IndexError("The last dimension is not 1")
     return x_tiled
-
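
Note on the resize step added in src/utils.py: scipy.ndimage.imread and scipy.misc.imresize were deprecated and later removed from SciPy (around releases 1.2 and 1.3), so the hunk above only runs on older SciPy versions. Below is a minimal sketch, not part of this patch, of how the same load-and-resize step could be written with Pillow; the helper name load_resized and the example path are illustrative only, and the interpolation default may differ slightly from the old imresize behaviour.

# Sketch only (not part of this diff): load a PNG and resize it to image_h_w x image_h_w
# without the removed scipy.ndimage.imread / scipy.misc.imresize helpers.
# Assumes Pillow is installed; load_resized is a hypothetical helper, not code from the repo.
import numpy as np
from PIL import Image

image_h_w = 256  # same target size introduced in src/utils.py


def load_resized(path, size=image_h_w):
    """Read an image file and return it as a (size, size) NumPy array."""
    with Image.open(path) as img:
        img = img.resize((size, size))  # Pillow expects (width, height)
        return np.array(img)


# Example usage, mirroring the loop in _load():
# image = load_resized("../data/Image/IM1.png")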