Files
micro_plastic/classification_model/WaveSelect/Cars.py
2026-02-25 09:42:51 +08:00

177 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import copy
from sklearn.cross_decomposition import PLSRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold
def PC_Cross_Validation(X, y, pc, cv):
    '''
    Cross-validate PLS models with 1, 2, ..., pc latent components.

    X : spectral matrix (n_samples x n_wavelengths); a DataFrame or any 2-D array-like
    y : reference (chemical) values; a Series or any array-like with n_samples rows
    pc: largest number of PLS components to try (must be >= 1)
    cv: number of KFold splits (folds are taken in order, not shuffled)
    return :
        RMSECV: list of mean fold RMSEs; RMSECV[i] belongs to a model with i + 1 components
        rindex: zero-based position of the smallest RMSECV, i.e. the best
                number of components is rindex + 1
    raises :
        ValueError: if pc < 1
    '''
    if pc < 1:
        raise ValueError(f"pc must be at least 1, got {pc}")
    # Positional indexing on plain arrays works for DataFrame/Series and ndarray alike.
    X = np.asarray(X)
    y = np.asarray(y)
    # KFold without shuffling is deterministic, so the folds are the same for
    # every component count; compute them once.
    folds = list(KFold(n_splits=cv).split(X))
    RMSECV = []
    for n_comp in range(1, pc + 1):
        fold_rmse = []
        for train_index, test_index in folds:
            pls = PLSRegression(n_components=n_comp)
            pls.fit(X[train_index], y[train_index])
            y_predict = pls.predict(X[test_index])
            fold_rmse.append(np.sqrt(mean_squared_error(y[test_index], y_predict)))
        RMSECV.append(np.mean(fold_rmse))
    rindex = np.argmin(RMSECV)
    return RMSECV, rindex
def Cross_Validation(X, y, pc, cv):
    '''
    Return the mean KFold RMSE of a PLS model with a fixed number of components.

    X : spectral matrix (n_samples x n_wavelengths); a DataFrame or any 2-D array-like
    y : reference (chemical) values; a Series or any array-like with n_samples rows
    pc: number of PLS components used in every fold (not a maximum)
    cv: number of KFold splits (folds are taken in order, not shuffled)
    return :
        RMSECV: mean over the folds of each fold's test RMSE (a single float)
    '''
    # Positional indexing on plain arrays works for DataFrame/Series and ndarray alike.
    X = np.asarray(X)
    y = np.asarray(y)
    RMSE = []
    for train_index, test_index in KFold(n_splits=cv).split(X):
        pls = PLSRegression(n_components=pc)
        pls.fit(X[train_index], y[train_index])
        y_predict = pls.predict(X[test_index])
        RMSE.append(np.sqrt(mean_squared_error(y[test_index], y_predict)))
    return np.mean(RMSE)
def _plot_cars_history(wave_counts, rmsecv, best_iter, save_fig, save_path):
    '''
    Draw the CARS diagnostic figure and optionally save it.

    Top panel: number of wavelengths kept in each iteration.
    Bottom panel: RMSECV of each iteration.
    The figure is left open (neither shown nor closed) so the caller can
    still display or modify it.
    '''
    iterations = np.arange(len(wave_counts))
    plt.figure(figsize=(12, 10))
    # Times New Roman for the default (sans-serif) family; this mutates global rcParams.
    plt.rcParams['font.sans-serif'] = ['Times New Roman']
    # Avoid the Unicode minus glyph, which some fonts cannot display.
    plt.rcParams['axes.unicode_minus'] = False
    fonts = 20
    plt.subplot(211)
    plt.xlabel('Monte Carlo Iterations', fontsize=fonts)
    plt.ylabel('Number of Selected Wavelengths', fontsize=fonts)
    plt.title('Optimal Iteration: ' + str(best_iter), fontsize=fonts)
    plt.plot(iterations, wave_counts)
    plt.subplot(212)
    plt.xlabel('Monte Carlo Iterations', fontsize=fonts)
    plt.ylabel('RMSECV', fontsize=fonts)
    plt.plot(iterations, rmsecv)
    if save_fig:
        plt.savefig(save_path)
        print(f"The figure has been saved as {save_path}")


def CARS_Cloud(X, y, N=50, f=20, cv=10, save_fig=False, save_path=None,
               random_state=None):
    '''
    Select informative wavelengths with a CARS-style PLS elimination loop.

    Over N iterations the number of retained wavelengths shrinks along an
    exponentially decreasing curve, from all n columns (iteration 1) to 2
    (iteration N). Each iteration does the following:
      1. Draw a random 80 % of the samples as the calibration set.
      2. Fit a PLS model on the currently retained wavelengths.
      3. Rank those wavelengths by |regression coefficient|; the next
         iteration keeps the top-ranked ones.
      4. Score the iteration's subset by the KFold RMSECV of a PLS model
         whose component count is itself chosen by cross-validation.
    The subset of the iteration with the lowest RMSECV is returned.
    Note: wavelengths are kept deterministically by |coef| rank; no adaptive
    reweighted sampling step is applied.

    X : spectral matrix, DataFrame or ndarray of shape (n_samples, n_wavelengths)
    y : reference (chemical) values, Series or ndarray with one value per sample
    N : number of Monte Carlo iterations (must be >= 2)
    f : maximum number of PLS components (also capped by the number of kept wavelengths)
    cv : number of KFold splits used for RMSECV
    save_fig : whether to save the diagnostic figure
    save_path : figure destination; required when save_fig is True
    random_state : None (default, uses NumPy's global random state as before),
                   an int seed, or a numpy Generator, for the calibration draws
    return :
        OptWave : sorted int array of column indices (into X) of the selected wavelengths
    raises :
        ValueError : if N < 2, or if save_fig is True and save_path is None
    Side effects: creates a matplotlib figure (left open) and overwrites the
    global rcParams 'font.sans-serif' and 'axes.unicode_minus'.
    '''
    if N < 2:
        # The decay rate below divides by N - 1.
        raise ValueError(f"N must be at least 2, got {N}")
    if save_fig and save_path is None:
        # Fail before the slow Monte Carlo loop rather than at savefig time.
        raise ValueError("save_path must be given when save_fig is True")

    # Work on plain arrays: the positional row/column indexing below does not
    # work on a DataFrame, and indexing a Series would be label-based.
    x = np.asarray(X)
    y = np.asarray(y)
    m, n = x.shape
    rng = np.random if random_state is None else np.random.default_rng(random_state)

    p = 0.8  # fraction of samples drawn as the calibration set per iteration
    # Retention ratio r_i = u * exp(-k * i): r_1 = 1 (all n wavelengths)
    # and r_N = 2 / n (2 wavelengths).
    u = np.power((n / 2), (1 / (N - 1)))
    k = (1 / (N - 1)) * np.log(n / 2)
    cal_num = int(np.round(m * p))

    ranking = np.arange(n)   # columns of x ordered best-first (identity before the first fit)
    orig_idx = np.arange(n)  # original wavelength index of each column of x
    subsets = []             # subsets[i]: original wavelength indices used in iteration i
    WaveNum = []
    RMSECV = []
    for i in range(1, N + 1):
        wave_num = int(np.round(u * np.exp(-1 * k * i) * n))
        WaveNum.append(wave_num)
        cal_index = rng.choice(np.arange(m), size=cal_num, replace=False)

        # Keep the wave_num best-ranked columns from the previous fit.
        # Fancy indexing returns new arrays, so stored subsets are never mutated later.
        keep = ranking[:wave_num]
        x = x[:, keep]
        orig_idx = orig_idx[keep]
        subsets.append(orig_idx)

        xcal = x[cal_index]
        ycal = y[cal_index].ravel()

        f = min(f, wave_num)  # PLS cannot use more components than features
        pls = PLSRegression(n_components=f)
        pls.fit(xcal, ycal)
        # coef_ is (n_targets, n_features) in newer scikit-learn and
        # (n_features, n_targets) in older releases; with a single target,
        # ravel() yields one coefficient per kept wavelength in both cases.
        weights = np.abs(np.ravel(pls.coef_))
        ranking = np.argsort(-weights, axis=0)

        _, rindex = PC_Cross_Validation(pd.DataFrame(xcal), pd.Series(ycal), f, cv)
        RMSECV.append(Cross_Validation(pd.DataFrame(xcal), pd.Series(ycal), rindex + 1, cv))

    MinIndex = np.argmin(RMSECV)
    # A boolean mask (rather than storing index values and filtering != 0)
    # keeps wavelength 0 when it was selected; flatnonzero returns sorted indices.
    selected = np.zeros(n, dtype=bool)
    selected[subsets[MinIndex]] = True
    OptWave = np.flatnonzero(selected)

    _plot_cars_history(WaveNum, RMSECV, MinIndex, save_fig, save_path)
    return OptWave