加入收藏 | 设为首页 | 会员中心 | 我要投稿 温州站长网 (https://www.0577zz.com/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 教程 > 正文

潮科技行业入门指南 | 深度学习理论与实战:提高篇(14)——Mask R-CNN代码简介

发布时间:2019-03-26 00:54:27 所属栏目:教程 来源:36氪
导读:编者按:本文节选自《深度学习理论与实战:提高篇 》一书,原文链接http://fancyerii.github.io/2019/03/14/dl-book/ 。作者李理,环信人工智能研发中心vp,有十多年自然语言处理和人工智能研发经验,主持研发过多款智能硬件的问答和对话系统,负责环信中

编者按:本文节选自《深度学习理论与实战:提高篇 》一书,原文链接http://fancyerii.github.io/2019/03/14/dl-book/ 。作者李理,环信人工智能研发中心vp,有十多年自然语言处理和人工智能研发经验,主持研发过多款智能硬件的问答和对话系统,负责环信中文语义分析开放平台和环信智能机器人的设计与研发。

以下为正文。

潮科技行业入门指南 | 深度学习理论与实战:提高篇(14)——Mask R-CNN代码简介

目录

  • 安装

    demo.ipynb

    • 运行

      关键代码

      train_shapes.ipynb

      • 配置

        Dataset

        创建模型

        训练

        检测

        测试

        inspect_data.ipynb

        • 选择数据集

          加载Dataset

          显示样本

          Bounding Box

          Mini Masks

          Anchor

          训练数据生成器

          Facebook(Mask R-CNN的作者He Kaiming等人目前在Facebook)的实现在这里。但是这是用Caffe2实现的,本书没有介绍这个框架,因此我们介绍Tensorflow和Keras的版本实现的版本。但是建议有兴趣的读者也可以尝试一下Facebook提供的代码。

          安装(((0)))

          demo.ipynb

          1、运行

          jupyter notebook
          打开文件samples/demo.ipynb,运行所有的Cell

          2、关键代码

          这里是使用预训练的模型,会自动上网下载,所以第一次运行会比较慢。这是下载模型参数的代码:

          COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
          # Download COCO trained weights from Releases if needed
          if not os.path.exists(COCO_MODEL_PATH):
          utils.download_trained_weights(COCO_MODEL_PATH)

          创建模型和加载参数:

          # 创建MaskRCNN对象,模式是inferencemodel = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)

          # 加载模型参数 model.load_weights(COCO_MODEL_PATH, by_name=True)

          读取图片并且进行分割:

          # 随机加载一张图片
          file_names = next(os.walk(IMAGE_DIR))[2]
          image = skimage.io.imread(os.path.join(IMAGE_DIR, random.choice(file_names)))

          # 进行目标检测和分割
          results = model.detect([image], verbose=1)

          # 显示结果
          r = results[0]
          visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'],
          class_names, r['scores'])

          检测结果r包括rois(RoI)、masks(对应RoI的每个像素是否属于目标物体)、scores(得分)和class_ids(类别)。

          下图是运行的效果,我们可以看到它检测出来4个目标物体,并且精确到像素级的分割处理物体和背景。

          潮科技行业入门指南 | 深度学习理论与实战:提高篇(14)——Mask R-CNN代码简介

          图:Mask RCNN检测效果

          train_shapes.ipynb

          除了可以使用训练好的模型,我们也可以用自己的数据进行训练,为了演示,这里使用了一个很小的shape数据集。这个数据集是on-the-fly的用代码生成的一些三角形、正方形、圆形,因此不需要下载数据。

          1、配置

          代码提供了基础的类Config,我们只需要继承并稍作修改:

          class ShapesConfig(Config):
          """用于训练shape数据集的配置
          继承子基本的Config类,然后override了一些配置项。
          """
          # 起个好记的名字
          NAME = "shapes"

          # 使用一个GPU训练,每个GPU上8个图片。因此batch大小是8 (GPUs * images/GPU).
          GPU_COUNT = 1
          IMAGES_PER_GPU = 8

          # 分类数(需要包括背景类)
          NUM_CLASSES = 1 + 3 # background + 3 shapes

          # 图片为固定的128x128
          IMAGE_MIN_DIM = 128
          IMAGE_MAX_DIM = 128

          # 因为图片比较小,所以RPN anchor也是比较小的
          RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128) # anchor side in pixels

          # 每张图片建议的RoI数量,对于这个小图片的例子可以取比较小的值。
          TRAIN_ROIS_PER_IMAGE = 32

          # 每个epoch的数据量
          STEPS_PER_EPOCH = 100

          # 每5步验证一下。
          VALIDATION_STEPS = 5

          config = ShapesConfig()
          config.display()

          2、Dataset

          对于我们自己的数据集,我们需要继承utils.Dataset类,并且重写如下方法:

          • load_image

            load_mask

            image_reference

            在重写这3个方法之前我们首先来看load_shapes,这个函数on-the-fly的生成数据。

            class ShapesDataset(utils.Dataset):
            """随机生成shape数据。包括三角形,正方形和圆形,以及它的位置。
            这是on-th-fly的生成数据,因此不需要访问文件。
            """

            def load_shapes(self, count, height, width):
            """生成图片
            count: 返回的图片数量
            height, width: 生成图片的height和width
            """
            # 类别
            self.add_class("shapes", 1, "square")
            self.add_class("shapes", 2, "circle")
            self.add_class("shapes", 3, "triangle")

            # 注意:这里只是生成图片的specifications(说明书),
            # 具体包括性质、颜色、大小和位置等信息。
            # 真正的图片是在load_image()函数里根据这些specifications
            # 来on-th-fly的生成。
            for i in range(count):
            bg_color, shapes = self.random_image(height, width)
            self.add_image("shapes", image_id=i, path=None,
            width=width, height=height,
            bg_color=bg_color, shapes=shapes)

            其中add_image是在基类中定义:

            def add_image(self, source, image_id, path, **kwargs):
            image_info = {
            "id": image_id,
            "source": source,
            "path": path,
            }
            image_info.update(kwargs)
            self.image_info.append(image_info)

            它有3个命名参数source、image_id和path。source是标识图片的来源,我们这里都是固定的字符串”shapes”;image_id是图片的id,我们这里用生成的序号i,而path一般标识图片的路径,我们这里是None。其余的参数就原封不动的保存下来。

            random_image函数随机的生成图片的位置,请读者仔细阅读代码注释。

            (编辑:温州站长网)

            【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

热点阅读