.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "build/examples_segmentation/voc_sota.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_build_examples_segmentation_voc_sota.py>` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_build_examples_segmentation_voc_sota.py:

6. Reproducing SoTA on Pascal VOC Dataset
=========================================

This is a semantic segmentation tutorial for reproducing state-of-the-art results
on the Pascal VOC dataset using the Gluon CV toolkit.

Start Training Now
~~~~~~~~~~~~~~~~~~

.. hint::

    Feel free to skip the tutorial: the training script is self-contained and ready to launch.

    :download:`Download Full Python Script: train.py<../../../scripts/segmentation/train.py>`

    Example training commands for DeepLabV3::

        # First fine-tune the COCO-pretrained model on the augmented set.
        # If you would like to train from scratch on COCO, please see deeplab_resnet101_coco.sh
        CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset pascal_aug --model-zoo deeplab_resnet101_coco --aux --lr 0.001 --syncbn --ngpus 4 --checkname res101

        # Then fine-tune on the original set.
        CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset pascal_voc --model deeplab --aux --backbone resnet101 --lr 0.0001 --syncbn --ngpus 4 --checkname res101 --resume runs/pascal_aug/deeplab/res101/checkpoint.params

    For more training command options, please run ``python train.py -h``.
    Please check out the `model_zoo <../model_zoo/index.html#semantic-segmentation>`_
    for the training commands used to reproduce each pretrained model.

.. GENERATED FROM PYTHON SOURCE LINES 28-34

.. code-block:: default

    import numpy as np
    import mxnet as mx
    from mxnet import gluon, autograd
    import gluoncv

.. GENERATED FROM PYTHON SOURCE LINES 35-93

Evils in the Training Details
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

State-of-the-art results [Chen17]_ [Zhao17]_ on the Pascal VOC dataset are typically
difficult to reproduce due to sophisticated training details.
In this tutorial we walk through our state-of-the-art implementation step by step.

DeepLabV3 Implementation
------------------------

We implement the state-of-the-art DeepLabV3 semantic segmentation model in Gluon-CV.
Atrous Spatial Pyramid Pooling (ASPP) is the key component of DeepLabV3, which is
built on top of FCN. It combines features with different receptive field sizes by
applying dilated convolutions with different atrous rates in parallel, and it
incorporates a global pooling branch with a global receptive field.
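Each dilated branch of ASPP keeps the spatial resolution of its input: a 3x3
convolution with dilation :math:`r` has an effective kernel extent of :math:`2r + 1`,
so padding by :math:`r` leaves the feature map size unchanged and the branches can be
concatenated channel-wise. Below is a minimal sketch of our own (not part of the
original tutorial); the input shape and atrous rates are chosen to match the
output-stride-8 setting shown in the model printout later on:

.. code-block:: python

    from mxnet import nd
    from mxnet.gluon import nn

    # A stride-8 backbone turns a 480x480 image into a 60x60 feature map.
    x = nd.random.uniform(shape=(1, 2048, 60, 60))

    # With kernel_size=3 and padding == dilation, the spatial size is preserved
    # for every atrous rate, so all ASPP branches can be concatenated on axis 1.
    for rate in (12, 24, 36):
        branch = nn.Conv2D(channels=256, kernel_size=3, padding=rate, dilation=rate)
        branch.initialize()
        print(rate, branch(x).shape)  # (1, 256, 60, 60) for every rate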
The ASPP module is defined as::

    class _ASPP(nn.HybridBlock):
        def __init__(self, in_channels, atrous_rates, norm_layer, norm_kwargs):
            super(_ASPP, self).__init__()
            out_channels = 256
            # 1x1 convolution branch
            b0 = nn.HybridSequential()
            with b0.name_scope():
                b0.add(nn.Conv2D(in_channels=in_channels, channels=out_channels,
                                 kernel_size=1, use_bias=False))
                b0.add(norm_layer(in_channels=out_channels, **norm_kwargs))
                b0.add(nn.Activation("relu"))

            # three dilated 3x3 convolution branches with different atrous rates
            rate1, rate2, rate3 = tuple(atrous_rates)
            b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer, norm_kwargs)
            b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer, norm_kwargs)
            b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer, norm_kwargs)
            # global average pooling branch
            b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer,
                              norm_kwargs=norm_kwargs)

            # concatenate all five branches along the channel axis
            self.concurent = gluon.contrib.nn.HybridConcurrent(axis=1)
            with self.concurent.name_scope():
                self.concurent.add(b0)
                self.concurent.add(b1)
                self.concurent.add(b2)
                self.concurent.add(b3)
                self.concurent.add(b4)

            # project the concatenated features back down to 256 channels
            self.project = nn.HybridSequential()
            with self.project.name_scope():
                self.project.add(nn.Conv2D(in_channels=5*out_channels,
                                           channels=out_channels, kernel_size=1,
                                           use_bias=False))
                self.project.add(norm_layer(in_channels=out_channels, **norm_kwargs))
                self.project.add(nn.Activation("relu"))
                self.project.add(nn.Dropout(0.5))

        def hybrid_forward(self, F, x):
            return self.project(self.concurent(x))

The DeepLabV3 model is provided in :class:`gluoncv.model_zoo.DeepLabV3`.
To get a DeepLabV3 model using a ResNet50 base network for the VOC dataset:

.. GENERATED FROM PYTHON SOURCE LINES 93-96

.. code-block:: default

    model = gluoncv.model_zoo.get_deeplab(dataset='pascal_voc', backbone='resnet50', pretrained=False)
    print(model)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    DeepLabV3(
      (conv1): HybridSequential(
        (0): Conv2D(3 -> 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
        (2): Activation(relu)
        (3): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
        (5): Activation(relu)
        (6): Conv2D(64 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
      (relu): Activation(relu)
      (maxpool): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(1, 1), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
      (layer1): HybridSequential(
        (0): BottleneckV1b(
          (conv1): Conv2D(128 -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
          (relu1): Activation(relu)
          (conv2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
          (relu2): Activation(relu)
          (conv3): Conv2D(64 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu3): Activation(relu)
          (downsample): HybridSequential(
            (0): Conv2D(128 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          )
        )
        (1): BottleneckV1b(
          (conv1): Conv2D(256 -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
          (relu1): Activation(relu)
          (conv2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
          (relu2): Activation(relu)
          (conv3): Conv2D(64 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu3): Activation(relu)
        )
        (2): BottleneckV1b(
          (conv1): Conv2D(256 -> 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
          (relu1): Activation(relu)
          (conv2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
          (relu2): Activation(relu)
          (conv3): Conv2D(64 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu3): Activation(relu)
        )
      )
      (layer2): HybridSequential(
        (0): BottleneckV1b(
          (conv1): Conv2D(256 -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (relu1): Activation(relu)
          (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (relu2): Activation(relu)
          (conv3): Conv2D(128 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu3): Activation(relu)
          (downsample): HybridSequential(
            (0): Conv2D(256 -> 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          )
        )
        (1): BottleneckV1b(
          (conv1): Conv2D(512 -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (relu1): Activation(relu)
          (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (relu2): Activation(relu)
          (conv3): Conv2D(128 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu3): Activation(relu)
        )
        (2): BottleneckV1b(
          (conv1): Conv2D(512 -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (relu1): Activation(relu)
          (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (relu2): Activation(relu)
          (conv3): Conv2D(128 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu3): Activation(relu)
        )
        (3): BottleneckV1b(
          (conv1): Conv2D(512 -> 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (relu1): Activation(relu)
          (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (relu2): Activation(relu)
          (conv3): Conv2D(128 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu3): Activation(relu)
        )
      )
      (layer3): HybridSequential(
        (0): BottleneckV1b(
          (conv1): Conv2D(512 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu1): Activation(relu)
          (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu2): Activation(relu)
          (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
          (relu3): Activation(relu)
          (downsample): HybridSequential(
            (0): Conv2D(512 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
          )
        )
        (1): BottleneckV1b(
          (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu1): Activation(relu)
          (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu2): Activation(relu)
          (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
          (relu3): Activation(relu)
        )
        (2): BottleneckV1b(
          (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu1): Activation(relu)
          (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu2): Activation(relu)
          (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
          (relu3): Activation(relu)
        )
        (3): BottleneckV1b(
          (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu1): Activation(relu)
          (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu2): Activation(relu)
          (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
          (relu3): Activation(relu)
        )
        (4): BottleneckV1b(
          (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu1): Activation(relu)
          (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu2): Activation(relu)
          (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
          (relu3): Activation(relu)
        )
        (5): BottleneckV1b(
          (conv1): Conv2D(1024 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu1): Activation(relu)
          (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (relu2): Activation(relu)
          (conv3): Conv2D(256 -> 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=1024)
          (relu3): Activation(relu)
        )
      )
      (layer4): HybridSequential(
        (0): BottleneckV1b(
          (conv1): Conv2D(1024 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu1): Activation(relu)
          (conv2): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu2): Activation(relu)
          (conv3): Conv2D(512 -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=2048)
          (relu3): Activation(relu)
          (downsample): HybridSequential(
            (0): Conv2D(1024 -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=2048)
          )
        )
        (1): BottleneckV1b(
          (conv1): Conv2D(2048 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu1): Activation(relu)
          (conv2): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu2): Activation(relu)
          (conv3): Conv2D(512 -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=2048)
          (relu3): Activation(relu)
        )
        (2): BottleneckV1b(
          (conv1): Conv2D(2048 -> 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu1): Activation(relu)
          (conv2): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4), bias=False)
          (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
          (relu2): Activation(relu)
          (conv3): Conv2D(512 -> 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=2048)
          (relu3): Activation(relu)
        )
      )
      (head): _DeepLabHead(
        (aspp): _ASPP(
          (concurent): HybridConcurrent(
            (0): HybridSequential(
              (0): Conv2D(2048 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
              (2): Activation(relu)
            )
            (1): HybridSequential(
              (0): Conv2D(2048 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(12, 12), dilation=(12, 12), bias=False)
              (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
              (2): Activation(relu)
            )
            (2): HybridSequential(
              (0): Conv2D(2048 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(24, 24), dilation=(24, 24), bias=False)
              (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
              (2): Activation(relu)
            )
            (3): HybridSequential(
              (0): Conv2D(2048 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(36, 36), dilation=(36, 36), bias=False)
              (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
              (2): Activation(relu)
            )
            (4): _AsppPooling(
              (gap): HybridSequential(
                (0): GlobalAvgPool2D(size=(1, 1), stride=(1, 1), padding=(0, 0), ceil_mode=True, global_pool=True, pool_type=avg, layout=NCHW)
                (1): Conv2D(2048 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
                (3): Activation(relu)
              )
            )
          )
          (project): HybridSequential(
            (0): Conv2D(1280 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
            (2): Activation(relu)
            (3): Dropout(p = 0.5, axes=())
          )
        )
        (block): HybridSequential(
          (0): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (2): Activation(relu)
          (3): Dropout(p = 0.1, axes=())
          (4): Conv2D(256 -> 21, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (auxlayer): _FCNHead(
        (block): HybridSequential(
          (0): Conv2D(1024 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (2): Activation(relu)
          (3): Dropout(p = 0.1, axes=())
          (4): Conv2D(256 -> 21, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
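As a quick sanity check, we can push a dummy batch through the network. This is a
small sketch of our own (not part of the original script); it assumes the default
auxiliary head is enabled, as in the printout above, and that both score maps are
upsampled back to the model's crop size (480):

.. code-block:: python

    # illustrative only: random input at the model's default crop size
    x = mx.nd.random.uniform(shape=(1, 3, 480, 480))
    out, aux_out = model(x)  # main head and auxiliary (FCN) head predictions
    # both outputs carry one score map per VOC class (21), e.g. (1, 21, 480, 480)
    print(out.shape, aux_out.shape)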
.. GENERATED FROM PYTHON SOURCE LINES 97-107

COCO Pretraining
----------------

COCO is a large instance segmentation dataset with 80 categories and 127K training
images. From the training set of MS-COCO, we select the images containing the 20
classes shared with the PASCAL dataset and having more than 1,000 labeled pixels,
resulting in 92.5K images. All other classes are marked as background.
You can simply get this dataset using the following command:

.. GENERATED FROM PYTHON SOURCE LINES 107-127

.. code-block:: default

    # image transform for color normalization
    from mxnet.gluon.data.vision import transforms
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])

    # get the dataset
    trainset = gluoncv.data.COCOSegmentation(split='train', transform=input_transform)
    print('Training images:', len(trainset))
    # set batch_size = 2 for toy example
    batch_size = 2
    # Create Training Loader
    train_data = gluon.data.DataLoader(
        trainset, batch_size, shuffle=True, last_batch='rollover',
        num_workers=0)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    train set
    loading annotations into memory...
    Done (t=13.29s)
    creating index...
    index created!
    Training images: 92516

.. GENERATED FROM PYTHON SOURCE LINES 128-130

Plot an example of the generated images:

.. GENERATED FROM PYTHON SOURCE LINES 130-159

.. code-block:: default

    # randomly pick one example for visualization
    import random
    random.seed()  # seed from system entropy
    idx = random.randrange(len(trainset))  # randrange avoids an off-by-one at the end
    img, mask = trainset[idx]
    from gluoncv.utils.viz import get_color_pallete, DeNormalize
    # get the color palette for visualizing the mask
    mask = get_color_pallete(mask.asnumpy(), dataset='coco')
    mask.save('mask.png')
    # denormalize the image
    img = DeNormalize([.485, .456, .406], [.229, .224, .225])(img)
    img = np.transpose((img.asnumpy()*255).astype(np.uint8), (1, 2, 0))

    from matplotlib import pyplot as plt
    import matplotlib.image as mpimg
    # subplot 1 for the image
    fig = plt.figure()
    fig.add_subplot(1, 2, 1)
    plt.imshow(img)
    # subplot 2 for the mask
    mmask = mpimg.imread('mask.png')
    fig.add_subplot(1, 2, 2)
    plt.imshow(mmask)
    # display
    plt.show()

.. image-sg:: /build/examples_segmentation/images/sphx_glr_voc_sota_001.png
   :alt: voc sota
   :srcset: /build/examples_segmentation/images/sphx_glr_voc_sota_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 160-186

Direct launch command of the COCO pretraining::

    CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset coco --model deeplab --aux --backbone resnet101 --lr 0.01 --syncbn --ngpus 4 --checkname res101 --epochs 30

You can also skip the COCO pretraining by getting the pretrained model::

    from gluoncv import model_zoo
    model_zoo.get_model('deeplab_resnet101_coco', pretrained=True)

Pascal VOC and the Augmented Set
--------------------------------

The Pascal VOC dataset [Everingham10]_ has 2,913 images in its training and validation sets.
The augmented set [Hariharan15]_ has 10,582 training and 1,449 validation images.
We first fine-tune the COCO-pretrained model on the Pascal augmented set,
then fine-tune it again on the Pascal VOC dataset to get the best performance.

Learning Rates
--------------

We use different learning rates for the pretrained base network and the DeepLab head,
which has no pretrained weights: the learning rate of the head is enlarged 10 times.
A poly-like learning rate scheduling strategy is used, where the learning rate is
given by :math:`lr = base\_lr \times (1 - iters/niters)^{power}`.
Please check https://gluon-cv.mxnet.io/api/utils.html#gluoncv.utils.LRScheduler
for more details.

.. GENERATED FROM PYTHON SOURCE LINES 186-189

.. code-block:: default

    lr_scheduler = gluoncv.utils.LRScheduler(mode='poly', base_lr=0.01,
                                             nepochs=30,
                                             iters_per_epoch=len(train_data),
                                             power=0.9)
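To make the schedule concrete, here is a small illustration of our own (not from the
training script): it evaluates the poly formula directly at a few iterations, and
shows the standard Gluon mechanism, ``Parameter.lr_mult``, as one way the 10x head
learning rate can be realized (the actual script may wire this up differently):

.. code-block:: python

    # evaluate lr = base_lr * (1 - iters/niters)^power at a few points
    base_lr, power = 0.01, 0.9
    niters = 30 * len(train_data)  # nepochs * iters_per_epoch
    for it in (0, niters // 2, niters - 1):
        print('iter %d: lr = %.6f' % (it, base_lr * (1 - it / niters) ** power))

    # enlarge the head learning rate 10x via per-parameter multipliers
    for _, param in model.head.collect_params().items():
        param.lr_mult = 10.0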
.. GENERATED FROM PYTHON SOURCE LINES 190-194

We first use a base learning rate of 0.01 to pretrain on the MS-COCO dataset, then
divide the base learning rate by 10 and by 100 when fine-tuning on the Pascal
augmented set and on the original Pascal VOC dataset, respectively.

.. GENERATED FROM PYTHON SOURCE LINES 196-214

You can `Start Training Now`_.

References
----------

.. [Chen17] Chen, Liang-Chieh, et al. "Rethinking atrous convolution for semantic image segmentation." \
    arXiv preprint arXiv:1706.05587 (2017).

.. [Zhao17] Zhao, Hengshuang, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, and Jiaya Jia. \
    "Pyramid scene parsing network." IEEE Conference on Computer Vision and Pattern Recognition (CVPR). 2017.

.. [Everingham10] Everingham, Mark, Luc Van Gool, Christopher K. I. Williams, John Winn, \
    and Andrew Zisserman. "The Pascal visual object classes (VOC) challenge." \
    International Journal of Computer Vision 88, no. 2 (2010): 303-338.

.. [Hariharan15] Hariharan, Bharath, Pablo Arbeláez, Ross Girshick, and Jitendra Malik. \
    "Hypercolumns for object segmentation and fine-grained localization." In Proceedings of \
    the IEEE Conference on Computer Vision and Pattern Recognition, pp. 447-456. 2015.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 23.867 seconds)

.. _sphx_glr_download_build_examples_segmentation_voc_sota.py:

.. only:: html

  .. container:: sphx-glr-footer
     :class: sphx-glr-footer-example

     .. container:: sphx-glr-download sphx-glr-download-python

        :download:`Download Python source code: voc_sota.py <voc_sota.py>`

     .. container:: sphx-glr-download sphx-glr-download-jupyter

        :download:`Download Jupyter notebook: voc_sota.ipynb <voc_sota.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_