324 lines
11 KiB
Python
324 lines
11 KiB
Python
# inventory-backend/app/utils/audit_events.py
|
||
"""
全局无侵入的审计日志拦截器

监听所有模型的增删改操作,自动提取旧值和新值存入 audit_logs 表

对接前端 AuditLog.vue 的解析逻辑 (changes, deleted_snapshot, created)
"""
|
||
import json
|
||
from datetime import datetime, date
|
||
from decimal import Decimal
|
||
from flask import request, has_request_context
|
||
from sqlalchemy import event, text
|
||
|
||
|
||
class AuditJSONEncoder(json.JSONEncoder):
    """JSON encoder for audit payloads that never raises on unknown types.

    Conversions applied to values the stock encoder rejects:

    * ``datetime`` / ``date`` -> ISO-8601 string via ``isoformat()``
    * ``Decimal`` -> ``float``
    * anything else -> ``str(obj)``
    """

    def default(self, obj):
        if isinstance(obj, Decimal):
            converted = float(obj)
        elif isinstance(obj, (datetime, date)):
            converted = obj.isoformat()
        else:
            # Fallback: stringify instead of raising TypeError.
            converted = str(obj)
        return converted
|
||
|
||
|
||
def model_to_dict(obj):
    """Return ``{column_name: value}`` for every column of *obj*'s table.

    Values are read with ``getattr`` using each column's name, so this
    assumes the mapped attribute names match the column names.
    """
    result = {}
    for column in obj.__table__.columns:
        key = column.name
        result[key] = getattr(obj, key)
    return result
|
||
|
||
|
||
def get_current_user_info():
    """Collect the acting user and request metadata for an audit record.

    Outside a Flask request context the ``system`` defaults are returned
    unchanged. Inside a request the client IP, HTTP method and path are
    filled in and, if flask_jwt_extended yields a JWT identity, the user id,
    username and display name as well. Any failure while reading the JWT is
    ignored and the defaults are kept.

    Returns:
        dict with keys ``user_id``, ``username``, ``display_name``,
        ``ip_address``, ``method`` and ``url``.
    """
    info = dict(
        user_id='system',
        username='system',
        display_name='System',
        ip_address='127.0.0.1',
        method='SYSTEM',
        url='',
    )

    if not has_request_context():
        return info

    # NOTE(review): X-Forwarded-For is client-controlled unless a trusted proxy
    # rewrites it, so the recorded IP can be spoofed. Consider werkzeug's
    # ProxyFix + request.remote_addr instead — confirm the deployment setup.
    client_ip = request.headers.get('X-Forwarded-For', '') or request.remote_addr or '127.0.0.1'
    if ',' in client_ip:
        # Keep only the first (left-most) address of a comma-separated chain.
        client_ip = client_ip.split(',')[0].strip()
    info['ip_address'] = client_ip
    info['method'] = request.method
    info['url'] = request.path

    try:
        from flask_jwt_extended import get_jwt_identity, get_jwt

        identity = get_jwt_identity()
        claims = get_jwt()
        if identity:
            info['user_id'] = str(identity)
            if claims:
                info['username'] = claims.get('username', 'unknown')
                info['display_name'] = claims.get('display_name', claims.get('username', 'Unknown'))
    except Exception:
        # Best effort: no JWT in this request (or extension unavailable) ->
        # keep the system defaults.
        pass

    return info
|
||
|
||
|
||
def serialize_value(value):
    """Convert a single column value into a JSON-friendly form.

    - ``None`` stays ``None``.
    - ``datetime`` -> ``'YYYY-MM-DD HH:MM:SS'`` (microseconds and tzinfo dropped).
    - ``date`` without a time part -> ``'YYYY-MM-DD'``.
    - ``Decimal`` -> ``float``.
    - ``bytes`` / ``bytearray`` -> UTF-8 text, or ``'[二进制数据]'``
      ("binary data") when the payload is not valid UTF-8.
    - Any other value is returned unchanged.
    """
    if value is None:
        return None
    # datetime is a subclass of date, so it must be tested first.
    if isinstance(value, datetime):
        return value.strftime('%Y-%m-%d %H:%M:%S')
    if isinstance(value, date):
        # A pure date has no time component; formatting it with %H:%M:%S
        # would fabricate a misleading "00:00:00".
        return value.strftime('%Y-%m-%d')
    if isinstance(value, Decimal):
        return float(value)
    if isinstance(value, (bytes, bytearray)):
        try:
            return value.decode('utf-8')
        except UnicodeDecodeError:
            return '[二进制数据]'
    return value
|
||
|
||
|
||
# Auto-maintained bookkeeping columns (timestamps and a row 'version' column)
# that are excluded from audit diffs and snapshots.
IGNORE_FIELDS = {
    # modification timestamps
    'updated_at',
    'update_time',
    'modified_time',
    'last_modified',
    # creation timestamps
    'created_at',
    'create_time',
    'created_on',
    # row version column
    'version',
}

# Name of the audit log table.
AUDIT_TABLE = 'audit_logs'

# Tables that are never audited: the audit/log tables themselves (auditing
# them would recurse) and Alembic's alembic_version table.
IGNORE_TABLES = {AUDIT_TABLE, 'sys_log', 'syslog', 'alembic_version'}
|
||
|
||
|
||
def insert_audit_log(connection, action, target, details):
    """Write one row to the audit log table describing a change to *target*.

    The row is inserted with ``connection.execute`` on the connection handed to
    the SQLAlchemy mapper event, bypassing the ORM session, so it is committed
    or rolled back together with the flush that triggered it.

    Args:
        connection: the Connection passed to the mapper event listener.
        action: action label stored verbatim (the listeners in this module
            send 'CREATE', 'UPDATE' or 'DELETE').
        target: the mapped instance being written.
        details: JSON-serializable payload such as ``{'changes': ...}``,
            ``{'deleted_snapshot': ...}`` or ``{'created': ...}``; encoded
            with AuditJSONEncoder.
    """
    tablename = target.__tablename__

    # Never audit the audit/log tables themselves, to avoid infinite recursion.
    if tablename in IGNORE_TABLES:
        return

    target_id = _resolve_target_id(target)
    # Fall back to "<table> ID:<id>" when no human-readable name is available.
    target_name = _resolve_target_name(connection, target) or f"{tablename} ID:{target_id}"

    user_info = get_current_user_info()
    module = _infer_module_name(tablename, target)

    # Raw SQL on the event's connection keeps the audit row inside the same
    # transaction as the change being audited.
    sql = text(f"""
        INSERT INTO {AUDIT_TABLE}
        (user_id, username, display_name, action, module, target_id, target_name, details, ip_address, method, url, created_at)
        VALUES
        (:user_id, :username, :display_name, :action, :module, :target_id, :target_name, :details, :ip_address, :method, :url, :created_at)
    """)

    connection.execute(sql, {
        'user_id': user_info['user_id'],
        'username': user_info['username'],
        'display_name': user_info['display_name'],
        'action': action,
        'module': module,
        'target_id': target_id,
        'target_name': target_name,
        'details': json.dumps(details, cls=AuditJSONEncoder),
        'ip_address': user_info['ip_address'],
        'method': user_info['method'],
        'url': user_info['url'],
        'created_at': datetime.now(),
    })


def _resolve_target_id(target):
    """Return the first identifier attribute present on *target* as a string ('' if none)."""
    for id_field in ('id', 'stock_id', 'uuid', 'bom_no'):
        if hasattr(target, id_field):
            return str(getattr(target, id_field))
    return ''


def _resolve_target_name(connection, target):
    """Return a best-effort display name for *target*, or '' if none is found.

    Lookup order: the first truthy common name attribute on the row itself,
    then the ``name`` of an already-loaded ``material`` relationship, then a
    ``material_base`` lookup by ``material_id`` issued on *connection*.
    """
    import logging

    from sqlalchemy import inspect

    for name_field in ('name', 'title', 'material_name', 'product_name', 'display_name', 'username'):
        value = getattr(target, name_field, None)
        if value:
            return str(value)

    # Only read `material` if it is already loaded: touching an unloaded
    # relationship here would lazy-load it through the Session while the flush
    # is in progress. Unloaded cases fall through to the SQL lookup below.
    if 'material' not in inspect(target).unloaded:
        material_name = getattr(getattr(target, 'material', None), 'name', None)
        if material_name:
            return str(material_name)

    material_id = getattr(target, 'material_id', None)
    if not material_id:
        return ''

    try:
        row = connection.execute(
            text("SELECT name FROM material_base WHERE id = :id"),
            {'id': material_id},
        ).fetchone()
    except Exception:
        # Best effort: the label is cosmetic, so a failed lookup only costs the name.
        # NOTE(review): on PostgreSQL a failed statement aborts the surrounding
        # transaction even when the exception is caught here — confirm backend.
        logging.getLogger(__name__).debug("Audit material name lookup failed", exc_info=True)
        return ''

    # A NULL name must not be rendered as the string 'None'.
    if row and row[0]:
        return str(row[0])
    return ''
|
||
|
||
|
||
def _infer_module_name(tablename, target):
    """Map a model instance to the business-module label used in the audit log.

    Rules are matched by substring against the model's class name, in order,
    and the first match wins. If no rule matches, *tablename* is returned, or
    '未知模块' ("unknown module") when *tablename* is empty.
    """
    class_name = target.__class__.__name__

    rules = (
        # Must precede the generic 'Stock' rule: otherwise every StockTake /
        # StockAdjust model matched 'Stock' first and was labelled 入库管理.
        (('StockTake', 'StockAdjust'), '盘点管理'),
        (('Stock', 'Buy', 'Inbound'), '入库管理'),
        (('Outbound',), '出库管理'),
        (('Borrow', 'Return'), '借还管理'),
        (('Repair',), '维修管理'),
        (('Scrap',), '报废管理'),
        (('Bom', 'BOM'), 'BOM管理'),
        # Kept at its original position so other classifications are unchanged.
        (('Adjustment',), '盘点管理'),
        (('Material', 'Base'), '基础数据'),
        (('SysUser', 'SysMenu', 'SysRole', 'SysPermission'), '系统管理'),
        (('Warehouse', 'Location'), '库位管理'),
    )

    for keywords, module in rules:
        if any(kw in class_name for kw in keywords):
            return module

    return tablename or '未知模块'
|
||
|
||
|
||
def _has_changes(history):
    """Report whether an attribute history object records any change.

    Delegates to ``history.has_changes()`` and returns its result as-is.
    """
    changed = history.has_changes()
    return changed
|
||
|
||
|
||
def register_audit_events(db):
    """Register global audit listeners on every model derived from ``db.Model``.

    Three mapper events are hooked with ``propagate=True`` so subclasses
    inherit them:

    - ``before_update`` -> 'UPDATE' row with ``{'changes': {attr: {'old', 'new'}}}``
    - ``before_delete`` -> 'DELETE' row with ``{'deleted_snapshot': {...}}``
    - ``after_insert``  -> 'CREATE' row with ``{'created': {...}}``

    Only column attributes are recorded; relationships and names in
    IGNORE_FIELDS are skipped, and models whose table is in IGNORE_TABLES are
    not audited. Python exceptions raised while auditing are logged with a
    traceback and swallowed rather than propagated into the flush.

    NOTE(review): each call registers a fresh set of listeners, so calling this
    more than once for the same ``db`` would produce duplicate audit rows.

    Args:
        db: the Flask-SQLAlchemy object whose ``Model`` base is instrumented.

    Returns:
        True once the listeners are registered.
    """
    import logging

    from sqlalchemy import inspect

    logger = logging.getLogger(__name__)

    def _is_ignored(target):
        # getattr: a model without __tablename__ must not crash the flush here.
        return getattr(target, '__tablename__', None) in IGNORE_TABLES

    def _audited_keys(mapper):
        # mapper.column_attrs holds only column-backed properties, so
        # relationships (and their lazy loads) are excluded by construction.
        # The previous `hasattr(attr, 'property')` test on AttributeState was
        # always False and therefore never skipped relationships.
        for prop in mapper.column_attrs:
            if prop.key not in IGNORE_FIELDS:
                yield prop.key

    def _snapshot(mapper, target):
        # Serialized {attribute: value} for every audited column of target.
        return {key: serialize_value(getattr(target, key, None)) for key in _audited_keys(mapper)}

    @event.listens_for(db.Model, 'before_update', propagate=True)
    def before_update_listener(mapper, connection, target):
        """UPDATE event: record old/new values of every changed audited column."""
        if _is_ignored(target):
            return

        try:
            attrs = inspect(target).attrs
            changes = {}

            for key in _audited_keys(mapper):
                history = attrs[key].history
                if not _has_changes(history):
                    continue

                # NOTE(review): if the attribute was not loaded before it was
                # modified, history.deleted is empty and 'old' is recorded as None.
                old_value = history.deleted[0] if history.deleted else None
                new_value = history.added[0] if history.added else None

                old_serialized = serialize_value(old_value)
                new_serialized = serialize_value(new_value)

                # Skip entries whose serialized old and new values are identical.
                if old_serialized != new_serialized:
                    changes[key] = {
                        'old': old_serialized,
                        'new': new_serialized,
                    }

            if changes:
                insert_audit_log(connection, 'UPDATE', target, {'changes': changes})

        except Exception as e:
            logger.exception("Audit Update Error: %s", e)

    @event.listens_for(db.Model, 'before_delete', propagate=True)
    def before_delete_listener(mapper, connection, target):
        """DELETE event: record a snapshot of the row as it was before deletion."""
        if _is_ignored(target):
            return

        try:
            insert_audit_log(connection, 'DELETE', target, {'deleted_snapshot': _snapshot(mapper, target)})
        except Exception as e:
            logger.exception("Audit Delete Error: %s", e)

    @event.listens_for(db.Model, 'after_insert', propagate=True)
    def after_insert_listener(mapper, connection, target):
        """INSERT event: record a snapshot of the newly inserted row."""
        if _is_ignored(target):
            return

        try:
            insert_audit_log(connection, 'CREATE', target, {'created': _snapshot(mapper, target)})
        except Exception as e:
            logger.exception("Audit Insert Error: %s", e)

    # Signal successful registration to the caller.
    return True
|