加入收藏 | 设为首页 | 会员中心 | 我要投稿 温州站长网 (https://www.0577zz.com/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 教程 > 正文

潮科技行业入门指南 | 深度学习理论与实战:提高篇(14)——Mask R-CNN代码简介

发布时间:2019-03-26 00:54:27 所属栏目:教程 来源:36氪
导读:编者按:本文节选自《深度学习理论与实战:提高篇 》一书,原文链接http://fancyerii.github.io/2019/03/14/dl-book/ 。作者李理,环信人工智能研发中心vp,有十多年自然语言处理和人工智能研发经验,主持研发过多款智能硬件的问答和对话系统,负责环信中

# 训练集500个图片
dataset_train = ShapesDataset()
dataset_train.load_shapes(500, config.IMAGE_SHAPE[0], config.IMAGE_SHAPE[1])
dataset_train.prepare()

# 验证集50个图片
dataset_val = ShapesDataset()
dataset_val.load_shapes(50, config.IMAGE_SHAPE[0], config.IMAGE_SHAPE[1])
dataset_val.prepare()


image_ids = np.random.choice(dataset_train.image_ids, 4)
for image_id in image_ids:
image = dataset_train.load_image(image_id)
mask, class_ids = dataset_train.load_mask(image_id)
visualize.display_top_masks(image, mask, class_ids, dataset_train.class_names)

随机生成的图片如下图所示,注意,因为每次都是随机生成,因此读者得到的结果可能是不同的。左图是生成的图片,右边是mask。

潮科技行业入门指南 | 深度学习理论与实战:提高篇(14)——Mask R-CNN代码简介

图:随机生成的Shape图片

3、创建模型

model = modellib.MaskRCNN(mode="training", config=config,
model_dir=MODEL_DIR)

因为我们的训练数据不多,因此使用预训练的模型进行Transfer Learning会效果更好。

# 默认使用coco模型来初始化
init_with = "coco" # imagenet, coco, or last

if init_with == "imagenet":
model.load_weights(model.get_imagenet_weights(), by_name=True)
elif init_with == "coco":
# 加载COCO模型的参数,去掉全连接层(mrcnn_bbox_fc),
# logits(mrcnn_class_logits)
# 输出的boudning box(mrcnn_bbox)和Mask(mrcnn_mask)
model.load_weights(COCO_MODEL_PATH, by_name=True,
exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",
"mrcnn_bbox", "mrcnn_mask"])
elif init_with == "last":
# 加载我们最近训练的模型来初始化
model.load_weights(model.find_last(), by_name=True)

4、训练

训练分为两个阶段:

  • heads 只训练上面没有初始化的4层网络的参数,适合训练数据较少(比如本例子)的情况

    all 训练所有的参数

    我们这里值训练heads就够了。

    model.train(dataset_train, dataset_val,
    learning_rate=config.LEARNING_RATE,
    epochs=1,
    layers='heads')

    保存模型参数:

    # 手动保存参数,这通常是不需要的,
    # 因为每次epoch介绍会自动保存,所以这里是注释掉的。
    # model_path = os.path.join(MODEL_DIR, "mask_rcnn_shapes.h5")
    # model.keras_model.save_weights(model_path)

    5、检测

    我们首先需要构造预测的Config并且加载模型参数。

    class InferenceConfig(ShapesConfig):
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1inference_config = InferenceConfig()# 重新构建用于inference的模型 model = modellib.MaskRCNN(mode="inference",
    config=inference_config,
    model_dir=MODEL_DIR)# 加载模型参数,可以手动指定也可以让它自己找最近的模型参数文件 # model_path = os.path.join(ROOT_DIR, ".h5 file name here")model_path = model.find_last()# 加载模型参数 print("Loading weights from ", model_path)model.load_weights(model_path, by_name=True)

    我们随机寻找一个图片来检测:

    # 随机选择验证集的一张图片。
    image_id = random.choice(dataset_val.image_ids)
    original_image, image_meta, gt_class_id, gt_bbox, gt_mask =
    modellib.load_image_gt(dataset_val, inference_config,
    image_id, use_mini_mask=False)

    log("original_image", original_image)
    log("image_meta", image_meta)
    log("gt_class_id", gt_class_id)
    log("gt_bbox", gt_bbox)
    log("gt_mask", gt_mask)

    visualize.display_instances(original_image, gt_bbox, gt_mask, gt_class_id,
    dataset_train.class_names, figsize=(8, 8))

    上面的代码加载一张图片,结果如下图所示,它显示的是真正的(gold/ground-truth) Bounding box和Mask。

    潮科技行业入门指南 | 深度学习理论与实战:提高篇(14)——Mask R-CNN代码简介

    图:随机挑选的测试图片

    接下来我们用模型来预测一下:

    results = model.detect([original_image], verbose=1)

    r = results[0]
    visualize.display_instances(original_image, r['rois'], r['masks'], r['class_ids'],
    dataset_val.class_names, r['scores'], ax=get_ax())

    模型预测的结果如下图所示,可以对比看成模型预测的非常准确。

    潮科技行业入门指南 | 深度学习理论与实战:提高篇(14)——Mask R-CNN代码简介

    图:模型预测的结果

    6、测试

    前面我们只是测试了一个例子,我们需要更加全面的评测。

    image_ids = np.random.choice(dataset_val.image_ids, 10)
    APs = []
    for image_id in image_ids:
    # 加载图片和正确的Bounding box以及mask
    image, image_meta, gt_class_id, gt_bbox, gt_mask =
    modellib.load_image_gt(dataset_val, inference_config,
    image_id, use_mini_mask=False)
    molded_images = np.expand_dims(modellib.mold_image(image, inference_config), 0)
    # 进行检测
    results = model.detect([image], verbose=0)
    r = results[0]
    # 计算AP
    AP, precisions, recalls, overlaps =
    utils.compute_ap(gt_bbox, gt_class_id, gt_mask,
    r["rois"], r["class_ids"], r["scores"], r['masks'])
    APs.append(AP)

    print("mAP: ", np.mean(APs))
    # 输出0.95

    inspect_data.ipynb

    这个notebook演示了Mask R-CNN的数据预处理过程。这个notebook可以用COCO数据集或者我们之前介绍的shape数据集进行演示,为了避免下载大量的COCO数据集,我们这里用shape数据集。

    1、选择数据集

    (编辑:温州站长网)

    【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

热点阅读