以下内容为原创,转载请注明出处!

往往我们用DRF来获取一个资源的时候,会用很多字段来过滤自己想要的内容,下面记录下怎么进行字段过滤并且自定义过滤,让过滤做到更加人性化


下面是简单的商品模型:

SOURCE_CHOICES = (
    (1, '平台'),
    (2, '商户'),
    (3, '自媒体'),
)

class Product(models.Model):
    name = models.CharField(max_length=64, verbose_name="产品名称", help_text="商品名称")
    img = models.CharField(max_length=256, verbose_name='产品图片', help_text="商品图片url")
    price = models.IntegerField(verbose_name='产品价格', help_text='产品价格')
    source_name = models.IntegerField(verbose_name='商品来源名称', choices=SOURCE_CHOICES, default=1, help_text='商品来源。1:平台,2:商户,3:自媒体')

    class Meta:
        ordering = ['-id']
        verbose_name_plural = verbose_name = '商品'

    def __str__(self):
        return self.name


简单过滤:

        以下是对name和source_name进行简单过滤,比如查找source_name为1的所有商品

from rest_framework import viewsets

class ProductApi(viewsets.ModelViewSet):
    serializer_class = ProductSerializers
    queryset = Product.objects.all()
    authentication_classes = [LoginAuthClass]
    pagination_class = StandardResultsSetPagination
    filter_backends = (DjangoFilterBackend, filters.SearchFilter)
    filter_fields = ('name', 'source_name')
    search_fields = ('name',)


自定义过滤:

        如果有这样的需求该怎么办:我要查看source_name为1和2的所有商品并且价格区间在100到500。上面那样明显满足不了。这时候需要我们继承django-filter中的方法了

from rest_framework import viewsets
import django_filters

class ProductSourceNameListFilter(django_filters.CharFilter):

    def filter(self, qs, value):
        value = list(filter(None, value.split(",")))
        return super(ProductSourceNameListFilter, self).filter(qs=qs, value=value)


class ProductFilter(django_filters.rest_framework.FilterSet):
    source_name = ProductSourceNameListFilter(field_name='source_name',lookup_expr='in')
    min_price = django_filters.NumberFilter(field_name='price', lookup_expr='gt')
    max_price = django_filters.NumberFilter(field_name='price', lookup_expr='lt')

    class Meta:
        model = Product
        fields = ['name', 'source_name', 'min_price', 'max_price']

class ProductApi(viewsets.ModelViewSet):
    serializer_class = ProductSerializers
    queryset = Product.objects.all()
    authentication_classes = [LoginAuthClass]
    pagination_class = StandardResultsSetPagination
    filter_backends = (DjangoFilterBackend, filters.SearchFilter)
    filter_class = ProductFilter #和上面那种过滤,这里是最大的区别点
    search_fields = ('name',)
  • 访问/api/v1/product/?source_name=1,2&min_price=19&max_price=31就能得到需要的结果了

  • DRF文档中也会自动生成source_name,min_price,max_price三个参数。

  • list(filter(None, value.split(",")))是为了前端传多个source_name过来,我这边split生成列表,用django orm中的in方法来查询;至于这个分隔符可以自己定义,不一定是英文逗号;