# app/utils/decorators.py from functools import wraps from flask_jwt_extended import get_jwt, verify_jwt_in_request, get_jwt_identity from flask import jsonify, g, request import logging import json def _verify_token_in_redis(): """ 验证当前 Token 是否与 Redis 中存储的 Token 一致(单设备登录互踢) """ from app.extensions import redis_client from flask import current_app if redis_client is None: # Redis 不可用,跳过验证 return True try: # 获取请求中的 Token auth_header = request.headers.get('Authorization', '') if not auth_header.startswith('Bearer '): return True request_token = auth_header[7:] # 去掉 'Bearer ' 前缀 # 获取当前用户 ID claims = get_jwt() user_id = claims.get('sub') if user_id is None: return True # 从 Redis 获取存储的 Token stored_token = redis_client.get(f"user_token_{user_id}") # 如果 Redis 中没有存储的 Token(可能是旧登录或 Redis 重启),允许通过 if stored_token is None: return True # 比较 Token 是否一致 if request_token != stored_token: current_app.logger.warning(f"Token mismatch for user {user_id}: request token != stored token") return False return True except Exception as e: current_app.logger.error(f"Redis token verification error: {e}") # 出错时默认放行,避免影响正常业务 return True def _raise_token_mismatch_error(): """抛出 Token 不一致的错误(用于单设备登录互踢)""" return jsonify({ 'msg': '您的账号已在其他设备登录,请重新登录', 'code': 401, 'reason': 'token_mismatch' }), 401 def role_required(*roles): """ 自定义装饰器:检查用户角色 使用方法: @role_required('super_admin', 'finance') """ def wrapper(fn): @wraps(fn) def decorator(*args, **kwargs): claims = get_jwt() user_role = claims.get('role') user_role_upper = user_role.upper() if user_role else None # 如果是超级管理员,拥有上帝视角,直接放行 (可选) if user_role_upper == 'SUPER_ADMIN': return fn(*args, **kwargs) if user_role_upper not in [r.upper() for r in roles]: return jsonify(msg='权限不足:您没有访问此资源的权限'), 403 return fn(*args, **kwargs) return decorator return wrapper def login_required(fn): """ 验证 JWT 令牌是否存在且有效 """ @wraps(fn) def decorator(*args, **kwargs): try: verify_jwt_in_request() except Exception as e: logging.warning(f"JWT verification failed: {e}") return jsonify(msg='登录已过期,请重新登录'), 401 # 单设备登录互踢检查 if not _verify_token_in_redis(): return _raise_token_mismatch_error() return fn(*args, **kwargs) return decorator def permission_required(permission_code): """ 检查当前用户是否拥有指定权限码 使用方法: @permission_required('material:base:read') """ def wrapper(fn): @wraps(fn) def decorator(*args, **kwargs): # 首先验证 JWT try: verify_jwt_in_request() except Exception as e: logging.warning(f"JWT verification failed: {e}") return jsonify(msg='登录已过期,请重新登录'), 401 # 单设备登录互踢检查 if not _verify_token_in_redis(): return _raise_token_mismatch_error() claims = get_jwt() user_role = claims.get('role') # 超级管理员放行 (忽略大小写) if user_role and user_role.upper() == 'SUPER_ADMIN': return fn(*args, **kwargs) # 根据角色查询数据库中的权限 try: from app.services.auth_service import AuthService perm_dict = AuthService.get_user_permissions(user_role) except Exception as e: logging.warning(f"Failed to fetch permissions for role {user_role}: {e}") return jsonify(msg='权限查询失败'), 403 # 合并菜单和元素权限 all_perms = perm_dict.get('menus', []) + perm_dict.get('elements', []) if permission_code not in all_perms: # 详细的调试日志 print(f"🔴 [权限拦截] 角色 '{user_role}' 访问被拒!需要权限码: '{permission_code}', 但该角色实际拥有: {all_perms}") logging.warning( f"权限检查失败: 角色={user_role}, 所需权限={permission_code}, 实际权限列表={all_perms}") return jsonify(msg='权限不足:您没有访问此资源的权限'), 403 return fn(*args, **kwargs) return decorator return wrapper def audit_log(module: str, action: str = None, get_target_id_fn=None, get_target_name_fn=None, get_details_fn=None): """ 审计日志装饰器 用法: @audit_log(module='inbound_buy', action='create') @audit_log(module='bom', action='update', get_target_id_fn=lambda: ..., get_details_fn=lambda req, resp: ...) 升级特性: - 自动捕获请求 Payload 作为变更明细 - 自动过滤过长的 Base64 图片数据 - 支持自定义 get_details_fn 覆盖默认行为 """ # 需要过滤的图片字段 IMAGE_FIELDS = {'arrival_photo', 'product_photo', 'photo', 'image', 'signature', 'borrow_signature', 'return_signature'} def _filter_payload(payload): """过滤 Payload 中的大字段,防止数据库膨胀""" if not payload or not isinstance(payload, dict): return payload filtered = {} for key, value in payload.items(): if key.lower() in IMAGE_FIELDS and isinstance(value, str) and len(value) > 100: filtered[key] = '[图片数据已省略]' elif isinstance(value, dict): filtered[key] = _filter_payload(value) elif isinstance(value, list): filtered[key] = [ _filter_payload(item) if isinstance(item, dict) else item for item in value ] else: filtered[key] = value return filtered def _get_payload(): """自动获取请求 Payload""" # 尝试 JSON payload = request.get_json(silent=True) if payload: return payload # 尝试 Form Data if request.form: return request.form.to_dict() return None def wrapper(fn): @wraps(fn) def decorator(*args, **kwargs): # 获取请求上下文 claims = get_jwt() user_id = get_jwt_identity() username = claims.get('username', '') display_name = claims.get('display_name', '') # 兜底:如果 display_name 为空,查询数据库获取 if not display_name and user_id: try: from app.models.system import SysUser user = SysUser.query.get(user_id) if user: user_info = user.to_dict() display_name = user_info.get('display_name', username) except Exception: pass # 获取IP ip_address = request.headers.get('X-Forwarded-For') or request.remote_addr or '' if ip_address and ',' in ip_address: ip_address = ip_address.split(',')[0].strip() # 获取请求信息 http_method = request.method url = request.url user_agent = request.headers.get('User-Agent', '')[:500] # 解析 action(支持动态) final_action = action if callable(action): final_action = action() # 预先获取 Payload(用于后续 details 记录) raw_payload = _get_payload() filtered_payload = _filter_payload(raw_payload) if raw_payload else None # 执行原函数 response = fn(*args, **kwargs) # 只记录成功的请求(响应状态码 200/201) status_code = 200 if hasattr(response, 'status_code'): status_code = response.status_code if status_code in [200, 201]: try: from app.models.audit import AuditLog from app.extensions import db from flask import current_app # 获取 target_id target_id = None if get_target_id_fn: try: target_id = get_target_id_fn() except Exception: pass if not target_id and hasattr(response, 'json'): resp_data = response.get_json() if resp_data and isinstance(resp_data, dict): target_id = resp_data.get('id') # 获取 target_name target_name = None if get_target_name_fn: try: target_name = get_target_name_fn() except Exception: pass # 如果仍未获取到目标名称,尝试从响应 JSON 中常见字段获取 if not target_name and hasattr(response, 'json'): resp_data = response.get_json() if resp_data and isinstance(resp_data, dict): # 优先从顶层获取 for field in ['order_no', 'outbound_no', 'borrow_no', 'adjustment_no', 'material_name']: if field in resp_data: target_name = resp_data[field] break # 再尝试从 data 字段获取(部分 API 返回格式) if not target_name and 'data' in resp_data: data = resp_data['data'] if isinstance(data, dict): for field in ['order_no', 'outbound_no', 'borrow_no', 'adjustment_no', 'material_name']: if field in data: target_name = data[field] break # 获取 details details = None if get_details_fn: # 优先使用自定义差异对比函数 try: details = get_details_fn(request, response) except Exception: pass elif filtered_payload: # 默认:记录请求 Payload details = {'payload': filtered_payload} # 保存日志 log_entry = AuditLog( user_id=user_id, username=username, display_name=display_name, action=final_action or http_method.lower(), module=module, target_id=str(target_id) if target_id else None, target_name=target_name, details=details, ip_address=ip_address, user_agent=user_agent, method=http_method, url=url, status_code=status_code ) db.session.add(log_entry) db.session.commit() except Exception as e: current_app.logger.error(f"审计日志记录失败: {str(e)}") db.session.rollback() return response return decorator return wrapper