diff --git a/.idea/.gitignore b/.idea/.gitignore
index 35410ca..f649f0f 100644
--- a/.idea/.gitignore
+++ b/.idea/.gitignore
@@ -1,4 +1,4 @@
-# 默认忽略的文件
+# Default ignored files
 /shelf/
 /workspace.xml
 # 基于编辑器的 HTTP 客户端请求
diff --git a/__pycache__/bil2rgb.cpython-310.pyc b/__pycache__/bil2rgb.cpython-310.pyc
deleted file mode 100644
index 72d14ec..0000000
Binary files a/__pycache__/bil2rgb.cpython-310.pyc and /dev/null differ
diff --git a/__pycache__/bil2rgb.cpython-312.pyc b/__pycache__/bil2rgb.cpython-312.pyc
deleted file mode 100644
index 739a8ec..0000000
Binary files a/__pycache__/bil2rgb.cpython-312.pyc and /dev/null differ
diff --git a/__pycache__/bil2rgb.cpython-313.pyc b/__pycache__/bil2rgb.cpython-313.pyc
deleted file mode 100644
index 52bad0f..0000000
Binary files a/__pycache__/bil2rgb.cpython-313.pyc and /dev/null differ
diff --git a/__pycache__/extact_shape.cpython-310.pyc b/__pycache__/extact_shape.cpython-310.pyc
deleted file mode 100644
index 1ec40df..0000000
Binary files a/__pycache__/extact_shape.cpython-310.pyc and /dev/null differ
diff --git a/__pycache__/extact_shape.cpython-312.pyc b/__pycache__/extact_shape.cpython-312.pyc
deleted file mode 100644
index 9325350..0000000
Binary files a/__pycache__/extact_shape.cpython-312.pyc and /dev/null differ
diff --git a/__pycache__/get_glcm.cpython-310.pyc b/__pycache__/get_glcm.cpython-310.pyc
deleted file mode 100644
index 735545c..0000000
Binary files a/__pycache__/get_glcm.cpython-310.pyc and /dev/null differ
diff --git a/__pycache__/get_glcm.cpython-312.pyc b/__pycache__/get_glcm.cpython-312.pyc
deleted file mode 100644
index 93f3c58..0000000
Binary files a/__pycache__/get_glcm.cpython-312.pyc and /dev/null differ
diff --git a/__pycache__/mask.cpython-310.pyc b/__pycache__/mask.cpython-310.pyc
deleted file mode 100644
index 2339e8e..0000000
Binary files a/__pycache__/mask.cpython-310.pyc and /dev/null differ
diff --git a/__pycache__/mask.cpython-312.pyc b/__pycache__/mask.cpython-312.pyc
deleted file mode 100644
index d1a2241..0000000
Binary files a/__pycache__/mask.cpython-312.pyc and /dev/null differ
diff --git a/__pycache__/outputs2dataframe.cpython-310.pyc b/__pycache__/outputs2dataframe.cpython-310.pyc
deleted file mode 100644
index fbbf515..0000000
Binary files a/__pycache__/outputs2dataframe.cpython-310.pyc and /dev/null differ
diff --git a/__pycache__/outputs2dataframe.cpython-312.pyc b/__pycache__/outputs2dataframe.cpython-312.pyc
deleted file mode 100644
index f9db41b..0000000
Binary files a/__pycache__/outputs2dataframe.cpython-312.pyc and /dev/null differ
diff --git a/__pycache__/shape_spectral.cpython-310.pyc b/__pycache__/shape_spectral.cpython-310.pyc
deleted file mode 100644
index 7432fa3..0000000
Binary files a/__pycache__/shape_spectral.cpython-310.pyc and /dev/null differ
diff --git a/__pycache__/shape_spectral.cpython-312.pyc b/__pycache__/shape_spectral.cpython-312.pyc
deleted file mode 100644
index 8f827bf..0000000
Binary files a/__pycache__/shape_spectral.cpython-312.pyc and /dev/null differ
diff --git a/__pycache__/shape_spectral_background.cpython-310.pyc b/__pycache__/shape_spectral_background.cpython-310.pyc
deleted file mode 100644
index 758fa20..0000000
Binary files a/__pycache__/shape_spectral_background.cpython-310.pyc and /dev/null differ
diff --git a/__pycache__/shape_spectral_background.cpython-312.pyc b/__pycache__/shape_spectral_background.cpython-312.pyc
deleted file mode 100644
index 6caeb3f..0000000
Binary files a/__pycache__/shape_spectral_background.cpython-312.pyc and /dev/null differ
diff --git a/chose_bands.py b/chose_bands.py
new file mode 100644
index 0000000..2202d32
--- /dev/null
+++ b/chose_bands.py
@@ -0,0 +1,160 @@
+import argparse
+import pandas as pd
+import numpy as np
+import itertools
+import re
+from collections import defaultdict
+
+def parse_args():
+    ap = argparse.ArgumentParser(description="Select 3 bands for false-color to maximize inter-class separability")
+    ap.add_argument("--csv", required=True, help="Input CSV: first column is class label, rest are spectra columns")
+    ap.add_argument("--top_k", type=int, default=30, help="How many best single bands to preselect (default: 30)")
+    ap.add_argument("--top_triplets", type=int, default=10, help="How many best triplets to print (default: 10)")
+    ap.add_argument("--map_order", choices=["auto", "wavelength_bgr"], default="auto",
+                    help="RGB mapping order: auto=use scored order; wavelength_bgr=short->B, mid->G, long->R")
+    return ap.parse_args()
+
+def try_parse_wavelength(col_name):
+    # Supports column names such as "wavelength_912.36", "912.36", "band_12", etc.
+    m = re.search(r"([0-9]+(\.[0-9]+)?)", str(col_name))
+    if m:
+        try:
+            return float(m.group(1))
+        except Exception:
+            return None
+    return None
+
+def read_data(csv_path):
+    df = pd.read_csv(csv_path)
+    if df.shape[1] < 4:
+        raise ValueError("Need at least 1 label column plus 3 spectral columns")
+    y = df.iloc[:, 0].to_numpy()
+    X = df.iloc[:, 1:].to_numpy(dtype=np.float64)
+    cols = list(df.columns[1:])
+    waves = []
+    for c in cols:
+        w = try_parse_wavelength(c)
+        waves.append(w if w is not None else c)
+    return X, y, cols, waves
+
+def compute_anova_f_scores(X, y):
+    # One-way ANOVA F-score per band: F = (SS_between/df_between) / (SS_within/df_within)
+    N, M = X.shape
+    classes = []
+    for cls in np.unique(y):
+        idx = (y == cls)
+        if np.sum(idx) > 0:
+            classes.append(idx)
+    k = len(classes)
+    if k < 2:
+        raise ValueError("Fewer than 2 classes; separability cannot be computed")
+
+    mu = X.mean(axis=0)
+    between_ss = np.zeros(M, dtype=np.float64)
+    within_ss = np.zeros(M, dtype=np.float64)
+
+    for idx in classes:
+        ng = np.sum(idx)
+        Xg = X[idx]
+        if ng == 0:
+            continue
+        mg = Xg.mean(axis=0)
+        # Unbiased variance; zero when ng == 1
+        vg = Xg.var(axis=0, ddof=1) if ng > 1 else np.zeros(M, dtype=np.float64)
+        between_ss += ng * (mg - mu) ** 2
+        within_ss += (ng - 1) * vg
+
+    dfb = k - 1
+    dfw = N - k if N > k else 1
+    F = (between_ss / max(dfb, 1)) / (within_ss / max(dfw, 1) + 1e-12)
+    return F  # shape (M,)
+
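For reference, compute_anova_f_scores can be sanity-checked against scipy.stats.f_oneway on toy data; a minimal sketch, assuming scipy is installed (the toy values are illustrative only, not part of the patch):

import numpy as np
from scipy.stats import f_oneway
from chose_bands import compute_anova_f_scores

rng = np.random.default_rng(0)
X = rng.normal(size=(60, 5))   # 60 samples, 5 bands
y = np.repeat([0, 1, 2], 20)   # 3 classes
X[:, 0] += y * 2.0             # only band 0 separates the classes

F = compute_anova_f_scores(X, y)
F_ref = f_oneway(X[y == 0], X[y == 1], X[y == 2]).statistic
assert np.allclose(F, F_ref, rtol=1e-6)  # band 0 should dominate in both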
+def score_triplet(X, y, idx_triplet, eps=1e-6):
+    # LDA criterion in the 3-band subspace: J = trace((Sw + eps*I)^-1 @ Sb)
+    t = list(idx_triplet)
+    Xt = X[:, t]  # (N, 3)
+    classes = []
+    for cls in np.unique(y):
+        idx = (y == cls)
+        if np.sum(idx) > 0:
+            classes.append(idx)
+
+    m = Xt.mean(axis=0)
+    Sb = np.zeros((3, 3), dtype=np.float64)
+    Sw = np.zeros((3, 3), dtype=np.float64)
+
+    for idx in classes:
+        Xg = Xt[idx]
+        ng = Xg.shape[0]
+        if ng == 0:
+            continue
+        mg = Xg.mean(axis=0)
+        diff = (mg - m).reshape(3, 1)
+        Sb += ng * (diff @ diff.T)
+        if ng > 1:
+            Cg = np.cov(Xg, rowvar=False, ddof=1)
+        else:
+            Cg = np.zeros((3, 3), dtype=np.float64)
+        Sw += (ng - 1) * Cg
+
+    # Numerically stable (pseudo-)inverse
+    A = Sw + eps * np.eye(3)
+    try:
+        Ainv = np.linalg.pinv(A)
+    except Exception:
+        return -np.inf
+    J = np.trace(Ainv @ Sb)
+    return float(J)
+
+def select_bands(csv_path, top_k=30, top_triplets=10, map_order="auto"):
+    X, y, cols, waves = read_data(csv_path)
+    F = compute_anova_f_scores(X, y)
+    order = np.argsort(-F)
+    top_idx = order[:min(top_k, X.shape[1])]
+
+    best = []
+    for comb in itertools.combinations(top_idx, 3):
+        s = score_triplet(X, y, comb)
+        best.append((s, comb))
+    best.sort(key=lambda x: -x[0])
+
+    results = []
+    for s, comb in best[:top_triplets]:
+        names = [cols[i] for i in comb]
+        wavs = [waves[i] for i in comb]
+        # Suggested RGB mapping
+        if map_order == "wavelength_bgr" and all(isinstance(w, (int, float)) for w in wavs):
+            order_rgb = np.argsort(wavs)  # ascending
+            # shortest -> B, middle -> G, longest -> R
+            b, g, r = names[order_rgb[0]], names[order_rgb[1]], names[order_rgb[2]]
+            rgb = {"R": r, "G": g, "B": b}
+            rgb_w = {"R": wavs[order_rgb[2]], "G": wavs[order_rgb[1]], "B": wavs[order_rgb[0]]}
+        else:
+            # Use the scored order directly: 1st -> R, 2nd -> G, 3rd -> B
+            rgb = {"R": names[0], "G": names[1], "B": names[2]}
+            rgb_w = {"R": waves[comb[0]], "G": waves[comb[1]], "B": waves[comb[2]]}
+        results.append({
+            "score": round(s, 6),
+            "bands": names,
+            "wavelengths": wavs,
+            "rgb_mapping": rgb,
+            "rgb_wavelengths": rgb_w
+        })
+    return results, top_idx, F, cols
+
+def main():
+    args = parse_args()
+    results, top_idx, F, cols = select_bands(args.csv, args.top_k, args.top_triplets, args.map_order)
+
+    print(f"Top-{args.top_k} bands by ANOVA F-score (index:col_name):")
+    print(", ".join([f"{i}:{cols[i]}" for i in top_idx]))
+
+    print("\nBest triplets:")
+    for r in results:
+        bands = r["bands"]
+        wavs = r["wavelengths"]
+        rgb = r["rgb_mapping"]
+        print(f"- score={r['score']:.4f} bands={bands} wavelengths={wavs} RGB={rgb}")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
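For reference, a minimal end-to-end sketch of the selection criterion on synthetic data (toy values only; assumes chose_bands.py is on the import path):

import itertools
import numpy as np
from chose_bands import compute_anova_f_scores, score_triplet

rng = np.random.default_rng(42)
X = rng.normal(size=(120, 10))        # 120 samples, 10 bands
y = np.repeat([0, 1, 2, 3], 30)       # 4 classes
for k, band in enumerate((1, 4, 7)):  # make 3 bands class-informative
    X[:, band] += (y == k) * 3.0

F = compute_anova_f_scores(X, y)
top = np.argsort(-F)[:5]              # preselect the 5 best single bands
best = max(itertools.combinations(top, 3),
           key=lambda c: score_triplet(X, y, c))
print(sorted(best))                   # expected to recover bands 1, 4 and 7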
diff --git a/classification_model/Classification/__pycache__/ClassicCls.cpython-312.pyc b/classification_model/Classification/__pycache__/ClassicCls.cpython-312.pyc
index 932e1db..f0608e6 100644
Binary files a/classification_model/Classification/__pycache__/ClassicCls.cpython-312.pyc and b/classification_model/Classification/__pycache__/ClassicCls.cpython-312.pyc differ
diff --git a/classification_model/DataLoad/__pycache__/DataLoad.cpython-312.pyc b/classification_model/DataLoad/__pycache__/DataLoad.cpython-312.pyc
index 1b2aca0..12039d1 100644
Binary files a/classification_model/DataLoad/__pycache__/DataLoad.cpython-312.pyc and b/classification_model/DataLoad/__pycache__/DataLoad.cpython-312.pyc differ
diff --git a/classification_model/Parallel/__pycache__/predict_plastic.cpython-310.pyc b/classification_model/Parallel/__pycache__/predict_plastic.cpython-310.pyc
index 1212988..2cd95e0 100644
Binary files a/classification_model/Parallel/__pycache__/predict_plastic.cpython-310.pyc and b/classification_model/Parallel/__pycache__/predict_plastic.cpython-310.pyc differ
diff --git a/classification_model/Parallel/__pycache__/predict_plastic.cpython-312.pyc b/classification_model/Parallel/__pycache__/predict_plastic.cpython-312.pyc
index 991a18b..491a64b 100644
Binary files a/classification_model/Parallel/__pycache__/predict_plastic.cpython-312.pyc and b/classification_model/Parallel/__pycache__/predict_plastic.cpython-312.pyc differ
diff --git a/classification_model/Parallel/predict_plastic.py b/classification_model/Parallel/predict_plastic.py
index ec8ee08..8db48f6 100644
--- a/classification_model/Parallel/predict_plastic.py
+++ b/classification_model/Parallel/predict_plastic.py
@@ -587,22 +587,22 @@ def predict_with_model(df, model_path, model_type='SVM', ProcessMethods1='SS', P
 # Main entry point, used for training
 if __name__ == "__main__":
     # Read the CSV file with pandas
-    file_path = r"E:\code\plastic\plastic20260224\plastic\plastic\output\20260224\all.csv"
+    file_path = r"D:\Data2\traindata1\all\isf0303.csv"
     df = pd.read_csv(
         file_path,
         encoding='utf-8',  # specify the encoding; try 'gbk' or 'gb18030' if this fails
         low_memory=False  # avoid dtype-inference problems
     )
-    # Select the columns to drop (columns 93 to 117, 0-indexed)
-    cols_to_remove = df.columns[np.r_[1:5, 87:110, 166:169]]
-
-    # Drop the selected columns
-    df_filtered = df.drop(columns=cols_to_remove)
+    # # Select the columns to drop (columns 93 to 117, 0-indexed)
+    # cols_to_remove = df.columns[87:110]
+    #
+    # # Drop the selected columns
+    # df_filtered = df.drop(columns=cols_to_remove)
     # Extract the feature data (from the 2nd column on; the first column is the label)
-    x = df_filtered.iloc[:, 1:]
-    # x = df.iloc[:, 1:]
+    # x = df_filtered.iloc[:, 1:]
+    x = df.iloc[:, 1:]
     # Extract the labels (first column)
     y = df.iloc[:, 0]
     X_train, X_test, y_train, y_test = SpectralQualitativeAnalysis(x, y, 'SS', 'None', 'None', 'random', use_smote=True)
@@ -622,7 +622,7 @@ if __name__ == "__main__":
     # save_model(clf, r"D:\WQ\plastic\classification_model\modelsave\svm.m", model_type='SVM')
 
     # Example 2: use the unified train-and-save helper (recommended)
-    save_dir = r"E:\code\plastic\plastic20260224\plastic\plastic\output\20260224\modelsave"
+    save_dir = r"D:\plastic\plastic\modelsave\240model\new\0303"
 
     # Train and save multiple models
     models_to_train = ['SVM']  # 'SVM', 'RF', 'XGBoost', 'LogisticRegression'
diff --git a/classification_model/Preprocessing/Preprocessing.py b/classification_model/Preprocessing/Preprocessing.py
index 7ebf697..c815d87 100644
--- a/classification_model/Preprocessing/Preprocessing.py
+++ b/classification_model/Preprocessing/Preprocessing.py
@@ -132,7 +132,7 @@ def Preprocessing(method, input_spectrum):
     elif method == 'MMS':
         output_spectrum = MMS(input_spectrum.values)
     elif method == 'SS':
-        output_spectrum = SS(input_spectrum.values, r'E:\code\plastic\plastic20260224\plastic\plastic\output\20260224\modelsave\scaler_params.pkl')
+        output_spectrum = SS(input_spectrum.values, r'D:\plastic\plastic\modelsave\240model\new\0303\scaler_params.pkl')
     elif method == 'CT':
         output_spectrum = CT(input_spectrum.values)
     elif method == 'SNV':
diff --git a/classification_model/Preprocessing/__pycache__/Preprocessing.cpython-310.pyc b/classification_model/Preprocessing/__pycache__/Preprocessing.cpython-310.pyc
index ce85df5..628d00d 100644
Binary files a/classification_model/Preprocessing/__pycache__/Preprocessing.cpython-310.pyc and b/classification_model/Preprocessing/__pycache__/Preprocessing.cpython-310.pyc differ
diff --git a/classification_model/Preprocessing/__pycache__/Preprocessing.cpython-312.pyc b/classification_model/Preprocessing/__pycache__/Preprocessing.cpython-312.pyc
index 6aa336f..f7e5ab8 100644
Binary files a/classification_model/Preprocessing/__pycache__/Preprocessing.cpython-312.pyc and b/classification_model/Preprocessing/__pycache__/Preprocessing.cpython-312.pyc differ
diff --git a/classification_model/WaveSelect/__pycache__/Cars.cpython-312.pyc b/classification_model/WaveSelect/__pycache__/Cars.cpython-312.pyc
index 390adb5..a7c4e46 100644
Binary files a/classification_model/WaveSelect/__pycache__/Cars.cpython-312.pyc and b/classification_model/WaveSelect/__pycache__/Cars.cpython-312.pyc differ
diff --git a/classification_model/WaveSelect/__pycache__/GA.cpython-312.pyc b/classification_model/WaveSelect/__pycache__/GA.cpython-312.pyc
index 511b90f..7cf74cb 100644
Binary files a/classification_model/WaveSelect/__pycache__/GA.cpython-312.pyc and b/classification_model/WaveSelect/__pycache__/GA.cpython-312.pyc differ
diff --git a/classification_model/WaveSelect/__pycache__/Lar.cpython-312.pyc b/classification_model/WaveSelect/__pycache__/Lar.cpython-312.pyc
index 9e9f375..8e1a205 100644
Binary files a/classification_model/WaveSelect/__pycache__/Lar.cpython-312.pyc and b/classification_model/WaveSelect/__pycache__/Lar.cpython-312.pyc differ
diff --git a/classification_model/WaveSelect/__pycache__/Pca.cpython-312.pyc
b/classification_model/WaveSelect/__pycache__/Pca.cpython-312.pyc index 43c7d1a..60b7a4d 100644 Binary files a/classification_model/WaveSelect/__pycache__/Pca.cpython-312.pyc and b/classification_model/WaveSelect/__pycache__/Pca.cpython-312.pyc differ diff --git a/classification_model/WaveSelect/__pycache__/ReliefF.cpython-312.pyc b/classification_model/WaveSelect/__pycache__/ReliefF.cpython-312.pyc index ab80594..b294d0e 100644 Binary files a/classification_model/WaveSelect/__pycache__/ReliefF.cpython-312.pyc and b/classification_model/WaveSelect/__pycache__/ReliefF.cpython-312.pyc differ diff --git a/classification_model/WaveSelect/__pycache__/Spa.cpython-312.pyc b/classification_model/WaveSelect/__pycache__/Spa.cpython-312.pyc index 72d57c9..d2bb9da 100644 Binary files a/classification_model/WaveSelect/__pycache__/Spa.cpython-312.pyc and b/classification_model/WaveSelect/__pycache__/Spa.cpython-312.pyc differ diff --git a/classification_model/WaveSelect/__pycache__/Spa_acc.cpython-312.pyc b/classification_model/WaveSelect/__pycache__/Spa_acc.cpython-312.pyc index a1333a4..c5ac209 100644 Binary files a/classification_model/WaveSelect/__pycache__/Spa_acc.cpython-312.pyc and b/classification_model/WaveSelect/__pycache__/Spa_acc.cpython-312.pyc differ diff --git a/classification_model/WaveSelect/__pycache__/Uve.cpython-312.pyc b/classification_model/WaveSelect/__pycache__/Uve.cpython-312.pyc index 59797b6..91d8f3e 100644 Binary files a/classification_model/WaveSelect/__pycache__/Uve.cpython-312.pyc and b/classification_model/WaveSelect/__pycache__/Uve.cpython-312.pyc differ diff --git a/classification_model/WaveSelect/__pycache__/WaveSelcet.cpython-312.pyc b/classification_model/WaveSelect/__pycache__/WaveSelcet.cpython-312.pyc index da46a5a..daba258 100644 Binary files a/classification_model/WaveSelect/__pycache__/WaveSelcet.cpython-312.pyc and b/classification_model/WaveSelect/__pycache__/WaveSelcet.cpython-312.pyc differ diff --git a/classification_model/__pycache__/__init__.cpython-312.pyc b/classification_model/__pycache__/__init__.cpython-312.pyc index fe56b52..8f2e6eb 100644 Binary files a/classification_model/__pycache__/__init__.cpython-312.pyc and b/classification_model/__pycache__/__init__.cpython-312.pyc differ diff --git a/main.py b/main.py index 1cd1ceb..45afec1 100644 --- a/main.py +++ b/main.py @@ -14,10 +14,8 @@ import time matplotlib.use('TkAgg') -# 训练相机波长(168通道) -TRAIN_WAVELENGTHS = [ - 898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 
1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2 -] +# 训练相机波长(237通道) +TRAIN_WAVELENGTHS = [912.36, 915.68, 919, 922.31, 925.63, 928.95, 932.27, 935.59, 938.91, 942.23, 945.55, 948.87, 952.18, 955.5, 958.82, 962.14, 965.46, 968.78, 972.1, 975.42, 978.74, 982.06, 985.38, 988.7, 992.02, 995.34, 998.65, 1002, 1005.3, 1008.6, 1011.9, 1015.3, 1018.6, 1021.9, 1025.2, 1028.5, 1031.9, 1035.2, 1038.5, 1041.8, 1045.1, 1048.5, 1051.8, 1055.1, 1058.4, 1061.7, 1065.1, 1068.4, 1071.7, 1075, 1078.3, 1081.7, 1085, 1088.3, 1091.6, 1094.9, 1098.3, 1101.6, 1104.9, 1108.2, 1111.5, 1114.9, 1118.2, 1121.5, 1124.8, 1128.1, 1131.5, 1134.8, 1138.1, 1141.4, 1144.8, 1148.1, 1151.4, 1154.7, 1158, 1161.4, 1164.7, 1168, 1171.3, 1174.6, 1178, 1181.3, 1184.6, 1187.9, 1191.3, 1194.6, 1197.9, 1201.2, 1204.5, 1207.9, 1211.2, 1214.5, 1217.8, 1221.2, 1224.5, 1227.8, 1231.1, 1234.4, 1237.8, 1241.1, 1244.4, 1247.7, 1251.1, 1254.4, 1257.7, 1261, 1264.3, 1267.7, 1271, 1274.3, 1277.6, 1281, 1284.3, 1287.6, 1290.9, 1294.2, 1297.6, 1300.9, 1304.2, 1307.5, 1310.9, 1314.2, 1317.5, 1320.8, 1324.2, 1327.5, 1330.8, 1334.1, 1337.4, 1340.8, 1344.1, 1347.4, 1350.7, 1354.1, 1357.4, 1360.7, 1364, 1367.4, 1370.7, 1374, 1377.3, 1380.7, 1384, 1387.3, 1390.6, 1394, 1397.3, 1400.6, 1403.9, 1407.2, 1410.6, 1413.9, 1417.2, 1420.5, 1423.9, 1427.2, 1430.5, 1433.8, 1437.2, 1440.5, 1443.8, 1447.1, 1450.5, 1453.8, 1457.1, 1460.4, 1463.8, 1467.1, 1470.4, 1473.7, 1477.1, 1480.4, 1483.7, 1487, 1490.4, 1493.7, 1497, 1500.3, 1503.7, 1507, 1510.3, 1513.6, 1517, 1520.3, 1523.6, 1526.9, 1530.3, 1533.6, 1536.9, 1540.2, 1543.6, 1546.9, 1550.2, 1553.6, 1556.9, 1560.2, 1563.5, 1566.9, 1570.2, 1573.5, 1576.8, 1580.2, 1583.5, 1586.8, 1590.1, 1593.5, 1596.8, 1600.1, 1603.4, 1606.8, 1610.1, 1613.4, 1616.7, 1620.1, 1623.4, 1626.7, 1630.1, 1633.4, 1636.7, 1640, 1643.4, 1646.7, 1650, 1653.3, 1656.7, 1660, 1663.3, 1666.7, 1670, 1673.3, 1676.6, 1680, 1683.3, 1686.6, 1689.9, 1693.3, 1696.6, 1699.9, 1703.3, 1706.6] def read_wavelengths_from_hdr(bil_path): hdr_path = os.path.splitext(bil_path)[0] + '.hdr' @@ -50,6 +48,28 @@ def resample_spectra_matrix(X, src_waves, dst_waves): return out +def apply_background_no_resample(df, bg_spectrum): + """ + 仅做背景校正,不做任何重采样。 + - 自动选择以 wavelength_ 或 band_ 开头的光谱列 + - 若背景长度与光谱列数不一致,按尾部对齐取最小长度进行校正 + """ + # 识别光谱列 + spec_cols = [c for c in df.columns if isinstance(c, str) and (c.startswith('wavelength_') or c.startswith('band_'))] + if not spec_cols: + raise ValueError("未找到光谱列(以 wavelength_ 或 band_ 开头)") + + bg = np.asarray(bg_spectrum, dtype=np.float64).ravel() + if bg.size == 0: + raise ValueError("背景光谱长度为0,无法进行背景校正") + + # 尾部对齐,取最小长度,避免维度不一致 + n = min(len(spec_cols), bg.shape[0]) + use_cols = spec_cols[-n:] + df.loc[:, use_cols] = df.loc[:, use_cols].div(bg[-n:], axis=1) + return df + + def parse_arguments(): """解析命令行参数""" parser = argparse.ArgumentParser(description='Microplastic spectral shape classification') @@ -59,7 +79,7 @@ def parse_arguments(): parser.add_argument('--output_path', required=True, help='Path to output classification result') parser.add_argument('--model_path', required=True, help='Path to primary classification model') - # 可选参数 + # parser.add_argument('--primary_model_type', default='SVM', help='Type of primary model (default: SVM)') # parser.add_argument('--primary_process_methods1', default='SS', help='Primary process method 1 (default: SS)') # 
parser.add_argument('--primary_process_methods2', default='None', help='Primary process method 2 (default: None)') @@ -176,11 +196,11 @@ def save_envi_classification(bil_path, df, savepath): for _, row in df.iterrows(): contour = row['contour'] - prediction = int(row['Predictions']) + 1 - contour = np.array(contour, dtype=np.int32) - # 先将 classification_result 中的 10 和 11 替换为 0 - classification_result[(classification_result == 10)] = 0 + prediction = int(row['Predictions']) + 1 # 先加1 + if prediction in (10, 11): # 再判断是否为10或11 + prediction = 0 # 视为背景 + contour = np.array(contour, dtype=np.int32) cv2.fillPoly(classification_result, [contour], prediction) output_path = savepath @@ -206,7 +226,6 @@ byte order = 0 wavelength units = nm """ filename, ext = os.path.splitext(savepath) - # 替换扩展名为 '.hdr' header_filename = filename + '.hdr' with open(header_filename, 'w') as header_file: @@ -214,300 +233,463 @@ wavelength units = nm def change_hdr_file(bil_path, wavelengths=None): - # wavelengths=None 时仅在HDR缺失wavelength字段才写入;若提供则按提供内容写入 hdr_path = os.path.splitext(bil_path)[0] + '.hdr' if not os.path.exists(hdr_path): print(f"错误: 找不到对应的HDR文件: {hdr_path}") return - with open(hdr_path, 'r') as file: + # 仅在缺少 wavelength 字段时才尝试写入 + with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as file: content = file.read() - if 'wavelength' in content and wavelengths is None: - print(f"File {os.path.basename(hdr_path)} already contains wavelength information; no changes needed.") + if 'wavelength' in content: + print(f"{os.path.basename(hdr_path)} 已包含 wavelength 字段,跳过追加。") return - if wavelengths is None: - print(f"No wavelengths provided and HDR lacks wavelength; skipping write to avoid wrong bands.") + if wavelengths is None or len(wavelengths) == 0: + print("HDR 缺少 wavelength,但未提供 wavelengths,跳过写入以避免错误。") return needs_newline = not content.endswith('\n') wavelength_info = "wavelength = {" + ", ".join(str(float(w)) for w in wavelengths) + "}\n" - with open(hdr_path, 'a') as file: + with open(hdr_path, 'a', encoding='utf-8', errors='ignore') as file: if needs_newline: file.write('\n') file.write(wavelength_info) - print(f"Successfully ensured wavelength information in file: {os.path.basename(hdr_path)}") + print(f"已在 {os.path.basename(hdr_path)} 末尾追加 wavelength 字段。") + + +def validate_inputs(bil_path, output_path, model_path): + """验证输入文件和参数""" + # 检查BIL和HDR文件存在 + if not os.path.exists(bil_path): + raise FileNotFoundError(f"BIL文件不存在: {bil_path}") + + hdr_path = os.path.splitext(bil_path)[0] + '.hdr' + if not os.path.exists(hdr_path): + raise FileNotFoundError(f"HDR文件不存在: {hdr_path}") + + # 检查输出目录可写 + output_dir = os.path.dirname(output_path) + if output_dir and not os.path.exists(output_dir): + try: + os.makedirs(output_dir, exist_ok=True) + except Exception as e: + raise RuntimeError(f"无法创建输出目录: {output_dir}") from e + + # 检查模型文件存在 + if not os.path.exists(model_path): + raise FileNotFoundError(f"主模型文件不存在: {model_path}") + + # 检查BIL文件波段数是否足够 + try: + from spectral.io import envi + img = envi.open(hdr_path, bil_path) + n_bands = img.nbands + # bil2rgb需要波段索引9, 59, 159 + if n_bands <= 159: + raise ValueError(f"BIL文件波段数不足: 需要至少160个波段,但只有{n_bands}个") + except Exception as e: + raise RuntimeError(f"无法读取BIL文件头信息: {bil_path}") from e + + +def generate_rgb(bil_path): + """处理BIL文件生成RGB图像""" + try: + rgb_img = process_bil_files(bil_path) + return rgb_img + except Exception as e: + raise RuntimeError(f"生成RGB图像失败: bil_path={bil_path}") from e + + +def run_segmentation(rgb_img, segmentation_model_path=None): + 
"""运行分割获取掩膜""" + try: + mask, filter_mask_original = detect_microplastic_mask_from_array( + image=rgb_img, + filter_method='threshold', + diameter=None, + flow_threshold=0.4, + cellprob_threshold=-1, + model_path=segmentation_model_path, + detect_filter=True + ) + return mask, filter_mask_original + except Exception as e: + raise RuntimeError("分割失败: 无法检测微塑料颗粒") from e + + +def extract_primary_features(bil_path, mask): + """提取主要特征""" + try: + df = process_images(bil_path, mask) + return df + except Exception as e: + raise RuntimeError(f"特征提取失败: bil_path={bil_path}") from e + + +def compute_background_spectrum(bil_path, mask): + """计算背景光谱""" + try: + df_correct = process_images_background(bil_path, mask) + return df_correct + except Exception as e: + raise RuntimeError(f"背景光谱计算失败: bil_path={bil_path}") from e + + +def apply_background_and_optional_resample(df, bg_spectrum, bil_path): + """应用背景校正和可选的重采样""" + # 识别光谱列:所有以wavelength_开头的列 + spec_cols = [c for c in df.columns if c.startswith('wavelength_')] + + if not spec_cols: + raise ValueError("未找到光谱列(以wavelength_开头的列)") + + if len(spec_cols) != len(bg_spectrum): + raise ValueError(f"光谱列数量({len(spec_cols)})与背景光谱长度({len(bg_spectrum)})不匹配") + + # 背景校正:用背景光谱逐列相除 + df[spec_cols] = df[spec_cols].div(bg_spectrum, axis=1) + + # 检查是否需要重采样 + src_waves = read_wavelengths_from_hdr(bil_path) + need_resample = (src_waves.size > 0 and + (src_waves.size != len(TRAIN_WAVELENGTHS) or + not np.allclose(src_waves, TRAIN_WAVELENGTHS, atol=1e-2))) + + if need_resample: + print(f"重采样光谱: 源波段数 {src_waves.size} -> 目标波段数 {len(TRAIN_WAVELENGTHS)}") + + # 提取光谱数据 + X_src = df[spec_cols].to_numpy(dtype=np.float64) + X_dst = resample_spectra_matrix(X_src, src_waves, TRAIN_WAVELENGTHS) + + # 替换光谱列 + spec_col_names = [f"band_{i+1}" for i in range(len(TRAIN_WAVELENGTHS))] + df = pd.concat([ + df.drop(columns=spec_cols), # 移除原有光谱列 + pd.DataFrame(X_dst, columns=spec_col_names, index=df.index) + ], axis=1) + + return df + + +def clean_and_select_columns(df): + """数据清理和列选择""" + # 移除NaN值 + df = df.dropna() + + # 过滤轮廓点数不足的样本 + df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)] + + # 过滤面积过小的样本 + df = df[df['area'] >= 500] + + # 列筛选:使用原来的硬编码索引删除逻辑 + cols_to_remove = df.columns[np.r_[-14: -1]] + df = df.drop(columns=cols_to_remove) + + return df + + +def run_primary_classification(df, primary_model_path): + """运行主要分类""" + try: + # 验证特征维度 + try: + import joblib + scaler_path = os.path.join(os.path.dirname(primary_model_path), 'scaler_params.pkl') + if os.path.exists(scaler_path): + scaler = joblib.load(scaler_path) + numeric_cols = [c for c in df.columns[1:] if np.issubdtype(df[c].dtype, np.number) and c != 'contour'] + if hasattr(scaler, 'mean_') and len(numeric_cols) != scaler.mean_.shape[0]: + raise ValueError(f"特征维度不匹配: 当前{numeric_cols}列 != 训练时{scaler.mean_.shape[0]}维") + except Exception as e: + print(f"警告: 无法验证特征维度: {e}") + + df_pre = predict_with_model( + df, + primary_model_path, + model_type='SVM', + ProcessMethods1='SS', + ProcessMethods2='None' + ) + return df_pre + except Exception as e: + raise RuntimeError(f"主要分类失败: model_path={primary_model_path}") from e + + +def run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original): + """根据需要运行二次分类""" + # 二次分类配置(保持在代码内) + secondary_model_path = os.path.join(os.path.dirname(__file__), 'modelsave', 'HDPELDPE_model', 'svm.m') + secondary_target_classes = [1, 2] # HDPE, LDPE + + target_classes = set(secondary_target_classes or []) + mask_secondary = 
+def clean_and_select_columns(df):
+    """Clean the data and select columns."""
+    # Drop rows with NaN values
+    df = df.dropna()
+
+    # Drop samples whose contours have too few points
+    df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
+
+    # Drop samples whose area is too small
+    df = df[df['area'] >= 500]
+
+    # Column selection: keep the original hard-coded index-based removal logic
+    cols_to_remove = df.columns[np.r_[-14: -1]]
+    df = df.drop(columns=cols_to_remove)
+
+    return df
+
+
+def run_primary_classification(df, primary_model_path):
+    """Run the primary classification."""
+    try:
+        # Validate the feature dimensionality
+        try:
+            import joblib
+            scaler_path = os.path.join(os.path.dirname(primary_model_path), 'scaler_params.pkl')
+            if os.path.exists(scaler_path):
+                scaler = joblib.load(scaler_path)
+                numeric_cols = [c for c in df.columns[1:] if np.issubdtype(df[c].dtype, np.number) and c != 'contour']
+                if hasattr(scaler, 'mean_') and len(numeric_cols) != scaler.mean_.shape[0]:
+                    raise ValueError(f"Feature dimension mismatch: {len(numeric_cols)} columns now vs {scaler.mean_.shape[0]} dimensions at training time")
+        except Exception as e:
+            print(f"Warning: could not validate feature dimensionality: {e}")
+
+        df_pre = predict_with_model(
+            df,
+            primary_model_path,
+            model_type='SVM',
+            ProcessMethods1='SS',
+            ProcessMethods2='None'
+        )
+        return df_pre
+    except Exception as e:
+        raise RuntimeError(f"Primary classification failed: model_path={primary_model_path}") from e
+
+
+def run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original):
+    """Run the secondary classification when needed."""
+    # Secondary classification configuration (kept in code)
+    secondary_model_path = os.path.join(os.path.dirname(__file__), 'modelsave', 'HDPELDPE_model', 'svm.m')
+    secondary_target_classes = [1, 2]  # HDPE, LDPE
+
+    target_classes = set(secondary_target_classes or [])
+    mask_secondary = df_pre['Predictions'].isin(target_classes)
+
+    if not mask_secondary.any():
+        print("No samples from the target classes found; skipping secondary classification")
+        return df_pre
+
+    print(f"Running secondary classification for classes {sorted(target_classes)}")
+
+    # Check that the secondary model exists
+    if not os.path.exists(secondary_model_path):
+        print(f"Warning: secondary model not found: {secondary_model_path}; skipping secondary classification")
+        return df_pre
+
+    try:
+        # Background correction of the image information
+        df_correct = shape_correct_background(bil_path, mask, filter_mask_original)
+
+        # Build a mask containing only the target classes
+        mask_second = np.zeros_like(mask, dtype=np.uint16)
+        for idx in df_pre[mask_secondary].index:
+            contour = df_pre.loc[idx, 'contour']
+            if isinstance(contour, list) and len(contour) > 0:
+                contour_array = np.array(contour, dtype=np.int32)
+                cv2.fillPoly(mask_second, [contour_array], idx + 1)
+
+        # Extract features
+        df_shape = extract_features(df_correct, mask_second)
+
+        # Ensure the first 13 columns are used as model input features
+        if len(df_shape.columns) >= 13:
+            df_shape = df_shape.iloc[:, :13]
+
+        # Secondary classification
+        df_secondary = predict_with_model(
+            df_shape,
+            secondary_model_path,
+            model_type='SVM',
+            ProcessMethods1='None',
+            ProcessMethods2='None'
+        )
+
+        # Update the predictions (class + 1)
+        df_pre.loc[mask_secondary, 'Predictions'] = df_secondary['Predictions'].values + 1
+
+    except Exception as e:
+        print(f"Warning: secondary classification failed; keeping the primary results: {e}")
+
+    return df_pre
+
+
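The label mask built above encodes each candidate row as its own integer label via cv2.fillPoly; a minimal standalone sketch of that pattern (toy contours; assumes OpenCV and NumPy):

import cv2
import numpy as np

mask_second = np.zeros((100, 100), dtype=np.uint16)
contours = {3: [[10, 10], [40, 10], [40, 40], [10, 40]],   # row index -> contour
            7: [[60, 60], [90, 60], [90, 90], [60, 90]]}
for idx, contour in contours.items():
    # label = row index + 1, so 0 stays reserved for background
    cv2.fillPoly(mask_second, [np.array(contour, dtype=np.int32)], idx + 1)
print(np.unique(mask_second))  # [0 4 8]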
+def postprocess_class7_shadow(df_pre, rgb_img):
+    """Post-process background shadows within classes 7/8 (more robust)."""
+    # Check both class 7 and class 8
+    mask_targets = df_pre['Predictions'].isin([7, 8])
+    if not mask_targets.any():
+        return df_pre
+
+    print(f"Processing {mask_targets.sum()} class-7/8 samples to identify background shadows...")
+
+    # Grayscale image
+    if hasattr(rgb_img, 'mode'):  # PIL Image
+        rgb_img_array = np.array(rgb_img)
+    else:
+        rgb_img_array = rgb_img
+    if len(rgb_img_array.shape) == 3:
+        gray_img = cv2.cvtColor(rgb_img_array, cv2.COLOR_RGB2GRAY)
+    else:
+        gray_img = rgb_img_array
+
+    # More stable gradient (Scharr)
+    grad_x = cv2.Scharr(gray_img, cv2.CV_64F, 1, 0)
+    grad_y = cv2.Scharr(gray_img, cv2.CV_64F, 0, 1)
+    gradient_magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)
+
+    # Statistics
+    edge_ratios = []
+    contrast_norms = []
+    areas_list = []
+    measures_per_idx = {}
+
+    edge_thick = 3
+    ring_thick = 5
+    eps = 1e-6
+
+    for idx in df_pre[mask_targets].index:
+        try:
+            contour = df_pre.loc[idx, 'contour']
+            if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
+                continue
+            contour_array = np.array(contour, dtype=np.int32)
+            if len(contour_array.shape) == 1:
+                continue
+
+            poly_mask = np.zeros(gray_img.shape, dtype=np.uint8)
+            cv2.fillPoly(poly_mask, [contour_array], 255)
+
+            # Boundary band
+            edge_mask = np.zeros_like(poly_mask)
+            cv2.drawContours(edge_mask, [contour_array], -1, 255, thickness=edge_thick)
+
+            # Outer ring: dilate the boundary, then remove the boundary itself and the interior
+            ring_mask = cv2.dilate(edge_mask, np.ones((ring_thick, ring_thick), np.uint8), iterations=1)
+            ring_mask = cv2.bitwise_and(ring_mask, cv2.bitwise_not(edge_mask))
+            ring_mask = cv2.bitwise_and(ring_mask, cv2.bitwise_not(poly_mask))
+
+            edge_vals = gradient_magnitude[edge_mask > 0]
+            ring_vals = gradient_magnitude[ring_mask > 0]
+            if edge_vals.size == 0 or ring_vals.size == 0:
+                continue
+            r_edge = float(np.median(edge_vals) / (np.median(ring_vals) + eps))
+
+            inside_vals = gray_img[poly_mask > 0]
+            outside_vals = gray_img[ring_mask > 0]
+            if inside_vals.size == 0 or outside_vals.size == 0:
+                continue
+            dI = float(np.median(inside_vals) - np.median(outside_vals))
+            c_norm = abs(dI) / (np.std(outside_vals) + eps)
+
+            # Area (optional guard)
+            area_val = None
+            if 'area' in df_pre.columns:
+                try:
+                    area_val = float(df_pre.loc[idx, 'area'])
+                except Exception:
+                    area_val = None
+
+            edge_ratios.append(r_edge)
+            contrast_norms.append(c_norm)
+            areas_list.append(area_val if area_val is not None else 0.0)
+            measures_per_idx[idx] = (r_edge, c_norm, area_val)
+        except Exception:
+            continue
+
+    if not measures_per_idx:
+        print("No usable class-7/8 samples for shadow discrimination")
+        return df_pre
+
+    def robust_q(arr, q):
+        vals = [v for v in arr if v is not None]
+        return float(np.percentile(vals, q)) if len(vals) > 0 else None
+
+    # Robust thresholds (values below the 30th percentile look more like shadow)
+    r_thresh = robust_q(edge_ratios, 30.0)
+    c_thresh = robust_q(contrast_norms, 30.0)
+    # Area guard: only allow rewriting smaller targets; use the 40th percentile of the area distribution, capped to avoid overly large values
+    a_thresh = robust_q(areas_list, 40.0)
+    if a_thresh is None or a_thresh <= 0:
+        a_thresh = 1200.0
+    a_thresh = min(a_thresh, 2000.0)
+
+    indices_to_update = []
+    for idx, (r_edge, c_norm, area_val) in measures_per_idx.items():
+        small_enough = (area_val is None) or (area_val <= a_thresh)
+        if (r_thresh is not None and c_thresh is not None and small_enough):
+            # Both metrics low and the area is not large -> judged to be shadow
+            if (r_edge < r_thresh) and (c_norm < c_thresh):
+                indices_to_update.append(idx)
+
+    if indices_to_update:
+        # Assign 9: save_envi_classification adds 1 (giving class 10) and maps 10/11 to background (0), so this ends up as background, not PVC
+        df_pre.loc[indices_to_update, 'Predictions'] = 9
+        print(f"Reassigned {len(indices_to_update)} class-7/8 samples to background (shadow), area threshold ≈ {a_thresh:.0f}")
+    else:
+        print("No class-7/8 samples needed updating")
+
+    return df_pre
+
+
+def write_outputs(bil_path, df_pre, output_path):
+    """Write the output results."""
+    try:
+        # Shrink the contours
+        df_pre = shrink_contours(bil_path, df_pre, shrink_pixels=1)
+
+        # Save the ENVI classification result
+        save_envi_classification(bil_path, df_pre, output_path)
+
+    except Exception as e:
+        raise RuntimeError(f"Saving results failed: output_path={output_path}") from e
 
 
 def main():
+    """Main entry point."""
     args = parse_arguments()
 
     bil_path = args.bil_path
     output_path = args.output_path
     primary_model_path = args.model_path
-    primary_model_type = 'SVM'
-    primary_process_methods1 = 'SS'
-    primary_process_methods2 = "None"
+    segmentation_model_path = None
 
-    # secondary_model_path = args.secondary_model
-    # secondary_model_type = args.secondary_model_type
-    # secondary_process_methods1 = args.secondary_process_methods1
-    # secondary_process_methods2 = args.secondary_process_methods2
-    # secondary_target_classes = args.secondary_target_classes
-
-    secondary_model_path = "E:\code\plastic\plastic20260224\plastic\plastic\modelsave\HDPELDPE_model\svm.m"
-    secondary_model_type = 'SVM'
-    secondary_process_methods1 = 'None'
-    secondary_process_methods2 = 'None'
-    secondary_target_classes = [1, 2]
 
     # Record the overall start time
     total_start_time = time.time()
 
-    # Generate an RGB image from the BIL file
-    print("Processing BIL file to generate RGB image...\n")
-    rgb_img = process_bil_files(bil_path)
-    # Update the HDR
-    change_hdr_file(bil_path)
-    segmentation_start_time = time.time()
-    # Generate the mask; mask is a 16-bit plastic-label mask
-    print("Generating mask ...\n")
-    mask, filter_mask_original = detect_microplastic_mask_from_array(
-        image=rgb_img,  # pass the cv2.imread result directly
-        filter_method='threshold',
-        diameter=None,
-        flow_threshold=0.4,
-        cellprob_threshold=-1,
-        detect_filter=True
-    )
-
-    # Extract features
-    print("Extracting features from BIL file...\n")
-    df = process_images(bil_path, mask)
-
-    # Background correction (keep the existing logic)
-    print("Applying background correction...\n")
-    df_correct = process_images_background(bil_path, mask)
-
-    # Auto-detect the spectral column range: assume spectral columns start at column 2 and run contiguously, length = number of background-correction columns
-    spec_start = 1
-    spec_len_src = len(df_correct)
-    spec_end_src = spec_start + spec_len_src
-
-    # Normalize (divide column-wise, per channel)
-    df.iloc[:, spec_start:spec_end_src] = df.iloc[:, spec_start:spec_end_src].div(df_correct, axis=1)
-
-    # Read the current BIL's wavelengths and resample to the training wavelengths
-    src_waves = read_wavelengths_from_hdr(bil_path)
-    need_resample = (src_waves.size > 0) and (src_waves.size != len(TRAIN_WAVELENGTHS) or not np.allclose(src_waves, TRAIN_WAVELENGTHS, atol=1e-2))
-    if need_resample:
-        print(f"Resampling 
spectra: src {src_waves.size} bands -> dst {len(TRAIN_WAVELENGTHS)} bands\n") - X_src = df.iloc[:, spec_start:spec_end_src].to_numpy(dtype=np.float64) - X_dst = resample_spectra_matrix(X_src, src_waves, TRAIN_WAVELENGTHS) - - # 用训练维度(168)替换原光谱列;其余形状/统计特征保持不变 - spec_col_names = [f"band_{i+1}" for i in range(len(TRAIN_WAVELENGTHS))] - df = pd.concat( - [ - df.iloc[:, :spec_start], - pd.DataFrame(X_dst, columns=spec_col_names, index=df.index), - df.iloc[:, spec_end_src:] - ], - axis=1 - ) - # 更新光谱列的新范围 - spec_end_src = spec_start + len(TRAIN_WAVELENGTHS) - - # 数据清理(保持原有规则) - print("Cleaning data...\n") - df = df.dropna() - df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)] - df = df[df['area'] >= 500] - - # 列筛选:此时光谱列已与训练对齐为168维,因此区间索引仍可复用 - cols_to_remove = df.columns[np.r_[1:5, 87:110, 166:169, -10:-1]] - df = df.drop(columns=cols_to_remove) - - # 继续后续流程 - segmentation_time = time.time() - segmentation_start_time - df = df.iloc[:, :] - # 预测分类(分类阶段) - classification_start_time = time.time() - # 预测分类 - print("Predicting classes...\n") - loaded_model = load_model(primary_model_path) - - # 断言:数值特征维度应等于训练时的scaler输入维度 try: - import joblib - scaler = joblib.load(os.path.join(os.path.dirname(primary_model_path), 'scaler_params.pkl')) - numeric_cols = [c for c in df.columns[1:] if np.issubdtype(df[c].dtype, np.number) and c != 'contour'] - if hasattr(scaler, 'mean_') and len(numeric_cols) != scaler.mean_.shape[0]: - raise ValueError(f"Feature dimension mismatch: current {len(numeric_cols)} != scaler {scaler.mean_.shape[0]}. Check resampling and column selection.") - except Exception: - pass + # 验证输入 + validate_inputs(bil_path, output_path, primary_model_path) + bands =[912.36, 915.68, 919, 922.31, 925.63, 928.95, 932.27, 935.59, 938.91, 942.23, 945.55, 948.87, 952.18, 955.5, 958.82, 962.14, 965.46, 968.78, 972.1, 975.42, 978.74, 982.06, 985.38, 988.7, 992.02, 995.34, 998.65, 1002, 1005.3, 1008.6, 1011.9, 1015.3, 1018.6, 1021.9, 1025.2, 1028.5, 1031.9, 1035.2, 1038.5, 1041.8, 1045.1, 1048.5, 1051.8, 1055.1, 1058.4, 1061.7, 1065.1, 1068.4, 1071.7, 1075, 1078.3, 1081.7, 1085, 1088.3, 1091.6, 1094.9, 1098.3, 1101.6, 1104.9, 1108.2, 1111.5, 1114.9, 1118.2, 1121.5, 1124.8, 1128.1, 1131.5, 1134.8, 1138.1, 1141.4, 1144.8, 1148.1, 1151.4, 1154.7, 1158, 1161.4, 1164.7, 1168, 1171.3, 1174.6, 1178, 1181.3, 1184.6, 1187.9, 1191.3, 1194.6, 1197.9, 1201.2, 1204.5, 1207.9, 1211.2, 1214.5, 1217.8, 1221.2, 1224.5, 1227.8, 1231.1, 1234.4, 1237.8, 1241.1, 1244.4, 1247.7, 1251.1, 1254.4, 1257.7, 1261, 1264.3, 1267.7, 1271, 1274.3, 1277.6, 1281, 1284.3, 1287.6, 1290.9, 1294.2, 1297.6, 1300.9, 1304.2, 1307.5, 1310.9, 1314.2, 1317.5, 1320.8, 1324.2, 1327.5, 1330.8, 1334.1, 1337.4, 1340.8, 1344.1, 1347.4, 1350.7, 1354.1, 1357.4, 1360.7, 1364, 1367.4, 1370.7, 1374, 1377.3, 1380.7, 1384, 1387.3, 1390.6, 1394, 1397.3, 1400.6, 1403.9, 1407.2, 1410.6, 1413.9, 1417.2, 1420.5, 1423.9, 1427.2, 1430.5, 1433.8, 1437.2, 1440.5, 1443.8, 1447.1, 1450.5, 1453.8, 1457.1, 1460.4, 1463.8, 1467.1, 1470.4, 1473.7, 1477.1, 1480.4, 1483.7, 1487, 1490.4, 1493.7, 1497, 1500.3, 1503.7, 1507, 1510.3, 1513.6, 1517, 1520.3, 1523.6, 1526.9, 1530.3, 1533.6, 1536.9, 1540.2, 1543.6, 1546.9, 1550.2, 1553.6, 1556.9, 1560.2, 1563.5, 1566.9, 1570.2, 1573.5, 1576.8, 1580.2, 1583.5, 1586.8, 1590.1, 1593.5, 1596.8, 1600.1, 1603.4, 1606.8, 1610.1, 1613.4, 1616.7, 1620.1, 1623.4, 1626.7, 1630.1, 1633.4, 1636.7, 1640, 1643.4, 1646.7, 1650, 1653.3, 1656.7, 1660, 1663.3, 1666.7, 1670, 1673.3, 1676.6, 1680, 1683.3, 1686.6, 1689.9, 
1693.3, 1696.6, 1699.9, 1703.3, 1706.6] - df_pre = predict_with_model( - df, - primary_model_path, - model_type=primary_model_type, - ProcessMethods1=primary_process_methods1, - ProcessMethods2=primary_process_methods2 - ) + # 修改HDR文件 + change_hdr_file(bil_path,bands) - # 对HDPE和LDPE进行二次分类 - # 从第一次分类结果中提取SECONDARY_TARGET_CLASSES类别的掩膜轮廓 - target_classes = set(secondary_target_classes or []) - mask_secondary = df_pre['Predictions'].isin(target_classes) + # 处理BIL文件生成RGB图像 + print("处理BIL文件生成RGB图像...") + rgb_img = generate_rgb(bil_path) - if mask_secondary.any(): - # 只有在找到目标类别时才进行背景校正和二次分类 - print(f"Running secondary classification for classes: {sorted(target_classes)}") - - # 图像信息的背景矫正 - df_correct = shape_correct_background(bil_path, mask, filter_mask_original) - - # 创建新的掩膜mask_second,只包含目标类别的轮廓 - mask_second = np.zeros_like(mask, dtype=np.uint16) - - for idx in df_pre[mask_secondary].index: - contour = df_pre.loc[idx, 'contour'] - if isinstance(contour, list) and len(contour) > 0: - contour_array = np.array(contour, dtype=np.int32) - cv2.fillPoly(mask_second, [contour_array], idx + 1) # 使用索引+1作为标签 + # 分割阶段 + segmentation_start_time = time.time() + print("生成掩膜...") + mask, filter_mask_original = run_segmentation(rgb_img, segmentation_model_path) # 提取特征 - df_shape = extract_features(df_correct, mask_second) + print("从BIL文件提取特征...") + df = extract_primary_features(bil_path, mask) - # 确保使用第2到13列作为模型输入特征 - if len(df_shape.columns) >= 13: - df_shape = df_shape.iloc[:, :13] + # 背景校正 + print("应用背景校正...") + bg_spectrum = compute_background_spectrum(bil_path, mask) - # 二次分类:使用第二个模型预测并更新分类结果 - if secondary_model_path: - df_secondary = predict_with_model( - df_shape, - secondary_model_path, - model_type=secondary_model_type, - ProcessMethods1=secondary_process_methods1, - ProcessMethods2=secondary_process_methods2 - ) - df_pre.loc[mask_secondary, 'Predictions'] = df_secondary['Predictions'].values + 1 - else: - print("Secondary model path not provided; skipping secondary classification.\n") - else: - print("No samples from target classes found; skipping secondary classification.\n") + # 背景校正 + 仅在与训练相机波长不一致时重采样 + df = apply_background_and_optional_resample(df, bg_spectrum, bil_path) - # 识别类别7中的背景阴影误判:通过边界清晰度特征 - # 真正的类别7边界清晰,背景阴影边界模糊 - class_7_mask = df_pre['Predictions'] == 7 - if class_7_mask.any(): - print(f"Processing {class_7_mask.sum()} samples with class 7 to identify background shadows...\n") + # 数据清理和列选择 + print("清理数据...") + df = clean_and_select_columns(df) - # 将PIL Image转换为numpy数组 - if hasattr(rgb_img, 'mode'): # 检查是否是PIL Image - rgb_img_array = np.array(rgb_img) - else: - rgb_img_array = rgb_img + segmentation_time = time.time() - segmentation_start_time - # 转换为灰度图(用于计算梯度) - if len(rgb_img_array.shape) == 3: - gray_img = cv2.cvtColor(rgb_img_array, cv2.COLOR_RGB2GRAY) - else: - gray_img = rgb_img_array + # 分类阶段 + classification_start_time = time.time() + print("预测分类...") + df_pre = run_primary_classification(df, primary_model_path) - # 计算梯度图(使用Sobel算子) - grad_x = cv2.Sobel(gray_img, cv2.CV_64F, 1, 0, ksize=3) - grad_y = cv2.Sobel(gray_img, cv2.CV_64F, 0, 1, ksize=3) - gradient_magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2) + # 二次分类 + df_pre = run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original) - # 先收集所有类别7样本的边缘梯度值,用于确定阈值 - all_class7_gradients = [] - valid_indices = [] - for idx in df_pre[class_7_mask].index: - try: - contour = df_pre.loc[idx, 'contour'] - if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3: - continue + # 后处理类别7阴影 + df_pre = 
postprocess_class7_shadow(df_pre, rgb_img) - contour_array = np.array(contour, dtype=np.int32) - if len(contour_array.shape) == 1: - continue + classification_time = time.time() - classification_start_time - mask_img = np.zeros(gray_img.shape, dtype=np.uint8) - cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2) - edge_gradients = gradient_magnitude[mask_img > 0] - if len(edge_gradients) > 0: - all_class7_gradients.extend(edge_gradients) - valid_indices.append(idx) - except: - continue + # 保存结果 + print("保存ENVI分类结果...") + write_outputs(bil_path, df_pre, output_path) - # 基于类别7样本的梯度分布确定阈值 - # 使用类别7样本梯度值的中位数作为基准,低于某个分位数(如30%)的认为是背景阴影 - if len(all_class7_gradients) > 0: - gradient_threshold = np.percentile(all_class7_gradients, 30) # 使用类别7样本梯度值的30%分位数 - else: - gradient_threshold = np.percentile(gradient_magnitude, 30) # 如果没有有效样本,使用整张图的30%分位数 + print(f"ENVI分类结果已保存至: {output_path}") - print(f"Gradient threshold for class 7: {gradient_threshold:.2f}\n") + # 计算总耗时 + total_time = time.time() - total_start_time - # 处理每个类别7的样本,判断是否为背景阴影 - indices_to_update = [] - for idx in valid_indices: - try: - contour = df_pre.loc[idx, 'contour'] - contour_array = np.array(contour, dtype=np.int32) + # 打印耗时统计 + print(f"\n{'=' * 60}") + print("处理完成") + print(f"{'=' * 60}") + print(f"分割耗时: {segmentation_time:.2f} 秒") + print(f"分类耗时: {classification_time:.2f} 秒") + print(f"总耗时: {total_time:.2f} 秒") + print(f"{'=' * 60}") + print(f"结果已保存至: {output_path}") - # 创建轮廓掩膜(线宽为2像素,用于提取边缘) - mask_img = np.zeros(gray_img.shape, dtype=np.uint8) - cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2) - - # 提取轮廓边缘的梯度值 - edge_gradients = gradient_magnitude[mask_img > 0] - if len(edge_gradients) == 0: - continue - - # 计算轮廓边缘的平均梯度强度 - mean_gradient = np.mean(edge_gradients) - - # 如果平均梯度强度低于阈值,认为是背景阴影(边界模糊) - if mean_gradient < gradient_threshold: - indices_to_update.append(idx) - print( - f"Sample {idx}: mean_gradient={mean_gradient:.2f}, threshold={gradient_threshold:.2f} -> identified as background shadow") - - except Exception as e: - print(f"Error processing sample at index {idx}: {str(e)}") - continue - - # 将背景阴影的类别7改为类别9 - if indices_to_update: - df_pre.loc[indices_to_update, 'Predictions'] = 9 - print(f"Updated {len(indices_to_update)} samples from class 7 to class 9 (background shadows)\n") - else: - print("No samples needed to be updated from class 7\n") - classification_time = time.time() - classification_start_time - - df_pre = shrink_contours(bil_path, df_pre, shrink_pixels=1) - # 保存ENVI分类结果 - print("Saving ENVI classification results...\n") - save_envi_classification(bil_path, df_pre, output_path) - print(f"ENVI classification results saved to: {output_path}") - # 计算总耗时 - total_time = time.time() - total_start_time - - # 打印耗时统计 - print(f"\n{'=' * 60}") - print(f"处理完成") - print(f"{'=' * 60}") - print(f"分割耗时: {segmentation_time:.2f} 秒") - print(f"分类耗时: {classification_time:.2f} 秒") - print(f"总耗时: {total_time:.2f} 秒") - print(f"{'=' * 60}") - print(f"结果已保存至: {output_path}") + except Exception as e: + print(f"处理失败: {e}") + raise if __name__ == "__main__": diff --git a/main_batch_nosample.py b/main_batch_nosample.py new file mode 100644 index 0000000..fb0f8a6 --- /dev/null +++ b/main_batch_nosample.py @@ -0,0 +1,764 @@ +import os +import cv2 +import matplotlib +import numpy as np +import argparse +import pandas as pd +from bil2rgb import process_bil_files +from classification_model.Parallel.predict_plastic import load_model, predict_with_model +from mask import detect_microplastic_mask_from_array 
+from shape_spectral import process_images +from shape_spectral_background import process_images_background +from extact_shape import shape_correct_background, extract_features +import time +##批量预测文件夹内的bil文件,不进行降采样,使用新采集的数据进行训练 +matplotlib.use('TkAgg') + +# 训练相机波长(168通道) +TRAIN_WAVELENGTHS = [ + 898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2 +] + +def read_wavelengths_from_hdr(bil_path): + hdr_path = os.path.splitext(bil_path)[0] + '.hdr' + if not os.path.exists(hdr_path): + return np.array([], dtype=np.float64) + with open(hdr_path, 'r') as f: + txt = f.read() + if 'wavelength' not in txt: + return np.array([], dtype=np.float64) + seg = txt.split('wavelength', 1)[1] + seg = seg[seg.find('{')+1: seg.find('}')] + vals = [v.strip() for v in seg.split(',') if v.strip()] + try: + waves = np.array([float(v) for v in vals], dtype=np.float64) + except Exception: + waves = np.array([], dtype=np.float64) + return waves + +def resample_spectra_matrix(X, src_waves, dst_waves): + src = np.asarray(src_waves, dtype=np.float64) + dst = np.asarray(dst_waves, dtype=np.float64) + X = np.asarray(X, dtype=np.float64) + if src.size == 0 or dst.size == 0: + return X + # 线性插值,越界用端点外推,避免维度缺失 + out = np.empty((X.shape[0], dst.size), dtype=np.float64) + for i in range(X.shape[0]): + row = X[i] + out[i] = np.interp(dst, src, row, left=row[0], right=row[-1]) + return out + + +def apply_background_no_resample(df, bg_spectrum): + """ + 仅做背景校正,不做任何重采样。 + - 自动选择以 wavelength_ 或 band_ 开头的光谱列 + - 若背景长度与光谱列数不一致,按尾部对齐取最小长度进行校正 + """ + # 识别光谱列 + spec_cols = [c for c in df.columns if isinstance(c, str) and (c.startswith('wavelength_') or c.startswith('band_'))] + if not spec_cols: + raise ValueError("未找到光谱列(以 wavelength_ 或 band_ 开头)") + + bg = np.asarray(bg_spectrum, dtype=np.float64).ravel() + if bg.size == 0: + raise ValueError("背景光谱长度为0,无法进行背景校正") + + # 尾部对齐,取最小长度,避免维度不一致 + n = min(len(spec_cols), bg.shape[0]) + use_cols = spec_cols[-n:] + df.loc[:, use_cols] = df.loc[:, use_cols].div(bg[-n:], axis=1) + return df + + +def parse_arguments(): + """解析命令行参数""" + parser = argparse.ArgumentParser(description='Microplastic spectral shape classification - Batch processing') + + # 必需参数 + parser.add_argument('--input_dir', required=True, help='Path to input directory containing BIL files') + 
parser.add_argument('--output_dir', required=True, help='Path to output directory for classification results') + parser.add_argument('--model_path', required=True, help='Path to primary classification model') + + # 可选参数 + # parser.add_argument('--primary_model_type', default='SVM', help='Type of primary model (default: SVM)') + # parser.add_argument('--primary_process_methods1', default='SS', help='Primary process method 1 (default: SS)') + # parser.add_argument('--primary_process_methods2', default='None', help='Primary process method 2 (default: None)') + + # parser.add_argument('--secondary_model', default="D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m", help='Path to secondary classification model') + # parser.add_argument('--secondary_model_type', default='SVM', help='Type of secondary model (default: SVM)') + # parser.add_argument('--secondary_process_methods1', default='None', + # help='Secondary process method 1 (default: None)') + # parser.add_argument('--secondary_process_methods2', default='None', + # help='Secondary process method 2 (default: None)') + # parser.add_argument('--secondary_target_classes', nargs='+', type=int, default=[1,2], + # help='Target classes for secondary classification (space separated)') + + return parser.parse_args() + + +# ---------------------------- +# 配置参数:直接在此修改 +# ---------------------------- +# BIL_PATH = r"D:/Data/Test/PET_bottle2.bil" +# OUTPUT_PATH = r'D:/Data/PET_bottle2_class.bil' +# +# PRIMARY_MODEL_PATH = r"D:\plastic\plastic\modelsave\svm.m" +# PRIMARY_MODEL_TYPE = 'SVM' +# PRIMARY_PROCESS_METHODS1 = 'SS' +# PRIMARY_PROCESS_METHODS2 = 'None' +# +# SECONDARY_MODEL_PATH = "D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m" # 若不需要二次分类,则保持为 None +# SECONDARY_MODEL_TYPE = 'SVM' +# SECONDARY_PROCESS_METHODS1 = 'None' +# SECONDARY_PROCESS_METHODS2 = 'None' +# SECONDARY_TARGET_CLASSES = [1, 2] + + +def read_hdr_file(bil_path): + hdr_path = bil_path.replace('.bil', '.hdr') + with open(hdr_path, 'r') as f: + header = f.readlines() + + samples, lines = None, None + + for line in header: + if line.startswith('samples'): + samples = int(line.split('=')[-1].strip()) + if line.startswith('lines'): + lines = int(line.split('=')[-1].strip()) + + return samples, lines + + +def shrink_contours(bil_path, df, shrink_pixels=1): + """ + 对DataFrame中的所有轮廓进行收缩操作,避免塑料之间的相连 + + Args: + bil_path: BIL文件路径,用于获取图像尺寸 + df: 包含contour列的DataFrame + shrink_pixels: 收缩的像素数,默认1像素 + + Returns: + 更新后的DataFrame,contour列已被收缩 + """ + samples, lines = read_hdr_file(bil_path) + + # 创建腐蚀核 + kernel = np.ones((2 * shrink_pixels + 1, 2 * shrink_pixels + 1), np.uint8) + + # 创建临时掩膜用于处理 + temp_mask = np.zeros((lines, samples), dtype=np.uint8) + + # 创建DataFrame副本 + df = df.copy() + + # 遍历每一行,更新轮廓 + for idx, row in df.iterrows(): + contour = row['contour'] + if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3: + continue + + try: + contour_array = np.array(contour, dtype=np.int32) + if len(contour_array.shape) == 1: + continue + + # 清空临时掩膜 + temp_mask.fill(0) + + # 填充轮廓 + cv2.fillPoly(temp_mask, [contour_array], 255) + + # 对掩膜进行腐蚀操作 + eroded_mask = cv2.erode(temp_mask, kernel, iterations=1) + + # 重新提取轮廓 + contours, _ = cv2.findContours(eroded_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + if len(contours) > 0: + # 选择最大的轮廓(如果有多个) + largest_contour = max(contours, key=cv2.contourArea) + # 转换为列表格式,保持与原始格式一致 + if len(largest_contour) >= 3: + updated_contour = largest_contour.reshape(-1, 2).tolist() + df.at[idx, 'contour'] = updated_contour + except Exception as e: + # 
如果处理失败,保留原始轮廓 + continue + + return df + + +def save_envi_classification(bil_path, df, savepath): + samples, lines = read_hdr_file(bil_path) + classification_result = np.zeros((lines, samples), dtype=np.uint16) + + # 预处理:清除可能存在的类别10和11(移到循环外以提高效率) + classification_result[(classification_result == 10)] = 0 + classification_result[(classification_result == 11)] = 0 + + for _, row in df.iterrows(): + contour = row['contour'] + prediction = int(row['Predictions']) + 1 + contour = np.array(contour, dtype=np.int32) + + cv2.fillPoly(classification_result, [contour], prediction) + + output_path = savepath + + with open(output_path, 'wb') as f: + classification_result.tofile(f) + + header_content = f"""ENVI +description = {{ +Classification Result.}} +samples = {samples} +lines = {lines} +bands = 1 +header offset = 0 +file type = ENVI Standard +data type = 2 +interleave = bil +classes = 10 +class = {{ background, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC }} +single pixel area = 0.000036 +unit = mm2 +byte order = 0 +wavelength units = nm +""" + filename, ext = os.path.splitext(savepath) + # 替换扩展名为 '.hdr' + header_filename = filename + '.hdr' + + with open(header_filename, 'w') as header_file: + header_file.write(header_content) + + +def change_hdr_file(bil_path, wavelengths=None): + # wavelengths=None 时仅在HDR缺失wavelength字段才写入;若提供则按提供内容写入 + hdr_path = os.path.splitext(bil_path)[0] + '.hdr' + if not os.path.exists(hdr_path): + print(f"错误: 找不到对应的HDR文件: {hdr_path}") + return + + with open(hdr_path, 'r') as file: + content = file.read() + + if 'wavelength' in content and wavelengths is None: + print(f"File {os.path.basename(hdr_path)} already contains wavelength information; no changes needed.") + return + + if wavelengths is None: + print(f"No wavelengths provided and HDR lacks wavelength; skipping write to avoid wrong bands.") + return + + needs_newline = not content.endswith('\n') + wavelength_info = "wavelength = {" + ", ".join(str(float(w)) for w in wavelengths) + "}\n" + + with open(hdr_path, 'a') as file: + if needs_newline: + file.write('\n') + file.write(wavelength_info) + + print(f"Successfully ensured wavelength information in file: {os.path.basename(hdr_path)}") + + +def get_bil_files(input_dir): + """获取输入目录中的所有BIL文件""" + if not os.path.exists(input_dir): + raise FileNotFoundError(f"输入目录不存在: {input_dir}") + + bil_files = [] + for file in os.listdir(input_dir): + if file.lower().endswith('.bil'): + bil_path = os.path.join(input_dir, file) + hdr_path = os.path.splitext(bil_path)[0] + '.hdr' + if os.path.exists(hdr_path): + bil_files.append(bil_path) + else: + print(f"警告: 找到BIL文件 {file} 但缺少对应的HDR文件,跳过") + + if not bil_files: + raise ValueError(f"在输入目录 {input_dir} 中未找到有效的BIL文件") + + return sorted(bil_files) + + +def validate_inputs(input_dir, output_dir, model_path): + """验证输入参数""" + # 检查输入目录存在 + if not os.path.exists(input_dir): + raise FileNotFoundError(f"输入目录不存在: {input_dir}") + + # 检查输出目录存在,如果不存在则创建 + if not os.path.exists(output_dir): + try: + os.makedirs(output_dir, exist_ok=True) + except Exception as e: + raise RuntimeError(f"无法创建输出目录: {output_dir}") from e + + # 检查模型文件存在 + if not os.path.exists(model_path): + raise FileNotFoundError(f"主模型文件不存在: {model_path}") + + +def validate_single_bil_file(bil_path): + """验证单个BIL文件""" + # 检查BIL和HDR文件存在 + if not os.path.exists(bil_path): + raise FileNotFoundError(f"BIL文件不存在: {bil_path}") + + hdr_path = os.path.splitext(bil_path)[0] + '.hdr' + if not os.path.exists(hdr_path): + raise FileNotFoundError(f"HDR文件不存在: {hdr_path}") + + # 检查BIL文件波段数是否足够 + 
try: + from spectral.io import envi + img = envi.open(hdr_path, bil_path) + n_bands = img.nbands + # bil2rgb需要波段索引9, 59, 159 + if n_bands <= 159: + raise ValueError(f"BIL文件波段数不足: 需要至少160个波段,但只有{n_bands}个") + except Exception as e: + raise RuntimeError(f"无法读取BIL文件头信息: {bil_path}") from e + + +def generate_rgb(bil_path): + """处理BIL文件生成RGB图像""" + try: + rgb_img = process_bil_files(bil_path) + return rgb_img + except Exception as e: + raise RuntimeError(f"生成RGB图像失败: bil_path={bil_path}") from e + + +def run_segmentation(rgb_img): + """运行分割获取掩膜""" + try: + mask, filter_mask_original = detect_microplastic_mask_from_array( + image=rgb_img, + filter_method='threshold', + diameter=None, + flow_threshold=0.4, + cellprob_threshold=-1, + detect_filter=True + ) + return mask, filter_mask_original + except Exception as e: + raise RuntimeError("分割失败: 无法检测微塑料颗粒") from e + + +def extract_primary_features(bil_path, mask): + """提取主要特征""" + try: + df = process_images(bil_path, mask) + return df + except Exception as e: + raise RuntimeError(f"特征提取失败: bil_path={bil_path}") from e + + +def compute_background_spectrum(bil_path, mask): + """计算背景光谱""" + try: + df_correct = process_images_background(bil_path, mask) + return df_correct + except Exception as e: + raise RuntimeError(f"背景光谱计算失败: bil_path={bil_path}") from e + + +def apply_background_and_optional_resample(df, bg_spectrum, bil_path): + """应用背景校正和可选的重采样""" + # 识别光谱列:所有以wavelength_开头的列 + spec_cols = [c for c in df.columns if c.startswith('wavelength_')] + + if not spec_cols: + raise ValueError("未找到光谱列(以wavelength_开头的列)") + + if len(spec_cols) != len(bg_spectrum): + raise ValueError(f"光谱列数量({len(spec_cols)})与背景光谱长度({len(bg_spectrum)})不匹配") + + # 背景校正:用背景光谱逐列相除 + df[spec_cols] = df[spec_cols].div(bg_spectrum, axis=1) + + # 检查是否需要重采样 + src_waves = read_wavelengths_from_hdr(bil_path) + need_resample = (src_waves.size > 0 and + (src_waves.size != len(TRAIN_WAVELENGTHS) or + not np.allclose(src_waves, TRAIN_WAVELENGTHS, atol=1e-2))) + + if need_resample: + print(f"重采样光谱: 源波段数 {src_waves.size} -> 目标波段数 {len(TRAIN_WAVELENGTHS)}") + + # 提取光谱数据 + X_src = df[spec_cols].to_numpy(dtype=np.float64) + X_dst = resample_spectra_matrix(X_src, src_waves, TRAIN_WAVELENGTHS) + + # 替换光谱列 + spec_col_names = [f"band_{i+1}" for i in range(len(TRAIN_WAVELENGTHS))] + df = pd.concat([ + df.drop(columns=spec_cols), # 移除原有光谱列 + pd.DataFrame(X_dst, columns=spec_col_names, index=df.index) + ], axis=1) + + return df + + +def clean_and_select_columns(df): + """数据清理和列选择""" + # 移除NaN值 + df = df.dropna() + + # 过滤轮廓点数不足的样本 + df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)] + + # 过滤面积过小的样本 + df = df[df['area'] >= 500] + + # 列筛选:使用原来的硬编码索引删除逻辑 + # cols_to_remove = df.columns[np.r_[1:10, 11:15, 97:120, 176:179 ]] + cols_to_remove = df.columns[np.r_[-10: -1]] + df = df.drop(columns=cols_to_remove) + + return df + + +def run_primary_classification(df, primary_model_path): + """运行主要分类""" + try: + # 验证特征维度 + try: + import joblib + scaler_path = os.path.join(os.path.dirname(primary_model_path), 'scaler_params.pkl') + if os.path.exists(scaler_path): + scaler = joblib.load(scaler_path) + numeric_cols = [c for c in df.columns[1:] if np.issubdtype(df[c].dtype, np.number) and c != 'contour'] + if hasattr(scaler, 'mean_') and len(numeric_cols) != scaler.mean_.shape[0]: + raise ValueError(f"特征维度不匹配: 当前{numeric_cols}列 != 训练时{scaler.mean_.shape[0]}维") + except Exception as e: + print(f"警告: 无法验证特征维度: {e}") + + df_pre = predict_with_model( + df, + primary_model_path, + 
+            model_type='SVM',
+            ProcessMethods1='SS',
+            ProcessMethods2='None'
+        )
+        return df_pre
+    except Exception as e:
+        raise RuntimeError(f"Primary classification failed: model_path={primary_model_path}") from e
+
+
+def run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original):
+    """Run the secondary classification when needed."""
+    # Secondary classification configuration (kept in code)
+    secondary_model_path = os.path.join(os.path.dirname(__file__), 'modelsave', 'HDPELDPE_model', 'svm.m')
+    secondary_target_classes = [1, 2]  # HDPE, LDPE
+
+    target_classes = set(secondary_target_classes or [])
+    mask_secondary = df_pre['Predictions'].isin(target_classes)
+
+    if not mask_secondary.any():
+        print("No samples of the target classes found; skipping secondary classification")
+        return df_pre
+
+    print(f"Running secondary classification for classes {sorted(target_classes)}")
+
+    # Check that the secondary model exists
+    if not os.path.exists(secondary_model_path):
+        print(f"Warning: secondary model not found: {secondary_model_path}; skipping secondary classification")
+        return df_pre
+
+    try:
+        # Background correction of the image information
+        df_correct = shape_correct_background(bil_path, mask, filter_mask_original)
+
+        # Build a mask containing only the target-class contours
+        mask_second = np.zeros_like(mask, dtype=np.uint16)
+        for idx in df_pre[mask_secondary].index:
+            contour = df_pre.loc[idx, 'contour']
+            if isinstance(contour, list) and len(contour) > 0:
+                contour_array = np.array(contour, dtype=np.int32)
+                cv2.fillPoly(mask_second, [contour_array], idx + 1)
+
+        # Extract features
+        df_shape = extract_features(df_correct, mask_second)
+
+        # Use the first 13 columns as model input features
+        if len(df_shape.columns) >= 13:
+            df_shape = df_shape.iloc[:, :13]
+
+        # Secondary classification
+        df_secondary = predict_with_model(
+            df_shape,
+            secondary_model_path,
+            model_type='SVM',
+            ProcessMethods1='None',
+            ProcessMethods2='None'
+        )
+
+        # Update the predictions (class index + 1)
+        df_pre.loc[mask_secondary, 'Predictions'] = df_secondary['Predictions'].values + 1
+
+    except Exception as e:
+        print(f"Warning: secondary classification failed; keeping primary results: {e}")
+
+    return df_pre
+
+
+def postprocess_class7_shadow(df_pre, rgb_img):
+    """Post-process background shadows within class 7."""
+    class_7_mask = df_pre['Predictions'] == 7
+    if not class_7_mask.any():
+        return df_pre
+
+    print(f"Screening {class_7_mask.sum()} class-7 samples for background shadows...")
+
+    # Convert a PIL Image to a numpy array
+    if hasattr(rgb_img, 'mode'):  # PIL Image
+        rgb_img_array = np.array(rgb_img)
+    else:
+        rgb_img_array = rgb_img
+
+    # Convert to grayscale
+    if len(rgb_img_array.shape) == 3:
+        gray_img = cv2.cvtColor(rgb_img_array, cv2.COLOR_RGB2GRAY)
+    else:
+        gray_img = rgb_img_array
+
+    # Gradient magnitude via the Sobel operator
+    grad_x = cv2.Sobel(gray_img, cv2.CV_64F, 1, 0, ksize=3)
+    grad_y = cv2.Sobel(gray_img, cv2.CV_64F, 0, 1, ksize=3)
+    gradient_magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)
+
+    # Collect edge-gradient values of the class-7 samples
+    all_class7_gradients = []
+    valid_indices = []
+
+    for idx in df_pre[class_7_mask].index:
+        try:
+            contour = df_pre.loc[idx, 'contour']
+            if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
+                continue
+
+            contour_array = np.array(contour, dtype=np.int32)
+            if len(contour_array.shape) == 1:
+                continue
+
+            mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
+            cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
+            edge_gradients = gradient_magnitude[mask_img > 0]
+
+            if len(edge_gradients) > 0:
+                all_class7_gradients.extend(edge_gradients)
+                valid_indices.append(idx)
+        except Exception:
+            continue
+
+    # Determine the gradient threshold
+    if len(all_class7_gradients) > 0:
+        gradient_threshold = np.percentile(all_class7_gradients, 30)
+    else:
+        gradient_threshold = np.percentile(gradient_magnitude, 30)
+
+    print(f"Class-7 gradient threshold: {gradient_threshold:.2f}")
+
+    # Evaluate each class-7 sample
+    indices_to_update = []
+    for idx in valid_indices:
+        try:
+            contour = df_pre.loc[idx, 'contour']
+            contour_array = np.array(contour, dtype=np.int32)
+
+            mask_img = np.zeros(gray_img.shape,
+                                dtype=np.uint8)
+            cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
+
+            edge_gradients = gradient_magnitude[mask_img > 0]
+            if len(edge_gradients) == 0:
+                continue
+
+            mean_gradient = np.mean(edge_gradients)
+
+            if mean_gradient < gradient_threshold:
+                indices_to_update.append(idx)
+                print(f"Sample {idx}: mean gradient={mean_gradient:.2f}, threshold={gradient_threshold:.2f} -> identified as background shadow")
+
+        except Exception as e:
+            print(f"Error while processing sample {idx}: {e}")
+            continue
+
+    # Update the classification results
+    if indices_to_update:
+        df_pre.loc[indices_to_update, 'Predictions'] = 9
+        print(f"Reclassified {len(indices_to_update)} samples from class 7 to class 9 (background shadow)")
+    else:
+        print("No class-7 samples needed updating")
+
+    return df_pre
+
+
+def write_outputs(bil_path, df_pre, output_path):
+    """Write the output products."""
+    try:
+        # Shrink contours to keep neighbouring plastics from merging
+        df_pre = shrink_contours(bil_path, df_pre, shrink_pixels=1)
+
+        # Save the ENVI classification result
+        save_envi_classification(bil_path, df_pre, output_path)
+
+    except Exception as e:
+        raise RuntimeError(f"Failed to save results: output_path={output_path}") from e
+
+
+def process_single_file(bil_path, output_path, primary_model_path):
+    """Full processing pipeline for a single BIL file."""
+    try:
+        # Validate the input
+        validate_single_bil_file(bil_path)
+
+        # Update the HDR file
+        change_hdr_file(bil_path)
+
+        # Generate an RGB image from the BIL file
+        print("  Generating RGB image from BIL file...")
+        rgb_img = generate_rgb(bil_path)
+
+        # Segmentation stage
+        segmentation_start_time = time.time()
+        print("  Generating mask...")
+        mask, filter_mask_original = run_segmentation(rgb_img)
+
+        # Feature extraction
+        print("  Extracting features from BIL file...")
+        df = extract_primary_features(bil_path, mask)
+
+        # Background correction
+        print("  Applying background correction...")
+        bg_spectrum = compute_background_spectrum(bil_path, mask)
+
+        # Apply background correction; resampling onto the training wavelength
+        # grid happens inside only when the source grid differs
+        df = apply_background_and_optional_resample(df, bg_spectrum, bil_path)
+
+        # Data cleaning and column selection
+        print("  Cleaning data...")
+        df = clean_and_select_columns(df)
+
+        segmentation_time = time.time() - segmentation_start_time
+
+        # Classification stage
+        classification_start_time = time.time()
+        print("  Predicting classes...")
+        df_pre = run_primary_classification(df, primary_model_path)
+
+        # Secondary classification
+        df_pre = run_secondary_classification_if_needed(df_pre, bil_path, mask, filter_mask_original)
+
+        # Post-process class-7 shadows
+        df_pre = postprocess_class7_shadow(df_pre, rgb_img)
+
+        classification_time = time.time() - classification_start_time
+
+        # Save results
+        print("  Saving ENVI classification results...")
+        write_outputs(bil_path, df_pre, output_path)
+
+        return segmentation_time, classification_time
+
+    except Exception as e:
+        print(f"Failed to process file {os.path.basename(bil_path)}: {e}")
+        raise
+
+
+def main():
+    """Entry point: batch processing."""
+    args = parse_arguments()
+
+    input_dir = args.input_dir
+    output_dir = args.output_dir
+    primary_model_path = args.model_path
+
+    # Record the overall start time
+    total_start_time = time.time()
+
+    try:
+        # Validate the input arguments
+        validate_inputs(input_dir, output_dir, primary_model_path)
+
+        # Collect all BIL files
+        bil_files = get_bil_files(input_dir)
+        print(f"Found {len(bil_files)} BIL files to process")
+
+        # Statistics
+        total_files = len(bil_files)
+        processed_files = 0
+        failed_files = 0
+        total_segmentation_time = 0
+        total_classification_time = 0
+
+        # Process files one by one
+        for i, bil_path in enumerate(bil_files, 1):
+            print(f"\n{'='*60}")
+            print(f"Processing file {i}/{total_files}: {os.path.basename(bil_path)}")
+            print(f"{'='*60}")
+
+            try:
+                # Output name: original file name + "_classification.bil"
+                base_name = os.path.splitext(os.path.basename(bil_path))[0]
+                output_filename = f"{base_name}_classification.bil"
+                output_path = os.path.join(output_dir, output_filename)
+
+                # Process a single file
+                segmentation_time, classification_time = process_single_file(
+                    bil_path, output_path, primary_model_path
+                )
+
+                # Update the statistics
+                total_segmentation_time += segmentation_time
+                total_classification_time += classification_time
+                processed_files += 1
+
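# --- Editor's sketch -------------------------------------------------------
# resample_spectra_matrix is called by apply_background_and_optional_resample
# above, but (like read_wavelengths_from_hdr and TRAIN_WAVELENGTHS) its
# definition lies outside this hunk. Assuming it performs per-sample linear
# interpolation onto the training wavelength grid, a minimal version could
# look like this; the body is an illustration, not the project's actual code.

import numpy as np

def resample_spectra_matrix(X_src, src_waves, dst_waves):
    """Linearly interpolate each spectrum (one row of X_src) onto dst_waves."""
    src_waves = np.asarray(src_waves, dtype=np.float64)
    dst_waves = np.asarray(dst_waves, dtype=np.float64)
    X_dst = np.empty((X_src.shape[0], dst_waves.size), dtype=np.float64)
    for i, spectrum in enumerate(X_src):
        # np.interp clamps to the edge values outside the source range
        X_dst[i] = np.interp(dst_waves, src_waves, spectrum)
    return X_dst
# ---------------------------------------------------------------------------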
print(f"文件 {os.path.basename(bil_path)} 处理完成") + print(f"结果保存至: {output_path}") + + except Exception as e: + print(f"文件 {os.path.basename(bil_path)} 处理失败: {e}") + failed_files += 1 + continue + + # 计算平均耗时 + if processed_files > 0: + avg_segmentation_time = total_segmentation_time / processed_files + avg_classification_time = total_classification_time / processed_files + avg_total_time = (total_segmentation_time + total_classification_time) / processed_files + + # 计算总耗时 + total_time = time.time() - total_start_time + + # 打印汇总统计 + print(f"\n{'=' * 60}") + print("批量处理完成") + print(f"{'=' * 60}") + print(f"总文件数: {total_files}") + print(f"成功处理: {processed_files}") + print(f"处理失败: {failed_files}") + print(f"成功率: {processed_files/total_files*100:.1f}%" if total_files > 0 else "成功率: 0%") + print(f"{'=' * 60}") + if processed_files > 0: + print(f"平均分割耗时: {avg_segmentation_time:.2f} 秒") + print(f"平均分类耗时: {avg_classification_time:.2f} 秒") + print(f"平均总耗时: {avg_total_time:.2f} 秒") + print(f"实际总耗时: {total_time:.2f} 秒") + print(f"{'=' * 60}") + print(f"结果保存目录: {output_dir}") + + except Exception as e: + print(f"批量处理失败: {e}") + raise + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/mainv1.py b/mainv1.py new file mode 100644 index 0000000..e8c308a --- /dev/null +++ b/mainv1.py @@ -0,0 +1,445 @@ +import os +import cv2 +import matplotlib +import numpy as np +import argparse +from bil2rgb import process_bil_files +from classification_model.Parallel.predict_plastic import load_model, predict_with_model +from mask import detect_microplastic_mask_from_array +from shape_spectral import process_images +from shape_spectral_background import process_images_background +from extact_shape import shape_correct_background, extract_features +import time + +matplotlib.use('TkAgg') + + +def parse_arguments(): + """解析命令行参数""" + parser = argparse.ArgumentParser(description='Microplastic spectral shape classification') + + # 必需参数 + parser.add_argument('--bil_path', required=True, help='Path to input BIL file') + parser.add_argument('--output_path', required=True, help='Path to output classification result') + parser.add_argument('--model_path', required=True, help='Path to primary classification model') + + # 可选参数 + # parser.add_argument('--primary_model_type', default='SVM', help='Type of primary model (default: SVM)') + # parser.add_argument('--primary_process_methods1', default='SS', help='Primary process method 1 (default: SS)') + # parser.add_argument('--primary_process_methods2', default='None', help='Primary process method 2 (default: None)') + + # parser.add_argument('--secondary_model', default="D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m", help='Path to secondary classification model') + # parser.add_argument('--secondary_model_type', default='SVM', help='Type of secondary model (default: SVM)') + # parser.add_argument('--secondary_process_methods1', default='None', + # help='Secondary process method 1 (default: None)') + # parser.add_argument('--secondary_process_methods2', default='None', + # help='Secondary process method 2 (default: None)') + # parser.add_argument('--secondary_target_classes', nargs='+', type=int, default=[1,2], + # help='Target classes for secondary classification (space separated)') + + return parser.parse_args() + + +# ---------------------------- +# 配置参数:直接在此修改 +# ---------------------------- +# BIL_PATH = r"D:/Data/Test/PET_bottle2.bil" +# OUTPUT_PATH = r'D:/Data/PET_bottle2_class.bil' +# +# PRIMARY_MODEL_PATH = r"D:\plastic\plastic\modelsave\svm.m" +# 
PRIMARY_MODEL_TYPE = 'SVM' +# PRIMARY_PROCESS_METHODS1 = 'SS' +# PRIMARY_PROCESS_METHODS2 = 'None' +# +# SECONDARY_MODEL_PATH = "D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m" # 若不需要二次分类,则保持为 None +# SECONDARY_MODEL_TYPE = 'SVM' +# SECONDARY_PROCESS_METHODS1 = 'None' +# SECONDARY_PROCESS_METHODS2 = 'None' +# SECONDARY_TARGET_CLASSES = [1, 2] + + +def read_hdr_file(bil_path): + hdr_path = bil_path.replace('.bil', '.hdr') + with open(hdr_path, 'r') as f: + header = f.readlines() + + samples, lines = None, None + + for line in header: + if line.startswith('samples'): + samples = int(line.split('=')[-1].strip()) + if line.startswith('lines'): + lines = int(line.split('=')[-1].strip()) + + return samples, lines + + +def shrink_contours(bil_path, df, shrink_pixels=1): + """ + 对DataFrame中的所有轮廓进行收缩操作,避免塑料之间的相连 + + Args: + bil_path: BIL文件路径,用于获取图像尺寸 + df: 包含contour列的DataFrame + shrink_pixels: 收缩的像素数,默认1像素 + + Returns: + 更新后的DataFrame,contour列已被收缩 + """ + samples, lines = read_hdr_file(bil_path) + + # 创建腐蚀核 + kernel = np.ones((2 * shrink_pixels + 1, 2 * shrink_pixels + 1), np.uint8) + + # 创建临时掩膜用于处理 + temp_mask = np.zeros((lines, samples), dtype=np.uint8) + + # 创建DataFrame副本 + df = df.copy() + + # 遍历每一行,更新轮廓 + for idx, row in df.iterrows(): + contour = row['contour'] + if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3: + continue + + try: + contour_array = np.array(contour, dtype=np.int32) + if len(contour_array.shape) == 1: + continue + + # 清空临时掩膜 + temp_mask.fill(0) + + # 填充轮廓 + cv2.fillPoly(temp_mask, [contour_array], 255) + + # 对掩膜进行腐蚀操作 + eroded_mask = cv2.erode(temp_mask, kernel, iterations=1) + + # 重新提取轮廓 + contours, _ = cv2.findContours(eroded_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + if len(contours) > 0: + # 选择最大的轮廓(如果有多个) + largest_contour = max(contours, key=cv2.contourArea) + # 转换为列表格式,保持与原始格式一致 + if len(largest_contour) >= 3: + updated_contour = largest_contour.reshape(-1, 2).tolist() + df.at[idx, 'contour'] = updated_contour + except Exception as e: + # 如果处理失败,保留原始轮廓 + continue + + return df + + +def save_envi_classification(bil_path, df, savepath): + samples, lines = read_hdr_file(bil_path) + classification_result = np.zeros((lines, samples), dtype=np.uint16) + + for _, row in df.iterrows(): + contour = row['contour'] + prediction = int(row['Predictions']) + 1 + contour = np.array(contour, dtype=np.int32) + # 先将 classification_result 中的 10 和 11 替换为 0 + classification_result[(classification_result == 10)] = 0 + + cv2.fillPoly(classification_result, [contour], prediction) + + output_path = savepath + + with open(output_path, 'wb') as f: + classification_result.tofile(f) + + header_content = f"""ENVI +description = {{ +Classification Result.}} +samples = {samples} +lines = {lines} +bands = 1 +header offset = 0 +file type = ENVI Standard +data type = 2 +interleave = bil +classes = 10 +class = {{ background, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC }} +single pixel area = 0.000036 +unit = mm2 +byte order = 0 +wavelength units = nm +""" + filename, ext = os.path.splitext(savepath) + # 替换扩展名为 '.hdr' + header_filename = filename + '.hdr' + + with open(header_filename, 'w') as header_file: + header_file.write(header_content) + + +def change_hdr_file(bil_path): + # 定义要追加的波长信息 + wavelength_info = """wavelength = {898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 
1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2}""" + + # 将.bil路径转换为.hdr路径 + hdr_path = os.path.splitext(bil_path)[0] + '.hdr' + + # 检查.hdr文件是否存在 + if not os.path.exists(hdr_path): + print(f"错误: 找不到对应的HDR文件: {hdr_path}") + return + + # 读取文件内容 + with open(hdr_path, 'r') as file: + content = file.read() + + # 检查是否已包含波长信息 + if 'wavelength' in content: + print(f"File {os.path.basename(hdr_path)} already contains wavelength information; no changes needed.") + return + + # 检查文件是否以换行符结尾 + needs_newline = not content.endswith('\n') + + # 追加波长信息 + with open(hdr_path, 'a') as file: + if needs_newline: + file.write('\n') # 确保新内容从新行开始 + file.write(wavelength_info + '\n') + + print(f"Successfully added wavelength information to file: {os.path.basename(hdr_path)}") + + +def main(): + args = parse_arguments() + + bil_path = args.bil_path + output_path = args.output_path + primary_model_path = args.model_path + primary_model_type = 'SVM' + primary_process_methods1 = 'SS' + primary_process_methods2 = "None" + + # secondary_model_path = args.secondary_model + # secondary_model_type = args.secondary_model_type + # secondary_process_methods1 = args.secondary_process_methods1 + # secondary_process_methods2 = args.secondary_process_methods2 + # secondary_target_classes = args.secondary_target_classes + + secondary_model_path = "D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m" + secondary_model_type = 'SVM' + secondary_process_methods1 = 'None' + secondary_process_methods2 = 'None' + secondary_target_classes = [1, 2] + # 记录总开始时间 + total_start_time = time.time() + # 处理BIL文件生成RGB图像 + print("Processing BIL file to generate RGB image...\n") + rgb_img = process_bil_files(bil_path) + + # 修改hdr + change_hdr_file(bil_path) + segmentation_start_time = time.time() + # 生成掩膜,mask为16位的塑料标签掩膜 + print("Generating mask ...\n") + mask, filter_mask_original = detect_microplastic_mask_from_array( + image=rgb_img, # 直接传入cv2.imread的结果 + filter_method='threshold', + diameter=None, + flow_threshold=0.4, + cellprob_threshold=-1, + detect_filter=False + ) + + # 提取特征 + print("Extracting features from BIL file...\n") + df = process_images(bil_path, mask) + + # 背景校正 + print("Applying background correction...\n") + df_correct = process_images_background(bil_path, mask) + df.iloc[:, 1:169] = df.iloc[:, 1:169].div(df_correct, axis=1) + + # 数据清理 + print("Cleaning data...\n") + df = df.dropna() + df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)] + df = df[df['area'] >= 500] + + # 使用pandas列选择:获取要删除的列名(从第 94 列到第 118 
, zero-based)
+    cols_to_remove = df.columns[np.r_[87:110, -10:-1]]
+    # cols_to_remove = df.columns[87:110]
+    # (the indices actually dropped are 87:110 plus the nine columns before the last)
+    # Drop the selected columns while keeping the DataFrame structure
+    df = df.drop(columns=cols_to_remove)
+    segmentation_time = time.time() - segmentation_start_time
+    # Keep the full DataFrame structure (converting to a numpy array via
+    # .values would lose the column names and DataFrame structure)
+    # Prediction (classification stage)
+    classification_start_time = time.time()
+    # Predict classes; predict_with_model loads the model from the given path,
+    # so no separate load_model call is needed here
+    print("Predicting classes...\n")
+    df_pre = predict_with_model(
+        df,
+        primary_model_path,
+        model_type=primary_model_type,
+        ProcessMethods1=primary_process_methods1,
+        ProcessMethods2=primary_process_methods2
+    )
+
+    # Secondary classification for HDPE and LDPE:
+    # extract mask contours of the SECONDARY_TARGET_CLASSES from the first-pass results
+    target_classes = set(secondary_target_classes or [])
+    mask_secondary = df_pre['Predictions'].isin(target_classes)
+
+    if mask_secondary.any():
+        # Only run background correction and secondary classification when
+        # target-class samples are present
+        print(f"Running secondary classification for classes: {sorted(target_classes)}")
+
+        # Background correction of the image information
+        df_correct = shape_correct_background(bil_path, mask, filter_mask_original)
+
+        # Build a new mask, mask_second, holding only the target-class contours
+        mask_second = np.zeros_like(mask, dtype=np.uint16)
+
+        for idx in df_pre[mask_secondary].index:
+            contour = df_pre.loc[idx, 'contour']
+            if isinstance(contour, list) and len(contour) > 0:
+                contour_array = np.array(contour, dtype=np.int32)
+                cv2.fillPoly(mask_second, [contour_array], idx + 1)  # label = index + 1
+
+        # Extract features
+        df_shape = extract_features(df_correct, mask_second)
+
+        # Use the first 13 columns as model input features
+        if len(df_shape.columns) >= 13:
+            df_shape = df_shape.iloc[:, :13]
+
+        # Secondary classification: predict with the second model and update the results
+        if secondary_model_path:
+            df_secondary = predict_with_model(
+                df_shape,
+                secondary_model_path,
+                model_type=secondary_model_type,
+                ProcessMethods1=secondary_process_methods1,
+                ProcessMethods2=secondary_process_methods2
+            )
+            df_pre.loc[mask_secondary, 'Predictions'] = df_secondary['Predictions'].values + 1
+        else:
+            print("Secondary model path not provided; skipping secondary classification.\n")
+    else:
+        print("No samples from target classes found; skipping secondary classification.\n")
+
+    # Identify background shadows misclassified as class 7 via boundary sharpness:
+    # genuine class-7 particles have crisp boundaries, background shadows are blurred
+    class_7_mask = df_pre['Predictions'] == 7
+    if class_7_mask.any():
+        print(f"Processing {class_7_mask.sum()} samples with class 7 to identify background shadows...\n")
+
+        # Convert a PIL Image to a numpy array
+        if hasattr(rgb_img, 'mode'):  # PIL Image check
+            rgb_img_array = np.array(rgb_img)
+        else:
+            rgb_img_array = rgb_img
+
+        # Convert to grayscale (for the gradient computation)
+        if len(rgb_img_array.shape) == 3:
+            gray_img = cv2.cvtColor(rgb_img_array, cv2.COLOR_RGB2GRAY)
+        else:
+            gray_img = rgb_img_array
+
+        # Gradient magnitude via the Sobel operator
+        grad_x = cv2.Sobel(gray_img, cv2.CV_64F, 1, 0, ksize=3)
+        grad_y = cv2.Sobel(gray_img, cv2.CV_64F, 0, 1, ksize=3)
+        gradient_magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)
+
+        # First collect edge-gradient values of all class-7 samples to set the threshold
+        all_class7_gradients = []
+        valid_indices = []
+        for idx in df_pre[class_7_mask].index:
+            try:
+                contour = df_pre.loc[idx, 'contour']
+                if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
+                    continue
+
+                contour_array = np.array(contour, dtype=np.int32)
+                if len(contour_array.shape) == 1:
+                    continue
+
+                mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
+                cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
+                edge_gradients = gradient_magnitude[mask_img > 0]
+                if len(edge_gradients) > 0:
+                    all_class7_gradients.extend(edge_gradients)
+                    valid_indices.append(idx)
+            except Exception:
+                continue
+
+        # Determine the threshold from the gradient distribution of class-7 samples
+        # 
使用类别7样本梯度值的中位数作为基准,低于某个分位数(如30%)的认为是背景阴影 + if len(all_class7_gradients) > 0: + gradient_threshold = np.percentile(all_class7_gradients, 30) # 使用类别7样本梯度值的30%分位数 + else: + gradient_threshold = np.percentile(gradient_magnitude, 30) # 如果没有有效样本,使用整张图的30%分位数 + + print(f"Gradient threshold for class 7: {gradient_threshold:.2f}\n") + + # 处理每个类别7的样本,判断是否为背景阴影 + indices_to_update = [] + for idx in valid_indices: + try: + contour = df_pre.loc[idx, 'contour'] + contour_array = np.array(contour, dtype=np.int32) + + # 创建轮廓掩膜(线宽为2像素,用于提取边缘) + mask_img = np.zeros(gray_img.shape, dtype=np.uint8) + cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2) + + # 提取轮廓边缘的梯度值 + edge_gradients = gradient_magnitude[mask_img > 0] + if len(edge_gradients) == 0: + continue + + # 计算轮廓边缘的平均梯度强度 + mean_gradient = np.mean(edge_gradients) + + # 如果平均梯度强度低于阈值,认为是背景阴影(边界模糊) + if mean_gradient < gradient_threshold: + indices_to_update.append(idx) + print( + f"Sample {idx}: mean_gradient={mean_gradient:.2f}, threshold={gradient_threshold:.2f} -> identified as background shadow") + + except Exception as e: + print(f"Error processing sample at index {idx}: {str(e)}") + continue + + # 将背景阴影的类别7改为类别9 + if indices_to_update: + df_pre.loc[indices_to_update, 'Predictions'] = 9 + print(f"Updated {len(indices_to_update)} samples from class 7 to class 9 (background shadows)\n") + else: + print("No samples needed to be updated from class 7\n") + classification_time = time.time() - classification_start_time + + df_pre = shrink_contours(bil_path, df_pre, shrink_pixels=1) + # 保存ENVI分类结果 + print("Saving ENVI classification results...\n") + save_envi_classification(bil_path, df_pre, output_path) + print(f"ENVI classification results saved to: {output_path}") + # 计算总耗时 + total_time = time.time() - total_start_time + + # 打印耗时统计 + print(f"\n{'=' * 60}") + print(f"处理完成") + print(f"{'=' * 60}") + print(f"分割耗时: {segmentation_time:.2f} 秒") + print(f"分类耗时: {classification_time:.2f} 秒") + print(f"总耗时: {total_time:.2f} 秒") + print(f"{'=' * 60}") + print(f"结果已保存至: {output_path}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/mask.py b/mask.py index 99cda97..8635058 100644 --- a/mask.py +++ b/mask.py @@ -901,22 +901,36 @@ class MicroplasticDetectorV2: self.filter_method = filter_method def load_model(self): - """加载Cellpose模型""" - if self.model is None: + """加载Cellpose模型(优先自定义,其次内置模型)""" + if self.model is not None: + return + + # 设备选择 + want_cpu = (self.device == 'cpu') + use_gpu = (not want_cpu) and torch.cuda.is_available() + + try: if self.model_path and Path(self.model_path).exists(): - # 加载自定义模型 - Cellpose 4.0 API + # 正确方式:在构造时通过 pretrained_model 传入自定义权重 self.model = models.CellposeModel( - gpu=torch.cuda.is_available() and self.device != 'cpu', - model_type=None + gpu=use_gpu, + pretrained_model=str(self.model_path), + model_type=None, # 自定义模型不需要 model_type ) - self.model.load_model(self.model_path) + print(f"已加载自定义模型: {self.model_path}") else: - # 使用预训练模型 - Cellpose 4.0 API + # 内置模型(如未安装 cpsam,可考虑改用 'cyto2') self.model = models.CellposeModel( - gpu=torch.cuda.is_available() and self.device != 'cpu', + gpu=use_gpu, model_type='cpsam' ) - print(f"模型已加载,使用设备: {'GPU' if torch.cuda.is_available() and self.device != 'cpu' else 'CPU'}") + print("使用内置模型: cpsam") + + print(f"模型已加载,使用设备: {'GPU' if use_gpu else 'CPU'}") + except Exception as e: + # 回退策略:若自定义/内置加载失败,退回CPU+cyto2以保证可用性 + print(f"模型加载失败({e}),回退到 CPU + cyto2") + self.model = models.CellposeModel(gpu=False, model_type='cyto2') def 
detect_microplastics(self, image_path: str, output_dir: str = None, diameter: float = 30, flow_threshold: float = 0.4, @@ -975,7 +989,8 @@ class MicroplasticDetectorV2: masked_image, diameter=diameter, flow_threshold=flow_threshold, - cellprob_threshold=cellprob_threshold) + cellprob_threshold=cellprob_threshold, + channels=[0, 0]) # 明确指定灰度图像通道 # 6. 分析检测结果 if debug: @@ -1247,8 +1262,9 @@ def detect_microplastic_mask_from_array(image, filter_method: str = 'shape', masked_image, diameter=diameter, flow_threshold=flow_threshold, - cellprob_threshold=cellprob_threshold - ) + cellprob_threshold=cellprob_threshold, + channels=[0, 0]) # 明确指定灰度图像通道 + return masks, filter_mask_original @@ -1268,7 +1284,7 @@ def main(): output_dir=output_dir + "_threshold", diameter=None, flow_threshold=0.4, - cellprob_threshold=0, + cellprob_threshold=-1, debug=True ) print(f"\n霍夫圆变换方法检测完成!") diff --git a/modelsave/HDPELDPE_model/svm.m b/modelsave/HDPELDPE_model/svm.m deleted file mode 100644 index e020440..0000000 Binary files a/modelsave/HDPELDPE_model/svm.m and /dev/null differ diff --git a/modelsave/scaler_params.pkl b/modelsave/scaler_params.pkl deleted file mode 100644 index 5d44095..0000000 Binary files a/modelsave/scaler_params.pkl and /dev/null differ diff --git a/modelsave/svm.m b/modelsave/svm.m deleted file mode 100644 index 1005d28..0000000 Binary files a/modelsave/svm.m and /dev/null differ diff --git a/shape_spectral.py b/shape_spectral.py index e060c2c..49315e1 100644 --- a/shape_spectral.py +++ b/shape_spectral.py @@ -9,6 +9,7 @@ import get_glcm import os import cv2 import numpy as np +import copy from plantcv.plantcv._helpers import _iterate_analysis, _cv2_findcontours, _object_composition, _grayscale_to_rgb, \ _scale_size from plantcv.plantcv import outputs, within_frame @@ -332,7 +333,7 @@ def process_images(full_bil_path, mask, outdir='None', debug="None"): spectral_array = pcv.readimage(filename=args.image, mode='envi') #过曝区域掩膜 bath_100 = spectral_array.array_data[:, :, 100] - bath_over = pcv.threshold.binary(gray_img=bath_100, threshold=11000, object_type='dark') + bath_over = pcv.threshold.binary(gray_img=bath_100, threshold=11000, object_type='dark')# # mask已经是16位标签掩膜,直接使用 # 确保掩膜是正确的数据类型 @@ -367,8 +368,11 @@ def process_images(full_bil_path, mask, outdir='None', debug="None"): shape_img = size(img=binary_img, labeled_mask=labeled_mask, n_labels=num_valid) # 分析光谱反射率 - # 应用过曝掩膜 - labeled_spectral_mask = mask * bath_over_binary + # 应用过曝掩膜 ;如果采集的反射率大于1.1,则会被错误判断为过曝区域导致该样本的光谱没有被统计 + # labeled_spectral_mask = mask * bath_over_binary + labeled_spectral_mask = mask + + spectral_hist = pcv.analyze.spectral_reflectance( hsi=spectral_array, @@ -377,13 +381,16 @@ def process_images(full_bil_path, mask, outdir='None', debug="None"): label=None ) - observations = pcv.outputs.observations + # 深拷贝观测结果,避免后续清空影响数据 + observations = copy.deepcopy(pcv.outputs.observations) + # 立即清空全局缓存,防止跨文件/阶段累加 + pcv.outputs.clear() # 将结果转换为列表 combined_data = process_plantcv_outputs(observations) return combined_data -# # 示例:批量处理指定路径下的光谱图像和掩膜文件 +# # # 示例:批量处理指定路径下的光谱图像和掩膜文件 # bil_path = r'D:\WQ\test\Traindata-05' # mask_path = r'D:\WQ\test\mask' # outdir = r"D:\WQ\test" # 输出文件夹路径
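# --- Editor's sketch -------------------------------------------------------
# The shape_spectral.py hunk above disables the over-exposure mask because
# samples with reflectance above roughly 1.1 were wrongly flagged as
# over-exposed by the fixed DN threshold, so their spectra were never counted.
# If masking is still wanted, one alternative (an assumption, not part of this
# patch) is to exclude only pixels at the sensor's true saturation ceiling
# rather than a fixed DN value:

import numpy as np

def saturation_mask(band, saturation_value=None):
    """Return a uint8 mask that is 1 where the band is NOT saturated."""
    if saturation_value is None:
        # Assumption: integer sensor data saturates at the dtype maximum
        saturation_value = np.iinfo(band.dtype).max
    return (band < saturation_value).astype(np.uint8)

# Usage at the call site would then be:
#     labeled_spectral_mask = mask * saturation_mask(bath_100)
# ---------------------------------------------------------------------------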