版本变更V3.35将图像的处理统一更换到新表当中
This commit is contained in:
@ -1,4 +1,3 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
db:
|
||||
|
||||
@ -1,6 +1,12 @@
|
||||
venv/
|
||||
__pycache__/
|
||||
.git
|
||||
.idea
|
||||
__pycache__
|
||||
*.pyc
|
||||
.git/
|
||||
*.pyo
|
||||
venv
|
||||
.venv
|
||||
env
|
||||
uploads
|
||||
pgdata
|
||||
.env
|
||||
pgdata/
|
||||
simhei.ttf
|
||||
|
||||
@ -19,6 +19,14 @@ from app.models.base import MaterialBase
|
||||
# 注册蓝图
|
||||
image_search_bp = Blueprint('image_search', __name__)
|
||||
|
||||
# ============================================================================
|
||||
# 可配置参数
|
||||
# ============================================================================
|
||||
# 以图搜图相似度阈值:余弦距离必须小于此值(距离越小越相似)
|
||||
# 即余弦相似度 = 1 - 距离,必须 > (1 - SIMILARITY_DISTANCE_THRESHOLD)
|
||||
# 默认 0.40 对应余弦相似度 > 0.60
|
||||
SIMILARITY_DISTANCE_THRESHOLD = 0.40
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# POST /api/v1/common/image-search
|
||||
@ -87,27 +95,80 @@ def image_search():
|
||||
ie.module_name,
|
||||
ie.target_id,
|
||||
ie.image_url,
|
||||
(1 - (ie.embedding <=> :query_vector)) AS similarity
|
||||
(1 - (ie.embedding <=> :query_vector)) AS similarity,
|
||||
(ie.embedding <=> :query_vector) AS distance
|
||||
FROM image_embeddings ie
|
||||
WHERE ie.embedding IS NOT NULL
|
||||
AND (ie.embedding <=> :query_vector) < :distance_threshold
|
||||
ORDER BY ie.embedding <=> :query_vector
|
||||
LIMIT 200
|
||||
""")
|
||||
|
||||
raw_records = db.session.execute(sql, {"query_vector": query_vector_str}).fetchall()
|
||||
raw_records = db.session.execute(sql, {
|
||||
"query_vector": query_vector_str,
|
||||
"distance_threshold": SIMILARITY_DISTANCE_THRESHOLD
|
||||
}).fetchall()
|
||||
if not raw_records:
|
||||
return jsonify({"code": 200, "data": []})
|
||||
return jsonify({"code": 200, "data": [], "msg": "未找到相似图片(阈值过滤后)"})
|
||||
|
||||
# 按 (module_name, target_id) 去重,每业务记录只保留最相似的那张图
|
||||
seen = {}
|
||||
# ---------------------------------------------------------
|
||||
# Step 1: 初步去重(同一业务记录 (module_name, target_id) 只保留最相似的图片)
|
||||
# ---------------------------------------------------------
|
||||
first_img_seen = {}
|
||||
unique_records = []
|
||||
for row in raw_records:
|
||||
key = (row.module_name, row.target_id)
|
||||
if key not in seen:
|
||||
seen[key] = row
|
||||
if key not in first_img_seen:
|
||||
first_img_seen[key] = True
|
||||
unique_records.append(row)
|
||||
|
||||
# 批量回填业务数据
|
||||
# ---------------------------------------------------------
|
||||
# Step 2: 按物料维度去重(相同物料只保留第一条 = 相似度最高的那条)
|
||||
# ---------------------------------------------------------
|
||||
target_ids_by_module = {}
|
||||
for row in seen.values():
|
||||
for row in unique_records:
|
||||
target_ids_by_module.setdefault(row.module_name, []).append(row.target_id)
|
||||
|
||||
# 查询每条记录的 base_id(跨 stock_buy/semi/product/material_base)
|
||||
base_id_map = {}
|
||||
|
||||
for module in ('stock_buy', 'stock_semi', 'stock_product'):
|
||||
if module not in target_ids_by_module:
|
||||
continue
|
||||
ids = target_ids_by_module[module]
|
||||
ModelCls = StockBuy if module == 'stock_buy' else (StockSemi if module == 'stock_semi' else StockProduct)
|
||||
id_col = getattr(ModelCls, 'id')
|
||||
base_col = getattr(ModelCls, 'base_id')
|
||||
|
||||
rows = (
|
||||
db.session.query(id_col, base_col)
|
||||
.outerjoin(MaterialBase, base_col == MaterialBase.id)
|
||||
.filter(id_col.in_(ids))
|
||||
.all()
|
||||
)
|
||||
for rec_id, base_id in rows:
|
||||
base_id_map[(module, rec_id)] = base_id
|
||||
|
||||
if 'material_base' in target_ids_by_module:
|
||||
for rec_id in target_ids_by_module['material_base']:
|
||||
base_id_map[('material_base', rec_id)] = rec_id
|
||||
|
||||
# 按 base_id 去重:相同物料只保留第一张图
|
||||
material_seen = {}
|
||||
final_records = []
|
||||
for row in unique_records:
|
||||
base_id = base_id_map.get((row.module_name, row.target_id))
|
||||
if base_id is not None and base_id in material_seen:
|
||||
continue
|
||||
if base_id is not None:
|
||||
material_seen[base_id] = True
|
||||
final_records.append(row)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Step 3: 批量回填业务数据(基于去重后的 final_records)
|
||||
# ---------------------------------------------------------
|
||||
target_ids_by_module = {}
|
||||
for row in final_records:
|
||||
target_ids_by_module.setdefault(row.module_name, []).append(row.target_id)
|
||||
|
||||
business_map = {}
|
||||
@ -205,9 +266,9 @@ def image_search():
|
||||
'url': '/material/index',
|
||||
}
|
||||
|
||||
# 组装最终返回
|
||||
# 组装最终返回(基于 final_records,按相似度从高到低)
|
||||
results = []
|
||||
for row in seen.values():
|
||||
for row in final_records:
|
||||
key = (row.module_name, row.target_id)
|
||||
biz = business_map.get(key, {})
|
||||
raw_url = row.image_url or ''
|
||||
|
||||
@ -561,13 +561,20 @@ class MaterialBaseService:
|
||||
db.session.add(new_material)
|
||||
db.session.flush() # 获取 new_material.id
|
||||
|
||||
# 提取产品图向量到独立表(失败不影响业务)
|
||||
# 先提交主事务,图片向量异步后台提取
|
||||
db.session.commit()
|
||||
|
||||
image_list = data.get('generalImage', [])
|
||||
if isinstance(image_list, list) and image_list:
|
||||
ImageEmbeddingService.save_embeddings(
|
||||
ImageEmbeddingService.MODULE_MATERIAL_BASE, new_material.id, image_list
|
||||
from flask import current_app
|
||||
from app.utils.executor import run_embedding_task
|
||||
run_embedding_task(
|
||||
ImageEmbeddingService.save_embeddings_background,
|
||||
current_app._get_current_object(),
|
||||
ImageEmbeddingService.MODULE_MATERIAL_BASE,
|
||||
new_material.id,
|
||||
image_list
|
||||
)
|
||||
db.session.commit()
|
||||
return new_material
|
||||
|
||||
except Exception as e:
|
||||
@ -597,9 +604,16 @@ class MaterialBaseService:
|
||||
if 'generalImage' in data:
|
||||
new_photo_list = data['generalImage']
|
||||
material.product_image = json.dumps(new_photo_list)
|
||||
# 保存向量到独立表(全量替换)
|
||||
ImageEmbeddingService.save_embeddings(
|
||||
ImageEmbeddingService.MODULE_MATERIAL_BASE, material.id, new_photo_list
|
||||
# 立即触发异步向量提取,不阻塞主事务提交
|
||||
if isinstance(new_photo_list, list) and new_photo_list:
|
||||
from flask import current_app
|
||||
from app.utils.executor import run_embedding_task
|
||||
run_embedding_task(
|
||||
ImageEmbeddingService.save_embeddings_background,
|
||||
current_app._get_current_object(),
|
||||
ImageEmbeddingService.MODULE_MATERIAL_BASE,
|
||||
material.id,
|
||||
new_photo_list
|
||||
)
|
||||
else:
|
||||
material.product_image = None
|
||||
|
||||
@ -183,13 +183,20 @@ class BuyInboundService:
|
||||
db.session.add(new_stock)
|
||||
db.session.flush() # 获取 new_stock.id
|
||||
|
||||
# 提取到货图片向量到新表(失败不影响业务)
|
||||
# 先提交主事务(入库单必须落盘),图片向量异步后台提取
|
||||
db.session.commit()
|
||||
|
||||
photo_list = data.get('arrival_photo', [])
|
||||
if isinstance(photo_list, list) and photo_list:
|
||||
ImageEmbeddingService.save_embeddings(
|
||||
ImageEmbeddingService.MODULE_STOCK_BUY, new_stock.id, photo_list
|
||||
from flask import current_app
|
||||
from app.utils.executor import run_embedding_task
|
||||
run_embedding_task(
|
||||
ImageEmbeddingService.save_embeddings_background,
|
||||
current_app._get_current_object(),
|
||||
ImageEmbeddingService.MODULE_STOCK_BUY,
|
||||
new_stock.id,
|
||||
photo_list
|
||||
)
|
||||
db.session.commit()
|
||||
return new_stock
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
@ -254,9 +261,16 @@ class BuyInboundService:
|
||||
if 'arrival_photo' in data:
|
||||
new_photo_list = data['arrival_photo']
|
||||
stock.arrival_photo = json.dumps(new_photo_list)
|
||||
# 保存向量到独立表(全量替换)
|
||||
ImageEmbeddingService.save_embeddings(
|
||||
ImageEmbeddingService.MODULE_STOCK_BUY, stock.id, new_photo_list
|
||||
# 立即触发异步向量提取,不阻塞主事务提交
|
||||
if isinstance(new_photo_list, list) and new_photo_list:
|
||||
from flask import current_app
|
||||
from app.utils.executor import run_embedding_task
|
||||
run_embedding_task(
|
||||
ImageEmbeddingService.save_embeddings_background,
|
||||
current_app._get_current_object(),
|
||||
ImageEmbeddingService.MODULE_STOCK_BUY,
|
||||
stock.id,
|
||||
new_photo_list
|
||||
)
|
||||
else:
|
||||
stock.arrival_photo = None
|
||||
|
||||
@ -189,12 +189,19 @@ class ProductInboundService:
|
||||
db.session.add(new_stock)
|
||||
db.session.flush() # 获取 new_stock.id
|
||||
|
||||
# 提取产品图片向量到独立表(失败不影响业务)
|
||||
if isinstance(photo_list, list) and photo_list:
|
||||
ImageEmbeddingService.save_embeddings(
|
||||
ImageEmbeddingService.MODULE_STOCK_PRODUCT, new_stock.id, photo_list
|
||||
)
|
||||
# 先提交主事务,图片向量异步后台提取
|
||||
db.session.commit()
|
||||
|
||||
if isinstance(photo_list, list) and photo_list:
|
||||
from flask import current_app
|
||||
from app.utils.executor import run_embedding_task
|
||||
run_embedding_task(
|
||||
ImageEmbeddingService.save_embeddings_background,
|
||||
current_app._get_current_object(),
|
||||
ImageEmbeddingService.MODULE_STOCK_PRODUCT,
|
||||
new_stock.id,
|
||||
photo_list
|
||||
)
|
||||
return new_stock
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
@ -225,9 +232,16 @@ class ProductInboundService:
|
||||
if 'product_photo' in data:
|
||||
new_photo_list = data['product_photo']
|
||||
stock.product_photo = json.dumps(new_photo_list)
|
||||
# 保存向量到独立表(全量替换)
|
||||
ImageEmbeddingService.save_embeddings(
|
||||
ImageEmbeddingService.MODULE_STOCK_PRODUCT, stock.id, new_photo_list
|
||||
# 立即触发异步向量提取,不阻塞主事务提交
|
||||
if isinstance(new_photo_list, list) and new_photo_list:
|
||||
from flask import current_app
|
||||
from app.utils.executor import run_embedding_task
|
||||
run_embedding_task(
|
||||
ImageEmbeddingService.save_embeddings_background,
|
||||
current_app._get_current_object(),
|
||||
ImageEmbeddingService.MODULE_STOCK_PRODUCT,
|
||||
stock.id,
|
||||
new_photo_list
|
||||
)
|
||||
else:
|
||||
stock.product_photo = None
|
||||
|
||||
@ -226,12 +226,19 @@ class SemiInboundService:
|
||||
db.session.add(new_stock)
|
||||
db.session.flush() # 获取 new_stock.id
|
||||
|
||||
# 提取到货图片向量到独立表(失败不影响业务)
|
||||
if isinstance(arrival_list, list) and arrival_list:
|
||||
ImageEmbeddingService.save_embeddings(
|
||||
ImageEmbeddingService.MODULE_STOCK_SEMI, new_stock.id, arrival_list
|
||||
)
|
||||
# 先提交主事务,图片向量异步后台提取
|
||||
db.session.commit()
|
||||
|
||||
if isinstance(arrival_list, list) and arrival_list:
|
||||
from flask import current_app
|
||||
from app.utils.executor import run_embedding_task
|
||||
run_embedding_task(
|
||||
ImageEmbeddingService.save_embeddings_background,
|
||||
current_app._get_current_object(),
|
||||
ImageEmbeddingService.MODULE_STOCK_SEMI,
|
||||
new_stock.id,
|
||||
arrival_list
|
||||
)
|
||||
return new_stock
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
@ -280,9 +287,16 @@ class SemiInboundService:
|
||||
if 'arrival_photo' in data:
|
||||
new_photo_list = data['arrival_photo']
|
||||
stock.arrival_photo = json.dumps(new_photo_list)
|
||||
# 保存向量到独立表(全量替换)
|
||||
ImageEmbeddingService.save_embeddings(
|
||||
ImageEmbeddingService.MODULE_STOCK_SEMI, stock.id, new_photo_list
|
||||
# 立即触发异步向量提取,不阻塞主事务提交
|
||||
if isinstance(new_photo_list, list) and new_photo_list:
|
||||
from flask import current_app
|
||||
from app.utils.executor import run_embedding_task
|
||||
run_embedding_task(
|
||||
ImageEmbeddingService.save_embeddings_background,
|
||||
current_app._get_current_object(),
|
||||
ImageEmbeddingService.MODULE_STOCK_SEMI,
|
||||
stock.id,
|
||||
new_photo_list
|
||||
)
|
||||
else:
|
||||
stock.arrival_photo = None
|
||||
|
||||
@ -9,6 +9,7 @@ import time
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import onnxruntime as ort
|
||||
import cv2
|
||||
|
||||
# ============================================================================
|
||||
# 全局模型单例(项目启动时加载一次)
|
||||
@ -48,33 +49,115 @@ IMAGENET_STD = [0.229, 0.224, 0.225]
|
||||
# 模型输入尺寸
|
||||
INPUT_SIZE = 224
|
||||
|
||||
# ============================================================================
|
||||
# 背景去除配置:HSV 色彩空间阈值
|
||||
# ============================================================================
|
||||
# OpenCV HSV(8 位图像): H∈[0,179], S∈[0,255], V∈[0,255]
|
||||
# 注意:OpenCV 对 8 位图像将色相按 H/2 存储,因此 H 通道范围是 0~179,对应常规色相环的 0~360 度
|
||||
|
||||
def _center_crop_and_resize(image: Image.Image) -> Image.Image:
|
||||
# 绿色背景阈值(工业绿幕常用色)
|
||||
# H: 35~90 对应绿色谱(浅绿到深绿)
|
||||
# S: 低饱和度(35)到高饱和度(255)
|
||||
# V: 明暗均可(30~255)
|
||||
BG_GREEN_LOWER = np.array([35, 35, 30])
|
||||
BG_GREEN_UPPER = np.array([90, 255, 255])
|
||||
|
||||
# 白色/浅色背景阈值(高明度、低饱和度区域)
|
||||
# H: 不限制(0~180),只看 S 和 V
|
||||
# S: 很低的饱和度(0~40)→ 接近纯灰/白色
|
||||
# V: 高明度(180~255)
|
||||
BG_WHITE_LOWER = np.array([0, 0, 180])
|
||||
BG_WHITE_UPPER = np.array([180, 40, 255])
|
||||
|
||||
# 中性灰填充色(BGR → 转换后 RGB 也是 128,128,128)
|
||||
NEUTRAL_GRAY_BGR = (128, 128, 128)
|
||||
|
||||
|
||||
def _remove_background(image: Image.Image) -> Image.Image:
    """Replace green and white/light background pixels with neutral gray.

    The image is converted to OpenCV's BGR layout and then to HSV. Two
    masks are built with ``cv2.inRange``: one from the
    ``BG_GREEN_LOWER``/``BG_GREEN_UPPER`` bounds and one from the
    ``BG_WHITE_LOWER``/``BG_WHITE_UPPER`` bounds. The masks are OR-ed
    together, smoothed with a morphological close followed by an open
    (5x5 elliptical kernel), and every pixel still set in the mask is
    overwritten with ``NEUTRAL_GRAY_BGR``.

    Args:
        image: PIL image; the conversion below treats it as 3-channel RGB
            (the visible caller converts with ``.convert('RGB')`` first).

    Returns:
        A new PIL image built from the processed pixel array, converted
        back from BGR to RGB.
    """
    # PIL delivers RGB; the working copy is kept in OpenCV's BGR order.
    bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)

    # A pixel counts as background if it matches either color range.
    background = cv2.bitwise_or(
        cv2.inRange(hsv, BG_GREEN_LOWER, BG_GREEN_UPPER),
        cv2.inRange(hsv, BG_WHITE_LOWER, BG_WHITE_UPPER),
    )

    # Close first, then open, to clean up small noise in the mask.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    for morph_op in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
        background = cv2.morphologyEx(background, morph_op, ellipse)

    # Paint the masked (non-zero) pixels gray on a copy of the BGR array.
    cleaned = bgr.copy()
    cleaned[background > 0] = NEUTRAL_GRAY_BGR

    return Image.fromarray(cv2.cvtColor(cleaned, cv2.COLOR_BGR2RGB))
|
||||
|
||||
|
||||
def _letterbox_image(image: Image.Image, size: int = 224) -> Image.Image:
    """Letterbox *image* into a ``size`` x ``size`` square without distortion.

    The image is scaled uniformly so that its longest side equals ``size``,
    then pasted centered onto an RGB(128, 128, 128) gray canvas so the
    remaining area is gray padding.

    Args:
        image: PIL image to resize.
        size: Edge length of the square output, 224 by default.

    Returns:
        A new ``size`` x ``size`` RGB PIL image.
    """
    w, h = image.size
    longest = max(w, h)

    # Integer arithmetic makes the longest side come out at exactly `size`;
    # the previous float product `int(w * scale)` could truncate to size - 1.
    # The floor of 1 keeps very elongated inputs (e.g. 1000x1) from
    # collapsing to a zero-pixel dimension, which resize cannot handle.
    new_w = max(1, w * size // longest)
    new_h = max(1, h * size // longest)

    resized = image.resize((new_w, new_h), Image.LANCZOS)

    # Gray canvas; the resized image is pasted at its center.
    canvas = Image.new('RGB', (size, size), (128, 128, 128))
    canvas.paste(resized, ((size - new_w) // 2, (size - new_h) // 2))

    return canvas
|
||||
|
||||
|
||||
def _normalize(image_np: np.ndarray) -> np.ndarray:
|
||||
@ -111,8 +194,12 @@ def get_image_embedding(image_path: str) -> list:
|
||||
load_clip_model()
|
||||
|
||||
# 1. 图片预处理
|
||||
# Step 1: 背景去除(HSV 色彩空间,绿色/白色背景 → 中性灰替换)
|
||||
image = Image.open(image_path).convert('RGB')
|
||||
image = _center_crop_and_resize(image)
|
||||
image = _remove_background(image)
|
||||
|
||||
# Step 2: Letterbox 等比例缩放(保持内容不变形)
|
||||
image = _letterbox_image(image, INPUT_SIZE)
|
||||
input_data = _normalize(np.array(image))
|
||||
input_data = np.expand_dims(input_data, axis=0) # [1, 3, 224, 224]
|
||||
|
||||
|
||||
@ -10,6 +10,8 @@ flask-cors==4.0.0
|
||||
redis==5.0.1
|
||||
# 图片处理核心库
|
||||
Pillow>=10.0.0
|
||||
# OpenCV(背景去除、HSV色彩空间抠图)
|
||||
opencv-python-headless>=4.8.0
|
||||
# ONNX 模型本地 CPU 推理
|
||||
onnxruntime>=1.16.0
|
||||
# 数值计算(ONNX 推理依赖)
|
||||
|
||||
Reference in New Issue
Block a user