Changes
This commit is contained in:
@ -19,7 +19,6 @@ from sklearn.cross_decomposition import PLSRegression
|
|||||||
from sklearn.ensemble import GradientBoostingRegressor, AdaBoostRegressor, ExtraTreesRegressor
|
from sklearn.ensemble import GradientBoostingRegressor, AdaBoostRegressor, ExtraTreesRegressor
|
||||||
from sklearn.tree import DecisionTreeRegressor
|
from sklearn.tree import DecisionTreeRegressor
|
||||||
from sklearn.neural_network import MLPRegressor
|
from sklearn.neural_network import MLPRegressor
|
||||||
from joblib import parallel_backend
|
|
||||||
# 第三方模型导入
|
# 第三方模型导入
|
||||||
# try:
|
# try:
|
||||||
# import lightgbm as lgb
|
# import lightgbm as lgb
|
||||||
@ -43,6 +42,11 @@ import os
|
|||||||
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
||||||
|
|
||||||
|
|
||||||
|
def _sklearn_parallel_n_jobs() -> int:
|
||||||
|
"""Return the n_jobs value for scikit-learn calls: 1 when ``sys.frozen`` is set, otherwise -1. In PyInstaller-style bundles joblib's loky backend relaunches the current exe, which shows up as several same-named processes."""
|
||||||
|
return 1 if getattr(sys, "frozen", False) else -1
|
||||||
|
|
||||||
|
|
||||||
class WaterQualityModelingBatch:
|
class WaterQualityModelingBatch:
|
||||||
"""水质参数反演批量建模类"""
|
"""水质参数反演批量建模类"""
|
||||||
|
|
||||||
@ -638,25 +642,26 @@ class WaterQualityModelingBatch:
|
|||||||
# 网格搜索 - 使用KFold代替StratifiedKFold
|
# 网格搜索 - 使用KFold代替StratifiedKFold
|
||||||
cv_strategy = KFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
|
cv_strategy = KFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
|
||||||
|
|
||||||
|
_n_jobs = _sklearn_parallel_n_jobs()
|
||||||
grid_search = GridSearchCV(
|
grid_search = GridSearchCV(
|
||||||
base_model,
|
base_model,
|
||||||
config['params'],
|
config['params'],
|
||||||
cv=cv_strategy,
|
cv=cv_strategy,
|
||||||
scoring=scoring,
|
scoring=scoring,
|
||||||
n_jobs=-1,
|
n_jobs=_n_jobs,
|
||||||
verbose=1
|
verbose=1
|
||||||
)
|
)
|
||||||
|
|
||||||
# 在训练集上训练模型
|
# 在训练集上训练模型
|
||||||
# with parallel_backend("threading", n_jobs=-1):
|
|
||||||
# grid_search.fit(X_train, y_train)
|
|
||||||
grid_search.fit(X_train, y_train)
|
grid_search.fit(X_train, y_train)
|
||||||
|
|
||||||
# 获取最佳模型
|
# 获取最佳模型
|
||||||
best_model = grid_search.best_estimator_
|
best_model = grid_search.best_estimator_
|
||||||
|
|
||||||
# 交叉验证评估(在训练集上)
|
# 交叉验证评估(在训练集上)
|
||||||
cv_scores = cross_val_score(best_model, X_train, y_train, cv=cv_strategy, scoring=scoring)
|
cv_scores = cross_val_score(
|
||||||
|
best_model, X_train, y_train, cv=cv_strategy, scoring=scoring, n_jobs=_n_jobs
|
||||||
|
)
|
||||||
|
|
||||||
# 计算训练集上的回归指标
|
# 计算训练集上的回归指标
|
||||||
y_train_pred = best_model.predict(X_train)
|
y_train_pred = best_model.predict(X_train)
|
||||||
|
|||||||
@ -2263,7 +2263,7 @@ class Step6Panel(QWidget):
|
|||||||
|
|
||||||
for i, method in enumerate(preproc_methods):
|
for i, method in enumerate(preproc_methods):
|
||||||
checkbox = QCheckBox(method)
|
checkbox = QCheckBox(method)
|
||||||
checkbox.setChecked(True) # 默认全选
|
checkbox.setChecked(method == 'SNV') # Only 'SNV' is checked by default
|
||||||
self.preproc_checkboxes[method] = checkbox
|
self.preproc_checkboxes[method] = checkbox
|
||||||
preproc_grid.addWidget(checkbox, i // 4, i % 4)
|
preproc_grid.addWidget(checkbox, i // 4, i % 4)
|
||||||
|
|
||||||
@ -2344,7 +2344,7 @@ class Step6Panel(QWidget):
|
|||||||
|
|
||||||
for i, method in enumerate(split_methods):
|
for i, method in enumerate(split_methods):
|
||||||
checkbox = QCheckBox(method)
|
checkbox = QCheckBox(method)
|
||||||
checkbox.setChecked(True) # 默认全选
|
checkbox.setChecked(method == 'spxy') # Only 'spxy' is checked by default
|
||||||
self.split_checkboxes[method] = checkbox
|
self.split_checkboxes[method] = checkbox
|
||||||
split_grid.addWidget(checkbox, 0, i)
|
split_grid.addWidget(checkbox, 0, i)
|
||||||
|
|
||||||
@ -4812,7 +4812,7 @@ class Step6_5Panel(QWidget):
|
|||||||
|
|
||||||
for i, method in enumerate(preproc_methods):
|
for i, method in enumerate(preproc_methods):
|
||||||
checkbox = QCheckBox(method)
|
checkbox = QCheckBox(method)
|
||||||
checkbox.setChecked(True)
|
checkbox.setChecked(method == 'SNV')
|
||||||
self.preproc_checkboxes[method] = checkbox
|
self.preproc_checkboxes[method] = checkbox
|
||||||
preproc_grid.addWidget(checkbox, i // 4, i % 4)
|
preproc_grid.addWidget(checkbox, i // 4, i % 4)
|
||||||
|
|
||||||
@ -4843,7 +4843,7 @@ class Step6_5Panel(QWidget):
|
|||||||
row_layout = QHBoxLayout()
|
row_layout = QHBoxLayout()
|
||||||
row_layout.setContentsMargins(0, 0, 0, 0)
|
row_layout.setContentsMargins(0, 0, 0, 0)
|
||||||
checkbox = QCheckBox(algorithm)
|
checkbox = QCheckBox(algorithm)
|
||||||
checkbox.setChecked(True)
|
checkbox.setChecked(algorithm == 'chl_a')
|
||||||
spinbox = QSpinBox()
|
spinbox = QSpinBox()
|
||||||
spinbox.setRange(0, 500)
|
spinbox.setRange(0, 500)
|
||||||
spinbox.setValue(0)
|
spinbox.setValue(0)
|
||||||
@ -5144,7 +5144,7 @@ class Step6_75Panel(QWidget):
|
|||||||
for i, method in enumerate(regression_methods):
|
for i, method in enumerate(regression_methods):
|
||||||
checkbox = QCheckBox(method)
|
checkbox = QCheckBox(method)
|
||||||
# 默认选择常用的方法
|
# 默认选择常用的方法
|
||||||
if method in ['linear', 'exponential', 'power', 'logarithmic']:
|
if method in ['linear', 'exponential', 'power']:
|
||||||
checkbox.setChecked(True)
|
checkbox.setChecked(True)
|
||||||
self.method_checkboxes[method] = checkbox
|
self.method_checkboxes[method] = checkbox
|
||||||
method_grid.addWidget(checkbox, i // 3, i % 3)
|
method_grid.addWidget(checkbox, i // 3, i % 3)
|
||||||
@ -5251,7 +5251,7 @@ class Step6_75Panel(QWidget):
|
|||||||
for i, col in enumerate(self.csv_columns):
|
for i, col in enumerate(self.csv_columns):
|
||||||
checkbox = QCheckBox(col)
|
checkbox = QCheckBox(col)
|
||||||
# 默认选择一些常见的指数列
|
# 默认选择一些常见的指数列
|
||||||
if any(keyword in col.lower() for keyword in ['index', 'ratio', 'normalized', 'nd', 'b']):
|
if any(keyword in col.lower() for keyword in ['index', 'ratio', 'normalized', 'nd','BGA_','Chl_','Turb_']):
|
||||||
checkbox.setChecked(True)
|
checkbox.setChecked(True)
|
||||||
self.x_column_checkboxes[col] = checkbox
|
self.x_column_checkboxes[col] = checkbox
|
||||||
self.x_columns_layout.addWidget(checkbox, i // 3, i % 3)
|
self.x_columns_layout.addWidget(checkbox, i // 3, i % 3)
|
||||||
@ -5260,7 +5260,8 @@ class Step6_75Panel(QWidget):
|
|||||||
for i, col in enumerate(self.csv_columns):
|
for i, col in enumerate(self.csv_columns):
|
||||||
checkbox = QCheckBox(col)
|
checkbox = QCheckBox(col)
|
||||||
# 默认选择一些常见的水质参数列
|
# 默认选择一些常见的水质参数列
|
||||||
if any(keyword in col.lower() for keyword in ['chl', 'tn', 'tp', 'turbidity', 'do', 'ph', 'conductivity']):
|
target_keywords = ['Chlorophyll', 'COD', 'DO', 'PH', 'Temperature', 'spCond', 'Turbidity', 'TDS', 'Cl-', 'NO3-N', 'NH3-N', 'BGA', 'TT']
|
||||||
|
if any(keyword.lower() in col.lower() for keyword in target_keywords):
|
||||||
checkbox.setChecked(True)
|
checkbox.setChecked(True)
|
||||||
self.y_column_checkboxes[col] = checkbox
|
self.y_column_checkboxes[col] = checkbox
|
||||||
self.y_columns_layout.addWidget(checkbox, i // 2, i % 2)
|
self.y_columns_layout.addWidget(checkbox, i // 2, i % 2)
|
||||||
@ -6990,7 +6991,6 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
#冻结,只显示1个exe
|
multiprocessing.freeze_support()
|
||||||
# multiprocessing.freeze_support()
|
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user