mirror of
https://gitlab.crans.org/bde/nk20-scripts
synced 2024-12-02 13:16:56 +00:00
115 lines
4.1 KiB
Python
115 lines
4.1 KiB
Python
# Copyright (C) 2018-2020 by BDE ENS Paris-Saclay
|
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
|
|
|
import json
|
|
import time
|
|
from collections import defaultdict
|
|
|
|
from django.core.management.base import BaseCommand
|
|
from django.apps import apps
|
|
from django.db import transaction
|
|
|
|
from polymorphic.models import PolymorphicModel
|
|
|
|
|
|
def timed(method):
    """
    Decorate a method so that its elapsed run time is reported after it returns.

    The decorated callable must be a method: its first positional argument
    (``self``) must provide ``print_success``, which receives the timing message.
    If the method raises, the exception propagates and nothing is printed.
    """
    from functools import wraps

    @wraps(method)  # keep the wrapped method's __name__/__doc__
    def _timed(*args, **kw):
        # perf_counter is monotonic and meant for measuring durations.
        start = time.perf_counter()
        result = method(*args, **kw)
        elapsed = time.perf_counter() - start
        args[0].print_success(f"\n {method.__name__} executed ({elapsed:.2f}s)")
        return result

    return _timed
|
|
|
|
|
|
class ImportCommand(BaseCommand):
    """
    Generic command for import of NK15 database.

    Subclasses get:
    - the shared command-line options (``--nk15db``, ``--nk15user``,
      ``--save``, ``--map``, ``--chunk``);
    - coloured output helpers;
    - a one-line progress printer;
    - JSON persistence of the ``MAP_IDBDE`` mapping.
    """

    def __init__(self, *args, **kwargs):
        # Unpack the arguments. Passing the raw tuple/dict positionally (as
        # before) bound them to BaseCommand's first two positional parameters.
        super().__init__(*args, **kwargs)
        # int -> int mapping of idbde values, filled by subclasses and
        # persisted with save_map()/load_map().
        # NOTE(review): presumably NK15 idbde -> new primary key; confirm.
        self.MAP_IDBDE = dict()

    def print_success(self, to_print):
        """Write ``to_print`` to the command's stdout in the SUCCESS style."""
        return self.stdout.write(self.style.SUCCESS(to_print))

    def print_error(self, to_print):
        """Write ``to_print`` to the command's stdout in the ERROR style."""
        return self.stdout.write(self.style.ERROR(to_print))

    def update_line(self, n, total, content):
        """
        Overwrite the current terminal line with a ``(n/total) content`` message.

        ``n`` is right-aligned to the width of ``total`` so the line keeps a
        constant width as the counter grows. ``content`` is padded or truncated
        to exactly 16 characters (it must support the string format spec).
        """
        total = str(total)
        # str.rjust returns a new string; the previous code discarded it.
        n = str(n).rjust(len(total))
        print(f"\r ({n}/{total}) {content:16.16}", end="")

    def create_parser(self, prog_name, subcommand, **kwargs):
        """Extend the default parser with the options shared by all NK15 imports."""
        parser = super().create_parser(prog_name, subcommand, **kwargs)
        parser.add_argument('--nk15db', action='store', default='nk15', help='NK15 database name')
        parser.add_argument('--nk15user', action='store', default='nk15_user', help='NK15 database owner')
        parser.add_argument('-s', '--save', default='map.json', action='store', help="save mapping of idbde")
        parser.add_argument('-m', '--map', default='map.json', action='store', help="import mapping of idbde")
        parser.add_argument('-c', '--chunk', type=int, default=100, help="chunk size for bulk_create")
        return parser

    def save_map(self, filename):
        """Dump ``MAP_IDBDE`` to ``filename`` as indented JSON with sorted keys."""
        with open(filename, 'w', encoding='utf-8') as fp:
            json.dump(self.MAP_IDBDE, fp, sort_keys=True, indent=2)

    def load_map(self, filename):
        """
        Replace ``MAP_IDBDE`` with the mapping stored in ``filename``.

        JSON object keys are always strings, so both keys and values are
        converted back to ``int``. The hook runs on every JSON object in the
        file, so a flat ``{idbde: id}`` mapping is expected.
        """
        with open(filename, 'r', encoding='utf-8') as fp:
            self.MAP_IDBDE = json.load(fp, object_hook=lambda d: {int(k): int(v) for k, v in d.items()})
|
|
|
|
|
|
class BulkCreateManager:
    """
    Queue ORM objects per model class and create them in chunks.

    This helper keeps track of ORM objects to be created for multiple model
    classes. It automatically creates them with ``bulk_create`` when the
    number of objects queued for a given model class reaches ``chunk_size``.

    When the loop that is ``add()``ing objects finishes, the caller must call
    ``done()`` so that the final partial chunk of every model is created.
    """

    def __init__(self, chunk_size=100):
        # Pending objects, keyed by ``model._meta.label``; done() resolves a
        # key back to its model class with ``apps.get_model``.
        self._create_queues = defaultdict(list)
        self.chunk_size = chunk_size

    def _commit(self, model_class):
        """Create every queued object of ``model_class`` and empty its queue."""
        model_key = model_class._meta.label
        # Multi-table inheritance happens when model_class is a grand-child of
        # PolymorphicModel. Django's bulk_create() does not support child
        # models of multi-table inheritance, so the parent's queue is flushed
        # first. The child objects are then saved one by one with
        # save_base(raw=True) inside a single transaction.
        if model_class.__base__.__base__ is PolymorphicModel:
            self._commit(model_class.__base__)
            with transaction.atomic():
                for obj in self._create_queues[model_key]:
                    obj.save_base(raw=True)
        else:
            model_class.objects.bulk_create(self._create_queues[model_key])
        self._create_queues[model_key] = []

    def add(self, *args):
        """
        Queue each object in ``args`` for creation.

        As soon as a model's queue reaches ``chunk_size`` objects, that queue
        is committed.
        """
        for obj in args:
            model_class = type(obj)
            model_key = model_class._meta.label
            self._create_queues[model_key].append(obj)
            if len(self._create_queues[model_key]) >= self.chunk_size:
                self._commit(model_class)

    def done(self):
        """
        Commit every non-empty queue.

        Always call this upon completion to make sure the final partial chunk
        is saved.
        """
        # Iterate over a snapshot of the keys. _commit() may insert a parent
        # model's key into the defaultdict, which made the previous iteration
        # over .items() raise "dictionary changed size during iteration".
        # The live queue is re-read for each key because an earlier commit
        # may already have flushed it (e.g. a parent flushed via its child).
        for model_name in list(self._create_queues):
            if self._create_queues[model_name]:
                self._commit(apps.get_model(model_name))
|