300 lines
11 KiB
Python
300 lines
11 KiB
Python
# app/utils/decorators.py
|
||
from functools import wraps
|
||
from flask_jwt_extended import get_jwt, verify_jwt_in_request, get_jwt_identity
|
||
from flask import jsonify, g, request
|
||
import logging
|
||
import json
|
||
|
||
|
||
def _verify_token_in_redis():
|
||
"""
|
||
验证当前 Token 是否与 Redis 中存储的 Token 一致(单设备登录互踢)
|
||
"""
|
||
from app.extensions import redis_client
|
||
from flask import current_app
|
||
|
||
if redis_client is None:
|
||
# Redis 不可用,跳过验证
|
||
return True
|
||
|
||
try:
|
||
# 获取请求中的 Token
|
||
auth_header = request.headers.get('Authorization', '')
|
||
if not auth_header.startswith('Bearer '):
|
||
return True
|
||
|
||
request_token = auth_header[7:] # 去掉 'Bearer ' 前缀
|
||
|
||
# 获取当前用户 ID
|
||
claims = get_jwt()
|
||
user_id = claims.get('sub')
|
||
if user_id is None:
|
||
return True
|
||
|
||
# 从 Redis 获取存储的 Token
|
||
stored_token = redis_client.get(f"user_token_{user_id}")
|
||
|
||
# 如果 Redis 中没有存储的 Token(可能是旧登录或 Redis 重启),允许通过
|
||
if stored_token is None:
|
||
return True
|
||
|
||
# 比较 Token 是否一致
|
||
if request_token != stored_token:
|
||
current_app.logger.warning(f"Token mismatch for user {user_id}: request token != stored token")
|
||
return False
|
||
|
||
return True
|
||
except Exception as e:
|
||
current_app.logger.error(f"Redis token verification error: {e}")
|
||
# 出错时默认放行,避免影响正常业务
|
||
return True
|
||
|
||
|
||
def _raise_token_mismatch_error():
|
||
"""抛出 Token 不一致的错误(用于单设备登录互踢)"""
|
||
return jsonify({
|
||
'msg': '您的账号已在其他设备登录,请重新登录',
|
||
'code': 401,
|
||
'reason': 'token_mismatch'
|
||
}), 401
|
||
|
||
|
||
def role_required(*roles):
|
||
"""
|
||
自定义装饰器:检查用户角色
|
||
使用方法: @role_required('super_admin', 'finance')
|
||
"""
|
||
|
||
def wrapper(fn):
|
||
@wraps(fn)
|
||
def decorator(*args, **kwargs):
|
||
claims = get_jwt()
|
||
user_role = claims.get('role')
|
||
user_role_upper = user_role.upper() if user_role else None
|
||
|
||
# 如果是超级管理员,拥有上帝视角,直接放行 (可选)
|
||
if user_role_upper == 'SUPER_ADMIN':
|
||
return fn(*args, **kwargs)
|
||
|
||
if user_role_upper not in [r.upper() for r in roles]:
|
||
return jsonify(msg='权限不足:您没有访问此资源的权限'), 403
|
||
|
||
return fn(*args, **kwargs)
|
||
|
||
return decorator
|
||
|
||
return wrapper
|
||
|
||
|
||
def login_required(fn):
|
||
"""
|
||
验证 JWT 令牌是否存在且有效
|
||
"""
|
||
@wraps(fn)
|
||
def decorator(*args, **kwargs):
|
||
try:
|
||
verify_jwt_in_request()
|
||
except Exception as e:
|
||
logging.warning(f"JWT verification failed: {e}")
|
||
return jsonify(msg='登录已过期,请重新登录'), 401
|
||
|
||
# 单设备登录互踢检查
|
||
if not _verify_token_in_redis():
|
||
return _raise_token_mismatch_error()
|
||
|
||
return fn(*args, **kwargs)
|
||
return decorator
|
||
|
||
|
||
def permission_required(permission_code):
|
||
"""
|
||
检查当前用户是否拥有指定权限码
|
||
使用方法: @permission_required('material:base:read')
|
||
"""
|
||
def wrapper(fn):
|
||
@wraps(fn)
|
||
def decorator(*args, **kwargs):
|
||
# 首先验证 JWT
|
||
try:
|
||
verify_jwt_in_request()
|
||
except Exception as e:
|
||
logging.warning(f"JWT verification failed: {e}")
|
||
return jsonify(msg='登录已过期,请重新登录'), 401
|
||
|
||
# 单设备登录互踢检查
|
||
if not _verify_token_in_redis():
|
||
return _raise_token_mismatch_error()
|
||
|
||
claims = get_jwt()
|
||
user_role = claims.get('role')
|
||
# 超级管理员放行 (忽略大小写)
|
||
if user_role and user_role.upper() == 'SUPER_ADMIN':
|
||
return fn(*args, **kwargs)
|
||
|
||
# 根据角色查询数据库中的权限
|
||
try:
|
||
from app.services.auth_service import AuthService
|
||
perm_dict = AuthService.get_user_permissions(user_role)
|
||
except Exception as e:
|
||
logging.warning(f"Failed to fetch permissions for role {user_role}: {e}")
|
||
return jsonify(msg='权限查询失败'), 403
|
||
|
||
# 合并菜单和元素权限
|
||
all_perms = perm_dict.get('menus', []) + perm_dict.get('elements', [])
|
||
if permission_code not in all_perms:
|
||
# 详细的调试日志
|
||
print(f"🔴 [权限拦截] 角色 '{user_role}' 访问被拒!需要权限码: '{permission_code}', 但该角色实际拥有: {all_perms}")
|
||
logging.warning(
|
||
f"权限检查失败: 角色={user_role}, 所需权限={permission_code}, 实际权限列表={all_perms}")
|
||
return jsonify(msg='权限不足:您没有访问此资源的权限'), 403
|
||
return fn(*args, **kwargs)
|
||
return decorator
|
||
return wrapper
|
||
|
||
|
||
def audit_log(module: str, action: str = None, get_target_id_fn=None, get_target_name_fn=None, get_details_fn=None):
|
||
"""
|
||
审计日志装饰器
|
||
用法: @audit_log(module='inbound_buy', action='create')
|
||
@audit_log(module='bom', action='update', get_target_id_fn=lambda: ..., get_details_fn=lambda req, resp: ...)
|
||
|
||
升级特性:
|
||
- 自动捕获请求 Payload 作为变更明细
|
||
- 自动过滤过长的 Base64 图片数据
|
||
- 支持自定义 get_details_fn 覆盖默认行为
|
||
"""
|
||
# 需要过滤的图片字段
|
||
IMAGE_FIELDS = {'arrival_photo', 'product_photo', 'photo', 'image', 'signature', 'borrow_signature', 'return_signature'}
|
||
|
||
def _filter_payload(payload):
|
||
"""过滤 Payload 中的大字段,防止数据库膨胀"""
|
||
if not payload or not isinstance(payload, dict):
|
||
return payload
|
||
filtered = {}
|
||
for key, value in payload.items():
|
||
if key.lower() in IMAGE_FIELDS and isinstance(value, str) and len(value) > 100:
|
||
filtered[key] = '[图片数据已省略]'
|
||
elif isinstance(value, dict):
|
||
filtered[key] = _filter_payload(value)
|
||
elif isinstance(value, list):
|
||
filtered[key] = [
|
||
_filter_payload(item) if isinstance(item, dict) else item
|
||
for item in value
|
||
]
|
||
else:
|
||
filtered[key] = value
|
||
return filtered
|
||
|
||
def _get_payload():
|
||
"""自动获取请求 Payload"""
|
||
# 尝试 JSON
|
||
payload = request.get_json(silent=True)
|
||
if payload:
|
||
return payload
|
||
# 尝试 Form Data
|
||
if request.form:
|
||
return request.form.to_dict()
|
||
return None
|
||
|
||
def wrapper(fn):
|
||
@wraps(fn)
|
||
def decorator(*args, **kwargs):
|
||
# 获取请求上下文
|
||
claims = get_jwt()
|
||
user_id = get_jwt_identity()
|
||
username = claims.get('username', '')
|
||
display_name = claims.get('display_name', '')
|
||
|
||
# 获取IP
|
||
ip_address = request.headers.get('X-Forwarded-For') or request.remote_addr or ''
|
||
if ip_address and ',' in ip_address:
|
||
ip_address = ip_address.split(',')[0].strip()
|
||
|
||
# 获取请求信息
|
||
http_method = request.method
|
||
url = request.url
|
||
user_agent = request.headers.get('User-Agent', '')[:500]
|
||
|
||
# 解析 action(支持动态)
|
||
final_action = action
|
||
if callable(action):
|
||
final_action = action()
|
||
|
||
# 预先获取 Payload(用于后续 details 记录)
|
||
raw_payload = _get_payload()
|
||
filtered_payload = _filter_payload(raw_payload) if raw_payload else None
|
||
|
||
# 执行原函数
|
||
response = fn(*args, **kwargs)
|
||
|
||
# 只记录成功的请求(响应状态码 200/201)
|
||
status_code = 200
|
||
if hasattr(response, 'status_code'):
|
||
status_code = response.status_code
|
||
|
||
if status_code in [200, 201]:
|
||
try:
|
||
from app.models.audit import AuditLog
|
||
from app.extensions import db
|
||
from flask import current_app
|
||
|
||
# 获取 target_id
|
||
target_id = None
|
||
if get_target_id_fn:
|
||
try:
|
||
target_id = get_target_id_fn()
|
||
except Exception:
|
||
pass
|
||
if not target_id and hasattr(response, 'json'):
|
||
resp_data = response.get_json()
|
||
if resp_data and isinstance(resp_data, dict):
|
||
target_id = resp_data.get('id')
|
||
|
||
# 获取 target_name
|
||
target_name = None
|
||
if get_target_name_fn:
|
||
try:
|
||
target_name = get_target_name_fn()
|
||
except Exception:
|
||
pass
|
||
|
||
# 获取 details
|
||
details = None
|
||
if get_details_fn:
|
||
# 优先使用自定义差异对比函数
|
||
try:
|
||
details = get_details_fn(request, response)
|
||
except Exception:
|
||
pass
|
||
elif filtered_payload:
|
||
# 默认:记录请求 Payload
|
||
details = {'payload': filtered_payload}
|
||
|
||
# 保存日志
|
||
log_entry = AuditLog(
|
||
user_id=user_id,
|
||
username=username,
|
||
display_name=display_name,
|
||
action=final_action or http_method.lower(),
|
||
module=module,
|
||
target_id=str(target_id) if target_id else None,
|
||
target_name=target_name,
|
||
details=details,
|
||
ip_address=ip_address,
|
||
user_agent=user_agent,
|
||
method=http_method,
|
||
url=url,
|
||
status_code=status_code
|
||
)
|
||
db.session.add(log_entry)
|
||
db.session.commit()
|
||
|
||
except Exception as e:
|
||
current_app.logger.error(f"审计日志记录失败: {str(e)}")
|
||
db.session.rollback()
|
||
|
||
return response
|
||
|
||
return decorator
|
||
return wrapper
|