feat(audit): 添加全局无侵入审计日志拦截器
This commit is contained in:
302
inventory-backend/app/utils/audit_events.py
Normal file
302
inventory-backend/app/utils/audit_events.py
Normal file
@ -0,0 +1,302 @@
|
||||
# inventory-backend/app/utils/audit_events.py
|
||||
"""
|
||||
全局无侵入的审计日志拦截器
|
||||
监听所有模型的增删改操作,自动提取旧值和新值存入 audit_logs 表
|
||||
完美对接前端 AuditLog.vue 的解析逻辑 (changes, deleted_snapshot, created)
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime, date
|
||||
from decimal import Decimal
|
||||
from flask import request, has_request_context
|
||||
from sqlalchemy import event, text
|
||||
|
||||
|
||||
class AuditJSONEncoder(json.JSONEncoder):
|
||||
"""JSON 序列化增强器,支持 datetime/Decimal 等特殊类型"""
|
||||
def default(self, obj):
|
||||
if isinstance(obj, (datetime, date)):
|
||||
return obj.isoformat()
|
||||
if isinstance(obj, Decimal):
|
||||
return float(obj)
|
||||
return str(obj)
|
||||
|
||||
|
||||
def model_to_dict(obj):
|
||||
"""将 SQLAlchemy 模型实例转换为字典"""
|
||||
return {c.name: getattr(obj, c.name) for c in obj.__table__.columns}
|
||||
|
||||
|
||||
def get_current_user_info():
|
||||
"""
|
||||
从当前 HTTP 请求上下文中提取用户信息
|
||||
兼容 JWT 和匿名访问
|
||||
"""
|
||||
user_info = {
|
||||
'user_id': 'system',
|
||||
'username': 'system',
|
||||
'display_name': 'System',
|
||||
'ip_address': '127.0.0.1',
|
||||
'method': 'SYSTEM',
|
||||
'url': ''
|
||||
}
|
||||
|
||||
if has_request_context():
|
||||
# 获取 IP 地址
|
||||
user_info['ip_address'] = request.headers.get('X-Forwarded-For', '') or request.remote_addr or '127.0.0.1'
|
||||
if ',' in user_info['ip_address']:
|
||||
user_info['ip_address'] = user_info['ip_address'].split(',')[0].strip()
|
||||
|
||||
user_info['method'] = request.method
|
||||
user_info['url'] = request.path
|
||||
|
||||
# 尝试从 JWT 获取用户信息
|
||||
try:
|
||||
from flask_jwt_extended import get_jwt_identity, get_jwt
|
||||
user_id = get_jwt_identity()
|
||||
claims = get_jwt()
|
||||
|
||||
if user_id:
|
||||
user_info['user_id'] = str(user_id)
|
||||
if claims:
|
||||
user_info['username'] = claims.get('username', 'unknown')
|
||||
user_info['display_name'] = claims.get('display_name', claims.get('username', 'Unknown'))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return user_info
|
||||
|
||||
|
||||
def serialize_value(value):
|
||||
"""序列化单个值,确保 JSON 兼容"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, (datetime, date)):
|
||||
return value.strftime('%Y-%m-%d %H:%M:%S')
|
||||
if isinstance(value, Decimal):
|
||||
return float(value)
|
||||
if isinstance(value, (bytes, bytearray)):
|
||||
try:
|
||||
return value.decode('utf-8')
|
||||
except Exception:
|
||||
return '[二进制数据]'
|
||||
return value
|
||||
|
||||
|
||||
# 需要忽略的审计字段(时间戳等自动维护字段)
|
||||
IGNORE_FIELDS = {
|
||||
'updated_at', 'update_time', 'modified_time', 'last_modified',
|
||||
'created_at', 'create_time', 'created_on', 'version',
|
||||
}
|
||||
|
||||
# 审计日志表名
|
||||
AUDIT_TABLE = 'audit_logs'
|
||||
|
||||
# 不需要审计的表
|
||||
IGNORE_TABLES = {'audit_logs', 'sys_log', 'syslog', 'alembic_version'}
|
||||
|
||||
|
||||
def insert_audit_log(connection, action, target, details):
|
||||
"""
|
||||
使用 connection.execute 直接插入审计日志
|
||||
避免干扰当前 session 事务,自动随主事务一起提交/回滚
|
||||
"""
|
||||
tablename = target.__tablename__
|
||||
|
||||
# 严禁监听日志表本身,防止无限递归
|
||||
if tablename in IGNORE_TABLES:
|
||||
return
|
||||
|
||||
# 获取目标 ID
|
||||
target_id = ''
|
||||
if hasattr(target, 'id'):
|
||||
target_id = str(target.id)
|
||||
elif hasattr(target, 'stock_id'):
|
||||
target_id = str(target.stock_id)
|
||||
elif hasattr(target, 'uuid'):
|
||||
target_id = str(target.uuid)
|
||||
elif hasattr(target, 'bom_no'):
|
||||
target_id = str(target.bom_no)
|
||||
|
||||
# 获取目标名称(用于展示)
|
||||
target_name = target_id
|
||||
for name_field in ['name', 'title', 'material_name', 'product_name', 'display_name', 'username']:
|
||||
if hasattr(target, name_field):
|
||||
val = getattr(target, name_field)
|
||||
if val:
|
||||
target_name = str(val)
|
||||
break
|
||||
|
||||
user_info = get_current_user_info()
|
||||
|
||||
# 推断模块名称
|
||||
module = _infer_module_name(tablename, target)
|
||||
|
||||
# 使用原始 SQL 插入,确保事务一致性
|
||||
sql = text("""
|
||||
INSERT INTO audit_logs
|
||||
(user_id, username, display_name, action, module, target_id, target_name, details, ip_address, method, url, created_at)
|
||||
VALUES
|
||||
(:user_id, :username, :display_name, :action, :module, :target_id, :target_name, :details, :ip_address, :method, :url, :created_at)
|
||||
""")
|
||||
|
||||
connection.execute(sql, {
|
||||
'user_id': user_info['user_id'],
|
||||
'username': user_info['username'],
|
||||
'display_name': user_info['display_name'],
|
||||
'action': action,
|
||||
'module': module,
|
||||
'target_id': target_id,
|
||||
'target_name': target_name,
|
||||
'details': json.dumps(details, cls=AuditJSONEncoder),
|
||||
'ip_address': user_info['ip_address'],
|
||||
'method': user_info['method'],
|
||||
'url': user_info['url'],
|
||||
'created_at': datetime.now()
|
||||
})
|
||||
|
||||
|
||||
def _infer_module_name(tablename, target):
|
||||
"""根据表名或模型类推断所属模块"""
|
||||
class_name = target.__class__.__name__
|
||||
|
||||
if any(kw in class_name for kw in ['Stock', 'Buy', 'Inbound']):
|
||||
return '入库管理'
|
||||
if any(kw in class_name for kw in ['Outbound']):
|
||||
return '出库管理'
|
||||
if any(kw in class_name for kw in ['Borrow', 'Return']):
|
||||
return '借还管理'
|
||||
if any(kw in class_name for kw in ['Repair']):
|
||||
return '维修管理'
|
||||
if any(kw in class_name for kw in ['Scrap']):
|
||||
return '报废管理'
|
||||
if any(kw in class_name for kw in ['Bom', 'BOM']):
|
||||
return 'BOM管理'
|
||||
if any(kw in class_name for kw in ['StockTake', 'StockAdjust', 'Adjustment']):
|
||||
return '盘点管理'
|
||||
if any(kw in class_name for kw in ['Material', 'Base']):
|
||||
return '基础数据'
|
||||
if any(kw in class_name for kw in ['SysUser', 'SysMenu', 'SysRole', 'SysPermission']):
|
||||
return '系统管理'
|
||||
if any(kw in class_name for kw in ['Warehouse', 'Location']):
|
||||
return '库位管理'
|
||||
|
||||
return tablename or '未知模块'
|
||||
|
||||
|
||||
def _has_changes(history):
|
||||
"""检查历史记录对象是否有变更"""
|
||||
return history.has_changes()
|
||||
|
||||
|
||||
def register_audit_events(db):
|
||||
"""
|
||||
全局注册审计事件监听器
|
||||
监听所有模型的 INSERT/UPDATE/DELETE 事件
|
||||
"""
|
||||
from sqlalchemy import inspect
|
||||
|
||||
@event.listens_for(db.Model, 'before_update', propagate=True)
|
||||
def before_update_listener(mapper, connection, target):
|
||||
"""UPDATE 事件:抓取字段变更明细"""
|
||||
if target.__tablename__ in IGNORE_TABLES:
|
||||
return
|
||||
|
||||
try:
|
||||
state = inspect(target)
|
||||
changes = {}
|
||||
|
||||
for attr in state.attrs:
|
||||
prop = attr.key
|
||||
|
||||
# 跳过忽略字段
|
||||
if prop in IGNORE_FIELDS:
|
||||
continue
|
||||
|
||||
# 跳过关系属性
|
||||
if hasattr(attr, 'property') and hasattr(attr.property, 'direction'):
|
||||
continue
|
||||
|
||||
if _has_changes(attr.history):
|
||||
old_value = attr.history.deleted[0] if attr.history.deleted else None
|
||||
new_value = attr.history.added[0] if attr.history.added else None
|
||||
|
||||
# 序列化值
|
||||
old_serialized = serialize_value(old_value)
|
||||
new_serialized = serialize_value(new_value)
|
||||
|
||||
# 只记录真正变化的字段
|
||||
if old_serialized != new_serialized:
|
||||
changes[prop] = {
|
||||
'old': old_serialized,
|
||||
'new': new_serialized
|
||||
}
|
||||
|
||||
if changes:
|
||||
insert_audit_log(connection, 'UPDATE', target, {'changes': changes})
|
||||
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.error(f"Audit Update Error: {e}")
|
||||
|
||||
@event.listens_for(db.Model, 'before_delete', propagate=True)
|
||||
def before_delete_listener(mapper, connection, target):
|
||||
"""DELETE 事件:抓取被删除对象的完整快照"""
|
||||
if target.__tablename__ in IGNORE_TABLES:
|
||||
return
|
||||
|
||||
try:
|
||||
state = inspect(target)
|
||||
snapshot = {}
|
||||
|
||||
for attr in state.attrs:
|
||||
prop = attr.key
|
||||
|
||||
# 跳过忽略字段
|
||||
if prop in IGNORE_FIELDS:
|
||||
continue
|
||||
|
||||
# 跳过关系属性
|
||||
if hasattr(attr, 'property') and hasattr(attr.property, 'direction'):
|
||||
continue
|
||||
|
||||
value = getattr(target, prop, None)
|
||||
snapshot[prop] = serialize_value(value)
|
||||
|
||||
insert_audit_log(connection, 'DELETE', target, {'deleted_snapshot': snapshot})
|
||||
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.error(f"Audit Delete Error: {e}")
|
||||
|
||||
@event.listens_for(db.Model, 'after_insert', propagate=True)
|
||||
def after_insert_listener(mapper, connection, target):
|
||||
"""INSERT 事件:抓取新增对象的完整快照"""
|
||||
if target.__tablename__ in IGNORE_TABLES:
|
||||
return
|
||||
|
||||
try:
|
||||
state = inspect(target)
|
||||
snapshot = {}
|
||||
|
||||
for attr in state.attrs:
|
||||
prop = attr.key
|
||||
|
||||
# 跳过忽略字段
|
||||
if prop in IGNORE_FIELDS:
|
||||
continue
|
||||
|
||||
# 跳过关系属性
|
||||
if hasattr(attr, 'property') and hasattr(attr.property, 'direction'):
|
||||
continue
|
||||
|
||||
value = getattr(target, prop, None)
|
||||
snapshot[prop] = serialize_value(value)
|
||||
|
||||
insert_audit_log(connection, 'CREATE', target, {'created': snapshot})
|
||||
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.error(f"Audit Insert Error: {e}")
|
||||
|
||||
# 返回注册成功信息
|
||||
return True
|
||||
Reference in New Issue
Block a user