# Source code for cool.views.mixins

# encoding: utf-8

from django.core.exceptions import ValidationError
from django.db import transaction
from django.utils.translation import gettext as _, gettext_lazy
from rest_framework import fields
from rest_framework.exceptions import ValidationError as RestValidationError
from rest_framework.utils import model_meta

from cool.core.utils import get_search_results
from cool.views import CoolAPIException, ErrorCode
from cool.views.fields import JSONCheckField, SplitCharField
from cool.views.utils import (
    get_rest_field_from_model_field, parse_validation_error,
)


class PageMixin:
    """
    Mixin that returns list data one page at a time.

    It contributes two request parameters, ``page`` and ``page_size``, and
    wraps the serialized rows in a paging envelope (see
    :meth:`get_page_context`).  Subclasses may override ``PAGE_SIZE_MAX`` and
    ``DEFAULT_PAGE_SIZE``.
    """
    # Upper bound accepted for the ``page_size`` request parameter.
    PAGE_SIZE_MAX = 200
    # Value used for ``page_size`` when the request does not supply one.
    DEFAULT_PAGE_SIZE = 100

    @classmethod
    def get_extend_param_fields(cls):
        """
        Return the parent's parameter fields plus ``page`` and ``page_size``.

        :raises AssertionError: if ``DEFAULT_PAGE_SIZE`` is not in the range
            ``(0, PAGE_SIZE_MAX]``.
        """
        assert 0 < cls.DEFAULT_PAGE_SIZE <= cls.PAGE_SIZE_MAX, (
            "DEFAULT_PAGE_SIZE must be between 0 and PAGE_SIZE_MAX in class %s" % cls.__name__
        )
        return super().get_extend_param_fields() + (
            (
                'page',
                fields.IntegerField(
                    label=gettext_lazy('Page number'),
                    default=1,
                    help_text=gettext_lazy('Start with %(start)s') % {'start': 1}
                )
            ),
            (
                'page_size',
                fields.IntegerField(
                    label=gettext_lazy('Page size'),
                    default=cls.DEFAULT_PAGE_SIZE,
                    min_value=1,
                    max_value=cls.PAGE_SIZE_MAX
                )
            ),
        )

    @classmethod
    def response_info_data(cls):
        """
        Describe the paged response: the parent's description becomes the
        single element of ``list`` and the paging keys are added around it.
        """
        return {
            'page_size': _('Page size'),
            'list': [super().response_info_data()],
            'page': _('Page number'),
            'total_page': _('Total page'),
            'total_data': _('Total data')
        }

    def get_page_context(self, request, queryset, serializer_cls):
        """
        Serialize the requested page of ``queryset``.

        :param request: request whose ``params`` carry ``page`` and ``page_size``
        :param queryset: object supporting ``count()`` and slicing
        :param serializer_cls: called as
            ``serializer_cls(rows, request=request, many=True)``; its ``data``
            attribute becomes the ``list`` entry
        :return: dict with ``page_size``, ``list``, ``page``, ``total_page``
            and ``total_data``.  ``list`` is empty when ``page`` lies outside
            ``1..total_page``.
        """
        page_size = request.params.page_size
        total_data = queryset.count()
        # Ceiling division: a partially filled last page still counts.
        total_page = (total_data + page_size - 1) // page_size
        page = request.params.page
        data = []
        if total_data > 0 and 1 <= page <= total_page:
            start = (page - 1) * page_size
            data = serializer_cls(queryset[start:start + page_size], request=request, many=True).data
        return {
            'page_size': page_size,
            'list': data,
            'page': page,
            'total_page': total_page,
            'total_data': total_data,
        }
class CRIDMixin:
    """
    Common base for the model-backed mixins in this module.

    Example::

        class ObjectAddDMixin(Add, APIBase):
            model = models.Object
            add_fields = ['name', 'desc']
    """
    model = None

    @classmethod
    def get_model_field_info(cls):
        """
        Return ``model_meta.get_field_info(cls.model)``, computed once per class.

        The cache is looked up in ``cls.__dict__`` rather than with
        ``hasattr`` so that a subclass declaring a different ``model`` does
        not pick up the field info cached on its parent.
        """
        if '_model_field_info' not in cls.__dict__:
            setattr(cls, '_model_field_info', model_meta.get_field_info(cls.model))
        return cls.__dict__['_model_field_info']

    @classmethod
    def get_field_detail(cls, fields_list):
        """
        Normalize a field list to ``(request_name, model_field_name)`` pairs.

        Entries that are already a list/tuple are kept as they are; a bare
        name ``x`` becomes ``(x, x)``.
        """
        return [field if isinstance(field, (list, tuple)) else (field, field) for field in fields_list]


class SearchListMixin(PageMixin, CRIDMixin):
    """
    Paged list view with optional keyword search and equality filters.

    NOTE(review): an ``order`` parameter is declared when ``order_fields`` is
    set, but nothing in this class reads it; ordering always comes from
    ``order_field``.
    """
    PAGE_SIZE_MAX = 1000
    model = None
    # Ordering applied to the default queryset.
    order_field = ('-pk', )
    order_fields = ()
    # Names (or ``(request_name, lookup)`` pairs) exposed as filter parameters.
    filter_fields = ()

    @classmethod
    def get_extend_param_fields(cls):
        """
        Add the search, order and filter parameters to the parent's fields.
        """
        ret = list()
        ret.extend(super().get_extend_param_fields())
        ret.append(('search_term', fields.CharField(label=_('Search key'), default='')))
        if cls.order_fields:
            ret.append(('order', fields.JSONField(label=_('Order field'), default='', )))
        if cls.model is not None and cls.filter_fields:
            for req_name, filter_id in cls.get_field_detail(cls.filter_fields):
                ret.append((req_name, get_rest_field_from_model_field(
                    cls.model, filter_id, **{'default': None}
                )))
        return tuple(ret)

    @property
    def name(self):
        """
        Name of the view in the API documentation.
        """
        return _("{model_name} List").format(model_name=self.model._meta.verbose_name)

    def get_search_fields(self):
        """
        Return the fields of this model that can be searched (delegates to
        ``model.get_search_fields()``).
        """
        return self.model.get_search_fields()

    def get_queryset(self, request, queryset=None):
        """
        Build the list queryset: default ordering, then every filter whose
        request value is not ``None``, then the keyword search.
        """
        if queryset is None:
            queryset = self.model.objects.order_by(*self.order_field)
        for req_name, field_name in self.get_field_detail(self.filter_fields):
            field = getattr(request.params, req_name)
            if field is not None:
                queryset = queryset.filter(**{field_name: field})
        if request.params.search_term:
            # Narrow the result down by the search keyword.
            queryset, use_distinct = get_search_results(
                queryset, request.params.search_term, self.get_search_fields(), self.model
            )
            if use_distinct:
                queryset = queryset.distinct()
        return queryset

    def get_context(self, request, *args, **kwargs):
        """Return the paged, serialized list for this request."""
        return self.get_page_context(request, self.get_queryset(request), self.response_info_serializer_class)


class InfoMixin(CRIDMixin):
    """
    Detail view that looks one object up by its primary key and/or by any of
    the unique fields listed in ``ex_unique_ids``.
    """
    model = None
    # Whether the primary key is accepted as a lookup parameter.
    pk_id = True
    # Additional unique fields accepted as lookup parameters.
    ex_unique_ids = []

    @property
    def name(self):
        """Name of the view in the API documentation."""
        return _("{model_name} Info").format(model_name=self.model._meta.verbose_name)

    @classmethod
    def get_extend_param_fields(cls):
        """
        Add one parameter per lookup key.

        With a single key the parameter is required; with several, each one
        defaults to ``None`` and :meth:`get_obj` insists on at least one.
        """
        ret = list()
        ret.extend(super().get_extend_param_fields())
        if cls.model is not None:
            num = len(cls.ex_unique_ids)
            info = cls.get_model_field_info()
            if cls.pk_id:
                num += 1
                ret.append((info.pk.name, get_rest_field_from_model_field(
                    cls.model, info.pk.name, **{'default': None} if num > 1 else {'required': True}
                )))
            for req_name, ex_unique_id in cls.get_field_detail(cls.ex_unique_ids):
                assert ex_unique_id in info.fields_and_pk and info.fields_and_pk[ex_unique_id].unique, (
                    "Field %s not found in %s's unique fields" % (ex_unique_id, cls.model.__name__)
                )
                ret.append((req_name, get_rest_field_from_model_field(
                    cls.model, ex_unique_id, **{'default': None} if num > 1 else {'required': True}
                )))
            assert num > 0, 'Must set unique fields to ex_unique_ids or set True to pk_id'
        return tuple(ret)

    def get_obj(self, request, queryset=None):
        """
        Return the first object matching every supplied lookup key.

        :raises CoolAPIException: ``ERROR_BAD_PARAMETER`` when all lookup
            parameters are ``None``.
        """
        if queryset is None:
            queryset = self.model.objects.all()
        blank = True
        param_fields = []
        if self.pk_id:
            pk_name = self.get_model_field_info().pk.name
            param_fields.append((pk_name, pk_name))
        param_fields.extend(self.get_field_detail(self.ex_unique_ids))
        for req_name, field_name in param_fields:
            field = getattr(request.params, req_name)
            if field is not None:
                blank = False
                queryset = queryset.filter(**{field_name: field})
        if blank:
            raise CoolAPIException(
                ErrorCode.ERROR_BAD_PARAMETER,
                data=_("{fields} cannot be empty at the same time").format(
                    fields=",".join(map(lambda x: x[0], param_fields))
                )
            )
        return queryset.first()

    def get_context(self, request, *args, **kwargs):
        """Return the serialized object selected by the request."""
        return self.response_info_serializer_class(self.get_obj(request), request=request).data


class AddMixin(CRIDMixin):
    """
    Create view: builds a new ``model`` instance from ``add_fields``.
    """
    # Names (or ``(request_name, model_field)`` pairs) accepted on create.
    add_fields = []

    @property
    def name(self):
        """Name of the view in the API documentation."""
        return _("Add {model_name}").format(model_name=self.model._meta.verbose_name)

    @classmethod
    def get_extend_param_fields(cls):
        """Add one parameter per entry of ``add_fields``."""
        ret = list()
        ret.extend(super().get_extend_param_fields())
        if cls.model is not None:
            for req_name, field_name in cls.get_field_detail(cls.add_fields):
                field = get_rest_field_from_model_field(cls.model, field_name)
                ret.append((req_name, field))
        return tuple(ret)

    def init_fields(self, request, obj):
        """Copy every non-``None`` request value onto ``obj``."""
        for req_name, field_name in self.get_field_detail(self.add_fields):
            value = getattr(request.params, req_name, None)
            if value is not None:
                setattr(obj, field_name, value)

    def save_obj(self, request, obj):
        """Validate ``obj`` and insert it."""
        obj.full_clean()
        obj.save(force_insert=True)

    def serializer_response(self, data, request):
        """Serialize ``data`` with the view's response serializer."""
        return self.response_info_serializer_class(data, request=request).data

    def get_context(self, request, *args, **kwargs):
        """
        Create the object inside a transaction and return it serialized.

        :raises RestValidationError: when model validation fails.
        """
        with transaction.atomic():
            try:
                obj = self.model()
                self.init_fields(request, obj)
                self.save_obj(request, obj)
            except ValidationError as e:
                raise RestValidationError(parse_validation_error(e))
        return self.serializer_response(obj, request=request)


class EditMixin(CRIDMixin):
    """
    Update view: loads the object identified by ``unique_key`` and applies
    the supplied ``edit_fields``.
    """
    # Key into ``fields_and_pk`` naming the unique field used for lookup.
    unique_key = 'pk'
    # Names (or ``(request_name, model_field)`` pairs) accepted on update.
    edit_fields = []

    @property
    def name(self):
        """Name of the view in the API documentation."""
        return _("Edit {model_name}").format(model_name=self.model._meta.verbose_name)

    @classmethod
    def get_extend_param_fields(cls):
        """
        Add the required lookup parameter and one optional parameter per
        entry of ``edit_fields``.

        The model field info is only read when ``model`` is set, matching the
        guard used by the other mixins in this module.
        """
        ret = list()
        ret.extend(super().get_extend_param_fields())
        if cls.model is not None:
            field = cls.get_model_field_info().fields_and_pk[cls.unique_key]
            assert field.unique, "Field %s is not unique" % cls.unique_key
            ret.append((field.name, get_rest_field_from_model_field(cls.model, field, required=True)))
            for req_name, field_name in cls.get_field_detail(cls.edit_fields):
                ret.append((req_name, get_rest_field_from_model_field(cls.model, field_name, default=None)))
        return tuple(ret)

    def get_obj(self, request):
        """
        Load the object to edit.

        The request value is read from the parameter declared in
        :meth:`get_extend_param_fields` (named after the lookup field) instead
        of a hard-coded ``id``.
        """
        info = self.get_model_field_info()
        field = info.fields_and_pk[self.unique_key]
        value = getattr(request.params, field.name)
        if field is info.pk:
            return self.model.get_obj_by_pk_from_cache(value)
        # NOTE(review): assumes the model offers get_obj_by_unique_key_from_cache,
        # as the extension models used by ExtManyToOneMixin.get_ext_obj do - confirm.
        return self.model.get_obj_by_unique_key_from_cache(**{field.name: value})

    def modify_obj(self, request, obj):
        """Copy every non-``None`` request value onto ``obj``."""
        for req_name, field_name in self.get_field_detail(self.edit_fields):
            value = getattr(request.params, req_name, None)
            if value is not None:
                setattr(obj, field_name, value)

    def save_obj(self, request, obj):
        """Validate ``obj`` and persist its changes."""
        obj.full_clean()
        obj.save_changed()

    def serializer_response(self, data, request):
        """Serialize ``data`` with the view's response serializer."""
        return self.response_info_serializer_class(data, request=request).data

    def get_context(self, request, *args, **kwargs):
        """
        Apply the edit inside a transaction and return the object serialized.

        :raises RestValidationError: when model validation fails (same
            conversion as :meth:`AddMixin.get_context`).
        """
        with transaction.atomic():
            obj = self.get_obj(request)
            try:
                self.modify_obj(request, obj)
                self.save_obj(request, obj)
            except ValidationError as e:
                raise RestValidationError(parse_validation_error(e)) from e
        return self.serializer_response(obj, request=request)


class DeleteMixin(CRIDMixin):
    """
    Delete view: removes every object whose ``unique_key`` value appears in
    the separated list sent by the client.
    """
    # Key into ``fields_and_pk`` naming the unique field used for lookup.
    unique_key = 'pk'
    # Separator between the keys in the request parameter.
    unique_key_sep = ','

    @property
    def name(self):
        """Name of the view in the API documentation."""
        return _("Delete {model_name}").format(model_name=self.model._meta.verbose_name)

    @classmethod
    def get_extend_param_fields(cls):
        """
        Add the ``<field name>s`` parameter holding the keys to delete.

        The model field info is only read when ``model`` is set, matching the
        guard used by the other mixins in this module.
        """
        ret = list()
        ret.extend(super().get_extend_param_fields())
        if cls.model is not None:
            field = cls.get_model_field_info().fields_and_pk[cls.unique_key]
            assert field.unique, "Field %s is not unique" % cls.unique_key
            ret.append((field.name + 's', SplitCharField(
                label=_('Primary keys'), sep=cls.unique_key_sep,
                child=get_rest_field_from_model_field(cls.model, field, required=True)
            )))
        return tuple(ret)

    def get_queryset(self, request):
        """
        Return the objects selected for deletion.

        Both the lookup and the request attribute are derived from the
        declared lookup field, so a key not named ``id`` works too.
        """
        field = self.get_model_field_info().fields_and_pk[self.unique_key]
        keys = getattr(request.params, field.name + 's')
        return self.model.objects.filter(**{'%s__in' % field.name: keys})

    def delete_object(self, request, obj):
        """Delete a single object; override to customize."""
        obj.delete()

    def delete_queryset(self, request, queryset):
        """Delete the objects one by one through :meth:`delete_object`."""
        for obj in queryset:
            self.delete_object(request, obj)

    def get_context(self, request, *args, **kwargs):
        """Delete the selected objects inside a transaction; returns ``None``."""
        with transaction.atomic():
            queryset = self.get_queryset(request)
            self.delete_queryset(request, queryset)
        return None


class ExtJSONCheckField(JSONCheckField):
    """
    JSON list field describing the rows of an extension model, with the
    "required when add" rule taken from an :class:`ExtModelFieldKey`.
    """

    def __init__(self, *args, **kwargs):
        """
        :param ext_model_field_key: (keyword, required) the
            :class:`ExtModelFieldKey` this field was generated from
        """
        self.ext_model_field_key = kwargs.pop('ext_model_field_key')
        assert isinstance(self.ext_model_field_key, ExtModelFieldKey)
        super().__init__(*args, **kwargs)

    def clean_dict_data(self, data):
        """
        Clean one item and enforce the fields that are required on add.

        When modification is enabled every child defaults to ``None``, so the
        requirement is checked here: an item without a key value is an
        addition and must carry every ``add_field_list`` entry that is not
        listed in ``add_not_required_field_list``.

        :raises RestValidationError: mapping each missing field to its
            ``required`` error.
        """
        data = super().clean_dict_data(data)
        key = self.ext_model_field_key
        if key.pk_field is None:
            # Add-only mode: the children themselves are declared required.
            return data
        errors = dict()
        if data.get(key.pk_field, None) is None:
            for field_name in key.add_field_list:
                if field_name in key.add_not_required_field_list:
                    continue
                if data.get(field_name, None) is None:
                    try:
                        self.children[field_name].fail('required')
                    except RestValidationError as e:
                        errors[field_name] = e.detail
        if errors:
            raise RestValidationError(errors)
        return data

    def run_children_validation(self, data):
        """Delegate to the parent implementation unchanged."""
        return super().run_children_validation(data)


class ExtModelFieldKey:
    """
    Description of one many-to-one extension model edited together with its
    parent object; consumed by :meth:`ExtManyToOneMixin.get_ext_model_fields`.
    """

    def __init__(
            self,
            field_name,
            ext_model,
            ext_foreign_key,
            edit_field_list=(),
            add_field_list=(),
            add_not_required_field_list=(),
            add_default_fields=None,
            pk_field='id',
            delete_not_found=True,
    ):
        """
        Build an entry for ``get_ext_model_fields``.

        :param field_name: name of the API parameter
        :param ext_model: extension model class
        :param ext_foreign_key: foreign-key field on the extension model
        :param add_field_list: fields accepted when adding
        :param edit_field_list: fields accepted when modifying
        :param add_not_required_field_list: fields that are optional when adding
        :param add_default_fields: default values applied when adding
        :param pk_field: key used for modification (must be the primary key or
            a unique key; a submitted item with a value for it is a
            modification, one without is an addition); ``None`` disables
            modification
        :param delete_not_found: whether existing rows that do not appear in
            the submitted list are deleted
        """
        if add_default_fields is None:
            add_default_fields = dict()
        self.field_name = field_name
        self.ext_model = ext_model
        self.ext_foreign_key = ext_foreign_key
        self.add_field_list = add_field_list
        self.edit_field_list = edit_field_list
        self.add_not_required_field_list = add_not_required_field_list
        self.add_default_fields = add_default_fields
        self.pk_field = pk_field
        self.delete_not_found = delete_not_found

    def get_json_check_field(self, label):
        """
        Build the :class:`ExtJSONCheckField` validating this extension list.

        :param label: label of the generated field
        """
        children = dict()
        add_only = self.pk_field is None
        if not add_only:
            children[self.pk_field] = get_rest_field_from_model_field(
                self.ext_model, self.pk_field, default=None,
                help_text=_('If the parameter has a value, it is modified; if it has no value, it is added')
            )
        field_list = list()
        field_list.extend(self.add_field_list)
        if not add_only:
            field_list.extend(self.edit_field_list)
        for field in field_list:
            if field in children:
                continue
            add_not_required = field in self.add_not_required_field_list
            field_kwargs = dict()
            # With modification enabled a child may legitimately be absent, so
            # it gets a None default; the add requirement is enforced later by
            # ExtJSONCheckField.clean_dict_data.
            if not add_only or add_not_required:
                field_kwargs['default'] = None
            if not add_only:
                help_text = []
                if field in self.add_field_list:
                    help_text.append(_('not required when add') if add_not_required else _('required when add'))
                if field in self.edit_field_list:
                    help_text.append(_('not required when edit'))
                field_kwargs['help_text'] = ",".join(help_text)
            children[field] = get_rest_field_from_model_field(
                self.ext_model, field, **field_kwargs
            )
        return ExtJSONCheckField(
            label=label, children=children, is_list=True, default=None, ext_model_field_key=self
        )

    def gen_objs(self, data, obj, get_ext_obj):
        """
        Turn the submitted items into objects to add, modify and delete.

        :param data: list of cleaned item dicts
        :param obj: parent object the extension rows belong to
        :param get_ext_obj: callable ``(ext_model, field, value)`` returning
            the existing row, or ``None``
        :return: ``(add_objs, edit_objs, del_objs)``; nothing is saved here
        :raises RestValidationError: errors keyed by the index of the item
        """
        params = dict()
        params[self.ext_foreign_key] = obj
        add_objs = []
        edit_objs = []
        edit_ids = []
        param_errors = dict()
        for idx, p in enumerate(data):
            if self.pk_field is not None and self.pk_field in p and p[self.pk_field] is not None:
                pk_value = p[self.pk_field]
                if pk_value in edit_ids:
                    param_errors[idx] = ValidationError(_('Primary key duplicate'))
                    continue
                edit_ids.append(pk_value)
                ext_obj = get_ext_obj(self.ext_model, self.pk_field, pk_value)
                if ext_obj is None:
                    param_errors[idx] = ValidationError(_('Modification item not found'))
                    continue
                for ext_key, ext_value in params.items():
                    # A row owned by another parent is reported as not found.
                    if getattr(ext_obj, ext_key) != ext_value:
                        param_errors[idx] = ValidationError(_('Modification item not found'))
                        break
                else:
                    for key, value in p.items():
                        if key in self.edit_field_list and value is not None:
                            setattr(ext_obj, key, value)
                    try:
                        ext_obj.full_clean()
                    except ValidationError as e:
                        param_errors[idx] = e
                    edit_objs.append(ext_obj)
            else:
                ext_obj = self.ext_model(
                    **self.add_default_fields,
                    **params,
                    **{key: value for key, value in p.items() if key in self.add_field_list},
                )
                try:
                    ext_obj.full_clean()
                except ValidationError as e:
                    param_errors[idx] = e
                add_objs.append(ext_obj)
        del_objs = []
        if self.delete_not_found:
            queryset = self.ext_model.objects.filter(**params)
            if edit_ids:
                queryset = queryset.exclude(**{"%s__in" % self.pk_field: edit_ids})
            del_objs = list(queryset)
        if param_errors:
            raise RestValidationError(parse_validation_error(param_errors))
        return add_objs, edit_objs, del_objs


class ExtManyToOneMixin:
    """
    Mixin that saves many-to-one extension rows right after the main object
    is saved; combine it with a mixin that provides ``save_obj``.
    """

    @classmethod
    def get_ext_model_fields(cls):
        """
        Return the extension descriptions, e.g.::

            [
                ExtModelFieldKey()
            ]
        """
        return ()

    def get_ext_obj(self, ext_model, unique_field, unique_field_value):
        """Load one existing extension row by its unique field."""
        return ext_model.get_obj_by_unique_key_from_cache(**{unique_field: unique_field_value})

    def delete_ext_obj(self, obj):
        """Delete one extension row; override to customize."""
        obj.delete()

    def delete_ext_objs(self, objs):
        """Delete the rows one by one through :meth:`delete_ext_obj`."""
        for obj in objs:
            self.delete_ext_obj(obj)

    def save_ext_obj(self, obj):
        """Persist the changes of one modified extension row."""
        obj.save_changed()

    def edit_ext_objs(self, objs):
        """Save the rows one by one through :meth:`save_ext_obj`."""
        for obj in objs:
            self.save_ext_obj(obj)

    def add_ext_objs(self, ext_model, objs):
        """Insert the new extension rows with ``bulk_create``."""
        ext_model.objects.bulk_create(objs)

    def save_obj(self, request, obj):
        """Save the main object, then process its extension rows."""
        super().save_obj(request, obj)
        self.do_ext(request, obj)

    def do_ext(self, request, obj):
        """
        Validate every submitted extension list, then apply deletions, edits
        and additions, in that order.

        All lists are validated before anything is written.

        :raises RestValidationError: errors keyed by parameter name
        """
        errors = dict()
        ex_objs = list()
        for model_fields in self.get_ext_model_fields():
            param = getattr(request.params, model_fields.field_name)
            if param is None:
                # Parameter not supplied: leave this extension untouched.
                continue
            try:
                add_objs, edit_objs, del_objs = model_fields.gen_objs(param, obj, self.get_ext_obj)
                ex_objs.append((model_fields.ext_model, add_objs, edit_objs, del_objs))
            except RestValidationError as e:
                errors[model_fields.field_name] = e
        if errors:
            raise RestValidationError(parse_validation_error(errors))
        for ext_model, add_objs, edit_objs, del_objs in ex_objs:
            if del_objs:
                self.delete_ext_objs(del_objs)
            if edit_objs:
                self.edit_ext_objs(edit_objs)
            if add_objs:
                self.add_ext_objs(ext_model, add_objs)

    @classmethod
    def get_extend_param_fields(cls):
        """Add one JSON list parameter per extension description."""
        ret = list()
        ret.extend(super().get_extend_param_fields())
        for model_fields in cls.get_ext_model_fields():
            ret.append((model_fields.field_name, model_fields.get_json_check_field(
                label="{model_name} List".format(model_name=model_fields.ext_model._meta.verbose_name)
            )))
        return tuple(ret)