160 lines
5.3 KiB
Python
160 lines
5.3 KiB
Python
import argparse
|
||
import pandas as pd
|
||
import numpy as np
|
||
import itertools
|
||
import re
|
||
from collections import defaultdict
|
||
|
||
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``csv``, ``top_k``,
        ``top_triplets`` and ``map_order``.
    """
    parser = argparse.ArgumentParser(
        description="Select 3 bands for false-color to maximize inter-class separability"
    )
    # (flag, keyword arguments) pairs, registered in this order.
    option_specs = (
        ("--csv", {
            "required": True,
            "help": "Input CSV: first column is class label, rest are spectra columns",
        }),
        ("--top_k", {
            "type": int,
            "default": 30,
            "help": "How many best single bands to preselect (default: 30)",
        }),
        ("--top_triplets", {
            "type": int,
            "default": 10,
            "help": "How many best triplets to print (default: 10)",
        }),
        ("--map_order", {
            "choices": ["auto", "wavelength_bgr"],
            "default": "auto",
            "help": "RGB mapping order: auto=use scored order; wavelength_bgr=short->B, mid->G, long->R",
        }),
    )
    for flag, settings in option_specs:
        parser.add_argument(flag, **settings)
    return parser.parse_args()
|
||
|
||
def try_parse_wavelength(col_name):
    """Extract the first number found in a column name.

    Handles names such as "wavelength_912.36", "912.36" or "band_12".

    Args:
        col_name: Column label; it is converted with ``str()`` before matching.

    Returns:
        The first run of digits (with optional decimal part) as a float,
        or None when the name contains no digits or conversion fails.
    """
    found = re.search(r"([0-9]+(\.[0-9]+)?)", str(col_name))
    if found is None:
        return None
    try:
        value = float(found.group(1))
    except Exception:
        value = None
    return value
|
||
|
||
def read_data(csv_path):
    """Load the labelled spectra table from a CSV file.

    The first column is taken as the class label and every remaining
    column as one spectral band.

    Args:
        csv_path: Path passed to ``pandas.read_csv``.

    Returns:
        Tuple ``(X, y, cols, waves)``:
        X is the band values as a float64 array, y the label column as an
        array, cols the list of band column names, and waves a list holding
        for each band either the number parsed from its name or, when no
        number could be parsed, the column name itself.

    Raises:
        ValueError: If the table has fewer than 4 columns
            (1 label column + 3 band columns).
    """
    frame = pd.read_csv(csv_path)
    n_columns = frame.shape[1]
    if n_columns < 4:
        raise ValueError("需要至少1个类别列 + 3个光谱列")

    labels = frame.iloc[:, 0].to_numpy()
    spectra = frame.iloc[:, 1:].to_numpy(dtype=np.float64)
    band_names = list(frame.columns[1:])

    # Fall back to the raw column name whenever no wavelength can be parsed.
    parsed = (try_parse_wavelength(name) for name in band_names)
    wavelengths = [
        name if value is None else value
        for name, value in zip(band_names, parsed)
    ]
    return spectra, labels, band_names, wavelengths
|
||
|
||
def compute_anova_f_scores(X, y):
    """Compute a one-way ANOVA F-score for every column of X.

    F = (SS_between / df_between) / (SS_within / df_within), evaluated
    independently per column.

    Args:
        X: 2-D array; rows are samples, columns are features.
        y: Class label per row of X.

    Returns:
        Array of shape (M,) with one F-score per column of X.

    Raises:
        ValueError: If fewer than two classes are present.
    """
    n_samples, n_features = X.shape

    # One boolean row mask per non-empty class.
    masks = [y == label for label in np.unique(y)]
    masks = [mask for mask in masks if np.any(mask)]
    n_classes = len(masks)
    if n_classes < 2:
        raise ValueError("类别不足2类,无法计算区分度")

    grand_mean = X.mean(axis=0)
    ss_between = np.zeros(n_features, dtype=np.float64)
    ss_within = np.zeros(n_features, dtype=np.float64)

    for mask in masks:
        group = X[mask]
        count = int(np.count_nonzero(mask))
        group_mean = group.mean(axis=0)
        # Unbiased variance; a single-sample class contributes zero spread.
        if count > 1:
            group_var = group.var(axis=0, ddof=1)
        else:
            group_var = np.zeros(n_features, dtype=np.float64)
        ss_between += count * (group_mean - grand_mean) ** 2
        ss_within += (count - 1) * group_var

    # Both degrees of freedom are floored at 1 to avoid dividing by zero.
    df_between = max(n_classes - 1, 1)
    df_within = max(n_samples - n_classes, 1)
    # The 1e-12 term keeps the ratio finite when a column has no within-class spread.
    return (ss_between / df_between) / (ss_within / df_within + 1e-12)  # shape (M,)
|
||
|
||
def score_triplet(X, y, idx_triplet, eps=1e-6):
|
||
# LDA准则: J = trace( (Sw+epsI)^-1 Sb ),三维空间
|
||
t = list(idx_triplet)
|
||
Xt = X[:, t] # (N,3)
|
||
classes = []
|
||
for cls in np.unique(y):
|
||
idx = (y == cls)
|
||
if np.sum(idx) > 0:
|
||
classes.append(idx)
|
||
|
||
m = Xt.mean(axis=0)
|
||
Sb = np.zeros((3, 3), dtype=np.float64)
|
||
Sw = np.zeros((3, 3), dtype=np.float64)
|
||
|
||
for idx in classes:
|
||
Xg = Xt[idx]
|
||
ng = Xg.shape[0]
|
||
if ng == 0:
|
||
continue
|
||
mg = Xg.mean(axis=0)
|
||
diff = (mg - m).reshape(3, 1)
|
||
Sb += ng * (diff @ diff.T)
|
||
if ng > 1:
|
||
Cg = np.cov(Xg, rowvar=False, ddof=1)
|
||
else:
|
||
Cg = np.zeros((3, 3), dtype=np.float64)
|
||
Sw += (ng - 1) * Cg
|
||
|
||
# 稳定逆
|
||
A = Sw + eps * np.eye(3)
|
||
try:
|
||
Ainv = np.linalg.pinv(A)
|
||
except Exception:
|
||
return -np.inf
|
||
J = np.trace(Ainv @ Sb)
|
||
return float(J)
|
||
|
||
def select_bands(csv_path, top_k=30, top_triplets=10, map_order="auto"):
    """Rank band triplets from a labelled spectra CSV by class separability.

    Bands are first ranked individually by ANOVA F-score; every 3-band
    combination of the ``top_k`` best bands is then scored with
    ``score_triplet`` and the highest-scoring ones are reported.

    Args:
        csv_path: Path of the input CSV, as accepted by ``read_data``.
        top_k: Number of best single bands to preselect. Values below 3
            yield no triplets.
        top_triplets: Maximum number of triplets to return.
        map_order: "wavelength_bgr" assigns the shortest wavelength to B,
            the middle one to G and the longest to R, provided all three
            wavelengths are numeric. Any other value (or a non-numeric
            wavelength) maps the triplet in scored order: first -> R,
            second -> G, third -> B.

    Returns:
        Tuple ``(results, top_idx, F)``: results is a list of dicts with
        keys "score", "bands", "wavelengths", "rgb_mapping" and
        "rgb_wavelengths", best first; top_idx holds the preselected band
        indices in descending F-score order; F is the per-band F-score array.

    Raises:
        ValueError: If ``top_k`` or ``top_triplets`` is negative.
    """
    # A negative count would be read as a from-the-end slice and silently
    # drop the last band/triplet instead of failing, so reject it up front.
    if top_k < 0 or top_triplets < 0:
        raise ValueError("top_k and top_triplets must be non-negative")

    X, y, cols, waves = read_data(csv_path)
    F = compute_anova_f_scores(X, y)
    order = np.argsort(-F)
    top_idx = order[:min(top_k, X.shape[1])]

    scored = [
        (score_triplet(X, y, comb), comb)
        for comb in itertools.combinations(top_idx, 3)
    ]
    scored.sort(key=lambda item: -item[0])

    results = []
    for s, comb in scored[:top_triplets]:
        names = [cols[i] for i in comb]
        wavs = [waves[i] for i in comb]

        # Position within the triplet that feeds each display channel.
        by_wavelength = map_order == "wavelength_bgr" and all(
            isinstance(w, (int, float)) for w in wavs
        )
        if by_wavelength:
            shortest, middle, longest = np.argsort(wavs)  # ascending wavelength
            channel_pos = {"R": longest, "G": middle, "B": shortest}
        else:
            # Scored order: first -> R, second -> G, third -> B.
            channel_pos = {"R": 0, "G": 1, "B": 2}

        results.append({
            "score": round(s, 6),
            "bands": names,
            "wavelengths": wavs,
            "rgb_mapping": {ch: names[pos] for ch, pos in channel_pos.items()},
            "rgb_wavelengths": {ch: wavs[pos] for ch, pos in channel_pos.items()},
        })
    return results, top_idx, F
|
||
|
||
def main():
    """Command-line entry point: select bands and print a text report.

    Prints the preselected bands as ``index:col_name`` pairs, followed by
    one line per best triplet with its score, band names, wavelengths and
    RGB mapping.
    """
    args = parse_args()
    results, top_idx, _ = select_bands(args.csv, args.top_k, args.top_triplets, args.map_order)

    # select_bands does not return column names, so re-read just the header
    # row. Dropping the first (label) column keeps positions aligned with the
    # band indices in top_idx.
    band_names = pd.read_csv(args.csv, nrows=0).columns[1:]

    print(f"Top-{args.top_k} bands by ANOVA F-score (index:col_name):")
    print(", ".join(f"{i}:{band_names[i]}" for i in top_idx))

    print("\nBest triplets:")
    for r in results:
        bands = r["bands"]
        wavs = r["wavelengths"]
        rgb = r["rgb_mapping"]
        print(f"- score={r['score']:.4f} bands={bands} wavelengths={wavs} RGB={rgb}")
|
||
|
||
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()