nk20-scripts/management/commands/_import_utils.py

115 lines
4.1 KiB
Python

# Copyright (C) 2018-2020 by BDE ENS Paris-Saclay
# SPDX-License-Identifier: GPL-3.0-or-later
import functools
import json
import time
from collections import defaultdict

from django.apps import apps
from django.core.management.base import BaseCommand
from django.db import transaction
from polymorphic.models import PolymorphicModel
def timed(method):
    """
    A simple decorator to measure time elapsed in class function (hence the args[0])

    The wrapped callable is invoked unchanged; once it returns, the elapsed
    time is reported through ``args[0].print_success``. The first positional
    argument must therefore be an object exposing ``print_success`` (in
    practice, the ``self`` of a command method).

    :param method: the function to wrap.
    :return: a wrapper that returns whatever ``method`` returns.
    """
    # functools.wraps keeps the decorated method's __name__ and __doc__.
    @functools.wraps(method)
    def _timed(*args, **kw):
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments.
        ts = time.perf_counter()
        result = method(*args, **kw)
        te = time.perf_counter()
        args[0].print_success(f"\n {method.__name__} executed ({te-ts:.2f}s)")
        return result
    return _timed
class ImportCommand(BaseCommand):
    """
    Generic command for import of NK15 database

    Provides the helpers shared by import commands: coloured output,
    a single-line progress display, common command-line options and
    JSON persistence of the ``MAP_IDBDE`` mapping.
    """

    def __init__(self, *args, **kwargs):
        # Forward the arguments unpacked: passing the tuple and the dict
        # themselves would hand them to the parent as its first two
        # positional parameters.
        super().__init__(*args, **kwargs)
        # Mapping of idbde values; integer keys and values once loaded by
        # load_map. NOTE(review): presumably old NK15 id -> new id, confirm
        # against the commands that fill it.
        self.MAP_IDBDE = dict()

    def print_success(self, to_print):
        """Write ``to_print`` to the command's stdout with the SUCCESS style."""
        return self.stdout.write(self.style.SUCCESS(to_print))

    def print_error(self, to_print):
        """Write ``to_print`` to the command's stdout with the ERROR style."""
        return self.stdout.write(self.style.ERROR(to_print))

    def update_line(self, n, total, content):
        """
        Rewrite the current terminal line with a ``(n/total) content`` progress display.

        :param n: current position.
        :param total: total number of items.
        :param content: string shown after the counter; it is padded and
            truncated to exactly 16 characters.
        """
        total = str(total)
        # str.rjust returns a new string; keep the result so the counter is
        # padded to the width of ``total`` and the line keeps a fixed width.
        n = str(n).rjust(len(total))
        # "\r" with no trailing newline makes the next call overwrite this line.
        print(f"\r ({n}/{total}) {content:16.16}", end="")

    def create_parser(self, prog_name, subcommand, **kwargs):
        """
        Extend the parent parser with the options common to import commands.

        :return: the parser built by the parent class, with the NK15 database
            name/owner, map save/load file names and bulk chunk size added.
        """
        parser = super().create_parser(prog_name, subcommand, **kwargs)
        parser.add_argument('--nk15db', action='store', default='nk15', help='NK15 database name')
        parser.add_argument('--nk15user', action='store', default='nk15_user', help='NK15 database owner')
        parser.add_argument('-s', '--save', default='map.json', action='store', help="save mapping of idbde")
        parser.add_argument('-m', '--map', default='map.json', action='store', help="import mapping of idbde")
        parser.add_argument('-c', '--chunk', type=int, default=100, help="chunk size for bulk_create")
        return parser

    def save_map(self, filename):
        """
        Dump ``MAP_IDBDE`` as indented, key-sorted JSON into ``filename``.

        JSON object keys are always strings, so integer keys are written as
        strings; ``load_map`` converts them back.
        """
        with open(filename, 'w') as fp:
            json.dump(self.MAP_IDBDE, fp, sort_keys=True, indent=2)

    def load_map(self, filename):
        """
        Replace ``MAP_IDBDE`` with the mapping stored as JSON in ``filename``.

        Every key and value of each decoded JSON object is converted with
        ``int``, so the file is expected to hold a flat object of
        integer-like keys and values.
        """
        with open(filename, 'r') as fp:
            self.MAP_IDBDE = json.load(fp, object_hook=lambda d: {int(k): int(v) for k, v in d.items()})
class BulkCreateManager(object):
    """
    This helper class keeps track of ORM objects to be created for multiple
    model classes, and automatically creates those objects with `bulk_create`
    when the number of objects accumulated for a given model class exceeds
    `chunk_size`.
    Upon completion of the loop that's `add()`ing objects, the developer must
    call `done()` to ensure the final set of objects is created for all models.
    """

    def __init__(self, chunk_size=100):
        # One pending-object queue per model, keyed by the model's
        # ``_meta.label``.
        self._create_queues = defaultdict(list)
        self.chunk_size = chunk_size

    def _commit(self, model_class):
        """
        Write every queued object of ``model_class`` to the database and
        empty its queue.
        """
        model_key = model_class._meta.label
        # Check for multi-table inheritance: it happens
        # if model_class is a grand-child of PolymorphicModel
        if model_class.__base__.__base__ is PolymorphicModel:
            # Flush the parent model's queue first, then save the queued
            # child objects one by one inside a single transaction.
            self._commit(model_class.__base__)
            with transaction.atomic():
                for obj in self._create_queues[model_key]:
                    obj.save_base(raw=True)
        else:
            model_class.objects.bulk_create(self._create_queues[model_key])
        self._create_queues[model_key] = []

    def add(self, *args):
        """
        Add an object to the queue to be created, and call bulk_create if we
        have enough objs.

        Several objects may be given at once; each is queued under the label
        of its own concrete class.
        """
        for obj in args:
            model_class = type(obj)
            model_key = model_class._meta.label
            self._create_queues[model_key].append(obj)
            if len(self._create_queues[model_key]) >= self.chunk_size:
                self._commit(model_class)

    def done(self):
        """
        Always call this upon completion to make sure the final partial chunk
        is saved.
        """
        # Iterate over a snapshot of the keys: _commit may recurse into a
        # parent model and insert that parent's key into the defaultdict,
        # which would raise "dictionary changed size during iteration" if the
        # dict itself were being iterated. The live queue is re-read for each
        # key so a parent already flushed by one of its children is skipped.
        for model_name in list(self._create_queues):
            if self._create_queues[model_name]:
                self._commit(apps.get_model(model_name))