修改分割模块

This commit is contained in:
2026-03-05 17:12:01 +08:00
parent d84d886f35
commit 10fd2b00d4
43 changed files with 1858 additions and 284 deletions

692
main.py
View File

@ -14,10 +14,8 @@ import time
matplotlib.use('TkAgg')
# Training-camera wavelengths, 168-channel variant.
# NOTE(review): this list was previously assigned to TRAIN_WAVELENGTHS and then
# immediately shadowed by the 237-channel list below (dead assignment). It is
# kept under a distinct name for reference; confirm it is still needed.
TRAIN_WAVELENGTHS_168 = [
898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2
]
# Training-camera wavelengths (237 channels) — the active set the pipeline
# resamples source spectra to before classification.
TRAIN_WAVELENGTHS = [912.36, 915.68, 919, 922.31, 925.63, 928.95, 932.27, 935.59, 938.91, 942.23, 945.55, 948.87, 952.18, 955.5, 958.82, 962.14, 965.46, 968.78, 972.1, 975.42, 978.74, 982.06, 985.38, 988.7, 992.02, 995.34, 998.65, 1002, 1005.3, 1008.6, 1011.9, 1015.3, 1018.6, 1021.9, 1025.2, 1028.5, 1031.9, 1035.2, 1038.5, 1041.8, 1045.1, 1048.5, 1051.8, 1055.1, 1058.4, 1061.7, 1065.1, 1068.4, 1071.7, 1075, 1078.3, 1081.7, 1085, 1088.3, 1091.6, 1094.9, 1098.3, 1101.6, 1104.9, 1108.2, 1111.5, 1114.9, 1118.2, 1121.5, 1124.8, 1128.1, 1131.5, 1134.8, 1138.1, 1141.4, 1144.8, 1148.1, 1151.4, 1154.7, 1158, 1161.4, 1164.7, 1168, 1171.3, 1174.6, 1178, 1181.3, 1184.6, 1187.9, 1191.3, 1194.6, 1197.9, 1201.2, 1204.5, 1207.9, 1211.2, 1214.5, 1217.8, 1221.2, 1224.5, 1227.8, 1231.1, 1234.4, 1237.8, 1241.1, 1244.4, 1247.7, 1251.1, 1254.4, 1257.7, 1261, 1264.3, 1267.7, 1271, 1274.3, 1277.6, 1281, 1284.3, 1287.6, 1290.9, 1294.2, 1297.6, 1300.9, 1304.2, 1307.5, 1310.9, 1314.2, 1317.5, 1320.8, 1324.2, 1327.5, 1330.8, 1334.1, 1337.4, 1340.8, 1344.1, 1347.4, 1350.7, 1354.1, 1357.4, 1360.7, 1364, 1367.4, 1370.7, 1374, 1377.3, 1380.7, 1384, 1387.3, 1390.6, 1394, 1397.3, 1400.6, 1403.9, 1407.2, 1410.6, 1413.9, 1417.2, 1420.5, 1423.9, 1427.2, 1430.5, 1433.8, 1437.2, 1440.5, 1443.8, 1447.1, 1450.5, 1453.8, 1457.1, 1460.4, 1463.8, 1467.1, 1470.4, 1473.7, 1477.1, 1480.4, 1483.7, 1487, 1490.4, 1493.7, 1497, 1500.3, 1503.7, 1507, 1510.3, 1513.6, 1517, 1520.3, 1523.6, 1526.9, 1530.3, 1533.6, 1536.9, 1540.2, 1543.6, 1546.9, 1550.2, 1553.6, 1556.9, 1560.2, 1563.5, 1566.9, 1570.2, 1573.5, 1576.8, 1580.2, 1583.5, 1586.8, 1590.1, 1593.5, 1596.8, 1600.1, 1603.4, 1606.8, 1610.1, 1613.4, 1616.7, 1620.1, 1623.4, 1626.7, 1630.1, 1633.4, 1636.7, 1640, 1643.4, 1646.7, 1650, 1653.3, 1656.7, 1660, 1663.3, 1666.7, 1670, 1673.3, 1676.6, 1680, 1683.3, 1686.6, 1689.9, 1693.3, 1696.6, 1699.9, 1703.3, 1706.6]
def read_wavelengths_from_hdr(bil_path):
hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
@ -50,6 +48,28 @@ def resample_spectra_matrix(X, src_waves, dst_waves):
return out
def apply_background_no_resample(df, bg_spectrum):
    """Divide the spectral columns of *df* by *bg_spectrum*; no resampling.

    Spectral columns are those whose (string) name starts with
    ``wavelength_`` or ``band_``. When the background length differs from
    the number of spectral columns, both are tail-aligned and only the
    overlapping part is corrected, so dimensions never mismatch.

    Returns the (mutated) DataFrame.
    Raises ValueError when no spectral columns exist or the background is empty.
    """
    spec_cols = [c for c in df.columns
                 if isinstance(c, str) and c.startswith(('wavelength_', 'band_'))]
    if not spec_cols:
        raise ValueError("未找到光谱列(以 wavelength_ 或 band_ 开头)")

    bg = np.asarray(bg_spectrum, dtype=np.float64).ravel()
    if bg.size == 0:
        raise ValueError("背景光谱长度为0无法进行背景校正")

    # Tail-align both sides and divide only the overlapping channels.
    overlap = min(len(spec_cols), bg.shape[0])
    tail_cols = spec_cols[-overlap:]
    df.loc[:, tail_cols] = df.loc[:, tail_cols].div(bg[-overlap:], axis=1)
    return df
def parse_arguments():
"""解析命令行参数"""
parser = argparse.ArgumentParser(description='Microplastic spectral shape classification')
@ -59,7 +79,7 @@ def parse_arguments():
parser.add_argument('--output_path', required=True, help='Path to output classification result')
parser.add_argument('--model_path', required=True, help='Path to primary classification model')
# 可选参数
# parser.add_argument('--primary_model_type', default='SVM', help='Type of primary model (default: SVM)')
# parser.add_argument('--primary_process_methods1', default='SS', help='Primary process method 1 (default: SS)')
# parser.add_argument('--primary_process_methods2', default='None', help='Primary process method 2 (default: None)')
@ -176,11 +196,11 @@ def save_envi_classification(bil_path, df, savepath):
for _, row in df.iterrows():
contour = row['contour']
prediction = int(row['Predictions']) + 1
contour = np.array(contour, dtype=np.int32)
# 先将 classification_result 中的 10 和 11 替换为 0
classification_result[(classification_result == 10)] = 0
prediction = int(row['Predictions']) + 1 # 先加1
if prediction in (10, 11): # 再判断是否为10或11
prediction = 0 # 视为背景
contour = np.array(contour, dtype=np.int32)
cv2.fillPoly(classification_result, [contour], prediction)
output_path = savepath
@ -206,7 +226,6 @@ byte order = 0
wavelength units = nm
"""
filename, ext = os.path.splitext(savepath)
# 替换扩展名为 '.hdr'
header_filename = filename + '.hdr'
with open(header_filename, 'w') as header_file:
@ -214,300 +233,463 @@ wavelength units = nm
def change_hdr_file(bil_path, wavelengths=None):
# wavelengths=None 时仅在HDR缺失wavelength字段才写入若提供则按提供内容写入
hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
if not os.path.exists(hdr_path):
print(f"错误: 找不到对应的HDR文件: {hdr_path}")
return
with open(hdr_path, 'r') as file:
# 仅在缺少 wavelength 字段时才尝试写入
with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as file:
content = file.read()
if 'wavelength' in content and wavelengths is None:
print(f"File {os.path.basename(hdr_path)} already contains wavelength information; no changes needed.")
if 'wavelength' in content:
print(f"{os.path.basename(hdr_path)} 已包含 wavelength 字段,跳过追加。")
return
if wavelengths is None:
print(f"No wavelengths provided and HDR lacks wavelength; skipping write to avoid wrong bands.")
if wavelengths is None or len(wavelengths) == 0:
print("HDR 缺少 wavelength,但未提供 wavelengths跳过写入以避免错误。")
return
needs_newline = not content.endswith('\n')
wavelength_info = "wavelength = {" + ", ".join(str(float(w)) for w in wavelengths) + "}\n"
with open(hdr_path, 'a') as file:
with open(hdr_path, 'a', encoding='utf-8', errors='ignore') as file:
if needs_newline:
file.write('\n')
file.write(wavelength_info)
print(f"Successfully ensured wavelength information in file: {os.path.basename(hdr_path)}")
print(f"已在 {os.path.basename(hdr_path)} 末尾追加 wavelength 字段。")
def validate_inputs(bil_path, output_path, model_path):
    """Validate input files and parameters before running the pipeline.

    Checks that the BIL file and its sibling ``.hdr`` exist, that the output
    directory exists (creating it if needed), that the model file exists, and
    that the BIL cube has at least 160 bands (bil2rgb reads band indices
    9, 59, 159).

    Raises:
        FileNotFoundError: BIL/HDR/model file missing.
        RuntimeError: output dir cannot be created or BIL header unreadable.
        ValueError: BIL has too few bands.
    """
    # BIL and HDR must both exist.
    if not os.path.exists(bil_path):
        raise FileNotFoundError(f"BIL文件不存在: {bil_path}")
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        raise FileNotFoundError(f"HDR文件不存在: {hdr_path}")
    # Output directory must be writable (create it when missing).
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        try:
            os.makedirs(output_dir, exist_ok=True)
        except Exception as e:
            raise RuntimeError(f"无法创建输出目录: {output_dir}") from e
    # Primary model must exist.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"主模型文件不存在: {model_path}")
    # Read the band count; keep the try narrow so the band-count ValueError
    # below is not swallowed and re-labelled as a header-read failure
    # (bug fix: the original raised ValueError inside this try and its own
    # except re-raised it as "无法读取BIL文件头信息").
    try:
        from spectral.io import envi
        img = envi.open(hdr_path, bil_path)
        n_bands = img.nbands
    except Exception as e:
        raise RuntimeError(f"无法读取BIL文件头信息: {bil_path}") from e
    # bil2rgb needs band indices 9, 59, 159.
    if n_bands <= 159:
        raise ValueError(f"BIL文件波段数不足: 需要至少160个波段但只有{n_bands}")
def generate_rgb(bil_path):
    """Build an RGB preview image from the BIL file.

    Thin wrapper around ``process_bil_files`` that converts any failure into
    a RuntimeError carrying the offending path.
    """
    try:
        return process_bil_files(bil_path)
    except Exception as e:
        raise RuntimeError(f"生成RGB图像失败: bil_path={bil_path}") from e
def run_segmentation(rgb_img, segmentation_model_path=None):
    """Segment particles in *rgb_img*.

    Returns the ``(mask, filter_mask_original)`` pair produced by
    ``detect_microplastic_mask_from_array``; wraps any failure in a
    RuntimeError.
    """
    try:
        return detect_microplastic_mask_from_array(
            image=rgb_img,
            filter_method='threshold',
            diameter=None,
            flow_threshold=0.4,
            cellprob_threshold=-1,
            model_path=segmentation_model_path,
            detect_filter=True,
        )
    except Exception as e:
        raise RuntimeError("分割失败: 无法检测微塑料颗粒") from e
def extract_primary_features(bil_path, mask):
    """Extract per-particle features from the BIL cube under *mask*.

    Delegates to ``process_images``; failures become RuntimeError with the
    offending path.
    """
    try:
        return process_images(bil_path, mask)
    except Exception as e:
        raise RuntimeError(f"特征提取失败: bil_path={bil_path}") from e
def compute_background_spectrum(bil_path, mask):
    """Compute the background spectrum used for correction.

    Delegates to ``process_images_background``; failures become RuntimeError
    with the offending path.
    """
    try:
        return process_images_background(bil_path, mask)
    except Exception as e:
        raise RuntimeError(f"背景光谱计算失败: bil_path={bil_path}") from e
def apply_background_and_optional_resample(df, bg_spectrum, bil_path):
    """Background-correct *df* and resample its spectra when needed.

    Divides the ``wavelength_*`` columns by *bg_spectrum*, then resamples to
    ``TRAIN_WAVELENGTHS`` only when the source wavelengths read from the HDR
    differ from the training set. Returns a (possibly new) DataFrame.

    Raises ValueError when no spectral columns exist or when their count does
    not match the background length.
    """
    # Robustness fix: guard against non-string column labels, consistent with
    # apply_background_no_resample (str.startswith would raise on ints).
    spec_cols = [c for c in df.columns if isinstance(c, str) and c.startswith('wavelength_')]
    if not spec_cols:
        raise ValueError("未找到光谱列以wavelength_开头的列")
    if len(spec_cols) != len(bg_spectrum):
        raise ValueError(f"光谱列数量({len(spec_cols)})与背景光谱长度({len(bg_spectrum)})不匹配")
    # Background correction: per-channel column-wise division.
    df[spec_cols] = df[spec_cols].div(bg_spectrum, axis=1)
    # Resample only if the source camera's wavelengths differ from training.
    src_waves = read_wavelengths_from_hdr(bil_path)
    need_resample = (src_waves.size > 0 and
                     (src_waves.size != len(TRAIN_WAVELENGTHS) or
                      not np.allclose(src_waves, TRAIN_WAVELENGTHS, atol=1e-2)))
    if need_resample:
        print(f"重采样光谱: 源波段数 {src_waves.size} -> 目标波段数 {len(TRAIN_WAVELENGTHS)}")
        X_src = df[spec_cols].to_numpy(dtype=np.float64)
        X_dst = resample_spectra_matrix(X_src, src_waves, TRAIN_WAVELENGTHS)
        # Replace the spectral columns with the training-aligned bands.
        spec_col_names = [f"band_{i+1}" for i in range(len(TRAIN_WAVELENGTHS))]
        df = pd.concat([
            df.drop(columns=spec_cols),  # drop original spectral columns
            pd.DataFrame(X_dst, columns=spec_col_names, index=df.index)
        ], axis=1)
    return df
def clean_and_select_columns(df):
    """Drop invalid rows and trim trailing feature columns.

    Removes rows with NaNs, rows whose contour (when a list) has fewer than
    two points, and rows with area below 500; then drops the 13 columns at
    positions -14..-2 (hard-coded positional selection inherited from the
    training pipeline — presumably matches the trained feature layout;
    TODO confirm against the model).
    """
    cleaned = df.dropna()
    has_contour = cleaned['contour'].apply(
        lambda c: len(c) > 1 if isinstance(c, list) else True)
    cleaned = cleaned[has_contour]
    cleaned = cleaned[cleaned['area'] >= 500]
    trailing = cleaned.columns[np.r_[-14:-1]]
    return cleaned.drop(columns=trailing)
def run_primary_classification(df, primary_model_path):
    """Run the primary SVM classification on the feature DataFrame.

    Performs a best-effort feature-dimension sanity check against the saved
    scaler (warn-only), then calls ``predict_with_model`` with the fixed
    SVM/SS configuration. Raises RuntimeError on failure.
    """
    try:
        # Best-effort validation: compare numeric feature count with the
        # scaler's training dimension; failures only print a warning.
        try:
            import joblib
            scaler_path = os.path.join(os.path.dirname(primary_model_path), 'scaler_params.pkl')
            if os.path.exists(scaler_path):
                scaler = joblib.load(scaler_path)
                numeric_cols = [c for c in df.columns[1:] if np.issubdtype(df[c].dtype, np.number) and c != 'contour']
                if hasattr(scaler, 'mean_') and len(numeric_cols) != scaler.mean_.shape[0]:
                    # Bug fix: report the column COUNT, not the entire column
                    # list (the parallel check in main uses len(numeric_cols)).
                    raise ValueError(f"特征维度不匹配: 当前{len(numeric_cols)}列 != 训练时{scaler.mean_.shape[0]}")
        except Exception as e:
            print(f"警告: 无法验证特征维度: {e}")
        df_pre = predict_with_model(
            df,
            primary_model_path,
            model_type='SVM',
            ProcessMethods1='SS',
            ProcessMethods2='None'
        )
        return df_pre
    except Exception as e:
        raise RuntimeError(f"主要分类失败: model_path={primary_model_path}") from e
def run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original):
    """Refine HDPE/LDPE predictions with a secondary shape-based model.

    Samples predicted as class 1 or 2 are re-classified with a second SVM
    using shape features extracted from a background-corrected image. On any
    failure the primary predictions are kept and a warning is printed.
    """
    # Secondary-classification configuration (kept in code).
    secondary_model_path = os.path.join(os.path.dirname(__file__), 'modelsave', 'HDPELDPE_model', 'svm.m')
    secondary_target_classes = [1, 2]  # HDPE, LDPE
    target_classes = set(secondary_target_classes or [])
    mask_secondary = df_pre['Predictions'].isin(target_classes)
    if not mask_secondary.any():
        print("未找到目标类别样本,跳过二次分类")
        return df_pre
    print(f"为类别 {sorted(target_classes)} 运行二次分类")
    if not os.path.exists(secondary_model_path):
        print(f"警告: 二次模型不存在: {secondary_model_path},跳过二次分类")
        return df_pre
    try:
        # Background-correct the image information.
        corrected = shape_correct_background(bil_path, mask, filter_mask_original)
        # Label image containing only the target-class contours
        # (label value = row index + 1).
        label_img = np.zeros_like(mask, dtype=np.uint16)
        for idx in df_pre[mask_secondary].index:
            contour = df_pre.loc[idx, 'contour']
            if isinstance(contour, list) and len(contour) > 0:
                cv2.fillPoly(label_img, [np.array(contour, dtype=np.int32)], idx + 1)
        # Shape features; the model expects at most the first 13 columns.
        shape_df = extract_features(corrected, label_img)
        if len(shape_df.columns) >= 13:
            shape_df = shape_df.iloc[:, :13]
        refined = predict_with_model(
            shape_df,
            secondary_model_path,
            model_type='SVM',
            ProcessMethods1='None',
            ProcessMethods2='None'
        )
        # Secondary labels are offset by +1 before overwriting.
        df_pre.loc[mask_secondary, 'Predictions'] = refined['Predictions'].values + 1
    except Exception as e:
        print(f"警告: 二次分类失败,将继续使用主要分类结果: {e}")
    return df_pre
def postprocess_class7_shadow(df_pre, rgb_img):
    """Relabel background shadows mistaken for classes 7/8.

    Real particles have sharp boundaries; shadows are blurry. For every
    class-7/8 sample this computes (a) the ratio of boundary gradient to the
    surrounding ring's gradient and (b) the normalized inside/outside
    intensity contrast, both via robust medians, then relabels samples whose
    two measures fall below the cohort's 30th percentile, subject to an area
    guard (only sufficiently small samples are changed).
    """
    mask_targets = df_pre['Predictions'].isin([7, 8])
    if not mask_targets.any():
        return df_pre
    print(f"处理 {mask_targets.sum()} 个类别7/8样本识别背景阴影...")

    # Grayscale image (rgb_img may be a PIL Image or an ndarray).
    rgb_arr = np.array(rgb_img) if hasattr(rgb_img, 'mode') else rgb_img
    gray_img = cv2.cvtColor(rgb_arr, cv2.COLOR_RGB2GRAY) if len(rgb_arr.shape) == 3 else rgb_arr

    # Scharr gradients are more stable than Sobel for thin edges.
    gx = cv2.Scharr(gray_img, cv2.CV_64F, 1, 0)
    gy = cv2.Scharr(gray_img, cv2.CV_64F, 0, 1)
    grad_mag = np.sqrt(gx ** 2 + gy ** 2)

    edge_ratios, contrast_norms, areas_list = [], [], []
    measures_per_idx = {}
    edge_thick, ring_thick, eps = 3, 5, 1e-6

    for idx in df_pre[mask_targets].index:
        try:
            contour = df_pre.loc[idx, 'contour']
            if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
                continue
            pts = np.array(contour, dtype=np.int32)
            if len(pts.shape) == 1:
                continue

            # Filled polygon, boundary band, and outer ring masks.
            poly_mask = np.zeros(gray_img.shape, dtype=np.uint8)
            cv2.fillPoly(poly_mask, [pts], 255)
            edge_mask = np.zeros_like(poly_mask)
            cv2.drawContours(edge_mask, [pts], -1, 255, thickness=edge_thick)
            ring_mask = cv2.dilate(edge_mask, np.ones((ring_thick, ring_thick), np.uint8), iterations=1)
            ring_mask = cv2.bitwise_and(ring_mask, cv2.bitwise_not(edge_mask))
            ring_mask = cv2.bitwise_and(ring_mask, cv2.bitwise_not(poly_mask))

            edge_vals = grad_mag[edge_mask > 0]
            ring_vals = grad_mag[ring_mask > 0]
            if edge_vals.size == 0 or ring_vals.size == 0:
                continue
            r_edge = float(np.median(edge_vals) / (np.median(ring_vals) + eps))

            inside_vals = gray_img[poly_mask > 0]
            outside_vals = gray_img[ring_mask > 0]
            if inside_vals.size == 0 or outside_vals.size == 0:
                continue
            delta_i = float(np.median(inside_vals) - np.median(outside_vals))
            c_norm = abs(delta_i) / (np.std(outside_vals) + eps)

            # Optional area (used as a guard against relabelling big targets).
            area_val = None
            if 'area' in df_pre.columns:
                try:
                    area_val = float(df_pre.loc[idx, 'area'])
                except Exception:
                    area_val = None

            edge_ratios.append(r_edge)
            contrast_norms.append(c_norm)
            areas_list.append(area_val if area_val is not None else 0.0)
            measures_per_idx[idx] = (r_edge, c_norm, area_val)
        except Exception:
            continue

    if not measures_per_idx:
        print("无可用的7/8类样本进行阴影判别")
        return df_pre

    def robust_q(arr, q):
        # Percentile over the non-None values; None when nothing is usable.
        vals = [v for v in arr if v is not None]
        return float(np.percentile(vals, q)) if len(vals) > 0 else None

    # Below the 30th percentile on both measures -> shadow-like.
    r_thresh = robust_q(edge_ratios, 30.0)
    c_thresh = robust_q(contrast_norms, 30.0)
    # Area guard: 40th percentile of the cohort, clamped to [1200, 2000].
    a_thresh = robust_q(areas_list, 40.0)
    if a_thresh is None or a_thresh <= 0:
        a_thresh = 1200.0
    a_thresh = min(a_thresh, 2000.0)

    to_update = []
    for idx, (r_edge, c_norm, area_val) in measures_per_idx.items():
        small_enough = (area_val is None) or (area_val <= a_thresh)
        if r_thresh is not None and c_thresh is not None and small_enough:
            if r_edge < r_thresh and c_norm < c_thresh:
                to_update.append(idx)

    if to_update:
        # NOTE(review): the original comment said "relabel to background (0),
        # not 9 (PVC)" while the code assigns 9; behavior kept as-is — confirm
        # which label is intended.
        df_pre.loc[to_update, 'Predictions'] = 9
        print(f"{len(to_update)} 个样本从类别7/8改为背景阴影面积阈值≈{a_thresh:.0f}")
    else:
        print("无需更新类别7/8样本")
    return df_pre
def write_outputs(bil_path, df_pre, output_path):
    """Shrink contours by one pixel and save the ENVI classification result.

    Any failure is wrapped in a RuntimeError carrying the output path.
    """
    try:
        shrunk = shrink_contours(bil_path, df_pre, shrink_pixels=1)
        save_envi_classification(bil_path, shrunk, output_path)
    except Exception as e:
        raise RuntimeError(f"保存结果失败: output_path={output_path}") from e
def main():
"""主函数"""
args = parse_arguments()
bil_path = args.bil_path
output_path = args.output_path
primary_model_path = args.model_path
primary_model_type = 'SVM'
primary_process_methods1 = 'SS'
primary_process_methods2 = "None"
segmentation_model_path = None
# secondary_model_path = args.secondary_model
# secondary_model_type = args.secondary_model_type
# secondary_process_methods1 = args.secondary_process_methods1
# secondary_process_methods2 = args.secondary_process_methods2
# secondary_target_classes = args.secondary_target_classes
secondary_model_path = "E:\code\plastic\plastic20260224\plastic\plastic\modelsave\HDPELDPE_model\svm.m"
secondary_model_type = 'SVM'
secondary_process_methods1 = 'None'
secondary_process_methods2 = 'None'
secondary_target_classes = [1, 2]
# 记录总开始时间
total_start_time = time.time()
# 处理BIL文件生成RGB图像
print("Processing BIL file to generate RGB image...\n")
rgb_img = process_bil_files(bil_path)
# 修改hdr
change_hdr_file(bil_path)
segmentation_start_time = time.time()
# 生成掩膜mask为16位的塑料标签掩膜
print("Generating mask ...\n")
mask, filter_mask_original = detect_microplastic_mask_from_array(
image=rgb_img, # 直接传入cv2.imread的结果
filter_method='threshold',
diameter=None,
flow_threshold=0.4,
cellprob_threshold=-1,
detect_filter=True
)
# 提取特征
print("Extracting features from BIL file...\n")
df = process_images(bil_path, mask)
# 背景校正(保持现有逻辑)
print("Applying background correction...\n")
df_correct = process_images_background(bil_path, mask)
# 自动识别光谱列范围假设从第2列开始连续为光谱列长度=背景校正矩阵列数
spec_start = 1
spec_len_src = len(df_correct)
spec_end_src = spec_start + spec_len_src
# 归一化(按通道逐列相除)
df.iloc[:, spec_start:spec_end_src] = df.iloc[:, spec_start:spec_end_src].div(df_correct, axis=1)
# 读取当前BIL的波长并重采样到训练用波长
src_waves = read_wavelengths_from_hdr(bil_path)
need_resample = (src_waves.size > 0) and (src_waves.size != len(TRAIN_WAVELENGTHS) or not np.allclose(src_waves, TRAIN_WAVELENGTHS, atol=1e-2))
if need_resample:
print(f"Resampling spectra: src {src_waves.size} bands -> dst {len(TRAIN_WAVELENGTHS)} bands\n")
X_src = df.iloc[:, spec_start:spec_end_src].to_numpy(dtype=np.float64)
X_dst = resample_spectra_matrix(X_src, src_waves, TRAIN_WAVELENGTHS)
# 用训练维度(168)替换原光谱列;其余形状/统计特征保持不变
spec_col_names = [f"band_{i+1}" for i in range(len(TRAIN_WAVELENGTHS))]
df = pd.concat(
[
df.iloc[:, :spec_start],
pd.DataFrame(X_dst, columns=spec_col_names, index=df.index),
df.iloc[:, spec_end_src:]
],
axis=1
)
# 更新光谱列的新范围
spec_end_src = spec_start + len(TRAIN_WAVELENGTHS)
# 数据清理(保持原有规则)
print("Cleaning data...\n")
df = df.dropna()
df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
df = df[df['area'] >= 500]
# 列筛选此时光谱列已与训练对齐为168维因此区间索引仍可复用
cols_to_remove = df.columns[np.r_[1:5, 87:110, 166:169, -10:-1]]
df = df.drop(columns=cols_to_remove)
# 继续后续流程
segmentation_time = time.time() - segmentation_start_time
df = df.iloc[:, :]
# 预测分类(分类阶段)
classification_start_time = time.time()
# 预测分类
print("Predicting classes...\n")
loaded_model = load_model(primary_model_path)
# 断言数值特征维度应等于训练时的scaler输入维度
try:
import joblib
scaler = joblib.load(os.path.join(os.path.dirname(primary_model_path), 'scaler_params.pkl'))
numeric_cols = [c for c in df.columns[1:] if np.issubdtype(df[c].dtype, np.number) and c != 'contour']
if hasattr(scaler, 'mean_') and len(numeric_cols) != scaler.mean_.shape[0]:
raise ValueError(f"Feature dimension mismatch: current {len(numeric_cols)} != scaler {scaler.mean_.shape[0]}. Check resampling and column selection.")
except Exception:
pass
# 验证输入
validate_inputs(bil_path, output_path, primary_model_path)
bands =[912.36, 915.68, 919, 922.31, 925.63, 928.95, 932.27, 935.59, 938.91, 942.23, 945.55, 948.87, 952.18, 955.5, 958.82, 962.14, 965.46, 968.78, 972.1, 975.42, 978.74, 982.06, 985.38, 988.7, 992.02, 995.34, 998.65, 1002, 1005.3, 1008.6, 1011.9, 1015.3, 1018.6, 1021.9, 1025.2, 1028.5, 1031.9, 1035.2, 1038.5, 1041.8, 1045.1, 1048.5, 1051.8, 1055.1, 1058.4, 1061.7, 1065.1, 1068.4, 1071.7, 1075, 1078.3, 1081.7, 1085, 1088.3, 1091.6, 1094.9, 1098.3, 1101.6, 1104.9, 1108.2, 1111.5, 1114.9, 1118.2, 1121.5, 1124.8, 1128.1, 1131.5, 1134.8, 1138.1, 1141.4, 1144.8, 1148.1, 1151.4, 1154.7, 1158, 1161.4, 1164.7, 1168, 1171.3, 1174.6, 1178, 1181.3, 1184.6, 1187.9, 1191.3, 1194.6, 1197.9, 1201.2, 1204.5, 1207.9, 1211.2, 1214.5, 1217.8, 1221.2, 1224.5, 1227.8, 1231.1, 1234.4, 1237.8, 1241.1, 1244.4, 1247.7, 1251.1, 1254.4, 1257.7, 1261, 1264.3, 1267.7, 1271, 1274.3, 1277.6, 1281, 1284.3, 1287.6, 1290.9, 1294.2, 1297.6, 1300.9, 1304.2, 1307.5, 1310.9, 1314.2, 1317.5, 1320.8, 1324.2, 1327.5, 1330.8, 1334.1, 1337.4, 1340.8, 1344.1, 1347.4, 1350.7, 1354.1, 1357.4, 1360.7, 1364, 1367.4, 1370.7, 1374, 1377.3, 1380.7, 1384, 1387.3, 1390.6, 1394, 1397.3, 1400.6, 1403.9, 1407.2, 1410.6, 1413.9, 1417.2, 1420.5, 1423.9, 1427.2, 1430.5, 1433.8, 1437.2, 1440.5, 1443.8, 1447.1, 1450.5, 1453.8, 1457.1, 1460.4, 1463.8, 1467.1, 1470.4, 1473.7, 1477.1, 1480.4, 1483.7, 1487, 1490.4, 1493.7, 1497, 1500.3, 1503.7, 1507, 1510.3, 1513.6, 1517, 1520.3, 1523.6, 1526.9, 1530.3, 1533.6, 1536.9, 1540.2, 1543.6, 1546.9, 1550.2, 1553.6, 1556.9, 1560.2, 1563.5, 1566.9, 1570.2, 1573.5, 1576.8, 1580.2, 1583.5, 1586.8, 1590.1, 1593.5, 1596.8, 1600.1, 1603.4, 1606.8, 1610.1, 1613.4, 1616.7, 1620.1, 1623.4, 1626.7, 1630.1, 1633.4, 1636.7, 1640, 1643.4, 1646.7, 1650, 1653.3, 1656.7, 1660, 1663.3, 1666.7, 1670, 1673.3, 1676.6, 1680, 1683.3, 1686.6, 1689.9, 1693.3, 1696.6, 1699.9, 1703.3, 1706.6]
df_pre = predict_with_model(
df,
primary_model_path,
model_type=primary_model_type,
ProcessMethods1=primary_process_methods1,
ProcessMethods2=primary_process_methods2
)
# 修改HDR文件
change_hdr_file(bil_path,bands)
# 对HDPE和LDPE进行二次分类
# 从第一次分类结果中提取SECONDARY_TARGET_CLASSES类别的掩膜轮廓
target_classes = set(secondary_target_classes or [])
mask_secondary = df_pre['Predictions'].isin(target_classes)
# 处理BIL文件生成RGB图像
print("处理BIL文件生成RGB图像...")
rgb_img = generate_rgb(bil_path)
if mask_secondary.any():
# 只有在找到目标类别时才进行背景校正和二次分类
print(f"Running secondary classification for classes: {sorted(target_classes)}")
# 图像信息的背景矫正
df_correct = shape_correct_background(bil_path, mask, filter_mask_original)
# 创建新的掩膜mask_second只包含目标类别的轮廓
mask_second = np.zeros_like(mask, dtype=np.uint16)
for idx in df_pre[mask_secondary].index:
contour = df_pre.loc[idx, 'contour']
if isinstance(contour, list) and len(contour) > 0:
contour_array = np.array(contour, dtype=np.int32)
cv2.fillPoly(mask_second, [contour_array], idx + 1) # 使用索引+1作为标签
# 分割阶段
segmentation_start_time = time.time()
print("生成掩膜...")
mask, filter_mask_original = run_segmentation(rgb_img, segmentation_model_path)
# 提取特征
df_shape = extract_features(df_correct, mask_second)
print("从BIL文件提取特征...")
df = extract_primary_features(bil_path, mask)
# 确保使用第2到13列作为模型输入特征
if len(df_shape.columns) >= 13:
df_shape = df_shape.iloc[:, :13]
# 背景校正
print("应用背景校正...")
bg_spectrum = compute_background_spectrum(bil_path, mask)
# 二次分类:使用第二个模型预测并更新分类结果
if secondary_model_path:
df_secondary = predict_with_model(
df_shape,
secondary_model_path,
model_type=secondary_model_type,
ProcessMethods1=secondary_process_methods1,
ProcessMethods2=secondary_process_methods2
)
df_pre.loc[mask_secondary, 'Predictions'] = df_secondary['Predictions'].values + 1
else:
print("Secondary model path not provided; skipping secondary classification.\n")
else:
print("No samples from target classes found; skipping secondary classification.\n")
# 背景校正 + 仅在与训练相机波长不一致时重采样
df = apply_background_and_optional_resample(df, bg_spectrum, bil_path)
# 识别类别7中的背景阴影误判通过边界清晰度特征
# 真正的类别7边界清晰背景阴影边界模糊
class_7_mask = df_pre['Predictions'] == 7
if class_7_mask.any():
print(f"Processing {class_7_mask.sum()} samples with class 7 to identify background shadows...\n")
# 数据清理和列选择
print("清理数据...")
df = clean_and_select_columns(df)
# 将PIL Image转换为numpy数组
if hasattr(rgb_img, 'mode'): # 检查是否是PIL Image
rgb_img_array = np.array(rgb_img)
else:
rgb_img_array = rgb_img
segmentation_time = time.time() - segmentation_start_time
# 转换为灰度图(用于计算梯度)
if len(rgb_img_array.shape) == 3:
gray_img = cv2.cvtColor(rgb_img_array, cv2.COLOR_RGB2GRAY)
else:
gray_img = rgb_img_array
# 分类阶段
classification_start_time = time.time()
print("预测分类...")
df_pre = run_primary_classification(df, primary_model_path)
# 计算梯度图使用Sobel算子
grad_x = cv2.Sobel(gray_img, cv2.CV_64F, 1, 0, ksize=3)
grad_y = cv2.Sobel(gray_img, cv2.CV_64F, 0, 1, ksize=3)
gradient_magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)
# 二次分类
df_pre = run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original)
# 先收集所有类别7样本的边缘梯度值用于确定阈值
all_class7_gradients = []
valid_indices = []
for idx in df_pre[class_7_mask].index:
try:
contour = df_pre.loc[idx, 'contour']
if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
continue
# 后处理类别7阴影
df_pre = postprocess_class7_shadow(df_pre, rgb_img)
contour_array = np.array(contour, dtype=np.int32)
if len(contour_array.shape) == 1:
continue
classification_time = time.time() - classification_start_time
mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
edge_gradients = gradient_magnitude[mask_img > 0]
if len(edge_gradients) > 0:
all_class7_gradients.extend(edge_gradients)
valid_indices.append(idx)
except:
continue
# 保存结果
print("保存ENVI分类结果...")
write_outputs(bil_path, df_pre, output_path)
# 基于类别7样本的梯度分布确定阈值
# 使用类别7样本梯度值的中位数作为基准低于某个分位数如30%)的认为是背景阴影
if len(all_class7_gradients) > 0:
gradient_threshold = np.percentile(all_class7_gradients, 30) # 使用类别7样本梯度值的30%分位数
else:
gradient_threshold = np.percentile(gradient_magnitude, 30) # 如果没有有效样本使用整张图的30%分位数
print(f"ENVI分类结果已保存至: {output_path}")
print(f"Gradient threshold for class 7: {gradient_threshold:.2f}\n")
# 计算总耗时
total_time = time.time() - total_start_time
# 处理每个类别7的样本判断是否为背景阴影
indices_to_update = []
for idx in valid_indices:
try:
contour = df_pre.loc[idx, 'contour']
contour_array = np.array(contour, dtype=np.int32)
# 打印耗时统计
print(f"\n{'=' * 60}")
print("处理完成")
print(f"{'=' * 60}")
print(f"分割耗时: {segmentation_time:.2f}")
print(f"分类耗时: {classification_time:.2f}")
print(f"总耗时: {total_time:.2f}")
print(f"{'=' * 60}")
print(f"结果已保存至: {output_path}")
# 创建轮廓掩膜线宽为2像素用于提取边缘
mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
# 提取轮廓边缘的梯度值
edge_gradients = gradient_magnitude[mask_img > 0]
if len(edge_gradients) == 0:
continue
# 计算轮廓边缘的平均梯度强度
mean_gradient = np.mean(edge_gradients)
# 如果平均梯度强度低于阈值,认为是背景阴影(边界模糊)
if mean_gradient < gradient_threshold:
indices_to_update.append(idx)
print(
f"Sample {idx}: mean_gradient={mean_gradient:.2f}, threshold={gradient_threshold:.2f} -> identified as background shadow")
except Exception as e:
print(f"Error processing sample at index {idx}: {str(e)}")
continue
# 将背景阴影的类别7改为类别9
if indices_to_update:
df_pre.loc[indices_to_update, 'Predictions'] = 9
print(f"Updated {len(indices_to_update)} samples from class 7 to class 9 (background shadows)\n")
else:
print("No samples needed to be updated from class 7\n")
classification_time = time.time() - classification_start_time
df_pre = shrink_contours(bil_path, df_pre, shrink_pixels=1)
# 保存ENVI分类结果
print("Saving ENVI classification results...\n")
save_envi_classification(bil_path, df_pre, output_path)
print(f"ENVI classification results saved to: {output_path}")
# 计算总耗时
total_time = time.time() - total_start_time
# 打印耗时统计
print(f"\n{'=' * 60}")
print(f"处理完成")
print(f"{'=' * 60}")
print(f"分割耗时: {segmentation_time:.2f}")
print(f"分类耗时: {classification_time:.2f}")
print(f"总耗时: {total_time:.2f}")
print(f"{'=' * 60}")
print(f"结果已保存至: {output_path}")
except Exception as e:
print(f"处理失败: {e}")
raise
if __name__ == "__main__":