13前后端联调登录/注册接口

2023年 7月 14日 72.0k 0

写在前面

本次篇幅较大,预计近2w字左右,请耐心看完。且对编程有些要求,本次后端将采用Dao设计模式来编写后端数据校验部分(采用异步),且新增了与User相关的modelschema。因此对新手来说难度可能偏大,耐心!

回顾

接上篇,编写好成型的login页面之后,我们就要开始着手准备给login页面的按钮加上相应的请求。简单把这篇文章分成几个模块:

  • 新增异步数据库操作的常量配置
  • models中创建异步连接
  • 新增basic models
  • 公共查询Mapper类的编写
  • 新增basic schema
  • user schema继承basic schema
  • 编写user_dao
  • 接口调用dao层
  • aiomysql异步操作数据库

    为什么要用aiomysql?

    aiomysql是一个基于异步IO的Python MySQL数据库驱动程序。相比于传统的同步IO库,它具有以下几个好处:

  • 异步IO:aiomysql利用Python的协程和异步IO机制,可以在执行数据库操作时不阻塞主线程或其他任务,提高了数据库访问的并发性能。
  • 高效性能:由于aiomysql使用异步IO,它可以有效地处理大量并发请求,使得数据库访问更加高效和快速。
  • 更好的可扩展性:aiomysql适用于异步编程模型,可以方便地与其他异步框架(如asyncio、aiohttp等)结合使用,提供更好的可扩展性和灵活性。
  • 兼容性:aiomysql完全兼容标准的MySQL协议,可以无缝与现有的MySQL数据库集成,且与传统同步IO库的代码迁移成本较低。
  • 总而言之,使用aiomysql可以帮助开发者在异步编程环境下更高效地与MySQL数据库交互,提升性能和可扩展性。

    在config中添加aiomysql相关配置
    # 异步数据库操作配置
    ASYNC_SQLALCHEMY_URI: str = f'mysql+aiomysql://{MYSQL_USER}:{parse.quote_plus(MYSQL_PASSWD)}@{MYSQL_HOST}:{MYSQL_PORT}/{MYSQL_DB}'
    
    models中创建异步连接
    from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession  
      
    async_engine = create_async_engine(AbandonConfig.ASYNC_SQLALCHEMY_URI, pool_recycle=1500)  
    async_session = sessionmaker(async_engine, class_=AsyncSession)
    
    在models中创建basic

    新增abandon-server/src/app/models/basic.py文件,编写Base基类。

    import json
    from datetime import datetime
    from sqlalchemy import INT, Column, BIGINT, TIMESTAMP
    from typing import Tuple
    from decimal import Decimal
    
    from src.app.models import Base
    
    
    class AbandonBase(Base):
        id = Column(INT, primary_key=True)
        created_at = Column(TIMESTAMP, nullable=False)
        updated_at = Column(TIMESTAMP, nullable=False)
        deleted_at = Column(BIGINT, nullable=False, default=0)
        create_user = Column(INT, nullable=False)
        update_user = Column(INT, nullable=False)
        __abstract__ = True
        __fields__: Tuple[Column] = [id]
        __tag__ = "未定义"
        __alias__ = dict(name="名称")
        __show__ = 1
    
        def __init__(self, user, id=None):
            self.created_at = datetime.now()
            self.updated_at = datetime.now()
            self.create_user = user
            self.update_user = user
    
        def serialize(self, *ignore):
            """
            dump self
            :return:
            """
            data = dict()
            for c in self.__table__.columns:
                if c.name in ignore:
                    # 如果字段忽略, 则不进行转换
                    continue
                val = getattr(self, c.name)
                if isinstance(val, datetime):
                    data[c.name] = val.strftime("%Y-%m-%d %H:%M:%S")
                elif isinstance(val, Decimal):
                    data[c.name] = str(val)
                elif isinstance(val, bytes):
                    data[c.name] = val.decode(encoding='utf-8')
                else:
                    data[c.name] = val
            return json.dumps(data, ensure_ascii=False)
    
    
    公共查询Mapper类的编写

    新增文件abandon-server/src/app/dao/common/mapper.py

    import time
    from collections.abc import Iterable
    from datetime import datetime
    from typing import Tuple, List, TypeVar, Any, Callable
    
    from sqlalchemy import select, update
    from sqlalchemy.ext.asyncio import AsyncSession
    
    from src.app.models.basic import AbandonBase
    
    
    class Mapper(object):
        __model__ = AbandonBase
    
        @classmethod
        async def select_list(cls, *, session: AsyncSession = None, condition: list = None, **kwargs):
            """
            基础model查询条件
            :param session: 查询session
            :param condition: 自定义查询条件
            :param kwargs: 普通查询条件
            :return:
            """
            sql = cls.query_wrapper(condition, **kwargs)
            result = await session.execute(sql)
            return result.scalars().all()
    
        @staticmethod
        def like(s: str):
            if s:
                return f"%{s}%"
            return s
    
        @staticmethod
        def rlike(s: str):
            if s:
                return f"{s}%"
            return s
    
        @staticmethod
        def llike(s: str):
            if s:
                return f"%{s}"
            return s
    
        @staticmethod
        async def pagination(page: int, size: int, session, sql: str, scalars=True, **kwargs):
            """
            分页查询
            :param scalars:
            :param session:
            :param page:
            :param size:
            :param sql:
            :return:
            """
            data = await session.execute(sql)
            total = data.raw.rowcount
            if total == 0:
                return [], 0
            sql = sql.offset((page - 1) * size).limit(size)
            data = await session.execute(sql)
            if scalars and kwargs.get("_join") is None:
                return data.scalars().all(), total
            return data.all(), total
    
        @staticmethod
        def update_model(dist, source, update_user=None, not_null=False):
            """
            :param dist:
            :param source:
            :param not_null:
            :param update_user:
            :return:
            """
            changed = []
            for var, value in vars(source).items():
                if not_null:
                    if value is None:
                        continue
                    if isinstance(value, bool) or isinstance(value, int) or value:
                        # 如果是bool值或者int, false和0也是可以接受的
                        if not hasattr(dist, var):
                            continue
                        if getattr(dist, var) != value:
                            changed.append(var)
                            setattr(dist, var, value)
                else:
                    if getattr(dist, var) != value:
                        changed.append(var)
                        setattr(dist, var, value)
            if update_user:
                setattr(dist, 'update_user', update_user)
            setattr(dist, 'updated_at', datetime.now())
            return changed
    
        @staticmethod
        def delete_model(dist, update_user):
            """
            删除数据,兼容老的deleted_at
            :param dist:
            :param update_user:
            :return:
            """
            if str(dist.__class__.deleted_at.property.columns[0].type) == "DATETIME":
                dist.deleted_at = datetime.now()
            else:
                dist.deleted_at = int(time.time() * 1000)
            dist.updated_at = datetime.now()
            dist.update_user = update_user
    
        @classmethod
        async def list_with_pagination(cls, page, size, /, *, session=None, **kwargs):
            """
            通过分页获取数据
            :param session:
            :param page:
            :param size:
            :param kwargs:
            :return:
            """
            return await cls.pagination(page, size, session, cls.query_wrapper(**kwargs), **kwargs)
    
        @classmethod
        def where(cls, param: Any, sentence, condition: list):
            """
            根据where语句的内容,决定是否生成对应的sql
            :param param:
            :param sentence:
            :param condition:
            :return:
            """
            if param is None:
                return cls
            if isinstance(param, bool):
                condition.append(sentence)
                return cls
            if isinstance(param, int):
                condition.append(sentence)
                return cls
            if param:
                condition.append(sentence)
            return cls
    
        @classmethod
        def query_wrapper(cls, condition=None, **kwargs):
            """
            包装查询条件,支持like, == 和自定义条件(condition)
            :param condition:
            :param kwargs:
            :return:
            """
            conditions = condition if condition else list()
            if getattr(cls.__model__, "deleted_at", None):
                conditions.append(getattr(cls.__model__, "deleted_at") == 0)
            _sort = kwargs.pop("_sort", None)
            _select = kwargs.pop("_select", list())
            _join = kwargs.pop("_join", None)
            # 遍历参数,当参数不为None的时候传递
            for k, v in kwargs.items():
                # 判断是否是like的情况
                like = isinstance(v, str) and (v.startswith("%") or v.endswith("%"))
                if like and v == "%%":
                    continue
                # 如果是like模式,则使用Model.字段.like 否则用 Model.字段 等于
                cls.where(v, getattr(cls.__model__, k).like(v) if like else getattr(cls.__model__, k) == v,
                          conditions)
            sql = select(cls.__model__, *_select)
            if isinstance(_join, Iterable):
                for j in _join:
                    sql = sql.outerjoin(*j)
            where = sql.where(*conditions)
            if _sort and isinstance(_sort, Iterable):
                for d in _sort:
                    where = getattr(where, "order_by")(d)
            return where
    
        @classmethod
        async def query_record(cls, session: AsyncSession = None, **kwargs):
            sql = cls.query_wrapper(**kwargs)
            result = await session.execute(sql)
            return result.scalars().first()
    
    
    新增basic schema

    新增文件abandon-server/src/app/exception/error.py

    class AuthError(Exception):
        """user authorization error
        """
    
    
    class CaseParametersError(Exception):
        """extract parameters error
        """
    
    
    class ParamsError(ValueError):
        """request params error
        """
    
    
    class RedisError(Exception):
        """redis error
        """
    
    
    class RpcError(Exception):
        """rpc error
        """
    
    

    新增文件abandon-server/src/app/schema/basic.py

    from src.app.exception.error import ParamsError
    
    
    class AbandonModel(object):
    
        @staticmethod
        def not_empty(v):
            if isinstance(v, str) and len(v.strip()) == 0:
                raise ParamsError("不能为空")
            if not isinstance(v, int):
                if not v:
                    raise ParamsError("不能为空")
            return v
    
        @property
        def parameters(self):
            raise NotImplementedError
    
    
    user schema继承basic schema

    新增文件abandon-server/src/app/schema/user.py

    from pydantic import BaseModel, validator
    
    from src.app.exception.error import ParamsError
    from src.app.schema.base import AbandonModel
    
    
    class UserUpdateForm(BaseModel):
        id: int
        name: str = None
        email: str = None
        phone: str = None
        role: int = None
        is_valid: bool = None
    
        @validator('id')
        def id_not_empty(cls, v):
            return AbandonModel.not_empty(v)
    
    
    class UserDto(BaseModel):
        name: str
        password: str
        username: str
        email: str
    
        @validator('name', 'password', 'username', 'email')
        def field_not_empty(cls, v):
            if isinstance(v, str) and len(v.strip()) == 0:
                raise ParamsError("不能为空")
            return v
    
    
    class UserForm(BaseModel):
        username: str
        password: str
    
        @validator('password', 'username')
        def name_not_empty(cls, v):
            if isinstance(v, str) and len(v.strip()) == 0:
                raise ParamsError("不能为空")
            return v
    
    
    class ResetPwdForm(BaseModel):
        password: str
        token: str
    
        @validator('token', 'password')
        def name_not_empty(cls, v):
            if isinstance(v, str) and len(v.strip()) == 0:
                raise ParamsError("不能为空")
            return v
    
    
    更改User表的字段

    还记得咱们之前在初始化数据库的时候讲到的吗,当时无法确定User表中具体信息,所以这次对User表进行了一些更新内容,更新内容如下:

    from datetime import datetime
    # dColumn 用于定义表字段,String 和 INT 分别表示字符串和整数类型,DATETIME 表示日期时间类型
    from sqlalchemy import Column, String, INT, DATETIME, Boolean
    
    from src.app.models import Base
    
    
    class User(Base):
        # 定义表名为 "abandon_user",表名和类名不必相同,但通常保持一致比较好
        __tablename__ = "abandon_user"
    
        id = Column(INT, primary_key=True, comment="用户唯一id")
        # 定义字段 id,类型为整数,是主键,注释为 "用户唯一id"
        username = Column(String(16), unique=True, index=True, comment="用户名")
        # 定义字段 username,类型为字符串,长度为 16,唯一且建立索引,注释为 "用户名"
        name = Column(String(16), index=True, comment="姓名")
        # 定义字段 name,类型为字符串,长度为 16,建立索引,注释为 "姓名"
        password = Column(String(32), unique=False, comment="用户密码")
        # 定义字段 password,类型为字符串,长度为 32,不唯一,注释为 "用户密码"
        email = Column(String(64), unique=True, nullable=False, comment="用户邮箱")
        # 定义字段 email,类型为字符串,长度为 64,唯一且不能为空,注释为 "用户邮箱"
        role = Column(INT, default=0, comment="0: 普通用户 1: 组长 2: 超级管理员")
        # 定义字段 role,类型为整数,缺省值为 0,注释为 "0: 普通用户 1: 组长 2: 超级管理员"
        created_at = Column(DATETIME, nullable=False, comment="创建时间")
        # 定义字段 created_at,类型为日期时间,不能为空,注释为 "创建时间"
        updated_at = Column(DATETIME, nullable=False, comment="更改时间")
        # 定义字段 updated_at,类型为日期时间,不能为空,注释为 "更改时间"
        deleted_at = Column(DATETIME, comment="删除时间")
        # 定义字段 deleted_at,类型为日期时间,可为空,注释为 "删除时间"
        last_login_at = Column(DATETIME, comment="上次登录时间")
        # 定义字段 last_login_at,类型为日期时间,可为空,注释为 "上次登录时间"
        avatar = Column(String(128), nullable=True, default=None)
        # 管理员可以禁用某个用户,当他离职后
        is_valid = Column(Boolean, nullable=False, default=True, comment="是否合法")
    
        def __init__(self, username, name, password, email):
            self.username = username
            self.password = password
            self.email = email
            self.name = name
            self.created_at = datetime.now()
            self.updated_at = datetime.now()
            self.role = 0
    
    编写user_dao

    新增文件abandon-server/src/app/dao/auth/user_dao.py

    from datetime import datetime  # 导入datetime类,用于处理日期和时间
    
    from sqlalchemy import or_, select, func  # 导入or_、select和func类/函数,用于构建SQL查询语句
    
    from src.app.dao.common.mapper import Mapper  # 导入Mapper类,用作该类的父类
    from src.app.middleware.my_jwt import AbandonJWT  # 导入AbandonJWT类,用于处理JWT认证
    from src.app.models.user import User  # 导入User类,用于操作用户数据表
    from src.app.utils.log_config import logger  # 导入logger,用于日志记录
    from src.app.models import async_session  # 导入async_session,用于操作异步数据库会话
    
    
    class UserDao(Mapper):  # 定义名为UserDao的类,继承自Mapper类
    
        @staticmethod
        async def register_user(username: str, name: str, password: str, email: str):
            """
            注册用户
            :param username: 用户名
            :param name: 姓名
            :param password: 密码
            :param email: 邮箱
            :return: 用户对象
            """
            try:
                # 采用aiomysql异步操作数据库
                async with async_session() as session:
                    async with session.begin():
                        # 检查用户名和邮箱是否已存在
                        users = await session.execute(
                            select(User).where(or_(User.username == username, User.email == email)))
                        counts = await session.execute(select(func.count(User.id)))
                        if users.scalars().first():
                            raise Exception("用户名或邮箱已存在")
                        # 注册时给密码加盐
                        pwd = AbandonJWT.add_salt(password)
                        user = User(username, name, pwd, email)
                        user.last_login_at = datetime.now()
                        session.add(user)
                        await session.flush()
                        session.expunge(user)
                        return user  # 返回注册成功的用户对象
            except Exception as e:
                logger.error(f"用户注册失败: {str(e)}")
                raise Exception(f"注册失败: {e}")
    
        @staticmethod
        async def login(username, password):
            """
            用户登录
            :param username: 用户名
            :param password: 密码
            :return: 用户对象
            """
            try:
                # 将输入的密码加密并赋值给变量pwd
                pwd = AbandonJWT.add_salt(password)
                # aiomysql异步操作数据库
                async with async_session() as session:
                    async with session.begin():
                        # 查询用户名/密码匹配且没有被删除的用户
                        # where中的语句意思:数据库中的username与输入的username相等,且数据库中的password与pwd相等
                        query = await session.execute(
                            select(User).where(or_(User.username == username, User.password == pwd)))
                        user = query.scalars().first()
                        if user is None:
                            raise Exception("用户名或密码错误")
                        if not user.is_valid:
                            # 说明用户被禁用
                            raise Exception("您的账号已被封禁, 请联系管理员")
                        user.last_login_at = datetime.now()
                        await session.flush()
                        session.expunge(user)
                        return user  # 返回登录成功的用户对象
            except Exception as e:
                logger.error(f"用户{username}登录失败: {str(e)}")
                raise e
    
        @staticmethod
        async def list_users():
            """
            获取用户列表
            :return: 用户列表
            """
            try:
                # aiomysql异步操作数据库
                async with async_session() as session:
                    query = await session.execute(select(User))
                    return query.scalars().all()  # 返回所有用户对象的列表
            except Exception as e:
                logger.error(f"获取用户列表失败: {str(e)}")
                raise Exception("获取用户列表失败")
    
        @staticmethod
        async def query_user(id: int):
            """
            查询用户
            :param id: 用户ID
            :return: 用户对象
            """
            async with async_session() as session:
                query = await session.execute(select(User).where(User.id == id))
                return query.scalars().first()  # 返回查询到的用户对象
    
        @staticmethod
        async def list_user_touch(*user):
            """
            获取用户联系方式列表
            :param user: 用户ID列表
            :return: 用户联系方式列表
            """
            try:
                if not user:
                    return []
                async with async_session() as session:
                    query = await session.execute(select(User).where(User.id.in_(user)))
                    # 返回包含用户邮箱和电话信息的字典列表
                    return [{"email": q.email, "phone": q.phone} for q in query.scalars().all()]
            except Exception as e:
                logger.error(f"获取用户联系方式失败: {str(e)}")
                raise Exception(f"获取用户联系方式失败: {e}")
    

    简单解释一下,我们这边新建了一些方法,接受参数是username和password,接着我们通过orm筛选出第一条username与password匹配且没有被删除的用户。

    注意: 如果这里没有这个用户的话,user变量会是None,所以我采用了判断None的方式

    最后我们把该用户的最后登录时间改成了当前时间。然后提交到了orm的session,这句话等同于执行sql。

    引入核心方法!!!

    编辑abandon-server/src/app/routes/auth/user.py

    from fastapi import APIRouter
    from starlette import status
    
    from src.app.customized.customized_response import AbandonJSONResponse
    from src.app.dao.auth.user_dao import UserDao
    from src.app.exception.request import AuthException
    from src.app.middleware.my_jwt import AbandonJWT
    from src.app.schema.user import UserDto, UserForm
    
    router = APIRouter(prefix="/auth")
    
    
    # router注册的函数都会自带/auth,所以url是/auth/register
    @router.post("/register")
    async def register(user: UserDto):
        try:
            user = await UserDao.register_user(**user.dict())
            user = AbandonJSONResponse.model_to_dict(user, "password")
            expire, token = AbandonJWT.get_token(user)
            return AbandonJSONResponse.success(dict(token=token, expire=expire, usr_info=user))
        except Exception as e:
            return AbandonJSONResponse.failed(e)
    
    
    @router.post("/login")
    async def login(data: UserForm):
        try:
            user = await UserDao.login(data.username, data.password)
            user = AbandonJSONResponse.model_to_dict(user, "password")
            expire, token = AbandonJWT.get_token(user)
            return AbandonJSONResponse.success(dict(token=token, expire=expire, usr_info=user))
        except Exception as e:
            return AbandonJSONResponse.failed(e)
    
    
    @router.get("/listUser")
    async def list_users():
        try:
            user = await UserDao.list_users()
            return AbandonJSONResponse.success(user, exclude=("password",))
        except Exception as e:
            return AbandonJSONResponse.failed(str(e))
    
    
    @router.get("/query")
    async def query_user_info(token: str):
        try:
            if not token:
                raise AuthException(status.HTTP_200_OK, "token不存在")
            user_info = AbandonJWT.parse_token(token)
            user = await UserDao.query_user(user_info['id'])
            if user is None:
                return AbandonJSONResponse.failed("用户不存在")
            return AbandonJSONResponse.success(
                dict(token=token, expire=("password",), usr_info=AbandonJSONResponse.model_to_dict(user, "password")))
        except Exception as e:
            return AbandonJSONResponse.failed(e)
    
    接口信息

    现在我们auth中共有四个接口,暂时足够我们使用,分别为:

  • 登录接口,post请求,json
  • 注册接口,post请求,json
  • 校验token接口,get请求
  • User_list接口,get请求
  • 后续会考虑出一份接口文档在项目中,敬请期待。

    已知BUG

    已知list_users接口返回的data中是null,但是实际上想要list中是user_info信息。因为list反序列化的问题,在本文章暂时不进行处理,下一篇在看。如图:

    image.png

    验证四个接口

    image.png

    image.png

    image.png

    相关文章

    JavaScript2024新功能:Object.groupBy、正则表达式v标志
    PHP trim 函数对多字节字符的使用和限制
    新函数 json_validate() 、randomizer 类扩展…20 个PHP 8.3 新特性全面解析
    使用HTMX为WordPress增效:如何在不使用复杂框架的情况下增强平台功能
    为React 19做准备:WordPress 6.6用户指南
    如何删除WordPress中的所有评论

    发布评论