初始提交
This commit is contained in:
116
classification_model/WaveSelect/Spa.py
Normal file
116
classification_model/WaveSelect/Spa.py
Normal file
@ -0,0 +1,116 @@
|
||||
import scipy.stats
|
||||
import numpy as np
|
||||
from scipy.linalg import qr, inv, pinv
|
||||
import scipy.stats
|
||||
from progress.bar import Bar
|
||||
from matplotlib import pyplot as plt
|
||||
class SPA:
    """Successive Projections Algorithm (SPA) variable selection for linear regression.

    The selection runs in three phases (see :meth:`spa`):

    1. For every starting column, build a chain of columns through a
       column-pivoted QR decomposition (:meth:`_projections_qr`).
    2. Score every chain prefix with the PRESS (sum of squared prediction
       errors) of an ordinary least-squares model (:meth:`_validation`) and
       keep the best prefix.
    3. Rank the kept variables by relevance and drop the least relevant ones
       as long as PRESS does not degrade significantly according to an F-test
       (:meth:`_scree_cutoff`).

    Inputs may be pandas objects or NumPy arrays; they are converted to float
    arrays internally. Only a single response variable is supported.
    """

    def _projections_qr(self, X, k, M):
        """Return the first ``M`` columns of the projection chain started at column ``k``.

        Args:
            X: 2-D data (samples x variables), DataFrame or array. Not modified.
            k: Positional index of the column that must start the chain.
            M: Number of chain entries to return.

        Returns:
            1-D integer array with the first ``M`` pivot column indices.
        """
        data = np.array(X, dtype=float)  # np.array copies, so the caller's data is untouched
        norms = np.sum(data ** 2, axis=0)
        norm_max = np.amax(norms)
        # Inflate column k so that its norm exceeds every other column's norm,
        # which makes the pivoted QR pick it first. An all-zero column cannot
        # be inflated (and would divide by zero), so it is left as it is.
        if norms[k] > 0:
            data[:, k] = data[:, k] * 2 * norm_max / norms[k]
        # Only the pivot order is needed; mode='r' skips building Q.
        _, order = qr(data, mode='r', pivoting=True)
        return order[:M]

    def _validation(self, Xcal, ycal, var_sel, Xval=None, yval=None):
        """Fit an OLS model with intercept on the selected columns and evaluate it.

        When a non-empty validation set is supplied, the model is fitted on the
        calibration set and evaluated on the validation set. Otherwise
        leave-one-out cross-validation is run on the calibration set.

        Args:
            Xcal: Calibration predictors (N x K), DataFrame or array.
            ycal: Calibration response with N values (Series, 1-D array or
                a single-column 2-D object).
            var_sel: Positional indices of the columns to use.
            Xval: Optional validation predictors (NV x K).
            yval: Validation response; required when ``Xval`` has rows.

        Returns:
            Tuple ``(yhat, e)`` of 1-D float arrays: the predictions and the
            residuals ``y - yhat`` for the evaluated samples.

        Raises:
            ValueError: If ``ycal`` does not hold one value per calibration
                sample, or ``Xval`` is given without ``yval``.
        """
        cols = np.asarray(var_sel, dtype=int)
        x_cal = np.asarray(Xcal, dtype=float)
        N = x_cal.shape[0]
        # Flatten the response so that residuals are always 1-D; mixing an
        # (N,) response with (N, 1) predictions would broadcast to (N, N).
        y_cal = np.asarray(ycal, dtype=float).reshape(-1)
        if y_cal.shape[0] != N:
            raise ValueError("ycal must contain exactly one value per row of Xcal")

        design_cal = np.column_stack([np.ones(N), x_cal[:, cols]])

        x_val = None if Xval is None else np.asarray(Xval, dtype=float)
        NV = x_val.shape[0] if x_val is not None else 0

        if NV > 0:
            if yval is None:
                raise ValueError("yval is required when Xval is provided")
            y_val = np.asarray(yval, dtype=float).reshape(-1)
            coef = np.linalg.lstsq(design_cal, y_cal, rcond=None)[0]
            design_val = np.column_stack([np.ones(NV), x_val[:, cols]])
            yhat = design_val @ coef
            e = y_val - yhat
        else:
            # Leave-one-out: refit without sample i, then predict sample i.
            yhat = np.empty(N)
            keep = np.ones(N, dtype=bool)
            for i in range(N):
                keep[i] = False
                coef = np.linalg.lstsq(design_cal[keep], y_cal[keep], rcond=None)[0]
                yhat[i] = design_cal[i] @ coef
                keep[i] = True
            e = y_cal - yhat

        return yhat, e

    def _scree_cutoff(self, PRESS_scree, dof, m_min, alpha=0.25):
        """Return how many of the relevance-ranked variables to keep.

        ``PRESS_scree[i]`` is the PRESS of the model built from the ``i + 1``
        most relevant variables. The smallest model whose PRESS is below
        ``min(PRESS_scree) * F_crit`` is chosen, where ``F_crit`` is the
        ``1 - alpha`` quantile of the F distribution with ``(dof, dof)``
        degrees of freedom.

        Args:
            PRESS_scree: 1-D sequence of PRESS values, one per model size.
            dof: Degrees of freedom of the F-test (number of residuals).
            m_min: Minimum number of variables to keep.
            alpha: Significance level of the F-test.

        Returns:
            Number of variables to keep, between 1 and ``len(PRESS_scree)``.

        Raises:
            ValueError: If ``PRESS_scree`` is empty.
        """
        press = np.asarray(PRESS_scree, dtype=float)
        if press.size == 0:
            raise ValueError("PRESS_scree is empty: no variable to choose from")
        fcrit = scipy.stats.f.ppf(1 - alpha, dof, dof)
        press_crit = np.min(press) * fcrit
        below = np.flatnonzero(press < press_crit)
        # With a perfect fit (minimum PRESS of 0) nothing is strictly below
        # the threshold; fall back to the best model.
        first = int(below[0]) if below.size else int(np.argmin(press))
        # `first` is a 0-based index, so the model it refers to uses
        # `first + 1` variables.
        n_keep = max(m_min, first + 1)
        return int(min(n_keep, press.size))

    def spa(self, Xcal, ycal, m_min=1, m_max=None, Xval=None, yval=None, autoscaling=1, save_path=None):
        """Select variables with SPA and plot the resulting scree curve.

        Args:
            Xcal: Calibration predictors (N samples x K variables).
            ycal: Calibration response (N values, single response).
            m_min: Minimum number of variables in the selected subset.
            m_max: Maximum number of variables; defaults to ``min(N - 1, K)``.
            Xval: Optional validation predictors. Without them, leave-one-out
                cross-validation on the calibration set is used.
            yval: Validation response, required together with ``Xval``.
            autoscaling: If truthy, columns are scaled to unit sample standard
                deviation before the projections; they are always centred.
            save_path: If given, the scree plot is written to this path;
                otherwise it is shown with ``plt.show()``.

        Returns:
            Tuple ``(var_sel, var_sel_phase2)`` of integer column-index arrays:
            the final selection ordered by decreasing relevance, and the
            larger phase-2 selection in projection order.

        Raises:
            ValueError: If ``m_min``/``m_max`` are inconsistent with the data,
                or phase 2 selects no variable at all.
        """
        X = np.asarray(Xcal, dtype=float)
        y = np.asarray(ycal, dtype=float).reshape(-1)
        x_val = None if Xval is None else np.asarray(Xval, dtype=float)
        y_val = None if yval is None else np.asarray(yval, dtype=float).reshape(-1)

        N, K = X.shape
        m_max = min(N - 1, K) if m_max is None else m_max
        if not 0 <= m_min <= m_max <= K:
            raise ValueError(
                f"need 0 <= m_min <= m_max <= K, got m_min={m_min}, m_max={m_max}, K={K}"
            )

        if autoscaling:
            normalization_factor = X.std(ddof=1, axis=0)
            # A constant column is all zeros after centring; dividing it by 1
            # instead of 0 keeps it finite.
            normalization_factor[normalization_factor == 0] = 1.0
        else:
            normalization_factor = np.ones(K)
        Xcaln = (X - X.mean(axis=0)) / normalization_factor

        # Phase 1: one projection chain per starting column.
        SEL = np.zeros((m_max, K), dtype=int)
        with Bar('Projections :', max=K) as bar:
            for k in range(K):
                SEL[:, k] = self._projections_qr(Xcaln, k, m_max)
                bar.next()

        # Phase 2: PRESS of every chain prefix; untested sizes stay at +inf.
        PRESS = np.full((m_max + 1, K), np.inf)
        with Bar('Evaluating subsets:', max=K * (m_max - m_min + 1)) as bar:
            for k in range(K):
                for m in range(m_min, m_max + 1):
                    _, e = self._validation(X, y, SEL[:m, k], x_val, y_val)
                    PRESS[m, k] = e @ e
                    bar.next()

        m_sel = np.argmin(PRESS, axis=0)
        k_sel = np.argmin(np.min(PRESS, axis=0))
        var_sel_phase2 = SEL[:m_sel[k_sel], k_sel]
        if var_sel_phase2.size == 0:
            raise ValueError("phase 2 selected no variables; check m_min and the input data")

        # Phase 3: rank the kept variables by |coefficient * standard deviation|.
        Xcal2 = np.column_stack([np.ones(N), X[:, var_sel_phase2]])
        b = np.linalg.lstsq(Xcal2, y, rcond=None)[0]
        std_deviation = Xcal2.std(ddof=1, axis=0)
        relev = np.abs(b * std_deviation)[1:]  # drop the intercept term

        index_decreasing_relev = np.argsort(-relev)
        ranked = var_sel_phase2[index_decreasing_relev]
        PRESS_scree = np.empty(len(ranked))
        for i in range(len(ranked)):
            _, e = self._validation(X, y, ranked[:i + 1], x_val, y_val)
            PRESS_scree[i] = e @ e

        dof = len(e)
        RMSEP_scree = np.sqrt(PRESS_scree / dof)
        n_sel = self._scree_cutoff(PRESS_scree, dof, m_min, alpha=0.25)
        var_sel = ranked[:n_sel]

        # Plot the scree curve.
        fig = plt.figure()

        # Use Times New Roman as the font.
        plt.rcParams['font.sans-serif'] = ['Times New Roman']
        plt.rcParams['axes.unicode_minus'] = False  # make sure minus signs render correctly

        # Axis labels and title.
        plt.xlabel('Number of variables included in the model', fontsize=14)
        plt.ylabel('RMSE', fontsize=14)
        plt.title(f'Final number of selected variables: {len(var_sel)} (RMSE={RMSEP_scree[n_sel - 1]:.4f})', fontsize=16)

        # RMSEP curve; the x axis is the actual number of variables (1-based).
        n_vars = np.arange(1, len(RMSEP_scree) + 1)
        plt.plot(n_vars, RMSEP_scree, label='RMSEP Scree Plot')
        plt.scatter(n_sel, RMSEP_scree[n_sel - 1], color='r', marker='s', label='Selected Point')

        # Grid and legend.
        plt.grid(True)
        plt.legend()

        # Save or show the figure.
        if save_path:
            plt.savefig(save_path, bbox_inches='tight', dpi=300)
            print(f"图像已保存至: {save_path}")
            plt.close(fig)  # release the figure once it is on disk
        else:
            plt.show()
        return var_sel, var_sel_phase2

    def __repr__(self):
        """Return a constructor-style representation."""
        return "SPA()"
|
||||
Reference in New Issue
Block a user