Files
micro_plastic/main.py
2026-03-05 17:12:01 +08:00

696 lines
28 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import matplotlib
import numpy as np
import argparse
import pandas as pd
from bil2rgb import process_bil_files
from classification_model.Parallel.predict_plastic import load_model, predict_with_model
from mask import detect_microplastic_mask_from_array
from shape_spectral import process_images
from shape_spectral_background import process_images_background
from extact_shape import shape_correct_background, extract_features
import time
matplotlib.use('TkAgg')
# 训练相机波长237通道
TRAIN_WAVELENGTHS = [912.36, 915.68, 919, 922.31, 925.63, 928.95, 932.27, 935.59, 938.91, 942.23, 945.55, 948.87, 952.18, 955.5, 958.82, 962.14, 965.46, 968.78, 972.1, 975.42, 978.74, 982.06, 985.38, 988.7, 992.02, 995.34, 998.65, 1002, 1005.3, 1008.6, 1011.9, 1015.3, 1018.6, 1021.9, 1025.2, 1028.5, 1031.9, 1035.2, 1038.5, 1041.8, 1045.1, 1048.5, 1051.8, 1055.1, 1058.4, 1061.7, 1065.1, 1068.4, 1071.7, 1075, 1078.3, 1081.7, 1085, 1088.3, 1091.6, 1094.9, 1098.3, 1101.6, 1104.9, 1108.2, 1111.5, 1114.9, 1118.2, 1121.5, 1124.8, 1128.1, 1131.5, 1134.8, 1138.1, 1141.4, 1144.8, 1148.1, 1151.4, 1154.7, 1158, 1161.4, 1164.7, 1168, 1171.3, 1174.6, 1178, 1181.3, 1184.6, 1187.9, 1191.3, 1194.6, 1197.9, 1201.2, 1204.5, 1207.9, 1211.2, 1214.5, 1217.8, 1221.2, 1224.5, 1227.8, 1231.1, 1234.4, 1237.8, 1241.1, 1244.4, 1247.7, 1251.1, 1254.4, 1257.7, 1261, 1264.3, 1267.7, 1271, 1274.3, 1277.6, 1281, 1284.3, 1287.6, 1290.9, 1294.2, 1297.6, 1300.9, 1304.2, 1307.5, 1310.9, 1314.2, 1317.5, 1320.8, 1324.2, 1327.5, 1330.8, 1334.1, 1337.4, 1340.8, 1344.1, 1347.4, 1350.7, 1354.1, 1357.4, 1360.7, 1364, 1367.4, 1370.7, 1374, 1377.3, 1380.7, 1384, 1387.3, 1390.6, 1394, 1397.3, 1400.6, 1403.9, 1407.2, 1410.6, 1413.9, 1417.2, 1420.5, 1423.9, 1427.2, 1430.5, 1433.8, 1437.2, 1440.5, 1443.8, 1447.1, 1450.5, 1453.8, 1457.1, 1460.4, 1463.8, 1467.1, 1470.4, 1473.7, 1477.1, 1480.4, 1483.7, 1487, 1490.4, 1493.7, 1497, 1500.3, 1503.7, 1507, 1510.3, 1513.6, 1517, 1520.3, 1523.6, 1526.9, 1530.3, 1533.6, 1536.9, 1540.2, 1543.6, 1546.9, 1550.2, 1553.6, 1556.9, 1560.2, 1563.5, 1566.9, 1570.2, 1573.5, 1576.8, 1580.2, 1583.5, 1586.8, 1590.1, 1593.5, 1596.8, 1600.1, 1603.4, 1606.8, 1610.1, 1613.4, 1616.7, 1620.1, 1623.4, 1626.7, 1630.1, 1633.4, 1636.7, 1640, 1643.4, 1646.7, 1650, 1653.3, 1656.7, 1660, 1663.3, 1666.7, 1670, 1673.3, 1676.6, 1680, 1683.3, 1686.6, 1689.9, 1693.3, 1696.6, 1699.9, 1703.3, 1706.6]
def read_wavelengths_from_hdr(bil_path):
    """Read the wavelength list from the ENVI .hdr file next to *bil_path*.

    Returns a float64 numpy array. An empty array is returned when the
    header file is missing, has no 'wavelength' field, or the field body
    cannot be parsed as numbers.
    """
    empty = np.array([], dtype=np.float64)
    header_file = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(header_file):
        return empty
    with open(header_file, 'r') as fh:
        header_text = fh.read()
    if 'wavelength' not in header_text:
        return empty
    # Take everything after the first 'wavelength' keyword, then the {...} body.
    tail = header_text.split('wavelength', 1)[1]
    body = tail[tail.find('{') + 1: tail.find('}')]
    tokens = [t.strip() for t in body.split(',') if t.strip()]
    try:
        return np.array([float(t) for t in tokens], dtype=np.float64)
    except Exception:
        return empty
def resample_spectra_matrix(X, src_waves, dst_waves):
    """Linearly resample each spectrum (row of X) from src_waves onto dst_waves.

    Out-of-range target wavelengths are clamped to the endpoint values so no
    band is left undefined. When either wavelength grid is empty, X is
    returned unchanged (cast to float64).
    """
    src_grid = np.asarray(src_waves, dtype=np.float64)
    dst_grid = np.asarray(dst_waves, dtype=np.float64)
    spectra = np.asarray(X, dtype=np.float64)
    if not src_grid.size or not dst_grid.size:
        return spectra
    resampled = np.empty((spectra.shape[0], dst_grid.size), dtype=np.float64)
    for row_idx, spectrum in enumerate(spectra):
        # Endpoint extrapolation keeps every destination band populated.
        resampled[row_idx] = np.interp(
            dst_grid, src_grid, spectrum, left=spectrum[0], right=spectrum[-1]
        )
    return resampled
def apply_background_no_resample(df, bg_spectrum):
    """Background-correct the spectral columns of *df*, with no resampling.

    Spectral columns are those whose name starts with 'wavelength_' or
    'band_'. When the background length differs from the number of spectral
    columns, both are tail-aligned to the shorter length before dividing.

    Raises:
        ValueError: no spectral columns found, or the background is empty.
    """
    spectral = [
        col for col in df.columns
        if isinstance(col, str) and (col.startswith('wavelength_') or col.startswith('band_'))
    ]
    if not spectral:
        raise ValueError("未找到光谱列(以 wavelength_ 或 band_ 开头)")
    background = np.asarray(bg_spectrum, dtype=np.float64).ravel()
    if background.size == 0:
        raise ValueError("背景光谱长度为0无法进行背景校正")
    # Tail-align both sides so mismatched lengths never raise a shape error.
    keep = min(len(spectral), background.shape[0])
    tail_cols = spectral[-keep:]
    df.loc[:, tail_cols] = df.loc[:, tail_cols].div(background[-keep:], axis=1)
    return df
def parse_arguments():
    """Build and parse the command-line arguments for this script.

    Returns:
        argparse.Namespace with bil_path, output_path and model_path.
    """
    parser = argparse.ArgumentParser(description='Microplastic spectral shape classification')
    # All three paths are mandatory.
    for flag, help_text in (
        ('--bil_path', 'Path to input BIL file'),
        ('--output_path', 'Path to output classification result'),
        ('--model_path', 'Path to primary classification model'),
    ):
        parser.add_argument(flag, required=True, help=help_text)
    return parser.parse_args()
# ----------------------------
# 配置参数:直接在此修改
# ----------------------------
# BIL_PATH = r"D:/Data/Test/PET_bottle2.bil"
# OUTPUT_PATH = r'D:/Data/PET_bottle2_class.bil'
#
# PRIMARY_MODEL_PATH = r"D:\plastic\plastic\modelsave\svm.m"
# PRIMARY_MODEL_TYPE = 'SVM'
# PRIMARY_PROCESS_METHODS1 = 'SS'
# PRIMARY_PROCESS_METHODS2 = 'None'
#
# SECONDARY_MODEL_PATH = "D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m" # 若不需要二次分类,则保持为 None
# SECONDARY_MODEL_TYPE = 'SVM'
# SECONDARY_PROCESS_METHODS1 = 'None'
# SECONDARY_PROCESS_METHODS2 = 'None'
# SECONDARY_TARGET_CLASSES = [1, 2]
def read_hdr_file(bil_path):
    """Read the image dimensions from the ENVI header next to *bil_path*.

    Args:
        bil_path: path to the .bil data file.

    Returns:
        (samples, lines) as ints; either is None when the header lacks
        the corresponding field.
    """
    # Use splitext (consistent with the rest of this module) instead of
    # str.replace('.bil', '.hdr'), which would corrupt any path containing
    # '.bil' somewhere other than the extension.
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    samples, lines = None, None
    with open(hdr_path, 'r') as f:
        for line in f:
            if line.startswith('samples'):
                samples = int(line.split('=')[-1].strip())
            elif line.startswith('lines'):
                lines = int(line.split('=')[-1].strip())
    return samples, lines
def shrink_contours(bil_path, df, shrink_pixels=1):
    """
    Erode every contour in the DataFrame by a few pixels so that touching
    plastic particles do not remain connected in the output raster.

    Args:
        bil_path: path to the BIL file, used only to look up the image size
            from the accompanying .hdr file.
        df: DataFrame with a 'contour' column (list of [x, y] points).
        shrink_pixels: number of pixels to shrink by (default: 1).

    Returns:
        A copy of *df* whose 'contour' column has been shrunk; rows whose
        contour is missing, too short, or fails to erode keep the original.
    """
    samples, lines = read_hdr_file(bil_path)
    # Erosion kernel: a square of side 2*shrink_pixels + 1.
    kernel = np.ones((2 * shrink_pixels + 1, 2 * shrink_pixels + 1), np.uint8)
    # Scratch mask reused for every contour to avoid per-row allocation.
    temp_mask = np.zeros((lines, samples), dtype=np.uint8)
    # Work on a copy so the caller's DataFrame is untouched.
    df = df.copy()
    # Walk every row, replacing its contour with the eroded version.
    for idx, row in df.iterrows():
        contour = row['contour']
        # Need at least 3 points to form a polygon.
        if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
            continue
        try:
            contour_array = np.array(contour, dtype=np.int32)
            if len(contour_array.shape) == 1:
                continue
            # Clear the scratch mask.
            temp_mask.fill(0)
            # Rasterize the polygon...
            cv2.fillPoly(temp_mask, [contour_array], 255)
            # ...erode the filled blob...
            eroded_mask = cv2.erode(temp_mask, kernel, iterations=1)
            # ...then re-extract its outline.
            contours, _ = cv2.findContours(eroded_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if len(contours) > 0:
                # Erosion may split the blob; keep only the largest piece.
                largest_contour = max(contours, key=cv2.contourArea)
                # Convert back to the list-of-points format used elsewhere.
                if len(largest_contour) >= 3:
                    updated_contour = largest_contour.reshape(-1, 2).tolist()
                    df.at[idx, 'contour'] = updated_contour
        except Exception as e:
            # On any failure, silently keep the original contour.
            continue
    return df
def save_envi_classification(bil_path, df, savepath):
    """Rasterize per-particle predictions into an ENVI classification file.

    Each row's contour is filled with (Predictions + 1) so that 0 remains
    the background class; shifted labels 10 and 11 are also written as
    background. The raster is saved as uint16 BIL data plus a matching
    ENVI .hdr describing the 10 class names.
    """
    samples, lines = read_hdr_file(bil_path)
    classification_result = np.zeros((lines, samples), dtype=np.uint16)
    for _, row in df.iterrows():
        contour = row['contour']
        prediction = int(row['Predictions']) + 1  # shift by 1 first so 0 = background
        if prediction in (10, 11):  # shifted labels 10/11 are not real classes
            prediction = 0  # treat as background
        contour = np.array(contour, dtype=np.int32)
        cv2.fillPoly(classification_result, [contour], prediction)
    output_path = savepath
    # Raw uint16 pixels, single band, row-major (matches the header below).
    with open(output_path, 'wb') as f:
        classification_result.tofile(f)
    header_content = f"""ENVI
description = {{
Classification Result.}}
samples = {samples}
lines = {lines}
bands = 1
header offset = 0
file type = ENVI Standard
data type = 2
interleave = bil
classes = 10
class = {{ background, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC }}
single pixel area = 0.000036
unit = mm2
byte order = 0
wavelength units = nm
"""
    # Write the companion .hdr next to the output data file.
    filename, ext = os.path.splitext(savepath)
    header_filename = filename + '.hdr'
    with open(header_filename, 'w') as header_file:
        header_file.write(header_content)
def change_hdr_file(bil_path, wavelengths=None):
    """Append a 'wavelength = {...}' field to the ENVI header beside
    *bil_path* when the header exists and does not already carry one.

    Prints a status message and returns without writing when the header is
    missing, already contains a wavelength field, or *wavelengths* is
    None/empty.
    """
    header_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(header_path):
        print(f"错误: 找不到对应的HDR文件: {header_path}")
        return
    # Only append when the field is genuinely absent.
    with open(header_path, 'r', encoding='utf-8', errors='ignore') as fh:
        existing = fh.read()
    if 'wavelength' in existing:
        print(f"{os.path.basename(header_path)} 已包含 wavelength 字段,跳过追加。")
        return
    if wavelengths is None or len(wavelengths) == 0:
        print("HDR 缺少 wavelength但未提供 wavelengths跳过写入以避免错误。")
        return
    field_line = "wavelength = {" + ", ".join(str(float(w)) for w in wavelengths) + "}\n"
    with open(header_path, 'a', encoding='utf-8', errors='ignore') as fh:
        # Make sure the new field starts on its own line.
        if not existing.endswith('\n'):
            fh.write('\n')
        fh.write(field_line)
    print(f"已在 {os.path.basename(header_path)} 末尾追加 wavelength 字段。")
def validate_inputs(bil_path, output_path, model_path):
    """Validate the input files, model file and output location.

    Raises:
        FileNotFoundError: BIL, HDR or model file is missing.
        RuntimeError: output directory cannot be created, or the BIL
            header cannot be read.
        ValueError: the cube has fewer bands than the pipeline needs.
    """
    # BIL and companion HDR must exist.
    if not os.path.exists(bil_path):
        raise FileNotFoundError(f"BIL文件不存在: {bil_path}")
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        raise FileNotFoundError(f"HDR文件不存在: {hdr_path}")
    # Output directory must exist (create it when possible).
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        try:
            os.makedirs(output_dir, exist_ok=True)
        except Exception as e:
            raise RuntimeError(f"无法创建输出目录: {output_dir}") from e
    # Primary model must exist.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"主模型文件不存在: {model_path}")
    # Read the band count inside the try; keep the insufficient-bands check
    # OUTSIDE it so its ValueError is not swallowed and re-raised as the
    # misleading generic "cannot read header" RuntimeError.
    try:
        from spectral.io import envi
        img = envi.open(hdr_path, bil_path)
        n_bands = img.nbands
    except Exception as e:
        raise RuntimeError(f"无法读取BIL文件头信息: {bil_path}") from e
    # bil2rgb needs band indices 9, 59 and 159.
    if n_bands <= 159:
        raise ValueError(f"BIL文件波段数不足: 需要至少160个波段但只有{n_bands}")
def generate_rgb(bil_path):
    """Convert the BIL cube at *bil_path* into an RGB preview image.

    Raises:
        RuntimeError: when the underlying conversion fails.
    """
    try:
        return process_bil_files(bil_path)
    except Exception as e:
        raise RuntimeError(f"生成RGB图像失败: bil_path={bil_path}") from e
def run_segmentation(rgb_img, segmentation_model_path=None):
    """Segment microplastic particles in *rgb_img*.

    Args:
        rgb_img: RGB preview image produced from the BIL cube.
        segmentation_model_path: optional path to segmentation weights;
            None means the detector's default model.

    Returns:
        (mask, filter_mask_original) as returned by the detector.

    Raises:
        RuntimeError: when detection fails.
    """
    detector_kwargs = dict(
        image=rgb_img,
        filter_method='threshold',
        diameter=None,
        flow_threshold=0.4,
        cellprob_threshold=-1,
        model_path=segmentation_model_path,
        detect_filter=True,
    )
    try:
        return detect_microplastic_mask_from_array(**detector_kwargs)
    except Exception as e:
        raise RuntimeError("分割失败: 无法检测微塑料颗粒") from e
def extract_primary_features(bil_path, mask):
    """Extract per-particle spectral features from the BIL cube under *mask*.

    Raises:
        RuntimeError: when feature extraction fails.
    """
    try:
        return process_images(bil_path, mask)
    except Exception as e:
        raise RuntimeError(f"特征提取失败: bil_path={bil_path}") from e
def compute_background_spectrum(bil_path, mask):
    """Compute the background spectrum of the BIL cube outside *mask*.

    Raises:
        RuntimeError: when the background computation fails.
    """
    try:
        return process_images_background(bil_path, mask)
    except Exception as e:
        raise RuntimeError(f"背景光谱计算失败: bil_path={bil_path}") from e
def apply_background_and_optional_resample(df, bg_spectrum, bil_path):
    """Background-correct the spectra, then resample them onto the training
    camera's wavelength grid only when the source grid differs.

    Args:
        df: feature DataFrame whose spectral columns start with 'wavelength_'.
        bg_spectrum: background spectrum, one value per spectral column.
        bil_path: BIL file whose .hdr supplies the source wavelengths.

    Returns:
        The corrected (and possibly resampled) DataFrame; after resampling
        the spectral columns are replaced by 'band_1'...'band_N'.

    Raises:
        ValueError: no spectral columns, or background length mismatch.
    """
    # Spectral columns: every column whose name starts with 'wavelength_'.
    spec_cols = [c for c in df.columns if c.startswith('wavelength_')]
    if not spec_cols:
        raise ValueError("未找到光谱列以wavelength_开头的列")
    if len(spec_cols) != len(bg_spectrum):
        raise ValueError(f"光谱列数量({len(spec_cols)})与背景光谱长度({len(bg_spectrum)})不匹配")
    # Background correction: divide each spectral column by the background.
    df[spec_cols] = df[spec_cols].div(bg_spectrum, axis=1)
    # Resample only when the header wavelengths differ from the training grid.
    src_waves = read_wavelengths_from_hdr(bil_path)
    need_resample = (src_waves.size > 0 and
                     (src_waves.size != len(TRAIN_WAVELENGTHS) or
                      not np.allclose(src_waves, TRAIN_WAVELENGTHS, atol=1e-2)))
    if need_resample:
        print(f"重采样光谱: 源波段数 {src_waves.size} -> 目标波段数 {len(TRAIN_WAVELENGTHS)}")
        # Interpolate every spectrum onto the training wavelength grid.
        X_src = df[spec_cols].to_numpy(dtype=np.float64)
        X_dst = resample_spectra_matrix(X_src, src_waves, TRAIN_WAVELENGTHS)
        # Swap the original spectral columns for the resampled band columns.
        spec_col_names = [f"band_{i+1}" for i in range(len(TRAIN_WAVELENGTHS))]
        df = pd.concat([
            df.drop(columns=spec_cols),  # remove the original spectral columns
            pd.DataFrame(X_dst, columns=spec_col_names, index=df.index)
        ], axis=1)
    return df
def clean_and_select_columns(df):
    """Drop incomplete rows and trim the trailing feature columns.

    Keeps rows that have no NaNs, a contour with more than one point
    (non-list contour values pass the filter unchanged), and an area of at
    least 500. Then removes the hard-coded block of columns at positions
    -14..-2, matching the training pipeline's column layout.
    """
    def _contour_ok(c):
        return len(c) > 1 if isinstance(c, list) else True

    cleaned = df.dropna()
    cleaned = cleaned[cleaned['contour'].apply(_contour_ok)]
    cleaned = cleaned[cleaned['area'] >= 500]
    # Same positional removal as the original hard-coded index logic.
    trailing = cleaned.columns[np.r_[-14:-1]]
    return cleaned.drop(columns=trailing)
def run_primary_classification(df, primary_model_path):
    """Run the primary SVM classification on the feature DataFrame.

    When a scaler saved next to the model (scaler_params.pkl) exists, it is
    used to sanity-check that the number of numeric feature columns matches
    the training-time dimension; the check is advisory and any failure only
    prints a warning.

    Raises:
        RuntimeError: when the prediction itself fails.
    """
    try:
        # Advisory feature-dimension check against the training scaler.
        try:
            import joblib
            scaler_path = os.path.join(os.path.dirname(primary_model_path), 'scaler_params.pkl')
            if os.path.exists(scaler_path):
                scaler = joblib.load(scaler_path)
                # Skip the first column and any non-numeric/contour columns.
                numeric_cols = [c for c in df.columns[1:] if np.issubdtype(df[c].dtype, np.number) and c != 'contour']
                if hasattr(scaler, 'mean_') and len(numeric_cols) != scaler.mean_.shape[0]:
                    # Report the column COUNT — the original interpolated the
                    # whole column list into the "当前…列" message.
                    raise ValueError(f"特征维度不匹配: 当前{len(numeric_cols)}列 != 训练时{scaler.mean_.shape[0]}")
        except Exception as e:
            print(f"警告: 无法验证特征维度: {e}")
        df_pre = predict_with_model(
            df,
            primary_model_path,
            model_type='SVM',
            ProcessMethods1='SS',
            ProcessMethods2='None'
        )
        return df_pre
    except Exception as e:
        raise RuntimeError(f"主要分类失败: model_path={primary_model_path}") from e
def run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original):
    """Refine HDPE/LDPE predictions with a dedicated second-stage model.

    Samples whose primary prediction is in the target classes (1, 2 —
    HDPE/LDPE) are re-classified by a shape-feature SVM. Every failure is
    non-fatal: the primary predictions are returned unchanged with a
    warning printed.
    """
    # Second-stage configuration is deliberately kept hard-coded here.
    secondary_model_path = os.path.join(os.path.dirname(__file__), 'modelsave', 'HDPELDPE_model', 'svm.m')
    secondary_target_classes = [1, 2]  # HDPE, LDPE
    target_classes = set(secondary_target_classes or [])
    mask_secondary = df_pre['Predictions'].isin(target_classes)
    if not mask_secondary.any():
        print("未找到目标类别样本,跳过二次分类")
        return df_pre
    print(f"为类别 {sorted(target_classes)} 运行二次分类")
    # Skip gracefully when the secondary model file is absent.
    if not os.path.exists(secondary_model_path):
        print(f"警告: 二次模型不存在: {secondary_model_path},跳过二次分类")
        return df_pre
    try:
        # Background-correct the image data for shape-feature extraction.
        df_correct = shape_correct_background(bil_path, mask, filter_mask_original)
        # Build a label mask containing only the target-class contours;
        # each contour is filled with (row index + 1) as its label.
        mask_second = np.zeros_like(mask, dtype=np.uint16)
        for idx in df_pre[mask_secondary].index:
            contour = df_pre.loc[idx, 'contour']
            if isinstance(contour, list) and len(contour) > 0:
                contour_array = np.array(contour, dtype=np.int32)
                cv2.fillPoly(mask_second, [contour_array], idx + 1)
        # Extract shape features for the selected particles.
        df_shape = extract_features(df_correct, mask_second)
        # The secondary model expects exactly the first 13 feature columns.
        if len(df_shape.columns) >= 13:
            df_shape = df_shape.iloc[:, :13]
        # Second-stage classification.
        df_secondary = predict_with_model(
            df_shape,
            secondary_model_path,
            model_type='SVM',
            ProcessMethods1='None',
            ProcessMethods2='None'
        )
        # Overwrite the primary predictions; +1 maps the secondary model's
        # label space back onto the primary one.
        df_pre.loc[mask_secondary, 'Predictions'] = df_secondary['Predictions'].values + 1
    except Exception as e:
        print(f"警告: 二次分类失败,将继续使用主要分类结果: {e}")
    return df_pre
def postprocess_class7_shadow(df_pre, rgb_img):
    """Reclassify class-7/8 detections that are actually background shadows.

    For each class-7/8 sample, compares the gradient strength on the contour
    edge against a ring just outside it, and the inside/outside intensity
    contrast. Samples where both measures fall below robust (30th
    percentile) thresholds — and that are not too large — are relabelled so
    they rasterize as background.
    """
    # Both class 7 and class 8 are checked.
    mask_targets = df_pre['Predictions'].isin([7, 8])
    if not mask_targets.any():
        return df_pre
    print(f"处理 {mask_targets.sum()} 个类别7/8样本识别背景阴影...")
    # Grayscale version of the RGB preview.
    if hasattr(rgb_img, 'mode'):  # PIL Image
        rgb_img_array = np.array(rgb_img)
    else:
        rgb_img_array = rgb_img
    if len(rgb_img_array.shape) == 3:
        gray_img = cv2.cvtColor(rgb_img_array, cv2.COLOR_RGB2GRAY)
    else:
        gray_img = rgb_img_array
    # Scharr gradients (more stable than Sobel for this purpose).
    grad_x = cv2.Scharr(gray_img, cv2.CV_64F, 1, 0)
    grad_y = cv2.Scharr(gray_img, cv2.CV_64F, 0, 1)
    gradient_magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)
    # Per-sample statistics collected for the robust thresholding below.
    edge_ratios = []
    contrast_norms = []
    areas_list = []
    measures_per_idx = {}
    edge_thick = 3   # width of the contour edge band (px)
    ring_thick = 5   # dilation used to build the outer ring (px)
    eps = 1e-6       # guards against division by zero
    for idx in df_pre[mask_targets].index:
        try:
            contour = df_pre.loc[idx, 'contour']
            if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
                continue
            contour_array = np.array(contour, dtype=np.int32)
            if len(contour_array.shape) == 1:
                continue
            poly_mask = np.zeros(gray_img.shape, dtype=np.uint8)
            cv2.fillPoly(poly_mask, [contour_array], 255)
            # Edge band: the contour drawn with a small thickness.
            edge_mask = np.zeros_like(poly_mask)
            cv2.drawContours(edge_mask, [contour_array], -1, 255, thickness=edge_thick)
            # Outer ring: dilate the edge, then remove the edge band and
            # the polygon interior.
            ring_mask = cv2.dilate(edge_mask, np.ones((ring_thick, ring_thick), np.uint8), iterations=1)
            ring_mask = cv2.bitwise_and(ring_mask, cv2.bitwise_not(edge_mask))
            ring_mask = cv2.bitwise_and(ring_mask, cv2.bitwise_not(poly_mask))
            edge_vals = gradient_magnitude[edge_mask > 0]
            ring_vals = gradient_magnitude[ring_mask > 0]
            if edge_vals.size == 0 or ring_vals.size == 0:
                continue
            # Edge-to-ring gradient ratio: low values suggest a soft
            # (shadow-like) boundary.
            r_edge = float(np.median(edge_vals) / (np.median(ring_vals) + eps))
            inside_vals = gray_img[poly_mask > 0]
            outside_vals = gray_img[ring_mask > 0]
            if inside_vals.size == 0 or outside_vals.size == 0:
                continue
            # Inside/outside intensity difference, normalized by the
            # spread of the outside values.
            dI = float(np.median(inside_vals) - np.median(outside_vals))
            c_norm = abs(dI) / (np.std(outside_vals) + eps)
            # Area (optional, used as a size guard below).
            area_val = None
            if 'area' in df_pre.columns:
                try:
                    area_val = float(df_pre.loc[idx, 'area'])
                except Exception:
                    area_val = None
            edge_ratios.append(r_edge)
            contrast_norms.append(c_norm)
            areas_list.append(area_val if area_val is not None else 0.0)
            measures_per_idx[idx] = (r_edge, c_norm, area_val)
        except Exception:
            continue
    if not measures_per_idx:
        print("无可用的7/8类样本进行阴影判别")
        return df_pre

    def robust_q(arr, q):
        # Percentile over the non-None values; None when nothing is usable.
        vals = [v for v in arr if v is not None]
        return float(np.percentile(vals, q)) if len(vals) > 0 else None

    # Robust thresholds: below the 30th percentile looks like a shadow.
    r_thresh = robust_q(edge_ratios, 30.0)
    c_thresh = robust_q(contrast_norms, 30.0)
    # Size guard: only smaller targets may be relabelled. Threshold is the
    # 40th area percentile, capped so large particles are never touched.
    a_thresh = robust_q(areas_list, 40.0)
    if a_thresh is None or a_thresh <= 0:
        a_thresh = 1200.0
    a_thresh = min(a_thresh, 2000.0)
    indices_to_update = []
    for idx, (r_edge, c_norm, area_val) in measures_per_idx.items():
        small_enough = (area_val is None) or (area_val <= a_thresh)
        if (r_thresh is not None and c_thresh is not None and small_enough):
            # Both measures low AND the sample is small -> judged a shadow.
            if (r_edge < r_thresh) and (c_norm < c_thresh):
                indices_to_update.append(idx)
    if indices_to_update:
        # NOTE(review): 9 looks like the PVC label, but save_envi_classification
        # adds 1 (-> 10) and maps 10/11 to background, so these samples
        # rasterize as background shadow in the output.
        df_pre.loc[indices_to_update, 'Predictions'] = 9
        print(f"{len(indices_to_update)} 个样本从类别7/8改为背景阴影面积阈值≈{a_thresh:.0f}")
    else:
        print("无需更新类别7/8样本")
    return df_pre
def write_outputs(bil_path, df_pre, output_path):
    """Shrink contours and persist the ENVI classification raster.

    Raises:
        RuntimeError: when either step fails.
    """
    try:
        # Pull touching particles apart by one pixel before rasterizing.
        shrunk = shrink_contours(bil_path, df_pre, shrink_pixels=1)
        save_envi_classification(bil_path, shrunk, output_path)
    except Exception as e:
        raise RuntimeError(f"保存结果失败: output_path={output_path}") from e
def main():
    """CLI entry point: segment, classify and export a microplastic map."""
    args = parse_arguments()
    bil_path = args.bil_path
    output_path = args.output_path
    primary_model_path = args.model_path
    segmentation_model_path = None  # use the segmenter's default weights
    # Overall wall-clock timer.
    total_start_time = time.time()
    try:
        # Validate inputs before any heavy work.
        validate_inputs(bil_path, output_path, primary_model_path)
        # Reuse the module-level constant instead of the duplicated
        # 237-entry literal list it was copied from.
        bands = TRAIN_WAVELENGTHS
        # Ensure the HDR carries a wavelength field.
        change_hdr_file(bil_path, bands)
        # Build an RGB preview from the BIL cube.
        print("处理BIL文件生成RGB图像...")
        rgb_img = generate_rgb(bil_path)
        # Segmentation stage.
        segmentation_start_time = time.time()
        print("生成掩膜...")
        mask, filter_mask_original = run_segmentation(rgb_img, segmentation_model_path)
        print("从BIL文件提取特征...")
        df = extract_primary_features(bil_path, mask)
        print("应用背景校正...")
        bg_spectrum = compute_background_spectrum(bil_path, mask)
        # Background correction + resampling only when the cube's
        # wavelengths differ from the training camera's grid.
        df = apply_background_and_optional_resample(df, bg_spectrum, bil_path)
        print("清理数据...")
        df = clean_and_select_columns(df)
        segmentation_time = time.time() - segmentation_start_time
        # Classification stage.
        classification_start_time = time.time()
        print("预测分类...")
        df_pre = run_primary_classification(df, primary_model_path)
        # Optional HDPE/LDPE refinement.
        df_pre = run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original)
        # Relabel class-7/8 shadow false positives as background.
        df_pre = postprocess_class7_shadow(df_pre, rgb_img)
        classification_time = time.time() - classification_start_time
        # Persist the result raster + header.
        print("保存ENVI分类结果...")
        write_outputs(bil_path, df_pre, output_path)
        print(f"ENVI分类结果已保存至: {output_path}")
        total_time = time.time() - total_start_time
        # Timing summary.
        print(f"\n{'=' * 60}")
        print("处理完成")
        print(f"{'=' * 60}")
        print(f"分割耗时: {segmentation_time:.2f}")
        print(f"分类耗时: {classification_time:.2f}")
        print(f"总耗时: {total_time:.2f}")
        print(f"{'=' * 60}")
        print(f"结果已保存至: {output_path}")
    except Exception as e:
        print(f"处理失败: {e}")
        raise
if __name__ == "__main__":
main()