# micro_plastic/main_batch_nosample.py
import os
import cv2
import matplotlib
import numpy as np
import argparse
import pandas as pd
from bil2rgb import process_bil_files
from classification_model.Parallel.predict_plastic import load_model, predict_with_model
from mask import detect_microplastic_mask_from_array
from shape_spectral import process_images
from shape_spectral_background import process_images_background
from extact_shape import shape_correct_background, extract_features
import time
# Batch-predict all BIL files in a folder without downsampling; the model was
# trained on newly collected data.
matplotlib.use('TkAgg')
# Training camera wavelengths (168 channels)
TRAIN_WAVELENGTHS = [
    898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2
]


def read_wavelengths_from_hdr(bil_path):
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        return np.array([], dtype=np.float64)
    with open(hdr_path, 'r') as f:
        txt = f.read()
    if 'wavelength' not in txt:
        return np.array([], dtype=np.float64)
    seg = txt.split('wavelength', 1)[1]
    seg = seg[seg.find('{') + 1: seg.find('}')]
    vals = [v.strip() for v in seg.split(',') if v.strip()]
    try:
        waves = np.array([float(v) for v in vals], dtype=np.float64)
    except Exception:
        waves = np.array([], dtype=np.float64)
    return waves
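
# Parsing sketch (illustrative header fragment, not from a real file): a header
# containing the line
#   wavelength = {898.82, 903.64, 908.46}
# yields np.array([898.82, 903.64, 908.46]).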


def resample_spectra_matrix(X, src_waves, dst_waves):
    src = np.asarray(src_waves, dtype=np.float64)
    dst = np.asarray(dst_waves, dtype=np.float64)
    X = np.asarray(X, dtype=np.float64)
    if src.size == 0 or dst.size == 0:
        return X
    # Linear interpolation; out-of-range targets are extrapolated with the
    # endpoint values so no dimensions are lost.
    out = np.empty((X.shape[0], dst.size), dtype=np.float64)
    for i in range(X.shape[0]):
        row = X[i]
        out[i] = np.interp(dst, src, row, left=row[0], right=row[-1])
    return out
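
# A minimal usage sketch (hypothetical values): a single 4-band spectrum is
# mapped onto 3 target wavelengths; 1250 lies beyond the source range and is
# clamped to the last band via the `right=` endpoint:
#   >>> X = np.array([[0.1, 0.2, 0.3, 0.4]])
#   >>> resample_spectra_matrix(X, [900, 1000, 1100, 1200], [950, 1050, 1250])
#   array([[0.15, 0.25, 0.4 ]])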


def apply_background_no_resample(df, bg_spectrum):
    """
    Apply background correction only, with no resampling.
    - Spectral columns are auto-detected as those starting with 'wavelength_' or 'band_'.
    - If the background length differs from the number of spectral columns,
      the two are tail-aligned and the shorter length is used.
    """
    # Identify the spectral columns
    spec_cols = [c for c in df.columns if isinstance(c, str) and (c.startswith('wavelength_') or c.startswith('band_'))]
    if not spec_cols:
        raise ValueError("No spectral columns found (names starting with 'wavelength_' or 'band_')")
    bg = np.asarray(bg_spectrum, dtype=np.float64).ravel()
    if bg.size == 0:
        raise ValueError("Background spectrum has length 0; cannot apply background correction")
    # Tail-align and use the minimum length to avoid a dimension mismatch
    n = min(len(spec_cols), bg.shape[0])
    use_cols = spec_cols[-n:]
    df.loc[:, use_cols] = df.loc[:, use_cols].div(bg[-n:], axis=1)
    return df
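
# Sketch of the tail alignment (made-up numbers): with 3 spectral columns but a
# 2-value background, only the last 2 columns are divided:
#   >>> df = pd.DataFrame({'band_1': [2.0], 'band_2': [4.0], 'band_3': [9.0]})
#   >>> apply_background_no_resample(df, [2.0, 3.0])
#      band_1  band_2  band_3
#   0     2.0     2.0     3.0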


def parse_arguments():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description='Microplastic spectral shape classification - Batch processing')
    # Required arguments
    parser.add_argument('--input_dir', required=True, help='Path to input directory containing BIL files')
    parser.add_argument('--output_dir', required=True, help='Path to output directory for classification results')
    parser.add_argument('--model_path', required=True, help='Path to primary classification model')
    # Optional arguments (currently disabled)
    # parser.add_argument('--primary_model_type', default='SVM', help='Type of primary model (default: SVM)')
    # parser.add_argument('--primary_process_methods1', default='SS', help='Primary process method 1 (default: SS)')
    # parser.add_argument('--primary_process_methods2', default='None', help='Primary process method 2 (default: None)')
    # parser.add_argument('--secondary_model', default=r"D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m", help='Path to secondary classification model')
    # parser.add_argument('--secondary_model_type', default='SVM', help='Type of secondary model (default: SVM)')
    # parser.add_argument('--secondary_process_methods1', default='None',
    #                     help='Secondary process method 1 (default: None)')
    # parser.add_argument('--secondary_process_methods2', default='None',
    #                     help='Secondary process method 2 (default: None)')
    # parser.add_argument('--secondary_target_classes', nargs='+', type=int, default=[1, 2],
    #                     help='Target classes for secondary classification (space separated)')
    return parser.parse_args()


# ----------------------------
# Configuration: edit here directly when not using CLI arguments
# ----------------------------
# BIL_PATH = r"D:/Data/Test/PET_bottle2.bil"
# OUTPUT_PATH = r'D:/Data/PET_bottle2_class.bil'
#
# PRIMARY_MODEL_PATH = r"D:\plastic\plastic\modelsave\svm.m"
# PRIMARY_MODEL_TYPE = 'SVM'
# PRIMARY_PROCESS_METHODS1 = 'SS'
# PRIMARY_PROCESS_METHODS2 = 'None'
#
# SECONDARY_MODEL_PATH = r"D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m"  # keep as None if secondary classification is not needed
# SECONDARY_MODEL_TYPE = 'SVM'
# SECONDARY_PROCESS_METHODS1 = 'None'
# SECONDARY_PROCESS_METHODS2 = 'None'
# SECONDARY_TARGET_CLASSES = [1, 2]


def read_hdr_file(bil_path):
    # Use splitext so an embedded '.bil' elsewhere in the path is not replaced
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    with open(hdr_path, 'r') as f:
        header = f.readlines()
    samples, lines = None, None
    for line in header:
        if line.startswith('samples'):
            samples = int(line.split('=')[-1].strip())
        if line.startswith('lines'):
            lines = int(line.split('=')[-1].strip())
    return samples, lines
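
# The parser above expects flat `key = value` lines in the ENVI header, e.g.
# (illustrative snippet):
#   samples = 640
#   lines = 1200
# The split is on '=' and the remainder is stripped, so extra padding around
# the '=' still parses.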


def shrink_contours(bil_path, df, shrink_pixels=1):
    """
    Shrink every contour in the DataFrame to keep adjacent plastics from
    touching each other.
    Args:
        bil_path: path to the BIL file, used to obtain the image size
        df: DataFrame containing a 'contour' column
        shrink_pixels: number of pixels to shrink by (default 1)
    Returns:
        The updated DataFrame with shrunken contours.
    """
    samples, lines = read_hdr_file(bil_path)
    # Build the erosion kernel
    kernel = np.ones((2 * shrink_pixels + 1, 2 * shrink_pixels + 1), np.uint8)
    # Temporary mask reused for each contour
    temp_mask = np.zeros((lines, samples), dtype=np.uint8)
    # Work on a copy of the DataFrame
    df = df.copy()
    # Update every contour
    for idx, row in df.iterrows():
        contour = row['contour']
        if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
            continue
        try:
            contour_array = np.array(contour, dtype=np.int32)
            if len(contour_array.shape) == 1:
                continue
            # Clear the temporary mask
            temp_mask.fill(0)
            # Fill the contour
            cv2.fillPoly(temp_mask, [contour_array], 255)
            # Erode the mask
            eroded_mask = cv2.erode(temp_mask, kernel, iterations=1)
            # Re-extract the contour
            contours, _ = cv2.findContours(eroded_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if len(contours) > 0:
                # Keep the largest contour if there are several
                largest_contour = max(contours, key=cv2.contourArea)
                # Convert back to a list to match the original format
                if len(largest_contour) >= 3:
                    updated_contour = largest_contour.reshape(-1, 2).tolist()
                    df.at[idx, 'contour'] = updated_contour
        except Exception:
            # On failure, keep the original contour
            continue
    return df
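
# Sketch of the kernel choice: shrink_pixels=1 gives a 3x3 erosion kernel, so
# each filled contour loses roughly a one-pixel rim; e.g. a filled 10x10 square
# mask erodes to 8x8 before the contour is re-extracted.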


def save_envi_classification(bil_path, df, savepath):
    samples, lines = read_hdr_file(bil_path)
    classification_result = np.zeros((lines, samples), dtype=np.uint16)
    for _, row in df.iterrows():
        contour = row['contour']
        prediction = int(row['Predictions']) + 1
        contour = np.array(contour, dtype=np.int32)
        cv2.fillPoly(classification_result, [contour], prediction)
    output_path = savepath
    with open(output_path, 'wb') as f:
        classification_result.tofile(f)
    # data type = 12 is ENVI's code for unsigned 16-bit integers, matching the
    # uint16 array written above
    header_content = f"""ENVI
description = {{
Classification Result.}}
samples = {samples}
lines = {lines}
bands = 1
header offset = 0
file type = ENVI Standard
data type = 12
interleave = bil
classes = 10
class = {{ background, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC }}
single pixel area = 0.000036
unit = mm2
byte order = 0
wavelength units = nm
"""
    filename, ext = os.path.splitext(savepath)
    # Swap the extension for '.hdr'
    header_filename = filename + '.hdr'
    with open(header_filename, 'w') as header_file:
        header_file.write(header_content)
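
# A minimal read-back sketch (using the same paths passed above): the class map
# is a single-band uint16 BIL file, so plain NumPy can reload it:
#   >>> samples, lines = read_hdr_file(bil_path)
#   >>> class_map = np.fromfile(savepath, dtype=np.uint16).reshape(lines, samples)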


def change_hdr_file(bil_path, wavelengths=None):
    # With wavelengths=None, write only when the HDR lacks a wavelength field;
    # if wavelengths are provided, write exactly those values.
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        print(f"Error: matching HDR file not found: {hdr_path}")
        return
    with open(hdr_path, 'r') as file:
        content = file.read()
    if 'wavelength' in content and wavelengths is None:
        print(f"File {os.path.basename(hdr_path)} already contains wavelength information; no changes needed.")
        return
    if wavelengths is None:
        print("No wavelengths provided and HDR lacks wavelength; skipping write to avoid wrong bands.")
        return
    needs_newline = not content.endswith('\n')
    wavelength_info = "wavelength = {" + ", ".join(str(float(w)) for w in wavelengths) + "}\n"
    with open(hdr_path, 'a') as file:
        if needs_newline:
            file.write('\n')
        file.write(wavelength_info)
    print(f"Successfully ensured wavelength information in file: {os.path.basename(hdr_path)}")


def get_bil_files(input_dir):
    """Collect all BIL files in the input directory."""
    if not os.path.exists(input_dir):
        raise FileNotFoundError(f"Input directory does not exist: {input_dir}")
    bil_files = []
    for file in os.listdir(input_dir):
        if file.lower().endswith('.bil'):
            bil_path = os.path.join(input_dir, file)
            hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
            if os.path.exists(hdr_path):
                bil_files.append(bil_path)
            else:
                print(f"Warning: found BIL file {file} but its HDR file is missing; skipping")
    if not bil_files:
        raise ValueError(f"No valid BIL files found in input directory {input_dir}")
    return sorted(bil_files)


def validate_inputs(input_dir, output_dir, model_path):
    """Validate the input arguments."""
    # The input directory must exist
    if not os.path.exists(input_dir):
        raise FileNotFoundError(f"Input directory does not exist: {input_dir}")
    # Create the output directory if it does not exist
    if not os.path.exists(output_dir):
        try:
            os.makedirs(output_dir, exist_ok=True)
        except Exception as e:
            raise RuntimeError(f"Cannot create output directory: {output_dir}") from e
    # The model file must exist
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Primary model file does not exist: {model_path}")


def validate_single_bil_file(bil_path):
    """Validate a single BIL file."""
    # Both the BIL and HDR files must exist
    if not os.path.exists(bil_path):
        raise FileNotFoundError(f"BIL file does not exist: {bil_path}")
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        raise FileNotFoundError(f"HDR file does not exist: {hdr_path}")
    # The BIL file must contain enough bands
    try:
        from spectral.io import envi
        img = envi.open(hdr_path, bil_path)
        n_bands = img.nbands
        # bil2rgb needs band indices 9, 59 and 159
        if n_bands <= 159:
            raise ValueError(f"Too few bands in BIL file: at least 160 required, got {n_bands}")
    except Exception as e:
        raise RuntimeError(f"Cannot read BIL header information: {bil_path}") from e


def generate_rgb(bil_path):
    """Render an RGB image from the BIL file."""
    try:
        rgb_img = process_bil_files(bil_path)
        return rgb_img
    except Exception as e:
        raise RuntimeError(f"Failed to generate RGB image: bil_path={bil_path}") from e


def run_segmentation(rgb_img):
    """Run segmentation to obtain the masks."""
    try:
        mask, filter_mask_original = detect_microplastic_mask_from_array(
            image=rgb_img,
            filter_method='threshold',
            diameter=None,
            flow_threshold=0.4,
            cellprob_threshold=-1,
            detect_filter=True
        )
        return mask, filter_mask_original
    except Exception as e:
        raise RuntimeError("Segmentation failed: could not detect microplastic particles") from e


def extract_primary_features(bil_path, mask):
    """Extract the primary feature set."""
    try:
        df = process_images(bil_path, mask)
        return df
    except Exception as e:
        raise RuntimeError(f"Feature extraction failed: bil_path={bil_path}") from e


def compute_background_spectrum(bil_path, mask):
    """Compute the background spectrum."""
    try:
        df_correct = process_images_background(bil_path, mask)
        return df_correct
    except Exception as e:
        raise RuntimeError(f"Background spectrum computation failed: bil_path={bil_path}") from e


def apply_background_and_optional_resample(df, bg_spectrum, bil_path):
    """Apply background correction and, if needed, resample the spectra."""
    # Identify the spectral columns (all columns starting with 'wavelength_')
    spec_cols = [c for c in df.columns if c.startswith('wavelength_')]
    if not spec_cols:
        raise ValueError("No spectral columns found (names starting with 'wavelength_')")
    if len(spec_cols) != len(bg_spectrum):
        raise ValueError(f"Number of spectral columns ({len(spec_cols)}) does not match background spectrum length ({len(bg_spectrum)})")
    # Background correction: divide each spectral column by the background
    df[spec_cols] = df[spec_cols].div(bg_spectrum, axis=1)
    # Check whether resampling is required
    src_waves = read_wavelengths_from_hdr(bil_path)
    need_resample = (src_waves.size > 0 and
                     (src_waves.size != len(TRAIN_WAVELENGTHS) or
                      not np.allclose(src_waves, TRAIN_WAVELENGTHS, atol=1e-2)))
    if need_resample:
        print(f"Resampling spectra: {src_waves.size} source bands -> {len(TRAIN_WAVELENGTHS)} target bands")
        # Extract the spectral data
        X_src = df[spec_cols].to_numpy(dtype=np.float64)
        X_dst = resample_spectra_matrix(X_src, src_waves, TRAIN_WAVELENGTHS)
        # Replace the spectral columns
        spec_col_names = [f"band_{i+1}" for i in range(len(TRAIN_WAVELENGTHS))]
        df = pd.concat([
            df.drop(columns=spec_cols),  # drop the original spectral columns
            pd.DataFrame(X_dst, columns=spec_col_names, index=df.index)
        ], axis=1)
    return df


def clean_and_select_columns(df):
    """Clean the data and select the model input columns."""
    # Drop rows containing NaN values
    df = df.dropna()
    # Drop samples whose contours have too few points
    df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
    # Drop samples whose area is too small
    df = df[df['area'] >= 500]
    # Column selection: keep the original hard-coded index-removal logic
    # cols_to_remove = df.columns[np.r_[1:10, 11:15, 97:120, 176:179]]
    cols_to_remove = df.columns[np.r_[-10:-1]]
    df = df.drop(columns=cols_to_remove)
    return df
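
# Note: np.r_[-10:-1] expands to indices -10 through -2, i.e. nine trailing
# columns are dropped while the very last column is kept.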


def run_primary_classification(df, primary_model_path):
    """Run the primary classification."""
    try:
        # Soft check of the feature dimensionality against the saved scaler
        try:
            import joblib
            scaler_path = os.path.join(os.path.dirname(primary_model_path), 'scaler_params.pkl')
            if os.path.exists(scaler_path):
                scaler = joblib.load(scaler_path)
                numeric_cols = [c for c in df.columns[1:] if np.issubdtype(df[c].dtype, np.number) and c != 'contour']
                if hasattr(scaler, 'mean_') and len(numeric_cols) != scaler.mean_.shape[0]:
                    raise ValueError(f"Feature dimension mismatch: {len(numeric_cols)} columns now != {scaler.mean_.shape[0]} at training time")
        except Exception as e:
            print(f"Warning: could not validate feature dimensionality: {e}")
        df_pre = predict_with_model(
            df,
            primary_model_path,
            model_type='SVM',
            ProcessMethods1='SS',
            ProcessMethods2='None'
        )
        return df_pre
    except Exception as e:
        raise RuntimeError(f"Primary classification failed: model_path={primary_model_path}") from e


def run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original):
    """Run secondary classification when samples of the target classes are present."""
    # Secondary classification configuration (kept in code)
    secondary_model_path = os.path.join(os.path.dirname(__file__), 'modelsave', 'HDPELDPE_model', 'svm.m')
    secondary_target_classes = [1, 2]  # HDPE, LDPE
    target_classes = set(secondary_target_classes or [])
    mask_secondary = df_pre['Predictions'].isin(target_classes)
    if not mask_secondary.any():
        print("No samples of the target classes found; skipping secondary classification")
        return df_pre
    print(f"Running secondary classification for classes {sorted(target_classes)}")
    # The secondary model must exist
    if not os.path.exists(secondary_model_path):
        print(f"Warning: secondary model not found: {secondary_model_path}; skipping secondary classification")
        return df_pre
    try:
        # Background correction of the shape information
        df_correct = shape_correct_background(bil_path, mask, filter_mask_original)
        # Build a mask containing only the target-class samples
        mask_second = np.zeros_like(mask, dtype=np.uint16)
        for idx in df_pre[mask_secondary].index:
            contour = df_pre.loc[idx, 'contour']
            if isinstance(contour, list) and len(contour) > 0:
                contour_array = np.array(contour, dtype=np.int32)
                cv2.fillPoly(mask_second, [contour_array], idx + 1)
        # Extract features
        df_shape = extract_features(df_correct, mask_second)
        # Use only the first 13 columns as model input features
        if len(df_shape.columns) >= 13:
            df_shape = df_shape.iloc[:, :13]
        # Secondary classification
        df_secondary = predict_with_model(
            df_shape,
            secondary_model_path,
            model_type='SVM',
            ProcessMethods1='None',
            ProcessMethods2='None'
        )
        # Update the predictions (class index + 1)
        df_pre.loc[mask_secondary, 'Predictions'] = df_secondary['Predictions'].values + 1
    except Exception as e:
        print(f"Warning: secondary classification failed; keeping the primary results: {e}")
    return df_pre


def postprocess_class7_shadow(df_pre, rgb_img):
    """Post-process class-7 samples to identify background shadows."""
    class_7_mask = df_pre['Predictions'] == 7
    if not class_7_mask.any():
        return df_pre
    print(f"Checking {class_7_mask.sum()} class-7 samples for background shadows...")
    # Convert a PIL Image to a numpy array if necessary
    if hasattr(rgb_img, 'mode'):  # PIL Image
        rgb_img_array = np.array(rgb_img)
    else:
        rgb_img_array = rgb_img
    # Convert to grayscale
    if len(rgb_img_array.shape) == 3:
        gray_img = cv2.cvtColor(rgb_img_array, cv2.COLOR_RGB2GRAY)
    else:
        gray_img = rgb_img_array
    # Compute the gradient-magnitude image
    grad_x = cv2.Sobel(gray_img, cv2.CV_64F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(gray_img, cv2.CV_64F, 0, 1, ksize=3)
    gradient_magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)
    # Collect the edge-gradient values of all class-7 samples
    all_class7_gradients = []
    valid_indices = []
    for idx in df_pre[class_7_mask].index:
        try:
            contour = df_pre.loc[idx, 'contour']
            if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
                continue
            contour_array = np.array(contour, dtype=np.int32)
            if len(contour_array.shape) == 1:
                continue
            mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
            cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
            edge_gradients = gradient_magnitude[mask_img > 0]
            if len(edge_gradients) > 0:
                all_class7_gradients.extend(edge_gradients)
                valid_indices.append(idx)
        except Exception:
            continue
    # Determine the gradient threshold
    if len(all_class7_gradients) > 0:
        gradient_threshold = np.percentile(all_class7_gradients, 30)
    else:
        gradient_threshold = np.percentile(gradient_magnitude, 30)
    print(f"Class-7 gradient threshold: {gradient_threshold:.2f}")
    # Evaluate each class-7 sample
    indices_to_update = []
    for idx in valid_indices:
        try:
            contour = df_pre.loc[idx, 'contour']
            contour_array = np.array(contour, dtype=np.int32)
            mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
            cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
            edge_gradients = gradient_magnitude[mask_img > 0]
            if len(edge_gradients) == 0:
                continue
            mean_gradient = np.mean(edge_gradients)
            if mean_gradient < gradient_threshold:
                indices_to_update.append(idx)
                print(f"Sample {idx}: mean gradient={mean_gradient:.2f}, threshold={gradient_threshold:.2f} -> identified as background shadow")
        except Exception as e:
            print(f"Error while processing sample {idx}: {e}")
            continue
    # Update the classification results
    if indices_to_update:
        df_pre.loc[indices_to_update, 'Predictions'] = 9
        print(f"{len(indices_to_update)} samples reassigned from class 7 to class 9 (background shadow)")
    else:
        print("No class-7 samples needed updating")
    return df_pre
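
# Threshold intuition (illustrative numbers): np.percentile interpolates
# linearly, e.g. np.percentile([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 30) == 3.7,
# so roughly the softest-edged 30% of class-7 contours fall below the cut-off.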


def write_outputs(bil_path, df_pre, output_path):
    """Write out the classification results."""
    try:
        # Shrink the contours
        df_pre = shrink_contours(bil_path, df_pre, shrink_pixels=1)
        # Save the ENVI classification result
        save_envi_classification(bil_path, df_pre, output_path)
    except Exception as e:
        raise RuntimeError(f"Failed to save results: output_path={output_path}") from e


def process_single_file(bil_path, output_path, primary_model_path):
    """Full processing pipeline for a single BIL file."""
    try:
        # Validate the input
        validate_single_bil_file(bil_path)
        # Ensure the HDR file has wavelength information
        change_hdr_file(bil_path)
        # Render an RGB image from the BIL file
        print("  Generating RGB image from BIL file...")
        rgb_img = generate_rgb(bil_path)
        # Segmentation stage
        segmentation_start_time = time.time()
        print("  Generating mask...")
        mask, filter_mask_original = run_segmentation(rgb_img)
        # Feature extraction
        print("  Extracting features from BIL file...")
        df = extract_primary_features(bil_path, mask)
        # Background correction
        print("  Applying background correction...")
        bg_spectrum = compute_background_spectrum(bil_path, mask)
        # Background correction only, no resampling
        df = apply_background_no_resample(df, bg_spectrum)
        # Data cleaning and column selection
        print("  Cleaning data...")
        df = clean_and_select_columns(df)
        segmentation_time = time.time() - segmentation_start_time
        # Classification stage
        classification_start_time = time.time()
        print("  Predicting classes...")
        df_pre = run_primary_classification(df, primary_model_path)
        # Secondary classification
        df_pre = run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original)
        # Post-process class-7 shadows
        df_pre = postprocess_class7_shadow(df_pre, rgb_img)
        classification_time = time.time() - classification_start_time
        # Save the results
        print("  Saving ENVI classification result...")
        write_outputs(bil_path, df_pre, output_path)
        return segmentation_time, classification_time
    except Exception as e:
        print(f"Failed to process file {os.path.basename(bil_path)}: {e}")
        raise


def main():
    """Main entry point - batch processing."""
    args = parse_arguments()
    input_dir = args.input_dir
    output_dir = args.output_dir
    primary_model_path = args.model_path
    # Record the overall start time
    total_start_time = time.time()
    try:
        # Validate the input arguments
        validate_inputs(input_dir, output_dir, primary_model_path)
        # Collect all BIL files
        bil_files = get_bil_files(input_dir)
        print(f"Found {len(bil_files)} BIL files to process")
        # Statistics
        total_files = len(bil_files)
        processed_files = 0
        failed_files = 0
        total_segmentation_time = 0
        total_classification_time = 0
        # Process the files one by one
        for i, bil_path in enumerate(bil_files, 1):
            print(f"\n{'=' * 60}")
            print(f"Processing file {i}/{total_files}: {os.path.basename(bil_path)}")
            print(f"{'=' * 60}")
            try:
                # Output file name: original name + "_classification.bil"
                base_name = os.path.splitext(os.path.basename(bil_path))[0]
                output_filename = f"{base_name}_classification.bil"
                output_path = os.path.join(output_dir, output_filename)
                # Process the single file
                segmentation_time, classification_time = process_single_file(
                    bil_path, output_path, primary_model_path
                )
                # Update the statistics
                total_segmentation_time += segmentation_time
                total_classification_time += classification_time
                processed_files += 1
                print(f"File {os.path.basename(bil_path)} processed successfully")
                print(f"Result saved to: {output_path}")
            except Exception as e:
                print(f"Failed to process file {os.path.basename(bil_path)}: {e}")
                failed_files += 1
                continue
        # Compute average timings
        if processed_files > 0:
            avg_segmentation_time = total_segmentation_time / processed_files
            avg_classification_time = total_classification_time / processed_files
            avg_total_time = (total_segmentation_time + total_classification_time) / processed_files
        # Compute the total elapsed time
        total_time = time.time() - total_start_time
        # Print the summary statistics
        print(f"\n{'=' * 60}")
        print("Batch processing finished")
        print(f"{'=' * 60}")
        print(f"Total files: {total_files}")
        print(f"Processed successfully: {processed_files}")
        print(f"Failed: {failed_files}")
        print(f"Success rate: {processed_files / total_files * 100:.1f}%" if total_files > 0 else "Success rate: 0%")
        print(f"{'=' * 60}")
        if processed_files > 0:
            print(f"Average segmentation time: {avg_segmentation_time:.2f} seconds")
            print(f"Average classification time: {avg_classification_time:.2f} seconds")
            print(f"Average total time per file: {avg_total_time:.2f} seconds")
        print(f"Actual total elapsed time: {total_time:.2f} seconds")
        print(f"{'=' * 60}")
        print(f"Results directory: {output_dir}")
    except Exception as e:
        print(f"Batch processing failed: {e}")
        raise


if __name__ == "__main__":
    main()
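
# Example invocation (all paths are illustrative placeholders):
#   python micro_plastic/main_batch_nosample.py --input_dir D:/Data/batch_in --output_dir D:/Data/batch_out --model_path D:/models/svm.m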