Skip to content

Instantly share code, notes, and snippets.

@zhreshold
Last active March 12, 2019 02:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zhreshold/850fcf0444121422144388a231f81aec to your computer and use it in GitHub Desktop.
Save zhreshold/850fcf0444121422144388a231f81aec to your computer and use it in GitHub Desktop.
YOLOv3 converter from darknet to GluonCV
import numpy as np
import os, sys
import gluoncv as gcv
import mxnet as mx
# Reference parameter shapes parsed from ref.txt, kept in file order.
# Each usable line is expected to contain one bracketed, comma-separated
# list of integers, e.g. "... [64, 32, 3, 3] ...".
shapes_ref = []
with open('ref.txt', 'r') as fid:
    for line in fid:
        pos = line.find('[')
        pos2 = line.find(']')
        if pos < 0 or pos2 <= pos:
            # No bracketed shape on this line (for example a blank trailing
            # line). str.find returns -1 in that case, which used to produce
            # a bogus slice followed by an int('') ValueError.
            continue
        t = line[pos + 1:pos2]
        # Ignore empty tokens so "[64,]" or "[]" do not crash the parse.
        s = [int(x) for x in t.split(',') if x.strip()]
        shapes_ref.append(s)
# Instantiate the 'yolo3_416_darknet53_coco' model from the GluonCV model zoo,
# passing pretrained_base=False.
net = gcv.model_zoo.get_model('yolo3_416_darknet53_coco', pretrained_base=False)
# Initialize the parameters with the default initializer and context.
net.initialize()
# Run one forward pass on an all-zero (1, 3, 416, 416) input.
# NOTE(review): presumably this is what makes concrete parameter shapes
# available to the shape checks further down (Gluon deferred init) - confirm.
x = mx.nd.zeros((1, 3, 416, 416))
net(x)
NUM = 366
# Maps an integer slot index to a (parameter name, parameter) pair. Later code
# walks the slots in ascending order.
keys = {}
# Regex selectors passed to net.collect_params, in the order the parameter
# groups are visited.
orders = [
    '.*darknet',
    '.*yolodetectionblockv30',
    '.*yolooutputv30',
    '.*yolov30_conv0|.*yolov30_batchnorm0',
    '.*yolodetectionblockv31',
    '.*yolooutputv31',
    '.*yolov30_conv1|.*yolov30_batchnorm1',
    '.*yolodetectionblockv32',
    '.*yolooutputv32',
]


def _darknet_slot(name, select, count):
    """Return the slot index for parameter ``name`` seen at position ``count``.

    ``count`` is the running index of the parameter in visiting order.

    For selectors containing 'yolooutput', a 'weight' parameter goes one slot
    later and a 'bias' parameter one slot earlier; any other name raises
    RuntimeError.

    For every other selector, a conv 'weight' goes four slots later, 'beta'
    two slots earlier, 'gamma' keeps its position, and anything else goes one
    slot earlier.

    NOTE(review): these offsets presumably reorder Gluon's per-layer parameter
    order into the order used by the Darknet weight file - confirm against
    Darknet's weight loader.
    """
    if 'yolooutput' in select:
        if 'weight' in name:
            return count + 1
        if 'bias' in name:
            return count - 1
        raise RuntimeError('invalid:{}'.format(name))
    if 'conv' in name and 'weight' in name:
        return count + 4
    if 'beta' in name:
        return count - 2
    if 'gamma' in name:
        return count
    return count - 1


count = 0
for select in orders:
    for k, v in net.collect_params(select=select).items():
        # Parameters whose name contains 'offset' or 'anchor' are not mapped
        # to a slot and do not advance the running index.
        if 'offset' in k or 'anchor' in k:
            print('skip:', k)
            continue
        slot = _darknet_slot(k, select, count)
        # Two parameters landing on the same slot means the ordering
        # assumption is broken; fail instead of silently overwriting.
        assert slot not in keys, "{} already exists, {}".format(slot, k)
        keys[slot] = (k, v)
        count += 1
print(len(keys))
# Check every mapped parameter against the reference shapes, pairing slots in
# ascending order with the entries of shapes_ref in file order.
# Iterating sorted(keys) replaces the former hard-coded range(400) wrapped in
# try/except KeyError, so no slot is silently ignored and the except clause
# can no longer mask a KeyError raised by anything other than the slot lookup.
ptr = 0
for i in sorted(keys):
    name, param = keys[i]
    if ptr >= len(shapes_ref):
        raise IndexError(
            'ref.txt provides only {} shapes but slot {} ({}) still needs one'
            .format(len(shapes_ref), i, name))
    expected = tuple(shapes_ref[ptr])
    print(i, name, param.shape, expected)
    assert expected == param.shape, '{}, {}'.format(param.shape, expected)
    ptr += 1
# Read the Darknet weight file: the first five int32 values are kept as the
# header, and everything after them is read as one flat float32 array.
with open('yolov3.weights', 'rb') as weights_file:
    header = np.fromfile(weights_file, np.int32, 5)
    print(header)
    weights = np.fromfile(weights_file, np.float32)
print(len(weights))
# Copy the flat weight array into the network parameters, consuming it slot by
# slot in ascending slot order, then save the converted parameters.
ptr = 0
for i in sorted(keys):  # every mapped slot, not just those below 400
    name, param = keys[i]
    shape = param.shape
    # int() keeps ptr a plain Python int (np.prod returns a NumPy scalar).
    offset = int(np.prod(shape))
    raw_data = weights[ptr:ptr + offset]
    if raw_data.size != offset:
        # A short slice means the weight file ran out of values; report that
        # directly instead of through an opaque reshape error.
        raise ValueError(
            'weight file exhausted at {}: need {} values, only {} left'
            .format(name, offset, raw_data.size))
    ptr += offset
    # The parameter mean before and after is printed as a quick visual check
    # that the values actually changed.
    before = param.data().mean().asscalar()
    param.set_data(mx.nd.array(raw_data.reshape(shape)))
    after = param.data().mean().asscalar()
    print(name, shape, before, after)
# Every float in the file must have been consumed, otherwise the layout
# assumed here does not match the weight file.
assert len(weights) == ptr, \
    'consumed {} of {} weight values'.format(ptr, len(weights))
print(ptr)
net.save_parameters('yolo3_416_darknet53_coco-converted.params')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment