.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "build/examples_detection/finetune_detection.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_build_examples_detection_finetune_detection.py:


08. Finetune a pretrained detection model
============================================

Fine-tuning is a commonly used approach to transfer a previously trained model to a new dataset.
It is especially useful if the target dataset is relatively small.

Fine-tuning from a pre-trained model can help reduce the risk of overfitting.
A fine-tuned model may also generalize better if the previously used dataset is in a domain similar to that of the new dataset.

This tutorial presents a practical approach to fine-tuning the object detection models provided by GluonCV.
More specifically, we show how to use a customized Pikachu dataset and illustrate the fine-tuning fundamentals step by step.
You will become familiar with the steps and can then modify them to fit your own object detection projects.

.. GENERATED FROM PYTHON SOURCE LINES 16-24

.. code-block:: default


    import time
    from matplotlib import pyplot as plt
    import numpy as np
    import mxnet as mx
    from mxnet import autograd, gluon
    import gluoncv as gcv
    from gluoncv.utils import download, viz


.. GENERATED FROM PYTHON SOURCE LINES 25-29

Pikachu Dataset
----------------
First, we will start with a nice Pikachu dataset generated by rendering 3D models on random real-world scenes.

You can refer to :ref:`sphx_glr_build_examples_datasets_detection_custom.py` for a tutorial on how to create your own datasets.

.. GENERATED FROM PYTHON SOURCE LINES 29-34

..
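
The next code block downloads two files: ``train.rec``, which holds the packed records, and ``train.idx``, a plain-text index that maps each record key to its byte offset inside the ``.rec`` file so that records can be read in any order (for example when the data loader shuffles). As a minimal sketch, assuming the tab-separated ``key<TAB>offset`` line layout that MXNet's ``MXIndexedRecordIO`` writes, the index can be parsed with the standard library alone. ``read_recordio_index`` is a hypothetical helper, not part of MXNet or GluonCV, and the sample lines are made up.

.. code-block:: python

    from typing import Dict, Iterable, List, Tuple


    def read_recordio_index(lines: Iterable[str]) -> Tuple[Dict[int, int], List[int]]:
        """Parse index lines of the form 'key<TAB>offset' (hypothetical helper).

        Assumes integer keys and byte offsets, one pair per line.
        Returns a key -> offset mapping and the keys in file order.
        """
        offsets: Dict[int, int] = {}
        keys: List[int] = []
        for line in lines:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            key_str, offset_str = line.split('\t')
            key = int(key_str)
            offsets[key] = int(offset_str)
            keys.append(key)
        return offsets, keys


    # Made-up index lines; on disk you would pass open('pikachu_train.idx') instead.
    sample = ['0\t0\n', '1\t12288\n', '2\t24592\n']
    offsets, keys = read_recordio_index(sample)
    print(keys)        # [0, 1, 2]
    print(offsets[2])  # 24592

With such a mapping, a reader can seek straight to any record's offset in the ``.rec`` file instead of scanning it sequentially.

..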
code-block:: default


    url = 'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/pikachu/train.rec'
    idx_url = 'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/pikachu/train.idx'
    download(url, path='pikachu_train.rec', overwrite=False)
    download(idx_url, path='pikachu_train.idx', overwrite=False)


.. rst-class:: sphx-glr-script-out

Out:

.. code-block:: none

    Downloading pikachu_train.rec from https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/pikachu/train.rec...
    0%| | 0/85604 [00:00`

.. GENERATED FROM PYTHON SOURCE LINES 80-95

.. code-block:: default


    def get_dataloader(net, train_dataset, data_shape, batch_size, num_workers):
        from gluoncv.data.batchify import Tuple, Stack, Pad
        from gluoncv.data.transforms.presets.ssd import SSDDefaultTrainTransform
        width, height = data_shape, data_shape
        # use fake data to generate fixed anchors for target generation
        with autograd.train_mode():
            _, _, anchors = net(mx.nd.zeros((1, 3, height, width)))
        batchify_fn = Tuple(Stack(), Stack(), Stack())  # stack image, cls_targets, box_targets
        train_loader = gluon.data.DataLoader(
            train_dataset.transform(SSDDefaultTrainTransform(width, height, anchors)),
            batch_size, True, batchify_fn=batchify_fn, last_batch='rollover', num_workers=num_workers)
        return train_loader

    train_data = get_dataloader(net, dataset, 512, 16, 0)


.. GENERATED FROM PYTHON SOURCE LINES 96-97

Try to use a GPU for training, falling back to the CPU if no GPU is available.

.. GENERATED FROM PYTHON SOURCE LINES 97-103

.. code-block:: default


    try:
        # allocate a tiny array on GPU 0; this raises an error if no GPU is usable
        a = mx.nd.zeros((1,), ctx=mx.gpu(0))
        a.wait_to_read()  # force execution so any error surfaces inside the try block
        ctx = [mx.gpu(0)]
    except mx.base.MXNetError:
        ctx = [mx.cpu()]


.. GENERATED FROM PYTHON SOURCE LINES 104-105

Start training (finetuning)

.. GENERATED FROM PYTHON SOURCE LINES 105-147

..
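
The trainer configured in the next code block uses SGD with ``learning_rate=0.001``, ``wd=0.0005`` and ``momentum=0.9``, and calls ``trainer.step(1)`` because the multibox loss is already normalized. As a rough scalar sketch, assuming the update rule given in the MXNet 1.x ``mx.optimizer.SGD`` documentation (the gradient is first multiplied by ``rescale_grad``, which ``Trainer.step(batch_size)`` sets to ``1 / batch_size``), one update step looks like this. ``sgd_momentum_step`` is illustrative only, not MXNet's implementation, and it ignores gradient clipping.

.. code-block:: python

    def sgd_momentum_step(weight, grad, state, lr=0.001, wd=0.0005,
                          momentum=0.9, batch_size=1):
        """One scalar SGD update with momentum and weight decay (illustrative sketch).

        `batch_size` plays the role of the argument passed to Trainer.step():
        the gradient is scaled by rescale_grad = 1 / batch_size before the update.
        Returns the new (weight, state) pair.
        """
        rescale_grad = 1.0 / batch_size
        step = lr * (rescale_grad * grad + wd * weight)  # weight decay added to the scaled gradient
        state = momentum * state + step                   # momentum accumulator
        weight = weight - state
        return weight, state


    # step(1): the already-normalized gradient is used as-is.
    w1, s1 = sgd_momentum_step(weight=1.0, grad=2.0, state=0.0, batch_size=1)
    # step(16): the gradient term would be divided by 16 a second time.
    w16, s16 = sgd_momentum_step(weight=1.0, grad=2.0, state=0.0, batch_size=16)
    print(w1, w16)

Passing the per-device batch size to ``step`` here would shrink every update by an extra factor of 16, which is why the tutorial uses ``step(1)`` after normalizing inside the loss.

..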
code-block:: default


    net.collect_params().reset_ctx(ctx)
    trainer = gluon.Trainer(
        net.collect_params(), 'sgd',
        {'learning_rate': 0.001, 'wd': 0.0005, 'momentum': 0.9})

    mbox_loss = gcv.loss.SSDMultiBoxLoss()
    ce_metric = mx.metric.Loss('CrossEntropy')
    smoothl1_metric = mx.metric.Loss('SmoothL1')

    for epoch in range(0, 2):
        ce_metric.reset()
        smoothl1_metric.reset()
        tic = time.time()
        btic = time.time()
        net.hybridize(static_alloc=True, static_shape=True)
        for i, batch in enumerate(train_data):
            batch_size = batch[0].shape[0]
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            cls_targets = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            box_targets = gluon.utils.split_and_load(batch[2], ctx_list=ctx, batch_axis=0)
            with autograd.record():
                cls_preds = []
                box_preds = []
                for x in data:
                    cls_pred, box_pred, _ = net(x)
                    cls_preds.append(cls_pred)
                    box_preds.append(box_pred)
                sum_loss, cls_loss, box_loss = mbox_loss(
                    cls_preds, box_preds, cls_targets, box_targets)
                autograd.backward(sum_loss)
            # since we have already normalized the loss, we don't want to normalize
            # by batch-size anymore
            trainer.step(1)
            ce_metric.update(0, [l * batch_size for l in cls_loss])
            smoothl1_metric.update(0, [l * batch_size for l in box_loss])
            name1, loss1 = ce_metric.get()
            name2, loss2 = smoothl1_metric.get()
            if i % 20 == 0:
                print('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}'.format(
                    epoch, i, batch_size/(time.time()-btic), name1, loss1, name2, loss2))
            btic = time.time()


.. rst-class:: sphx-glr-script-out

Out:

..
code-block:: none

    [Epoch 0][Batch 0], Speed: 8.521 samples/sec, CrossEntropy=11.958, SmoothL1=1.986
    [Epoch 0][Batch 20], Speed: 19.758 samples/sec, CrossEntropy=4.358, SmoothL1=1.233
    [Epoch 0][Batch 40], Speed: 21.477 samples/sec, CrossEntropy=3.299, SmoothL1=0.968
    [Epoch 1][Batch 0], Speed: 20.556 samples/sec, CrossEntropy=1.529, SmoothL1=0.308
    [Epoch 1][Batch 20], Speed: 19.868 samples/sec, CrossEntropy=1.613, SmoothL1=0.454
    [Epoch 1][Batch 40], Speed: 19.864 samples/sec, CrossEntropy=1.577, SmoothL1=0.453


.. GENERATED FROM PYTHON SOURCE LINES 148-149

Save finetuned weights to disk

.. GENERATED FROM PYTHON SOURCE LINES 149-151

.. code-block:: default


    net.save_parameters('ssd_512_mobilenet1.0_pikachu.params')


.. GENERATED FROM PYTHON SOURCE LINES 152-155

Predict with finetuned model
----------------------------
We can now test the performance of the model using the finetuned weights.

.. GENERATED FROM PYTHON SOURCE LINES 155-164

.. code-block:: default


    test_url = 'https://raw.githubusercontent.com/zackchase/mxnet-the-straight-dope/master/img/pikachu.jpg'
    download(test_url, 'pikachu_test.jpg')
    net = gcv.model_zoo.get_model('ssd_512_mobilenet1.0_custom', classes=classes, pretrained_base=False)
    net.load_parameters('ssd_512_mobilenet1.0_pikachu.params')
    x, image = gcv.data.transforms.presets.ssd.load_test('pikachu_test.jpg', 512)
    cid, score, bbox = net(x)
    ax = viz.plot_bbox(image, bbox[0], score[0], cid[0], class_names=classes)
    plt.show()


.. image-sg:: /build/examples_detection/images/sphx_glr_finetune_detection_002.png
   :alt: finetune detection
   :srcset: /build/examples_detection/images/sphx_glr_finetune_detection_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

Out:

.. code-block:: none

    Downloading pikachu_test.jpg from https://raw.githubusercontent.com/zackchase/mxnet-the-straight-dope/master/img/pikachu.jpg...
    0%| | 0/88 [00:00`

:download:`Download train_yolo.py<../../../scripts/detection/yolo/train_yolo.py>`

..
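
In the prediction step above, the network returns class ids (``cid``), confidence scores (``score``) and boxes (``bbox``), and ``viz.plot_bbox`` draws only a subset of them. As a minimal sketch, assuming the behavior of GluonCV's ``plot_bbox`` (entries with a negative class id are treated as padding and skipped, and its ``thresh`` argument, 0.5 by default, drops low-score boxes) and a single ``pikachu`` class, the same filtering can be written with plain Python lists. ``filter_detections`` is a hypothetical helper and the detection values are made up.

.. code-block:: python

    from typing import List, Sequence, Tuple

    Box = Tuple[float, float, float, float]  # (xmin, ymin, xmax, ymax) in pixels


    def filter_detections(class_ids: Sequence[float], scores: Sequence[float],
                          boxes: Sequence[Box], class_names: Sequence[str],
                          thresh: float = 0.5) -> List[Tuple[str, float, Box]]:
        """Keep detections with a valid class id and a score of at least `thresh`."""
        kept = []
        for cid, score, box in zip(class_ids, scores, boxes):
            if score < thresh:
                continue  # low-confidence detection
            if cid < 0:
                continue  # padding entry with no class assigned
            kept.append((class_names[int(cid)], score, box))
        return kept


    # Made-up outputs for one image.
    cids = [0.0, 0.0, -1.0]
    scores = [0.93, 0.21, -1.0]
    boxes = [(120.0, 80.0, 260.0, 240.0), (10.0, 10.0, 40.0, 50.0), (-1.0, -1.0, -1.0, -1.0)]
    print(filter_detections(cids, scores, boxes, ['pikachu']))
    # [('pikachu', 0.93, (120.0, 80.0, 260.0, 240.0))]

Lowering ``thresh`` shows more candidate boxes, which can help when judging whether a short fine-tuning run has converged.

..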
rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 1 minutes 33.133 seconds)


.. _sphx_glr_download_build_examples_detection_finetune_detection.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: finetune_detection.py `


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: finetune_detection.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_