Files
KCGL/inventory-backend/app/utils/decorators.py

317 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# app/utils/decorators.py
from functools import wraps
from flask_jwt_extended import get_jwt, verify_jwt_in_request, get_jwt_identity
from flask import jsonify, g, request
import logging
import json
def _verify_token_in_redis():
"""
验证当前 Token 是否与 Redis 中存储的 Token 一致(单设备登录互踢)
"""
from app.extensions import redis_client
from flask import current_app
if redis_client is None:
# Redis 不可用,跳过验证
return True
try:
# 获取请求中的 Token
auth_header = request.headers.get('Authorization', '')
if not auth_header.startswith('Bearer '):
return True
request_token = auth_header[7:] # 去掉 'Bearer ' 前缀
# 获取当前用户 ID
claims = get_jwt()
user_id = claims.get('sub')
if user_id is None:
return True
# 从 Redis 获取存储的 Token
stored_token = redis_client.get(f"user_token_{user_id}")
# 如果 Redis 中没有存储的 Token可能是旧登录或 Redis 重启),允许通过
if stored_token is None:
return True
# 比较 Token 是否一致
if request_token != stored_token:
current_app.logger.warning(f"Token mismatch for user {user_id}: request token != stored token")
return False
return True
except Exception as e:
current_app.logger.error(f"Redis token verification error: {e}")
# 出错时默认放行,避免影响正常业务
return True
def _raise_token_mismatch_error():
"""抛出 Token 不一致的错误(用于单设备登录互踢)"""
return jsonify({
'msg': '您的账号已在其他设备登录,请重新登录',
'code': 401,
'reason': 'token_mismatch'
}), 401
def role_required(*roles):
"""
自定义装饰器:检查用户角色
使用方法: @role_required('super_admin', 'finance')
"""
def wrapper(fn):
@wraps(fn)
def decorator(*args, **kwargs):
claims = get_jwt()
user_role = claims.get('role')
user_role_upper = user_role.upper() if user_role else None
# 如果是超级管理员,拥有上帝视角,直接放行 (可选)
if user_role_upper == 'SUPER_ADMIN':
return fn(*args, **kwargs)
if user_role_upper not in [r.upper() for r in roles]:
return jsonify(msg='权限不足:您没有访问此资源的权限'), 403
return fn(*args, **kwargs)
return decorator
return wrapper
def login_required(fn):
"""
验证 JWT 令牌是否存在且有效
"""
@wraps(fn)
def decorator(*args, **kwargs):
try:
verify_jwt_in_request()
except Exception as e:
logging.warning(f"JWT verification failed: {e}")
return jsonify(msg='登录已过期,请重新登录'), 401
# 单设备登录互踢检查
if not _verify_token_in_redis():
return _raise_token_mismatch_error()
return fn(*args, **kwargs)
return decorator
def permission_required(permission_code):
"""
检查当前用户是否拥有指定权限码
使用方法: @permission_required('material:base:read')
"""
def wrapper(fn):
@wraps(fn)
def decorator(*args, **kwargs):
# 首先验证 JWT
try:
verify_jwt_in_request()
except Exception as e:
logging.warning(f"JWT verification failed: {e}")
return jsonify(msg='登录已过期,请重新登录'), 401
# 单设备登录互踢检查
if not _verify_token_in_redis():
return _raise_token_mismatch_error()
claims = get_jwt()
user_role = claims.get('role')
# 超级管理员放行 (忽略大小写)
if user_role and user_role.upper() == 'SUPER_ADMIN':
return fn(*args, **kwargs)
# 根据角色查询数据库中的权限
try:
from app.services.auth_service import AuthService
perm_dict = AuthService.get_user_permissions(user_role)
except Exception as e:
logging.warning(f"Failed to fetch permissions for role {user_role}: {e}")
return jsonify(msg='权限查询失败'), 403
# 合并菜单和元素权限
all_perms = perm_dict.get('menus', []) + perm_dict.get('elements', [])
if permission_code not in all_perms:
# 详细的调试日志
print(f"🔴 [权限拦截] 角色 '{user_role}' 访问被拒!需要权限码: '{permission_code}', 但该角色实际拥有: {all_perms}")
logging.warning(
f"权限检查失败: 角色={user_role}, 所需权限={permission_code}, 实际权限列表={all_perms}")
return jsonify(msg='权限不足:您没有访问此资源的权限'), 403
return fn(*args, **kwargs)
return decorator
return wrapper
def audit_log(module: str, action: str = None, get_target_id_fn=None, get_target_name_fn=None, get_details_fn=None):
"""
审计日志装饰器
用法: @audit_log(module='inbound_buy', action='create')
@audit_log(module='bom', action='update', get_target_id_fn=lambda: ..., get_details_fn=lambda req, resp: ...)
升级特性:
- 自动捕获请求 Payload 作为变更明细
- 自动过滤过长的 Base64 图片数据
- 支持自定义 get_details_fn 覆盖默认行为
"""
# 需要过滤的图片字段
IMAGE_FIELDS = {'arrival_photo', 'product_photo', 'photo', 'image', 'signature', 'borrow_signature', 'return_signature'}
def _filter_payload(payload):
"""过滤 Payload 中的大字段,防止数据库膨胀"""
if not payload or not isinstance(payload, dict):
return payload
filtered = {}
for key, value in payload.items():
if key.lower() in IMAGE_FIELDS and isinstance(value, str) and len(value) > 100:
filtered[key] = '[图片数据已省略]'
elif isinstance(value, dict):
filtered[key] = _filter_payload(value)
elif isinstance(value, list):
filtered[key] = [
_filter_payload(item) if isinstance(item, dict) else item
for item in value
]
else:
filtered[key] = value
return filtered
def _get_payload():
"""自动获取请求 Payload"""
# 尝试 JSON
payload = request.get_json(silent=True)
if payload:
return payload
# 尝试 Form Data
if request.form:
return request.form.to_dict()
return None
def wrapper(fn):
@wraps(fn)
def decorator(*args, **kwargs):
# 获取请求上下文
claims = get_jwt()
user_id = get_jwt_identity()
username = claims.get('username', '')
display_name = claims.get('display_name', '')
# 获取IP
ip_address = request.headers.get('X-Forwarded-For') or request.remote_addr or ''
if ip_address and ',' in ip_address:
ip_address = ip_address.split(',')[0].strip()
# 获取请求信息
http_method = request.method
url = request.url
user_agent = request.headers.get('User-Agent', '')[:500]
# 解析 action支持动态
final_action = action
if callable(action):
final_action = action()
# 预先获取 Payload用于后续 details 记录)
raw_payload = _get_payload()
filtered_payload = _filter_payload(raw_payload) if raw_payload else None
# 执行原函数
response = fn(*args, **kwargs)
# 只记录成功的请求(响应状态码 200/201
status_code = 200
if hasattr(response, 'status_code'):
status_code = response.status_code
if status_code in [200, 201]:
try:
from app.models.audit import AuditLog
from app.extensions import db
from flask import current_app
# 获取 target_id
target_id = None
if get_target_id_fn:
try:
target_id = get_target_id_fn()
except Exception:
pass
if not target_id and hasattr(response, 'json'):
resp_data = response.get_json()
if resp_data and isinstance(resp_data, dict):
target_id = resp_data.get('id')
# 获取 target_name
target_name = None
if get_target_name_fn:
try:
target_name = get_target_name_fn()
except Exception:
pass
# 如果仍未获取到目标名称,尝试从响应 JSON 中常见字段获取
if not target_name and hasattr(response, 'json'):
resp_data = response.get_json()
if resp_data and isinstance(resp_data, dict):
# 优先从顶层获取
for field in ['order_no', 'outbound_no', 'borrow_no', 'adjustment_no', 'material_name']:
if field in resp_data:
target_name = resp_data[field]
break
# 再尝试从 data 字段获取(部分 API 返回格式)
if not target_name and 'data' in resp_data:
data = resp_data['data']
if isinstance(data, dict):
for field in ['order_no', 'outbound_no', 'borrow_no', 'adjustment_no', 'material_name']:
if field in data:
target_name = data[field]
break
# 获取 details
details = None
if get_details_fn:
# 优先使用自定义差异对比函数
try:
details = get_details_fn(request, response)
except Exception:
pass
elif filtered_payload:
# 默认:记录请求 Payload
details = {'payload': filtered_payload}
# 保存日志
log_entry = AuditLog(
user_id=user_id,
username=username,
display_name=display_name,
action=final_action or http_method.lower(),
module=module,
target_id=str(target_id) if target_id else None,
target_name=target_name,
details=details,
ip_address=ip_address,
user_agent=user_agent,
method=http_method,
url=url,
status_code=status_code
)
db.session.add(log_entry)
db.session.commit()
except Exception as e:
current_app.logger.error(f"审计日志记录失败: {str(e)}")
db.session.rollback()
return response
return decorator
return wrapper