3.Rbob.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604
  1. import pandas as pd
  2. import numpy as np
  3. import xgboost as xgb
  4. from xgboost import XGBRegressor
  5. from sklearn.metrics import mean_squared_error, r2_score
  6. import matplotlib.pyplot as plt
  7. from skopt import BayesSearchCV
  8. from sklearn.preprocessing import StandardScaler
  9. from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, TimeSeriesSplit
  10. import argparse # 添加argparse模块用于解析命令行参数
  11. import itertools
  12. import random
  13. from skopt.space import Real, Integer, Categorical
  14. import json
  15. from Dtool import fill_missing_values, reverse_column
  16. from api import fetch_data_by_indicators
  17. # 添加命令行参数解析
  18. def parse_arguments():
  19. parser = argparse.ArgumentParser(description='RBOB汽油裂解预测模型')
  20. # XGBoost参数
  21. parser.add_argument('--objective', type=str, default='reg:squarederror', help='XGBoost目标函数')
  22. parser.add_argument('--learning_rate', type=float, default=0.1, help='学习率')
  23. parser.add_argument('--max_depth', type=int, default=8, help='最大树深度')
  24. parser.add_argument('--min_child_weight', type=int, default=3, help='最小子权重')
  25. parser.add_argument('--gamma', type=float, default=2, help='gamma参数')
  26. parser.add_argument('--subsample', type=float, default=0.85, help='子样本比例')
  27. parser.add_argument('--colsample_bytree', type=float, default=0.75, help='每棵树的列采样率')
  28. parser.add_argument('--eval_metric', type=str, default='rmse', help='评估指标')
  29. parser.add_argument('--seed', type=int, default=42, help='随机种子')
  30. parser.add_argument('--reg_alpha', type=float, default=0.45, help='L1正则化')
  31. parser.add_argument('--reg_lambda', type=float, default=1.29, help='L2正则化')
  32. # 其他参数
  33. parser.add_argument('--num_boost_round', type=int, default=1000, help='提升迭代次数')
  34. parser.add_argument('--use_hyperparam_tuning', type=str, default='False', help='是否使用超参数调优')
  35. parser.add_argument('--output_prefix', type=str, default='', help='输出文件前缀,如传入1234则生成1234_update.xlsx')
  36. args = parser.parse_args()
  37. return args
# Indicator IDs requested from the data API.
# NOTE(review): 'REFOC-T-EIA_241114135248' appears twice in this list — looks
# like a paste error; confirm whether fetching it twice is intentional.
INDICATOR_IDS = ["RBWTICKMc1", "C2406121350446455",'USGGBE02 Index', "Cinjcjc4 index",'injcjc4 index','C2201059138_241106232710','C2406036178','C22411071623523660','C2312081670','REFOC-T-EIA_241114135248','C2304065621_241024124344','REFOC-T-EIA_241114135248','C22503031424010431']
# These globals are overwritten inside main() from the parsed CLI arguments.
NUM_BOOST_ROUND = 1000
RANDOM_STATE = 42
USE_HYPERPARAM_TUNING = False  # if False, train directly with xgb.train
TARGET_COL = '美国RBOB汽油裂解'  # target column: US RBOB gasoline crack
TEST_PERIOD = 20  # number of trailing observed rows held out as the test window
SEARCH_MODE = 'random'  # hyperparameter search strategy: 'grid' / 'bayesian' / 'random'
SHOW_PLOTS = False
ADJUST_FULL_PREDICTIONS = True  # shift full-model forecasts so they join the last actual
# Metadata written into the Excel '列表页' (list) sheet.
TARGET_NAME = '美国RBOB汽油裂解'
CLASSIFICATION = '原油'
MODEL_FRAMEWORK = 'XGBoost'
CREATOR = '张立舟'
PRED_DATE = '2024/11/11'
FREQUENCY = '月度'
OUTPUT_PATH = 'update.xlsx'
# Default XGBoost parameters; overwritten inside main() from CLI arguments.
DEFAULT_PARAMS = {
    'objective': 'reg:squarederror',
    'learning_rate': 0.1,
    'max_depth': 8,
    'min_child_weight': 3,
    'gamma': 2,
    'subsample': 0.85,
    'colsample_bytree': 0.75,
    'eval_metric': 'rmse',
    'seed': 42,
    'reg_alpha': 0.45,
    'reg_lambda': 1.29,
}
# —— Factor preprocessing configuration ——
# Per-column gap-filling method passed to fill_missing_values().
FILL_METHODS = {
    '美国2年通胀预期': 'rolling_mean_5',
    '美国首次申领失业金人数/4WMA': 'interpolate',
    '道琼斯旅游与休闲/工业平均指数': 'interpolate',
    '美国EIA成品油总库存(预测/供应需求3年季节性)': 'interpolate',
    '美国成品车用汽油倒推产量(预测/汽油库存维持上年季节性)/8WMA': 'interpolate',
    '美国成品车用汽油炼厂与调和装置净产量/4WMA(预测/上年季节性)超季节性/5年': 'interpolate',
    '美国炼厂可用产能(路透)(预测)': 'interpolate',
    '美国炼厂CDU装置检修量(新)': 'interpolate',
    '美湾单位辛烷值价格(预测/季节性)': 'interpolate',
    '美国汽油调和组分RBOB库存(预测/线性外推)超季节性/3年': 'interpolate'
}
# (source column, shift in rows, new column name) — lagged copies built with .shift().
SHIFT_CONFIG = [
    ('美国2年通胀预期', 56, '美国2年通胀预期_提前56天'),
    ('美国首次申领失业金人数/4WMA', 100, '美国首次申领失业金人数/4WMA_提前100天'),
    ('美国首次申领失业金人数/4WMA', 112, '美国首次申领失业金人数/4WMA_提前112天'),
    ('道琼斯旅游与休闲/工业平均指数', 14, '道琼斯旅游与休闲/工业平均指数_提前14天'),
    ('美国EIA成品油总库存(预测/供应需求3年季节性)', 15,
     '美国EIA成品油总库存(预测/供应需求3年季节性)_提前15天'),
    ('美国成品车用汽油炼厂与调和装置净产量/4WMA(预测/上年季节性)超季节性/5年',
     30,
     '美国成品车用汽油炼厂与调和装置净产量/4WMA(预测/上年季节性)超季节性/5年_提前30天'),
    ('美国炼厂CDU装置检修量(新)', 30, '美国炼厂CDU装置检修量(新)_提前30天'),
    ('美国炼厂可用产能(路透)(预测)', 100,
     '美国炼厂可用产能(路透)(预测)_提前100天')
]
# (source column, new column name) — reversed copies built with reverse_column().
# NOTE(review): the last-but-one entry maps '..._提前100天' to a name WITHOUT the
# '_提前100天' suffix — confirm the output name is intended.
REVERSE_CONFIG = [
    ('美国首次申领失业金人数/4WMA',
     '美国首次申领失业金人数/4WMA_逆序'),
    ('美国首次申领失业金人数/4WMA_提前100天',
     '美国首次申领失业金人数/4WMA_提前100天_逆序'),
    ('美国首次申领失业金人数/4WMA_提前112天',
     '美国首次申领失业金人数/4WMA_提前112天_逆序'),
    ('美国EIA成品油总库存(预测/供应需求3年季节性)',
     '美国EIA成品油总库存(预测/供应需求3年季节性)_逆序'),
    ('美国EIA成品油总库存(预测/供应需求3年季节性)_提前15天',
     '美国EIA成品油总库存(预测/供应需求3年季节性)_提前15天_逆序'),
    ('美国炼厂可用产能(路透)(预测)_提前100天',
     '美国炼厂可用产能(路透)(预测)_逆序'),
    ('美国汽油调和组分RBOB库存(预测/线性外推)超季节性/3年',
     '美国汽油调和组分RBOB库存(预测/线性外推)超季节性/3年_逆序')
]
# Columns that only become active on/after a cutoff date (NaN before it).
SPECIAL_REVERSE = {
    '美国汽油调和组分RBOB库存(预测/线性外推)超季节性/3年_逆序_2022-01-01': {
        'base_column': '美国汽油调和组分RBOB库存(预测/线性外推)超季节性/3年_逆序',
        'condition_date': pd.Timestamp('2022-01-01')
    }
}
  119. # ------------ 数据加载与预处理 ------------
  120. def load_and_preprocess_data():
  121. # 直接从API获取数据
  122. df = fetch_data_by_indicators(INDICATOR_IDS)
  123. # print("Initial DataFrame columns:", df.columns)
  124. df.index = pd.to_datetime(df.index)
  125. df_daily = df.copy()
  126. df_daily['Date'] = df_daily.index
  127. df_daily = df_daily.reset_index(drop=True)
  128. #预处理流程
  129. df_daily = fill_missing_values(df_daily, FILL_METHODS, return_only_filled=False)
  130. for col, days, new_col in SHIFT_CONFIG:
  131. df_daily[new_col] = df_daily[col].shift(days)
  132. last_idx = df_daily[TARGET_COL].last_valid_index()
  133. last_day = df_daily.loc[last_idx, 'Date']
  134. df_daily = df_daily[(df_daily['Date'] >= '2009-08-01') & (df_daily['Date'] <= last_day + pd.Timedelta(days=30))]
  135. df_daily = df_daily[df_daily['Date'].dt.weekday < 5]
  136. for base, new in REVERSE_CONFIG:
  137. df_daily[new] = reverse_column(df_daily, base)
  138. for col, cfg in SPECIAL_REVERSE.items():
  139. df_daily[col] = np.where(df_daily['Date'] >= cfg['condition_date'],
  140. df_daily[cfg['base_column']],
  141. np.nan)
  142. df_daily = df_daily[(df_daily['Date'] > last_day)|df_daily[TARGET_COL].notna()]
  143. return df_daily, last_day
  144. # ------------ 划分与特征构建 ------------
  145. def split_and_build_features(df_daily, last_day):
  146. train = df_daily[df_daily['Date'] <= last_day].copy()
  147. test = train.tail(TEST_PERIOD).copy()
  148. train = train.iloc[:-TEST_PERIOD].copy()
  149. future = df_daily[df_daily['Date'] > last_day].copy()
  150. feature_columns = [
  151. '美湾单位辛烷值价格(预测/季节性)',
  152. '美国炼厂CDU装置检修量(新)_提前30天',
  153. '美国EIA成品油总库存(预测/供应需求3年季节性)_提前15天_逆序',
  154. '美国首次申领失业金人数/4WMA_提前100天_逆序',
  155. '美国成品车用汽油倒推产量(预测/汽油库存维持上年季节性)/8WMA',
  156. '美国成品车用汽油炼厂与调和装置净产量/4WMA(预测/上年季节性)超季节性/5年_提前30天',
  157. '美国汽油调和组分RBOB库存(预测/线性外推)超季节性/3年_逆序_2022-01-01'
  158. ]
  159. X_train = train[feature_columns]
  160. y_train = train[TARGET_COL]
  161. X_test = test[feature_columns]
  162. y_test = test[TARGET_COL]
  163. X_future = future[feature_columns]
  164. return X_train, y_train, X_test, y_test, X_future, train, test, future
  165. # ------------ 特征缩放与异常值权重 ------------
  166. def scale_and_weight_features(X_train, X_test, X_future):
  167. scaler = StandardScaler()
  168. X_tr = scaler.fit_transform(X_train)
  169. X_te = scaler.transform(X_test)
  170. X_fu = scaler.transform(X_future)
  171. return scaler, X_tr, X_te, X_fu
  172. def detect_outliers_weights(X,weight_normal=1.0,weight_outlier=0.05,threshold=3):
  173. z = np.abs((X - X.mean()) / X.std())
  174. mask = (z > threshold).any(axis=1)
  175. return np.where(mask, weight_outlier, weight_normal)
  176. # ------------ 模型训练 ------------
  177. def train_model_with_tuning(X_tr, y_tr, X_te, y_te, weights, use_tuning):
  178. if use_tuning:
  179. param_dist = {
  180. 'learning_rate': list(np.arange(0.01, 0.11, 0.01)),
  181. 'max_depth': list(range(4, 11)),
  182. 'min_child_weight': list(range(1, 6)),
  183. 'gamma': list(np.arange(0, 0.6, 0.1)),
  184. 'subsample': list(np.arange(0.5, 1.01, 0.05)),
  185. 'colsample_bytree': list(np.arange(0.5, 1.01, 0.05)),
  186. 'reg_alpha': [0, 0.1, 0.2, 0.3, 0.4, 0.45, 0.5],
  187. 'reg_lambda': list(np.arange(1.0, 1.6, 0.1))
  188. }
  189. # 将数据转换为DMatrix格式
  190. dtrain = xgb.DMatrix(X_tr, label=y_tr, weight=weights)
  191. dtest = xgb.DMatrix(X_te, label=y_te)
  192. # 基础参数设置
  193. base_params = {
  194. 'objective': 'reg:squarederror',
  195. 'eval_metric': 'rmse',
  196. 'seed': RANDOM_STATE
  197. }
  198. best_score = float('inf')
  199. best_params = None
  200. # 网格搜索
  201. if SEARCH_MODE == 'grid':
  202. param_combinations = [dict(zip(param_dist.keys(), v))
  203. for v in itertools.product(*param_dist.values())]
  204. for params in param_combinations:
  205. curr_params = {**base_params, **params}
  206. cv_results = xgb.cv(curr_params, dtrain,
  207. num_boost_round=NUM_BOOST_ROUND,
  208. nfold=3,
  209. early_stopping_rounds=20,
  210. verbose_eval=False)
  211. score = cv_results['test-rmse-mean'].min()
  212. if score < best_score:
  213. best_score = score
  214. best_params = curr_params
  215. # 贝叶斯搜索
  216. elif SEARCH_MODE == 'bayesian':
  217. search_spaces = {
  218. 'learning_rate': Real(0.01, 0.11, prior='uniform'),
  219. 'max_depth': Integer(4, 11),
  220. 'min_child_weight': Integer(1, 6),
  221. 'gamma': Real(0.0, 0.6, prior='uniform'),
  222. 'subsample': Real(0.5, 1.01, prior='uniform'),
  223. 'colsample_bytree': Real(0.5, 1.01, prior='uniform'),
  224. 'reg_alpha': Real(0.0, 0.5, prior='uniform'),
  225. 'reg_lambda': Real(1.0, 1.6, prior='uniform')
  226. }
  227. def objective(params):
  228. curr_params = {**base_params, **params}
  229. cv_results = xgb.cv(curr_params, dtrain,
  230. num_boost_round=NUM_BOOST_ROUND,
  231. nfold=3,
  232. early_stopping_rounds=20,
  233. verbose_eval=False)
  234. return cv_results['test-rmse-mean'].min()
  235. # 执行贝叶斯优化
  236. from skopt import gp_minimize
  237. result = gp_minimize(
  238. objective,
  239. dimensions=[space for space in search_spaces.values()],
  240. n_calls=50,
  241. random_state=RANDOM_STATE
  242. )
  243. best_params = dict(zip(search_spaces.keys(), result.x))
  244. best_params = {**base_params, **best_params}
  245. best_score = result.fun
  246. # 随机搜索
  247. else:
  248. for _ in range(50):
  249. params = {k: random.choice(v) for k, v in param_dist.items()}
  250. curr_params = {**base_params, **params}
  251. cv_results = xgb.cv(curr_params, dtrain,
  252. num_boost_round=NUM_BOOST_ROUND,
  253. nfold=3,
  254. early_stopping_rounds=20,
  255. verbose_eval=False)
  256. score = cv_results['test-rmse-mean'].min()
  257. if score < best_score:
  258. best_score = score
  259. best_params = curr_params
  260. print("调优后的最佳参数:", best_params)
  261. print("最佳得分:", best_score)
  262. # 使用最佳参数训练最终模型
  263. best_model = xgb.train(best_params,
  264. dtrain,
  265. num_boost_round=NUM_BOOST_ROUND,
  266. evals=[(dtrain, 'Train'), (dtest, 'Test')],
  267. early_stopping_rounds=20,
  268. verbose_eval=False)
  269. else:
  270. # 直接使用默认参数训练
  271. dtrain = xgb.DMatrix(X_tr, label=y_tr, weight=weights)
  272. dtest = xgb.DMatrix(X_te, label=y_te)
  273. best_model = xgb.train(DEFAULT_PARAMS,
  274. dtrain,
  275. num_boost_round=NUM_BOOST_ROUND,
  276. evals=[(dtrain, 'Train'),
  277. (dtest, 'Test')],
  278. verbose_eval=False)
  279. return best_model
  280. # ------------ 评估与预测 ------------
  281. def evaluate_and_predict(model, scaler, X_tr, y_tr, X_te, y_te, X_fu, use_tuning):
  282. X_tr_s = scaler.transform(X_tr)
  283. X_te_s = scaler.transform(X_te)
  284. X_fu_s = scaler.transform(X_fu)
  285. if isinstance(model, xgb.Booster):
  286. y_tr_pred = model.predict(xgb.DMatrix(X_tr_s))
  287. y_te_pred = model.predict(xgb.DMatrix(X_te_s))
  288. y_fu_pred = model.predict(xgb.DMatrix(X_fu_s))
  289. else:
  290. y_tr_pred = model.predict(X_tr_s)
  291. y_te_pred = model.predict(X_te_s)
  292. y_fu_pred = model.predict(X_fu_s)
  293. # 计算评估指标并保留4位有效数字
  294. train_mse = float(f"{mean_squared_error(y_tr, y_tr_pred):.4g}")
  295. test_mse = float(f"{mean_squared_error(y_te, y_te_pred):.4g}")
  296. train_r2 = float(f"{r2_score(y_tr, y_tr_pred):.4g}")
  297. test_r2 = float(f"{r2_score(y_te, y_te_pred):.4g}") if len(y_te) >= 2 else None
  298. print("Train MSE:", train_mse, "Test MSE:", test_mse)
  299. if len(y_te) >= 2:
  300. print("Train R2:", train_r2, "Test R2:", test_r2)
  301. else:
  302. print("Test 样本不足,跳过 R² 计算")
  303. metrics = {
  304. 'train_mse': train_mse,
  305. 'test_mse': test_mse,
  306. 'train_r2': train_r2,
  307. 'test_r2': test_r2
  308. }
  309. # 保存为JSON
  310. json_path = 'model_metrics.json'
  311. try:
  312. with open(json_path, 'r', encoding='utf-8') as f:
  313. existing_metrics = json.load(f)
  314. except FileNotFoundError:
  315. existing_metrics = []
  316. existing_metrics.append(metrics)
  317. with open(json_path, 'w', encoding='utf-8') as f:
  318. json.dump(existing_metrics, f, ensure_ascii=False, indent=4)
  319. print(f"评估指标已保存至 {json_path}")
  320. return y_tr_pred, y_te_pred, y_fu_pred
  321. # ------------ 结果后处理(生成日度 & 月度 DataFrame) ------------
  322. def merge_and_prepare_df(train, test, future, y_te_pred, y_fu_pred):
  323. # 合并历史与未来预测
  324. test = test.copy()
  325. future = future.copy()
  326. test['预测值'] = y_te_pred
  327. future['预测值'] = y_fu_pred
  328. hist_actual = pd.concat([
  329. train[train['Date'].dt.year >= 2023][['Date', TARGET_COL]],
  330. test[['Date', TARGET_COL]]
  331. ])
  332. hist_actual.columns = ['Date', '实际值']
  333. future_pred = future[future['Date'] >= '2022-08-01'][['Date', '预测值']].rename(columns={'预测值': TARGET_COL}).copy()
  334. last_val = hist_actual.iloc[-1]['实际值']
  335. future_pred[TARGET_COL] = future_pred[TARGET_COL].astype(last_val.dtype)
  336. future_pred.iloc[0, 1] = last_val
  337. # 日度重采样
  338. merged = pd.merge(hist_actual, future_pred,on='Date', how='outer').sort_values('Date', ascending=False)
  339. daily_df = merged.copy()
  340. # 月度重采样
  341. monthly_df = daily_df.copy()
  342. monthly_df['Date'] = pd.to_datetime(monthly_df['Date'])
  343. monthly_df.set_index('Date', inplace=True)
  344. monthly_df = monthly_df.resample('ME').mean().reset_index()
  345. # 方向准确率
  346. pred_dir = np.sign(monthly_df[TARGET_COL].diff())
  347. true_dir = np.sign(monthly_df['实际值'].diff())
  348. valid = monthly_df[TARGET_COL].notna() & monthly_df['实际值'].notna()
  349. monthly_df['方向准确率'] = np.where(valid & (pred_dir == true_dir), '正确',
  350. np.where(valid & (pred_dir != true_dir), '错误', np.nan))
  351. # 绝对偏差
  352. monthly_df['绝对偏差'] = (monthly_df[TARGET_COL] - monthly_df['实际值']).abs()
  353. monthly_df = monthly_df.sort_values('Date', ascending=False).reset_index(drop=True)
  354. return daily_df, monthly_df
def generate_and_fill_excel(
    daily_df,
    monthly_df,
    target_name,      # display name written as the "预测标的" cell
    classification,   # list sheet: category
    model_framework,  # list sheet: model framework
    creator,          # list sheet: creator
    pred_date,        # list sheet: prediction date
    frequency,        # list sheet: prediction frequency
    output_path='update.xlsx'
):
    """Write the output workbook with three sheets via xlsxwriter.

    Sheets:
        '列表页'     — one summary row: target metadata, the latest monthly
                       forecast, direction accuracy and mean absolute deviation.
        '详情页'     — monthly actual/forecast detail rows.
        '日度数据表' — daily actual/forecast rows.
    """
    with pd.ExcelWriter(output_path, engine='xlsxwriter') as writer:
        workbook = writer.book
        # —— Three summary values ——
        # 1) Test value: the most recent monthly forecast (monthly_df is
        #    sorted descending, so row 0 is the latest month).
        test_value = monthly_df[TARGET_COL].iloc[0]
        # 2) Direction accuracy: correct count / valid count.
        total = monthly_df['方向准确率'].notna().sum()
        correct = (monthly_df['方向准确率'] == '正确').sum()
        direction_accuracy = f"{correct/total:.2%}" if total > 0 else ""
        # 3) Mean absolute deviation.
        absolute_deviation = monthly_df['绝对偏差'].mean()
        # ========= list sheet ('列表页') =========
        ws_list = workbook.add_worksheet('列表页')
        # Register the hand-made sheet with pandas' writer bookkeeping.
        writer.sheets['列表页'] = ws_list
        headers = ['预测标的','分类','模型框架','创建人','预测日期','测试值','预测频度','方向准确率','绝对偏差']
        ws_list.write_row(0, 0, headers)
        ws_list.write_row(1, 0, [
            target_name,
            classification,
            model_framework,
            creator,
            pred_date,
            test_value,
            frequency,
            direction_accuracy,
            absolute_deviation
        ])
        # ========= detail sheet ('详情页') =========
        detail_df = monthly_df[['Date', '实际值', TARGET_COL, '方向准确率', '绝对偏差']].copy()
        detail_df.columns = ['指标日期','实际值','预测值','方向','偏差率']
        # Data starts at row 2; rows 0-1 hold the title and the header row.
        detail_df.to_excel(writer, sheet_name='详情页', index=False, header=False, startrow=2)
        ws_detail = writer.sheets['详情页']
        ws_detail.write(0, 0, target_name)
        ws_detail.write_row(1, 0, ['指标日期','实际值','预测值','方向','偏差率'])
        # ========= daily sheet ('日度数据表') =========
        daily_out = daily_df[['Date', '实际值', TARGET_COL]].copy()
        daily_out.columns = ['指标日期','实际值','预测值']
        daily_out.to_excel(writer, sheet_name='日度数据表', index=False, header=False, startrow=2)
        ws_daily = writer.sheets['日度数据表']
        ws_daily.write(0, 0, target_name)
        ws_daily.write_row(1, 0, ['指标日期','实际值','预测值'])
    print(f"已生成并填充 {output_path}")
  408. # ------------ 全量训练与预测 ------------
  409. def train_full_model_and_predict(X_train, y_train, X_test, y_test, X_future):
  410. X_all = pd.concat([X_train, X_test])
  411. y_all = pd.concat([y_train, y_test])
  412. scaler_all = StandardScaler().fit(X_all)
  413. X_all_s = scaler_all.transform(X_all)
  414. X_fu_s = scaler_all.transform(X_future)
  415. model = XGBRegressor(**DEFAULT_PARAMS, n_estimators=NUM_BOOST_ROUND)
  416. model.fit(X_all_s, y_all)
  417. y_fu_full = model.predict(X_fu_s)
  418. return model, y_fu_full, scaler_all
  419. # ------------ 可视化 ------------
  420. def plot_final_predictions(train, y_tr, y_tr_pred, test, y_te, y_te_pred,
  421. future, last_day):
  422. plt.figure(figsize=(15, 6))
  423. plt.plot(train['Date'], y_tr, label='Train True')
  424. plt.plot(train['Date'], y_tr_pred, label='Train Pred')
  425. plt.plot(test['Date'], y_te, label='Test True', alpha=0.7)
  426. plt.plot(test['Date'], y_te_pred, label='Test Pred')
  427. plt.plot(future['Date'], future['预测值'], label='Future Pred')
  428. plt.axvline(test['Date'].iloc[0], color='gray', linestyle='--')
  429. plt.axvline(last_day, color='black', linestyle='--')
  430. plt.legend()
  431. plt.xlabel('Date')
  432. plt.ylabel(TARGET_COL)
  433. plt.title('Prediction Visualization')
  434. plt.grid(True)
  435. plt.show()
  436. # ------------ 主函数 ------------
def main():
    """CLI entry point: configure globals from args, train, evaluate, export.

    Returns:
        (daily_df, monthly_df) as produced by merge_and_prepare_df().
    """
    # Parse command-line arguments.
    args = parse_arguments()
    # Overwrite module-level configuration from the CLI.
    global NUM_BOOST_ROUND, USE_HYPERPARAM_TUNING, OUTPUT_PATH, DEFAULT_PARAMS
    NUM_BOOST_ROUND = args.num_boost_round
    # The flag arrives as a string; only the literal 'true' (any case) enables tuning.
    USE_HYPERPARAM_TUNING = args.use_hyperparam_tuning.lower() == 'true'
    # Derive the output path from the optional prefix.
    if args.output_prefix:
        OUTPUT_PATH = f"{args.output_prefix}_update.xlsx"
    # Rebuild the XGBoost parameter dict from CLI values.
    DEFAULT_PARAMS = {
        'objective': args.objective,
        'learning_rate': args.learning_rate,
        'max_depth': args.max_depth,
        'min_child_weight': args.min_child_weight,
        'gamma': args.gamma,
        'subsample': args.subsample,
        'colsample_bytree': args.colsample_bytree,
        'eval_metric': args.eval_metric,
        'seed': args.seed,
        'reg_alpha': args.reg_alpha,
        'reg_lambda': args.reg_lambda,
    }
    # Pipeline: load -> split -> scale -> weight -> train -> evaluate -> export.
    df_daily, last_day = load_and_preprocess_data()
    X_tr, y_tr, X_te, y_te, X_fu, train, test, future = split_and_build_features(df_daily, last_day)
    scaler, X_tr_s, X_te_s, X_fu_s = scale_and_weight_features(X_tr, X_te, X_fu)
    weights = detect_outliers_weights(X_tr_s)
    model = train_model_with_tuning(X_tr_s, y_tr, X_te_s, y_te, weights, USE_HYPERPARAM_TUNING)
    y_tr_pred, y_te_pred, y_fu_pred = evaluate_and_predict(model, scaler, X_tr, y_tr, X_te, y_te, X_fu, USE_HYPERPARAM_TUNING)
    daily_df, monthly_df = merge_and_prepare_df(train, test, future, y_te_pred, y_fu_pred)
    generate_and_fill_excel(
        daily_df,
        monthly_df,
        target_name=TARGET_NAME,
        classification=CLASSIFICATION,
        model_framework=MODEL_FRAMEWORK,
        creator=CREATOR,
        pred_date=PRED_DATE,
        frequency=FREQUENCY,
        output_path=OUTPUT_PATH
    )
    # Refit on all observed data for the "full" forecast.
    full_model, y_fu_full, scaler_full = train_full_model_and_predict(X_tr, y_tr, X_te, y_te, X_fu)
    if ADJUST_FULL_PREDICTIONS:
        # Shift the full-model forecast so it starts at the last actual value.
        offset = y_te.iloc[-1] - y_fu_full[0]
        y_fu_full += offset
    if SHOW_PLOTS:
        plot_final_predictions(
            train, y_tr, y_tr_pred, test, y_te, y_te_pred,
            future.assign(预测值=y_fu_full), last_day)
    return daily_df, monthly_df
if __name__ == '__main__':
    daily_df, monthly_df = main()