import argparse
import itertools
import re
from collections import defaultdict

import numpy as np
import pandas as pd


def parse_args(argv=None):
    """Build the command-line parser and parse arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        argparse.Namespace with ``csv``, ``top_k``, ``top_triplets`` and
        ``map_order`` attributes.
    """
    ap = argparse.ArgumentParser(description="Select 3 bands for false-color to maximize inter-class separability")
    ap.add_argument("--csv", required=True, help="Input CSV: first column is class label, rest are spectra columns")
    ap.add_argument("--top_k", type=int, default=30, help="How many best single bands to preselect (default: 30)")
    ap.add_argument("--top_triplets", type=int, default=10, help="How many best triplets to print (default: 10)")
    ap.add_argument("--map_order", choices=["auto", "wavelength_bgr"], default="auto",
                    help="RGB mapping order: auto=use scored order; wavelength_bgr=short->B, mid->G, long->R")
    return ap.parse_args(argv)


def try_parse_wavelength(col_name):
    """Extract the first decimal number from a column name.

    Handles names such as "wavelength_912.36", "912.36" or "band_12".
    Note that "band_12" yields 12.0: a band *index* embedded in the name is
    indistinguishable from a wavelength here.

    Returns:
        The number as a float, or None if the name contains no digits.
    """
    m = re.search(r"([0-9]+(\.[0-9]+)?)", str(col_name))
    if m:
        try:
            return float(m.group(1))
        except ValueError:
            return None
    return None


def read_data(csv_path):
    """Load the labelled spectra CSV.

    The first column is taken as the class label and every remaining column
    as a spectral band.

    Returns:
        Tuple ``(X, y, cols, waves)``:

        - ``X`` is a float64 array of shape ``(n_rows, n_bands)``.
        - ``y`` holds the labels.
        - ``cols`` holds the band column names.
        - ``waves`` holds, per band, the parsed number from the column name,
          or the raw column name when no number is found.

    Raises:
        ValueError: the CSV has fewer than 1 label column + 3 band columns.
    """
    df = pd.read_csv(csv_path)
    if df.shape[1] < 4:
        raise ValueError("需要至少1个类别列 + 3个光谱列")
    y = df.iloc[:, 0].to_numpy()
    X = df.iloc[:, 1:].to_numpy(dtype=np.float64)
    cols = list(df.columns[1:])
    waves = []
    for c in cols:
        w = try_parse_wavelength(c)
        waves.append(w if w is not None else c)
    return X, y, cols, waves


def compute_anova_f_scores(X, y):
    """Compute a one-way ANOVA F-score independently for every column of X.

    F = (SS_between / df_between) / (SS_within / df_within), with ``y`` as the
    grouping variable.

    Args:
        X: 2-D array, rows are samples and columns are bands.
        y: 1-D array of class labels, one per row of ``X``.

    Returns:
        1-D array of length ``X.shape[1]``; larger values mean the band
        separates the classes better.

    Raises:
        ValueError: fewer than two distinct classes are present.
    """
    N, M = X.shape
    classes = []
    for cls in np.unique(y):
        idx = (y == cls)
        if np.sum(idx) > 0:
            classes.append(idx)
    k = len(classes)
    if k < 2:
        raise ValueError("类别不足2类,无法计算区分度")
    mu = X.mean(axis=0)
    between_ss = np.zeros(M, dtype=np.float64)
    within_ss = np.zeros(M, dtype=np.float64)
    for idx in classes:  # every mask here selects at least one row
        ng = np.sum(idx)
        Xg = X[idx]
        mg = Xg.mean(axis=0)
        # Unbiased variance; a single-sample class contributes zero spread.
        vg = Xg.var(axis=0, ddof=1) if ng > 1 else np.zeros(M, dtype=np.float64)
        between_ss += ng * (mg - mu) ** 2
        within_ss += (ng - 1) * vg
    dfb = k - 1
    dfw = N - k if N > k else 1
    # 1e-12 avoids division by zero when every class is constant in a band.
    F = (between_ss / max(dfb, 1)) / (within_ss / max(dfw, 1) + 1e-12)
    return F  # shape (M,)


def score_triplet(X, y, idx_triplet, eps=1e-6):
    """Score a 3-band subset with the LDA (Fisher) separability criterion.

    J = trace( pinv(Sw + eps*I) @ Sb ), computed on the three selected
    columns. Sb is the between-class scatter matrix and Sw the within-class
    scatter matrix.

    Args:
        X: 2-D array, rows are samples and columns are bands.
        y: 1-D array of class labels.
        idx_triplet: Exactly three column indices into ``X``.
        eps: Ridge added to the diagonal of Sw before pseudo-inversion.

    Returns:
        J as a Python float (larger is more separable). Returns ``-inf`` if
        the pseudo-inverse fails or J is NaN (e.g. NaN input), so such
        triplets always rank last.
    """
    t = list(idx_triplet)
    Xt = X[:, t]  # (N, 3)
    classes = []
    for cls in np.unique(y):
        idx = (y == cls)
        if np.sum(idx) > 0:
            classes.append(idx)
    m = Xt.mean(axis=0)
    Sb = np.zeros((3, 3), dtype=np.float64)
    Sw = np.zeros((3, 3), dtype=np.float64)
    for idx in classes:
        Xg = Xt[idx]
        ng = Xg.shape[0]
        mg = Xg.mean(axis=0)
        diff = (mg - m).reshape(3, 1)
        Sb += ng * (diff @ diff.T)
        if ng > 1:
            Cg = np.cov(Xg, rowvar=False, ddof=1)
        else:
            Cg = np.zeros((3, 3), dtype=np.float64)
        Sw += (ng - 1) * Cg
    # Ridge-regularised pseudo-inverse keeps near-singular Sw stable.
    A = Sw + eps * np.eye(3)
    try:
        Ainv = np.linalg.pinv(A)
    except np.linalg.LinAlgError:
        return -np.inf
    J = float(np.trace(Ainv @ Sb))
    if np.isnan(J):
        # NaN would make the descending sort in select_bands ill-defined.
        return -np.inf
    return J


def select_bands(csv_path, top_k=30, top_triplets=10, map_order="auto"):
    """Pick band triplets for a false-color composite that best separate classes.

    Steps:

    1. Rank single bands by ANOVA F-score and keep the best ``top_k``.
    2. Exhaustively score every 3-combination of those bands with
       ``score_triplet``.
    3. Return the ``top_triplets`` best, each with a suggested RGB mapping.

    Args:
        csv_path: Path to the CSV read by ``read_data``.
        top_k: Number of single bands to pre-select; must be >= 3.
        top_triplets: Number of triplets to return; must be >= 0.
        map_order: "auto" maps the triplet in its pre-selection order
            (highest single-band F first) to R, G, B. "wavelength_bgr" maps
            shortest->B, middle->G, longest->R. It falls back to "auto"
            when any band has no numeric wavelength.

    Returns:
        Tuple ``(results, top_idx, F)``:

        - ``results`` is a list of dicts with keys "score", "bands",
          "wavelengths", "rgb_mapping" and "rgb_wavelengths".
        - ``top_idx`` holds the pre-selected column indices.
        - ``F`` holds the per-band F-scores.

    Raises:
        ValueError: ``top_k`` < 3 or ``top_triplets`` < 0. Also propagates
            ``read_data`` / ``compute_anova_f_scores`` errors.
    """
    # Negative values would otherwise slice from the end of the arrays.
    if top_k < 3:
        raise ValueError(f"top_k must be >= 3 to form a triplet, got {top_k}")
    if top_triplets < 0:
        raise ValueError(f"top_triplets must be >= 0, got {top_triplets}")

    X, y, cols, waves = read_data(csv_path)
    F = compute_anova_f_scores(X, y)
    order = np.argsort(-F)
    top_idx = order[:min(top_k, X.shape[1])]

    # combinations() keeps top_idx order, so comb[0] has the highest F.
    best = [(score_triplet(X, y, comb), comb) for comb in itertools.combinations(top_idx, 3)]
    best.sort(key=lambda x: -x[0])

    results = []
    for s, comb in best[:top_triplets]:
        names = [cols[i] for i in comb]
        wavs = [waves[i] for i in comb]
        if map_order == "wavelength_bgr" and all(isinstance(w, (int, float)) for w in wavs):
            # Ascending wavelength: short -> B, middle -> G, long -> R.
            lo, mid, hi = np.argsort(wavs)
            rgb = {"R": names[hi], "G": names[mid], "B": names[lo]}
            rgb_w = {"R": wavs[hi], "G": wavs[mid], "B": wavs[lo]}
        else:
            # Use the triplet's own order: first -> R, second -> G, third -> B.
            rgb = {"R": names[0], "G": names[1], "B": names[2]}
            rgb_w = {"R": wavs[0], "G": wavs[1], "B": wavs[2]}
        results.append({
            "score": round(s, 6),
            "bands": names,
            "wavelengths": wavs,
            "rgb_mapping": rgb,
            "rgb_wavelengths": rgb_w,
        })
    return results, top_idx, F


def main(argv=None):
    """Command-line entry point.

    Prints the pre-selected bands as ``index:column_name``, then the best
    triplets.

    Args:
        argv: Optional argument list forwarded to ``parse_args``; ``None``
            uses ``sys.argv``.
    """
    args = parse_args(argv)
    results, top_idx, _ = select_bands(args.csv, args.top_k, args.top_triplets, args.map_order)
    # select_bands only returns indices; read just the header to label them.
    cols = list(pd.read_csv(args.csv, nrows=0).columns[1:])
    print(f"Top-{len(top_idx)} bands by ANOVA F-score (index:col_name):")
    print(", ".join(f"{i}:{cols[i]}" for i in top_idx))
    print("\nBest triplets:")
    for r in results:
        bands = r["bands"]
        wavs = r["wavelengths"]
        rgb = r["rgb_mapping"]
        print(f"- score={r['score']:.4f} bands={bands} wavelengths={wavs} RGB={rgb}")


if __name__ == "__main__":
    main()