Flask一种通用视图,增删改查RESTful API的设计

模型设计是后端开发的第一步。数据模型反映了各种对象之间的相互关系。

from app import db

class Role(db.Model):
    """角色"""  # TODO: 权限控制
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(64), unique=True)
    default = db.Column(db.Boolean, default=False, index=True)
    users = db.relationship('User', backref='role_users', lazy='dynamic')

class User(db.Model):
    """用户"""
    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(50), unique=True)
    password_hash = db.Column(db.String(128))
    role_id = db.Column(db.Integer, db.ForeignKey('role.id'))

对于不同的数据模型之后,往往都要开发增删改查功能的页面或者API。而每个模型的接口写起来又有很多相似之处。因此便可以封装统一的操作API,来快速实现任意模型的增删改查功能。

这里采用RESTful API的典型设计,每个模型设计两个接口:列表接口和详情接口,支持的功能如下:

  • 列表接口:GET获取列表,POST新建对象
  • 详情接口:GET获取详情,PUT修改对象,DELETE删除对象。

比如Role对象可以设计以下两个接口:

@app.route('/roles', methods=['GET', 'POST'])
def roles():
    ...

@app.route('/roles/<int:role_id>', methods=['GET', 'PUT', 'DELETE'])
def role(role_id):
    ...

这里我们都使用复数roles来统一endpoint。
默认情况下,我们需要在每个方法下通过if reqeust.method进行判断然后进行相应增删改查的操作,如:

from flask import request, jsonify, abort
from models import db, Role

@app.route('/roles', methods=['GET', 'POST'])
def roles():
    if request.method == 'GET':
        role_list = Role.query.all()
        # 序列化并返回响应
        return jsonify([{'id': role.id, 'name': role.name, 'default': role.default} for obj in role_list]
    else:
        json_data = request.get_json():
        name = json_data.get('name')
        if name is None:
            abort(404)
        role = Role(name=name)
        db.session.add(role)
        db.commit()
        return jsonify({'id': role.id, 'name': role.name, 'default': role.default})


@app.route('/roles/<int:role_id>', methods=['GET', 'PUT', 'DELETE'])
def role(role_id):
    if request.method == 'GET':
        role = Role.query.get_or_404(role_id)
        return jsonify({'id': role.id, 'name': role.name, 'default': role.default})
    elif request.method == 'PUT':
        role = Role.query.get_or_404(role_id)
        name = json_data.get('name')
        if name is None:
            abort(404)
        role.name = name
        db.commit()
        return jsonify({'id': role.id, 'name': role.name, 'default': role.default})
    else:
        role = Role.query.get_or_404(role_id)
        db.session.delete(role)
        db.commit()
        return jsonify({})

以上代码只简略处理了name一个参数。如果每个模型都要这样写一遍的话会非常繁琐。以下设计了一种通用的API视图。

# filename: utils.py
import datetime
from flask import jsonify, request
from models import db

LIST_METHODS = ['GET', 'POST']  # 资源列表默认方法,GET获取,POST新建
DETAIL_METHODS = ['GET', 'PUT', 'DELETE']  # 资源详情默认方法,GET获取,PUT修改,DELETE删除

def get_obj_fields(obj):
    """获取模型对象的表字段, obj或model均可"""
    if obj is None:
        return []
    return [column.name for column in obj.__table__.columns]

def obj2dict(obj):
    """对象转字典"""
    # 优先取模型fields字段指定的范围
    fields = getattr(obj, 'fields') if hasattr(obj, 'fields') else get_obj_fields(obj)
    obj_dict = {}
    for field in fields:
        value = getattr(obj, field)
        if isinstance(value, (datetime.datetime, datetime.date)):  # 处理datetime对象
            value = value.isoformat()
        obj_dict[field] = value
    return obj_dict

def get_request_valid_data(obj):
    data = request.get_json()
    if data is not None:
        data = {key: value for key, value in request.get_json().items()
                if key in get_obj_fields(obj)}
    return data

class Api(object):
    """通用Api"""
    def __init__(self, model, obj_id=None):
        self.model = model
        self.obj_id = obj_id
        self.data = get_request_valid_data(model)
        if obj_id is not None:
            self.obj = self.model.query.get_or_404(self.obj_id)

    @property
    def resource(self):
        """根据请求方法执行对应的操作,返回对象或列表"""
        func = getattr(self, request.method.lower())
        return func()

    @property
    def jsonify(self):
        """将资源结果转为JSON响应"""
        res = self.resource
        if isinstance(res, list):
            return jsonify([obj2dict(obj) for obj in res])
        else:
            return jsonify(obj2dict(res))

class ListApi(Api):
    """资源列表通用Api"""
    def get(self):
        """GET获取列表"""
        obj_list = self.model.query.all()
        return obj_list

    def post(self):
        """POST创建对象"""
        obj = self.model(**self.data)
        db.session.add(obj)
        db.session.commit()
        return obj

class DetailApi(Api):
    """资源列表通用Api"""
    def get(self):
        """GET获取详情"""
        return self.obj

    def put(self):
        """修改对象"""
        [setattr(self.obj, key, value) for key, value in self.data.items()]
        db.session.commit()
        return self.obj

    def delete(self):
        """删除对象"""
        db.session.delete(self.obj)
        db.session.commit()

使用方法为在views.py中

from models import Role, User
from utils import LIST_METHODS, DETAIL_METHODS, ListApi, DetailApi

@app.route('/roles', methods=LIST_METHODS)
def roles():
    return ListApi(Role).jsonify

@app.route('/roles/<int:role_id>', methods=DETAIL_METHODS)
def role(role_id):
    return DetailApi(Role, role_id).jsonify

这样每个模型的通用接口写起来就非常简单。
另外接口数据的验证方法可以参考如下参考如下,也可以在app层集中捕获并处理异常。

def verify_data(obj, data):
    """验证数据"""
    if data is None:
        return False
    fields = get_obj_fields(obj)
    columns = {column.name: column for column in obj.__table__.columns}
    data = {key: value for key, value in data if key in fields}  # 剔除非fields参数

    for column in columns:
        # 不检查主键和有默认值的列
        if column.primary_key is True or column.default is not None:
            continue

        # nullable 验证
        value = data.get(column.name)
        column_name = column.name
        if value is None:
            if column.nullable is False:
                return False, f'{column_name} 不能为空'
        else:
            column_type = str(column.type)
            # Integer()类型验证
            if column_type== 'INTEGER':
                if not isinstance(value, int):
                    return False, f'{column_name} 需为int类型'

            # todo 浮点型验证datetime类型验证
            # String()类型验证
            if column_type.startswith('VARCHAR'):
                length = int(column_type.split('(')[1][-1])
                if not isinstance(value, str):
                    return False, f'{column_name} 需为str类型'
                elif len(value) > length:
                    return False, f'{column_name} 长度不能大于 {length}'

    return True, 'OK'