# 训练集500个图片 dataset_train = ShapesDataset() dataset_train.load_shapes(500, config.IMAGE_SHAPE[0], config.IMAGE_SHAPE[1]) dataset_train.prepare()
# 验证集50个图片 dataset_val = ShapesDataset() dataset_val.load_shapes(50, config.IMAGE_SHAPE[0], config.IMAGE_SHAPE[1]) dataset_val.prepare()
image_ids = np.random.choice(dataset_train.image_ids, 4) for image_id in image_ids: image = dataset_train.load_image(image_id) mask, class_ids = dataset_train.load_mask(image_id) visualize.display_top_masks(image, mask, class_ids, dataset_train.class_names)
随机生成的图片如下图所示,注意,因为每次都是随机生成,因此读者得到的结果可能是不同的。左图是生成的图片,右边是mask。

图:随机生成的Shape图片
3、创建模型
model = modellib.MaskRCNN(mode="training", config=config, model_dir=MODEL_DIR)
因为我们的训练数据不多,因此使用预训练的模型进行Transfer Learning会效果更好。
# 默认使用coco模型来初始化 init_with = "coco" # imagenet, coco, or last
if init_with == "imagenet": model.load_weights(model.get_imagenet_weights(), by_name=True) elif init_with == "coco": # 加载COCO模型的参数,去掉全连接层(mrcnn_bbox_fc), # logits(mrcnn_class_logits) # 输出的boudning box(mrcnn_bbox)和Mask(mrcnn_mask) model.load_weights(COCO_MODEL_PATH, by_name=True, exclude=["mrcnn_class_logits", "mrcnn_bbox_fc", "mrcnn_bbox", "mrcnn_mask"]) elif init_with == "last": # 加载我们最近训练的模型来初始化 model.load_weights(model.find_last(), by_name=True)
4、训练
训练分为两个阶段:
|