Skip to content

Instantly share code, notes, and snippets.

@zhreshold
Created May 9, 2020 00:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zhreshold/24d1c695b4895ea25674555bb25504d8 to your computer and use it in GitHub Desktop.
Save zhreshold/24d1c695b4895ea25674555bb25504d8 to your computer and use it in GitHub Desktop.
GluonCV object detection Batch Input
import gluoncv as gcv
import mxnet as mx
from matplotlib import pyplot as plt
# Image files to run detection on. These are bare filenames, so they are
# resolved against the current working directory when read.
image_list = [
    "cat1.jpg",
    "dog.jpg",
]
class ImageListData(mx.gluon.data.Dataset):
    """Gluon dataset over a list of image file paths.

    Nothing is decoded up front: each access reads the file from disk again
    with ``mx.image.imread`` (flag ``1``), so indexing the same item twice
    performs two reads.

    Args:
        im_list: indexable sequence of image paths, kept as ``self.im_list``.
    """

    def __init__(self, im_list):
        # Stored as-is; ``__len__`` and ``__getitem__`` delegate to it.
        self.im_list = im_list

    def __len__(self):
        """Return the number of image paths held."""
        paths = self.im_list
        return len(paths)

    def __getitem__(self, idx):
        """Decode and return the image at position ``idx``."""
        path = self.im_list[idx]
        # Flag 1 is passed positionally exactly as before; MXNet documents it
        # as color (3-channel) decoding.
        return mx.image.imread(path, 1)
class Transform:
    """Resize an HWC image to a fixed size and normalize it for the detector.

    Calling an instance on an image returns ``(tensor, scale)``:

    * ``tensor``: the image resized to ``width`` x ``height`` (bilinear,
      ``interp=1``), converted with ``mx.nd.image.to_tensor``, then normalized
      with ``mean`` / ``std``.
    * ``scale``: a 1x4 array ``[[sx, sy, sx, sy]]``, where ``sx`` is
      original width / target width and ``sy`` is original height / target
      height. Multiplying boxes whose four coordinates are laid out as
      x, y, x, y by this array maps them from the resized image back to the
      original image's pixel grid.
    """

    def __init__(self, width, height, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # Target size fed to the network.
        self.w, self.h = width, height
        # Per-channel normalization constants (tuples, so safe as defaults).
        self.mean, self.std = mean, std

    def __call__(self, src):
        # Unpacking requires a 3-D (height, width, channels) image.
        orig_h, orig_w, _channels = src.shape

        # Ratios computed from the pre-resize size, repeated for both box corners.
        ratio_x = orig_w / self.w
        ratio_y = orig_h / self.h
        scale = mx.nd.array([[ratio_x, ratio_y, ratio_x, ratio_y]])

        resized = mx.image.imresize(src, self.w, self.h, interp=1)
        tensor = mx.nd.image.normalize(
            mx.nd.image.to_tensor(resized), mean=self.mean, std=self.std
        )
        return tensor, scale
# Run a pretrained YOLOv3 (Darknet-53, VOC classes) over ``image_list`` in
# mini-batches and plot each image's detections.
batch_size = 2
data = ImageListData(image_list)
loader = mx.gluon.data.DataLoader(
    data.transform_first(Transform(width=512, height=512)),
    batch_size=batch_size,
)
net = gcv.model_zoo.get_model('yolo3_darknet53_voc', pretrained=True)

# Index into ``data`` of the first image in the current batch; used to fetch
# the unresized original image for plotting.
count = 0
for x, scale in loader:
    cids, scores, bboxes = net(x)
    # Map boxes predicted on the 512x512 network input back to each original
    # image's pixel coordinates (``scale`` is stacked per image by the loader).
    bboxes = bboxes * scale
    # DataLoader keeps a short final batch by default, so iterate over the
    # actual number of images in this batch, not ``batch_size``. Otherwise
    # ``bboxes[i]`` / ``data[count + i]`` would run past the end.
    n_in_batch = x.shape[0]
    for i in range(n_in_batch):
        gcv.utils.viz.plot_bbox(data[count + i], bboxes[i], scores[i], cids[i], class_names=net.classes)
        plt.show()
    count += n_in_batch
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment