.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "build/examples_pose/dive_deep_simple_pose.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_build_examples_pose_dive_deep_simple_pose.py>` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_build_examples_pose_dive_deep_simple_pose.py:


4. Dive deep into Training a Simple Pose Model on COCO Keypoints
===================================================================

In this tutorial, we show you how to train a pose estimation model [1]_ on the COCO dataset.

First, let's import the necessary modules.

.. GENERATED FROM PYTHON SOURCE LINES 8-26

.. code-block:: default

    from __future__ import division

    import time, logging, os, math

    import numpy as np
    import mxnet as mx
    from mxnet import gluon, nd
    from mxnet import autograd as ag
    from mxnet.gluon import nn
    from mxnet.gluon.data.vision import transforms

    from gluoncv.data import mscoco
    from gluoncv.model_zoo import get_model
    from gluoncv.utils import makedirs, LRScheduler
    from gluoncv.data.transforms.presets.simple_pose import SimplePoseDefaultTrainTransform
    from gluoncv.utils.metrics import HeatmapAccuracy


.. GENERATED FROM PYTHON SOURCE LINES 27-32

Loading the data
----------------

We can load the COCO Keypoints dataset with the official API.

.. GENERATED FROM PYTHON SOURCE LINES 32-37

.. code-block:: default

    train_dataset = mscoco.keypoints.COCOKeyPoints('~/.mxnet/datasets/coco',
                                                   splits=('person_keypoints_train2017',))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    loading annotations into memory...
    Done (t=9.00s)
    creating index...
    index created!


.. GENERATED FROM PYTHON SOURCE LINES 38-52

The dataset object enables us to retrieve images containing a person,
the person's keypoints, and meta-information.

Following the original paper, we resize the input to ``(256, 192)``.
For augmentation, we randomly scale, rotate, or flip the input.
Finally, we normalize it with the standard ImageNet statistics.

The COCO keypoints dataset contains 17 keypoints for a person.
Each keypoint is annotated with three numbers ``(x, y, v)``, where ``x`` and
``y`` mark the coordinates, and ``v`` indicates whether the keypoint is
visible (in COCO, ``0`` means not labeled, ``1`` labeled but occluded, and
``2`` labeled and visible).

For each keypoint, we generate a Gaussian kernel centered at its ``(x, y)``
coordinate, and use it as the training label. This means the model is trained
to predict a Gaussian distribution on a feature map.

.. GENERATED FROM PYTHON SOURCE LINES 52-58

.. code-block:: default

    transform_train = SimplePoseDefaultTrainTransform(num_joints=train_dataset.num_joints,
                                                      joint_pairs=train_dataset.joint_pairs,
                                                      image_size=(256, 192), heatmap_size=(64, 48),
                                                      scale_factor=0.30, rotation_factor=40,
                                                      random_flip=True)


.. GENERATED FROM PYTHON SOURCE LINES 59-62

Now we can define our data loader with the dataset and transformation.
We will iterate over ``train_data`` in our training loop.

.. GENERATED FROM PYTHON SOURCE LINES 63-70

.. code-block:: default

    batch_size = 32
    train_data = gluon.data.DataLoader(
        train_dataset.transform(transform_train),
        batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=0)


.. GENERATED FROM PYTHON SOURCE LINES 71-86

Deconvolution Layer
-------------------

A deconvolution layer (also known as a transposed convolution) enlarges the
feature map it receives, so it can be seen as a layer upsampling its input.

.. image:: https://raw.githubusercontent.com/vdumoulin/conv_arithmetic/master/gif/no_padding_no_strides_transposed.gif
    :width: 40%
    :align: center

In the above image, the blue map is the input feature map, and the cyan map
is the output.

In a ``ResNet`` model, the last feature map shrinks to only 1/32 of the
input's height and width, which may be too small for precise heatmap
prediction. Following it with several deconvolution layers enlarges the
feature map, making the prediction easier.

.. GENERATED FROM PYTHON SOURCE LINES 89-99

Model Definition
-----------------

A Simple Pose model consists of a ``ResNet`` main body followed by several
deconvolution layers. Its final layer is a convolution layer predicting one
heatmap for each keypoint.

Let's take a look at the smallest one from the GluonCV Model Zoo, using
``ResNet18`` as its base model.

We load the pre-trained parameters for the ``ResNet18`` layers, and
initialize the deconvolution layers and the final convolution layer.

.. GENERATED FROM PYTHON SOURCE LINES 99-106

.. code-block:: default

    context = mx.gpu(0)
    net = get_model('simple_pose_resnet18_v1b', num_joints=17, pretrained_base=True,
                    ctx=context, pretrained_ctx=context)
    net.deconv_layers.initialize(ctx=context)
    net.final_layer.initialize(ctx=context)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Downloading /root/.mxnet/models/resnet18_v1b-2d9d980c.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/resnet18_v1b-2d9d980c.zip...
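
Training Setup
--------------

The training loop below relies on a loss function ``L``, a ``trainer``, and a
``metric``. The snippets in this section are a minimal setup consistent with
that loop; for the exact hyperparameters used to reproduce our results, refer
to the full training script linked at the end of this tutorial.

Since the model predicts heatmaps, we can compare them against the
ground-truth heatmaps with an L2 loss. Gluon's ``L2Loss`` accepts a sample
weight as its third argument, which lets us mask out keypoints that are not
visible so they do not contribute to the gradient.

.. code-block:: default

    # L2 loss between predicted and target heatmaps; the per-keypoint
    # weight (the third argument in the loop below) masks invisible keypoints
    L = gluon.loss.L2Loss()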
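
For optimization we wrap the model parameters in a ``gluon.Trainer``, and we
track progress with the ``HeatmapAccuracy`` metric imported above, which
checks whether the peak of each predicted heatmap lands close enough to the
ground-truth keypoint. The ``adam`` optimizer and the learning rate below are
illustrative choices; the full training script decays the learning rate over
epochs with the imported ``LRScheduler``.

.. code-block:: default

    # illustrative optimizer settings; see the full training script for
    # the step-decay learning rate schedule used to reproduce our results
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            {'learning_rate': 0.001})

    # fraction of keypoints whose predicted heatmap peak falls within a
    # threshold distance of the ground truth
    metric = HeatmapAccuracy()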
Training Loop
-------------

With the data, model, loss, trainer, and metric ready, we can put everything
together and start training.

.. code-block:: default

    for i, batch in enumerate(train_data):
        if i > 0:
            # we train on a single batch in this tutorial; see the note below
            break

        # split the images, target heatmaps, and keypoint weights onto the device(s)
        data = gluon.utils.split_and_load(batch[0], ctx_list=[context], batch_axis=0)
        label = gluon.utils.split_and_load(batch[1], ctx_list=[context], batch_axis=0)
        weight = gluon.utils.split_and_load(batch[2], ctx_list=[context], batch_axis=0)

        # record the forward pass so gradients can be computed
        with ag.record():
            outputs = [net(X) for X in data]
            loss = [L(yhat, y, w) for yhat, y, w in zip(outputs, label, weight)]
        for l in loss:
            l.backward()
        trainer.step(batch_size)

        metric.update(label, outputs)


.. GENERATED FROM PYTHON SOURCE LINES 199-212

Due to resource limitations, we only train the model for one batch in this
tutorial. Please check out the full
:download:`training script <../../../scripts/pose/simple_pose/train_simple_pose.py>`
to reproduce our results.

References
----------

.. [1] Xiao, Bin, Haiping Wu, and Yichen Wei. "Simple Baselines for Human
       Pose Estimation and Tracking." Proceedings of the European Conference
       on Computer Vision (ECCV). 2018.
.. [2] https://github.com/Microsoft/human-pose-estimation.pytorch/issues/48
.. [3] https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/human_pose_estimation#known-issues


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 4 minutes 16.713 seconds)


.. _sphx_glr_download_build_examples_pose_dive_deep_simple_pose.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: dive_deep_simple_pose.py <dive_deep_simple_pose.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: dive_deep_simple_pose.ipynb <dive_deep_simple_pose.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_