import argparse
import os
import re
import time

import cv2
import matplotlib
import numpy as np
import pandas as pd

from bil2rgb import process_bil_files
from classification_model.Parallel.predict_plastic import load_model, predict_with_model
from mask import detect_microplastic_mask_from_array
from shape_spectral import process_images
from shape_spectral_background import process_images_background
from extact_shape import shape_correct_background, extract_features

matplotlib.use('TkAgg')

# Band-centre wavelengths (nm) of the camera the classification model was trained on.
# Spectra from an input whose HDR lists different wavelengths are resampled onto this grid.
TRAIN_WAVELENGTHS = [
    912.36, 915.68, 919, 922.31, 925.63, 928.95, 932.27, 935.59, 938.91, 942.23,
    945.55, 948.87, 952.18, 955.5, 958.82, 962.14, 965.46, 968.78, 972.1, 975.42,
    978.74, 982.06, 985.38, 988.7, 992.02, 995.34, 998.65, 1002, 1005.3, 1008.6,
    1011.9, 1015.3, 1018.6, 1021.9, 1025.2, 1028.5, 1031.9, 1035.2, 1038.5, 1041.8,
    1045.1, 1048.5, 1051.8, 1055.1, 1058.4, 1061.7, 1065.1, 1068.4, 1071.7, 1075,
    1078.3, 1081.7, 1085, 1088.3, 1091.6, 1094.9, 1098.3, 1101.6, 1104.9, 1108.2,
    1111.5, 1114.9, 1118.2, 1121.5, 1124.8, 1128.1, 1131.5, 1134.8, 1138.1, 1141.4,
    1144.8, 1148.1, 1151.4, 1154.7, 1158, 1161.4, 1164.7, 1168, 1171.3, 1174.6,
    1178, 1181.3, 1184.6, 1187.9, 1191.3, 1194.6, 1197.9, 1201.2, 1204.5, 1207.9,
    1211.2, 1214.5, 1217.8, 1221.2, 1224.5, 1227.8, 1231.1, 1234.4, 1237.8, 1241.1,
    1244.4, 1247.7, 1251.1, 1254.4, 1257.7, 1261, 1264.3, 1267.7, 1271, 1274.3,
    1277.6, 1281, 1284.3, 1287.6, 1290.9, 1294.2, 1297.6, 1300.9, 1304.2, 1307.5,
    1310.9, 1314.2, 1317.5, 1320.8, 1324.2, 1327.5, 1330.8, 1334.1, 1337.4, 1340.8,
    1344.1, 1347.4, 1350.7, 1354.1, 1357.4, 1360.7, 1364, 1367.4, 1370.7, 1374,
    1377.3, 1380.7, 1384, 1387.3, 1390.6, 1394, 1397.3, 1400.6, 1403.9, 1407.2,
    1410.6, 1413.9, 1417.2, 1420.5, 1423.9, 1427.2, 1430.5, 1433.8, 1437.2, 1440.5,
    1443.8, 1447.1, 1450.5, 1453.8, 1457.1, 1460.4, 1463.8, 1467.1, 1470.4, 1473.7,
    1477.1, 1480.4, 1483.7, 1487, 1490.4, 1493.7, 1497, 1500.3, 1503.7, 1507,
    1510.3, 1513.6, 1517, 1520.3, 1523.6, 1526.9, 1530.3, 1533.6, 1536.9, 1540.2,
    1543.6, 1546.9, 1550.2, 1553.6, 1556.9, 1560.2, 1563.5, 1566.9, 1570.2, 1573.5,
    1576.8, 1580.2, 1583.5, 1586.8, 1590.1, 1593.5, 1596.8, 1600.1, 1603.4, 1606.8,
    1610.1, 1613.4, 1616.7, 1620.1, 1623.4, 1626.7, 1630.1, 1633.4, 1636.7, 1640,
    1643.4, 1646.7, 1650, 1653.3, 1656.7, 1660, 1663.3, 1666.7, 1670, 1673.3,
    1676.6, 1680, 1683.3, 1686.6, 1689.9, 1693.3, 1696.6, 1699.9, 1703.3, 1706.6,
]

# Prediction value used to mark a particle as background. save_envi_classification
# shifts every prediction by +1 and writes 10/11 as class 0 (background), so 9 ends
# up as background. It is NOT the PVC class: per the class list written to the
# header, PVC is prediction 8 -> class 9.
BACKGROUND_PREDICTION = 9

# ENVI ``wavelength = {...}`` list (group 1 captures the values). Anchored at the
# start of a line so ``wavelength units = ...`` is not mistaken for the list.
_HDR_WAVELENGTH_LIST_RE = re.compile(r'^\s*wavelength\s*=\s*\{([^}]*)\}', re.IGNORECASE | re.MULTILINE)
# Presence of a ``wavelength = ...`` field, whether or not its braces are well-formed.
_HDR_WAVELENGTH_KEY_RE = re.compile(r'^\s*wavelength\s*=', re.IGNORECASE | re.MULTILINE)


def _hdr_path_for(bil_path):
    """Return the ENVI header path that sits next to *bil_path* (same stem, ``.hdr``)."""
    return os.path.splitext(bil_path)[0] + '.hdr'


def _hdr_int_field(header_text, key):
    """Return the integer value of ``key = N`` in ENVI header text, or None if absent."""
    match = re.search(rf'^\s*{re.escape(key)}\s*=\s*(\d+)', header_text, re.IGNORECASE | re.MULTILINE)
    return int(match.group(1)) if match else None


def read_wavelengths_from_hdr(bil_path):
    """Return the wavelengths listed in the ENVI header next to *bil_path*.

    Returns:
        float64 array of the ``wavelength = {...}`` values. The array is empty
        when the header is missing, has no wavelength list, or contains a
        non-numeric entry.
    """
    empty = np.array([], dtype=np.float64)
    hdr_path = _hdr_path_for(bil_path)
    if not os.path.exists(hdr_path):
        return empty
    with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as f:
        txt = f.read()
    match = _HDR_WAVELENGTH_LIST_RE.search(txt)
    if match is None:
        return empty
    vals = [v.strip() for v in match.group(1).split(',') if v.strip()]
    try:
        return np.array([float(v) for v in vals], dtype=np.float64)
    except ValueError:
        return empty


def resample_spectra_matrix(X, src_waves, dst_waves):
    """Linearly resample every row (spectrum) of *X* from *src_waves* onto *dst_waves*.

    Args:
        X: 2-D array-like, one spectrum per row, columns ordered like *src_waves*.
        src_waves: wavelengths of X's columns (any order).
        dst_waves: target wavelengths.

    Returns:
        float64 array of shape ``(n_rows, len(dst_waves))``. Targets outside the
        source range take the nearest endpoint value (constant extrapolation),
        so no output column is lost. If either wavelength list is empty, *X* is
        returned unchanged (as float64).

    Raises:
        ValueError: if X is not 2-D or its column count differs from len(src_waves).
    """
    src = np.asarray(src_waves, dtype=np.float64).ravel()
    dst = np.asarray(dst_waves, dtype=np.float64).ravel()
    X = np.asarray(X, dtype=np.float64)
    if src.size == 0 or dst.size == 0:
        return X
    if X.ndim != 2 or X.shape[1] != src.size:
        raise ValueError(f"光谱列数({X.shape[-1] if X.ndim else 0})与源波长数({src.size})不一致,无法重采样")
    # np.interp requires increasing sample points; sort the wavelength axis once for all rows.
    order = np.argsort(src, kind='stable')
    src = src[order]
    X = X[:, order]
    out = np.empty((X.shape[0], dst.size), dtype=np.float64)
    for i, row in enumerate(X):
        # np.interp's default left/right fill is fp[0]/fp[-1], i.e. endpoint clamping.
        out[i] = np.interp(dst, src, row)
    return out


def apply_background_no_resample(df, bg_spectrum):
    """Divide the spectral columns of *df* by a background spectrum, without resampling.

    Spectral columns are the string columns starting with ``wavelength_`` or
    ``band_``. When the background length differs from the number of spectral
    columns, both are aligned at their tail and only the last
    ``min(len(columns), len(background))`` columns are corrected.

    Returns:
        *df*, modified in place.

    Raises:
        ValueError: if there are no spectral columns or the background is empty.
    """
    spec_cols = [c for c in df.columns if isinstance(c, str) and c.startswith(('wavelength_', 'band_'))]
    if not spec_cols:
        raise ValueError("未找到光谱列(以 wavelength_ 或 band_ 开头)")
    bg = np.asarray(bg_spectrum, dtype=np.float64).ravel()
    if bg.size == 0:
        raise ValueError("背景光谱长度为0,无法进行背景校正")
    # Tail-align and use the common length so mismatched dimensions do not fail.
    n = min(len(spec_cols), bg.shape[0])
    use_cols = spec_cols[-n:]
    df.loc[:, use_cols] = df.loc[:, use_cols].div(bg[-n:], axis=1)
    return df


def parse_arguments():
    """Parse the command-line arguments (input BIL, output path, primary model path)."""
    parser = argparse.ArgumentParser(description='Microplastic spectral shape classification')
    parser.add_argument('--bil_path', required=True, help='Path to input BIL file')
    parser.add_argument('--output_path', required=True, help='Path to output classification result')
    parser.add_argument('--model_path', required=True, help='Path to primary classification model')
    # The secondary (HDPE/LDPE) model is configured inside run_secondary_classification_if_needed.
    return parser.parse_args()


def read_hdr_file(bil_path):
    """Return ``(samples, lines)`` from the ENVI header next to *bil_path*.

    Either value is None when the corresponding field is absent.
    """
    with open(_hdr_path_for(bil_path), 'r', encoding='utf-8', errors='ignore') as f:
        header = f.read()
    return _hdr_int_field(header, 'samples'), _hdr_int_field(header, 'lines')


def _image_size_or_raise(bil_path):
    """Return ``(samples, lines)`` for *bil_path*, raising ValueError if either is missing."""
    samples, lines = read_hdr_file(bil_path)
    if samples is None or lines is None:
        raise ValueError(f"HDR文件缺少 samples/lines 字段: {_hdr_path_for(bil_path)}")
    return samples, lines


def shrink_contours(bil_path, df, shrink_pixels=1):
    """Erode every contour in ``df['contour']`` so adjacent particles do not touch.

    Args:
        bil_path: BIL file path; its header supplies the image size.
        df: DataFrame with a ``contour`` column of point lists.
        shrink_pixels: erosion radius in pixels (kernel side ``2 * shrink_pixels + 1``).

    Returns:
        A copy of *df* whose contours were replaced by the largest external
        contour of the eroded polygon. A contour is kept unchanged when it has
        fewer than 3 points, cannot be processed, or erodes away completely.
    """
    samples, lines = _image_size_or_raise(bil_path)
    ksize = 2 * shrink_pixels + 1
    kernel = np.ones((ksize, ksize), np.uint8)
    # One scratch mask, cleared before each contour, avoids a full-image allocation per row.
    temp_mask = np.zeros((lines, samples), dtype=np.uint8)
    df = df.copy()

    for idx, contour in zip(df.index, df['contour'].tolist()):
        if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
            continue
        try:
            contour_array = np.array(contour, dtype=np.int32)
            if contour_array.ndim == 1:
                continue
            temp_mask.fill(0)
            cv2.fillPoly(temp_mask, [contour_array], 255)
            eroded_mask = cv2.erode(temp_mask, kernel, iterations=1)
            contours, _ = cv2.findContours(eroded_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if not contours:
                continue
            # Erosion may split a particle; keep only the largest piece.
            largest_contour = max(contours, key=cv2.contourArea)
            if len(largest_contour) >= 3:
                df.at[idx, 'contour'] = largest_contour.reshape(-1, 2).tolist()
        except Exception:
            # Deliberate best effort: keep the original contour if OpenCV/numpy rejects it.
            continue
    return df


def save_envi_classification(bil_path, df, savepath):
    """Rasterise the classified contours and write them as a single-band ENVI file.

    Each contour is filled with ``Predictions + 1``; values 10 and 11 (i.e.
    predictions 9 and 10) are written as 0 (background). The data file goes to
    *savepath* and the header to ``<savepath stem>.hdr``.
    """
    samples, lines = _image_size_or_raise(bil_path)
    classification_result = np.zeros((lines, samples), dtype=np.uint16)

    for contour, raw_prediction in zip(df['contour'], df['Predictions']):
        prediction = int(raw_prediction) + 1  # shift first...
        if prediction in (10, 11):            # ...then map 10/11 to background
            prediction = 0
        cv2.fillPoly(classification_result, [np.array(contour, dtype=np.int32)], prediction)

    # The header declares data type 2 (16-bit signed) and byte order 0 (little-endian);
    # write exactly that regardless of the host's native byte order.
    with open(savepath, 'wb') as f:
        classification_result.astype('<i2').tofile(f)

    # NOTE(review): standard ENVI uses "class names" for the name list; "class" is kept
    # as-is in case downstream tools parse it.
    header_content = f"""ENVI
description = {{ Classification Result.}}
samples = {samples}
lines = {lines}
bands = 1
header offset = 0
file type = ENVI Standard
data type = 2
interleave = bil
classes = 10
class = {{ background, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC }}
single pixel area = 0.000036
unit = mm2
byte order = 0
wavelength units = nm
"""
    header_filename = os.path.splitext(savepath)[0] + '.hdr'
    with open(header_filename, 'w') as header_file:
        header_file.write(header_content)


def change_hdr_file(bil_path, wavelengths=None):
    """Append a ``wavelength = {...}`` field to the BIL's header when it has none.

    The header is left untouched when it is missing, already has a wavelength
    list, no *wavelengths* are given, or its ``bands`` count differs from
    ``len(wavelengths)`` (writing would produce an inconsistent header).
    """
    hdr_path = _hdr_path_for(bil_path)
    if not os.path.exists(hdr_path):
        print(f"错误: 找不到对应的HDR文件: {hdr_path}")
        return

    with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as file:
        content = file.read()
    if _HDR_WAVELENGTH_KEY_RE.search(content):
        print(f"{os.path.basename(hdr_path)} 已包含 wavelength 字段,跳过追加。")
        return
    if wavelengths is None or len(wavelengths) == 0:
        print("HDR 缺少 wavelength,但未提供 wavelengths,跳过写入以避免错误。")
        return
    n_bands = _hdr_int_field(content, 'bands')
    if n_bands is not None and n_bands != len(wavelengths):
        print(f"HDR 波段数({n_bands})与提供的 wavelengths 数量({len(wavelengths)})不一致,跳过写入。")
        return

    needs_newline = not content.endswith('\n')
    wavelength_info = "wavelength = {" + ", ".join(str(float(w)) for w in wavelengths) + "}\n"
    with open(hdr_path, 'a', encoding='utf-8', errors='ignore') as file:
        if needs_newline:
            file.write('\n')
        file.write(wavelength_info)
    print(f"已在 {os.path.basename(hdr_path)} 末尾追加 wavelength 字段。")


def validate_inputs(bil_path, output_path, model_path):
    """Check that the inputs exist and the BIL has enough bands; create the output dir.

    Raises:
        FileNotFoundError: BIL, HDR or model file missing.
        RuntimeError: output directory cannot be created or the header cannot be read.
        ValueError: the BIL has 159 bands or fewer.
    """
    if not os.path.exists(bil_path):
        raise FileNotFoundError(f"BIL文件不存在: {bil_path}")
    hdr_path = _hdr_path_for(bil_path)
    if not os.path.exists(hdr_path):
        raise FileNotFoundError(f"HDR文件不存在: {hdr_path}")

    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        try:
            os.makedirs(output_dir, exist_ok=True)
        except Exception as e:
            raise RuntimeError(f"无法创建输出目录: {output_dir}") from e

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"主模型文件不存在: {model_path}")

    try:
        from spectral.io import envi
        n_bands = envi.open(hdr_path, bil_path).nbands
    except Exception as e:
        raise RuntimeError(f"无法读取BIL文件头信息: {bil_path}") from e
    # bil2rgb reads band indices 9, 59 and 159. Checked outside the try so the
    # real cause is not re-reported as "cannot read header".
    if n_bands <= 159:
        raise ValueError(f"BIL文件波段数不足: 需要至少160个波段,但只有{n_bands}个")


def generate_rgb(bil_path):
    """Build the RGB preview image of *bil_path* via bil2rgb."""
    try:
        return process_bil_files(bil_path)
    except Exception as e:
        raise RuntimeError(f"生成RGB图像失败: bil_path={bil_path}") from e


def run_segmentation(rgb_img, segmentation_model_path=None):
    """Segment particles in *rgb_img*; return ``(mask, filter_mask_original)``."""
    try:
        mask, filter_mask_original = detect_microplastic_mask_from_array(
            image=rgb_img,
            filter_method='threshold',
            diameter=None,
            flow_threshold=0.4,
            cellprob_threshold=-1,
            model_path=segmentation_model_path,
            detect_filter=True,
        )
        return mask, filter_mask_original
    except Exception as e:
        raise RuntimeError("分割失败: 无法检测微塑料颗粒") from e


def extract_primary_features(bil_path, mask):
    """Extract per-particle features from the BIL using the segmentation mask."""
    try:
        return process_images(bil_path, mask)
    except Exception as e:
        raise RuntimeError(f"特征提取失败: bil_path={bil_path}") from e


def compute_background_spectrum(bil_path, mask):
    """Compute the background spectrum of the BIL outside the particle mask."""
    try:
        return process_images_background(bil_path, mask)
    except Exception as e:
        raise RuntimeError(f"背景光谱计算失败: bil_path={bil_path}") from e


def apply_background_and_optional_resample(df, bg_spectrum, bil_path):
    """Divide the ``wavelength_*`` columns by the background, then resample if needed.

    Resampling onto TRAIN_WAVELENGTHS happens only when the HDR lists
    wavelengths that differ (in count, or by more than 0.01 nm) from the
    training grid. In that case the ``wavelength_*`` columns are replaced by
    ``band_1..band_N``.

    Raises:
        ValueError: no spectral columns, background length mismatch, or the
            HDR wavelength count does not match the spectral columns.
    """
    spec_cols = [c for c in df.columns if isinstance(c, str) and c.startswith('wavelength_')]
    if not spec_cols:
        raise ValueError("未找到光谱列(以wavelength_开头的列)")
    if len(spec_cols) != len(bg_spectrum):
        raise ValueError(f"光谱列数量({len(spec_cols)})与背景光谱长度({len(bg_spectrum)})不匹配")

    df[spec_cols] = df[spec_cols].div(bg_spectrum, axis=1)

    src_waves = read_wavelengths_from_hdr(bil_path)
    need_resample = src_waves.size > 0 and (
        src_waves.size != len(TRAIN_WAVELENGTHS)
        or not np.allclose(src_waves, TRAIN_WAVELENGTHS, atol=1e-2)
    )
    if not need_resample:
        return df

    print(f"重采样光谱: 源波段数 {src_waves.size} -> 目标波段数 {len(TRAIN_WAVELENGTHS)}")
    X_src = df[spec_cols].to_numpy(dtype=np.float64)
    X_dst = resample_spectra_matrix(X_src, src_waves, TRAIN_WAVELENGTHS)
    band_cols = [f"band_{i + 1}" for i in range(len(TRAIN_WAVELENGTHS))]
    return pd.concat(
        [df.drop(columns=spec_cols), pd.DataFrame(X_dst, columns=band_cols, index=df.index)],
        axis=1,
    )


def clean_and_select_columns(df):
    """Drop incomplete/small samples and the 13 columns the model does not use.

    Rows with any NaN, list contours with at most one point, and areas below
    500 are removed. The 13 columns at positions -14..-2 are then dropped
    (hard-coded layout, kept from the original pipeline).
    """
    df = df.dropna()
    # astype(bool) keeps this a boolean mask even when df is empty (object-dtype result).
    keep = df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True).astype(bool)
    df = df[keep]
    df = df[df['area'] >= 500]
    cols_to_remove = df.columns[np.r_[-14: -1]]
    return df.drop(columns=cols_to_remove)


def _warn_if_feature_dim_mismatch(df, model_path):
    """Best-effort check of df's numeric feature count against ``scaler_params.pkl``.

    Only prints warnings; never raises. The scaler is loaded with joblib
    (pickle-based), so the model directory must be trusted.
    """
    scaler_path = os.path.join(os.path.dirname(model_path), 'scaler_params.pkl')
    if not os.path.exists(scaler_path):
        return
    try:
        import joblib
        scaler = joblib.load(scaler_path)
        mean = getattr(scaler, 'mean_', None)
        if mean is None:
            return
        numeric_cols = [c for c in df.columns[1:]
                        if c != 'contour' and np.issubdtype(df[c].dtype, np.number)]
    except Exception as e:
        print(f"警告: 无法验证特征维度: {e}")
        return
    if len(numeric_cols) != mean.shape[0]:
        print(f"警告: 特征维度不匹配: 当前{len(numeric_cols)}列 != 训练时{mean.shape[0]}维")


def run_primary_classification(df, primary_model_path):
    """Classify every sample with the primary SVM model (SS preprocessing)."""
    _warn_if_feature_dim_mismatch(df, primary_model_path)
    try:
        return predict_with_model(
            df,
            primary_model_path,
            model_type='SVM',
            ProcessMethods1='SS',
            ProcessMethods2='None',
        )
    except Exception as e:
        raise RuntimeError(f"主要分类失败: model_path={primary_model_path}") from e


def run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original):
    """Re-classify HDPE/LDPE predictions (1, 2) with a shape-based secondary model.

    Skipped when no sample is predicted 1 or 2 or the secondary model file is
    missing. Any failure is reported and the primary predictions are kept.
    *df_pre* is updated in place and returned.
    """
    secondary_model_path = os.path.join(os.path.dirname(__file__), 'modelsave', 'HDPELDPE_model', 'svm.m')
    target_classes = {1, 2}  # HDPE, LDPE

    mask_secondary = df_pre['Predictions'].isin(target_classes)
    if not mask_secondary.any():
        print("未找到目标类别样本,跳过二次分类")
        return df_pre

    print(f"为类别 {sorted(target_classes)} 运行二次分类")
    if not os.path.exists(secondary_model_path):
        print(f"警告: 二次模型不存在: {secondary_model_path},跳过二次分类")
        return df_pre

    try:
        df_correct = shape_correct_background(bil_path, mask, filter_mask_original)

        # Label mask containing only the target particles, labelled by row index + 1.
        mask_second = np.zeros_like(mask, dtype=np.uint16)
        for idx in df_pre[mask_secondary].index:
            contour = df_pre.loc[idx, 'contour']
            if isinstance(contour, (list, np.ndarray)) and len(contour) > 0:
                cv2.fillPoly(mask_second, [np.array(contour, dtype=np.int32)], idx + 1)

        df_shape = extract_features(df_correct, mask_second)
        # The secondary model takes the first 13 columns as its input features.
        if len(df_shape.columns) >= 13:
            df_shape = df_shape.iloc[:, :13]

        df_secondary = predict_with_model(
            df_shape,
            secondary_model_path,
            model_type='SVM',
            ProcessMethods1='None',
            ProcessMethods2='None',
        )
        # Secondary labels are 0-based; shift by one to land on classes 1/2.
        df_pre.loc[mask_secondary, 'Predictions'] = df_secondary['Predictions'].values + 1
    except Exception as e:
        print(f"警告: 二次分类失败,将继续使用主要分类结果: {e}")

    return df_pre


def _percentile_or_none(values, q):
    """Return the q-th percentile of the non-None entries of *values*, or None if there are none."""
    vals = [v for v in values if v is not None]
    return float(np.percentile(vals, q)) if vals else None


def _to_gray(rgb_img):
    """Return a 2-D gray array from a PIL image or an RGB/single-channel ndarray."""
    arr = np.array(rgb_img) if hasattr(rgb_img, 'mode') else rgb_img
    if len(arr.shape) == 3:
        return cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
    return arr


def _shadow_metrics(contour_array, gray_img, gradient_magnitude, edge_thick=3, ring_thick=5, eps=1e-6):
    """Compute the two shadow-discrimination metrics for one contour.

    Returns:
        ``(r_edge, c_norm)``. ``r_edge`` is the median gradient magnitude on the
        contour's boundary band divided by the median on a thin ring just outside
        it. ``c_norm`` is |median(inside) - median(ring)| / std(ring). Returns
        None if the boundary band, ring or interior is empty.
    """
    poly_mask = np.zeros(gray_img.shape, dtype=np.uint8)
    cv2.fillPoly(poly_mask, [contour_array], 255)

    edge_mask = np.zeros_like(poly_mask)
    cv2.drawContours(edge_mask, [contour_array], -1, 255, thickness=edge_thick)

    # Outer ring: dilated boundary minus the boundary itself and the interior.
    ring_mask = cv2.dilate(edge_mask, np.ones((ring_thick, ring_thick), np.uint8), iterations=1)
    ring_mask = cv2.bitwise_and(ring_mask, cv2.bitwise_not(edge_mask))
    ring_mask = cv2.bitwise_and(ring_mask, cv2.bitwise_not(poly_mask))

    edge_vals = gradient_magnitude[edge_mask > 0]
    ring_vals = gradient_magnitude[ring_mask > 0]
    if edge_vals.size == 0 or ring_vals.size == 0:
        return None
    r_edge = float(np.median(edge_vals) / (np.median(ring_vals) + eps))

    inside_vals = gray_img[poly_mask > 0]
    outside_vals = gray_img[ring_mask > 0]
    if inside_vals.size == 0 or outside_vals.size == 0:
        return None
    d_intensity = float(np.median(inside_vals) - np.median(outside_vals))
    c_norm = abs(d_intensity) / (np.std(outside_vals) + eps)
    return r_edge, c_norm


def postprocess_class7_shadow(df_pre, rgb_img):
    """Relabel class 7/8 detections that look like background shadows as background.

    Classes 7/8 are PTFE/PVC per the class list written by save_envi_classification.
    For each 7/8 sample, the edge-gradient ratio and the normalised contrast
    against its surroundings are computed. A sample whose two values are both
    below the 30th percentile of all 7/8 samples, and whose area is at most an
    area threshold, is set to BACKGROUND_PREDICTION. The area threshold is the
    40th percentile of the areas, falling back to 1200 and capped at 2000.
    *df_pre* is updated in place and returned.
    """
    mask_targets = df_pre['Predictions'].isin([7, 8])
    if not mask_targets.any():
        return df_pre

    print(f"处理 {mask_targets.sum()} 个类别7/8样本,识别背景阴影...")

    gray_img = _to_gray(rgb_img)
    # Scharr gives a more stable gradient estimate than Sobel at 3x3.
    grad_x = cv2.Scharr(gray_img, cv2.CV_64F, 1, 0)
    grad_y = cv2.Scharr(gray_img, cv2.CV_64F, 0, 1)
    gradient_magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)

    edge_ratios, contrast_norms, areas_list = [], [], []
    measures_per_idx = {}

    for idx in df_pre[mask_targets].index:
        try:
            contour = df_pre.loc[idx, 'contour']
            if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
                continue
            contour_array = np.array(contour, dtype=np.int32)
            if contour_array.ndim == 1:
                continue
            metrics = _shadow_metrics(contour_array, gray_img, gradient_magnitude)
            if metrics is None:
                continue
            r_edge, c_norm = metrics

            area_val = None
            if 'area' in df_pre.columns:
                try:
                    area_val = float(df_pre.loc[idx, 'area'])
                except Exception:
                    area_val = None
        except Exception:
            # Best effort: a sample that cannot be measured keeps its prediction.
            continue

        edge_ratios.append(r_edge)
        contrast_norms.append(c_norm)
        areas_list.append(area_val if area_val is not None else 0.0)
        measures_per_idx[idx] = (r_edge, c_norm, area_val)

    if not measures_per_idx:
        print("无可用的7/8类样本进行阴影判别")
        return df_pre

    # Values below the 30th percentile look more like shadows.
    r_thresh = _percentile_or_none(edge_ratios, 30.0)
    c_thresh = _percentile_or_none(contrast_norms, 30.0)
    # Only small objects may be relabelled: 40th-percentile area, fallback 1200, capped at 2000.
    a_thresh = _percentile_or_none(areas_list, 40.0)
    if a_thresh is None or a_thresh <= 0:
        a_thresh = 1200.0
    a_thresh = min(a_thresh, 2000.0)

    indices_to_update = []
    if r_thresh is not None and c_thresh is not None:
        for idx, (r_edge, c_norm, area_val) in measures_per_idx.items():
            small_enough = area_val is None or area_val <= a_thresh
            if small_enough and r_edge < r_thresh and c_norm < c_thresh:
                indices_to_update.append(idx)

    if indices_to_update:
        # BACKGROUND_PREDICTION (9) is exported as class 0 (background), not PVC.
        df_pre.loc[indices_to_update, 'Predictions'] = BACKGROUND_PREDICTION
        print(f"将 {len(indices_to_update)} 个样本从类别7/8改为背景(阴影),面积阈值≈{a_thresh:.0f}")
    else:
        print("无需更新类别7/8样本")

    return df_pre


def write_outputs(bil_path, df_pre, output_path):
    """Shrink contours by one pixel and write the ENVI classification raster."""
    try:
        df_pre = shrink_contours(bil_path, df_pre, shrink_pixels=1)
        save_envi_classification(bil_path, df_pre, output_path)
    except Exception as e:
        raise RuntimeError(f"保存结果失败: output_path={output_path}") from e


def main():
    """Run the pipeline: validate, segment, extract, classify, post-process, save."""
    args = parse_arguments()
    bil_path = args.bil_path
    output_path = args.output_path
    primary_model_path = args.model_path
    segmentation_model_path = None

    total_start_time = time.time()

    try:
        validate_inputs(bil_path, output_path, primary_model_path)

        # Ensure the header carries a wavelength list (the training-camera grid) if it lacks one.
        change_hdr_file(bil_path, TRAIN_WAVELENGTHS)

        print("处理BIL文件生成RGB图像...")
        rgb_img = generate_rgb(bil_path)

        # "Segmentation" timing also covers feature extraction, background correction and cleaning.
        segmentation_start_time = time.time()
        print("生成掩膜...")
        mask, filter_mask_original = run_segmentation(rgb_img, segmentation_model_path)

        print("从BIL文件提取特征...")
        df = extract_primary_features(bil_path, mask)

        print("应用背景校正...")
        bg_spectrum = compute_background_spectrum(bil_path, mask)
        # Background correction, plus resampling only when the camera grid differs.
        df = apply_background_and_optional_resample(df, bg_spectrum, bil_path)

        print("清理数据...")
        df = clean_and_select_columns(df)
        segmentation_time = time.time() - segmentation_start_time

        classification_start_time = time.time()
        print("预测分类...")
        df_pre = run_primary_classification(df, primary_model_path)
        df_pre = run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original)
        df_pre = postprocess_class7_shadow(df_pre, rgb_img)
        classification_time = time.time() - classification_start_time

        print("保存ENVI分类结果...")
        write_outputs(bil_path, df_pre, output_path)
        print(f"ENVI分类结果已保存至: {output_path}")

        total_time = time.time() - total_start_time
        print(f"\n{'=' * 60}")
        print("处理完成")
        print(f"{'=' * 60}")
        print(f"分割耗时: {segmentation_time:.2f} 秒")
        print(f"分类耗时: {classification_time:.2f} 秒")
        print(f"总耗时: {total_time:.2f} 秒")
        print(f"{'=' * 60}")
        print(f"结果已保存至: {output_path}")

    except Exception as e:
        print(f"处理失败: {e}")
        raise


if __name__ == "__main__":
    main()