Dtool.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856
  1. import pandas as pd
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. import random
  5. import warnings
  6. from statsmodels.tsa.stattools import acf, q_stat
  7. from statsmodels.stats.diagnostic import acorr_ljungbox
  8. from statsmodels.graphics.tsaplots import plot_acf
  9. from statsmodels.tsa.stattools import adfuller, kpss
  10. from statsmodels.tsa.seasonal import STL
  11. from statsmodels.nonparametric.smoothers_lowess import lowess
  12. import numpy as np
  13. import matplotlib.pyplot as plt
  14. from statsmodels.tsa.seasonal import STL
  15. from scipy.stats import linregress
  16. from xgboost import XGBRegressor
  17. from sklearn.metrics import mean_squared_error, r2_score
  18. from xgboost import plot_importance
  19. import math
  20. import seaborn as sns
  21. plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置字体
  22. plt.rcParams['axes.unicode_minus'] = False # 正常显示负号
  23. def perform_kpss_test(residuals):
  24. with warnings.catch_warnings(record=True) as w:
  25. warnings.simplefilter("always")
  26. kpss_test = kpss(residuals, regression='c')
  27. kpss_statistic = kpss_test[0]
  28. kpss_p_value = kpss_test[1]
  29. kpss_critical_values = kpss_test[3]
  30. print(f'KPSS Statistic: {kpss_statistic:.4f}')
  31. print(f'p-value: {kpss_p_value:.4f}')
  32. if kpss_statistic < kpss_critical_values['5%']:
  33. print("The residuals are stationary (fail to reject the null hypothesis).")
  34. else:
  35. print("The residuals are not stationary (reject the null hypothesis).")
  36. # Check for any warning and print a message if the KPSS statistic is outside the range
  37. if len(w) > 0 and issubclass(w[0].category, Warning):
  38. print("Warning: The test statistic is outside the range of p-values available in the look-up table. "
  39. "The actual p-value is smaller than the returned value.")
  40. # ADF 检验函数
  41. def perform_adf_test(residuals):
  42. adf_test = adfuller(residuals)
  43. adf_statistic = adf_test[0]
  44. p_value = adf_test[1]
  45. critical_values = adf_test[4]
  46. print(f'ADF Statistic: {adf_statistic:.4f}')
  47. print(f'p-value: {p_value:.4f}')
  48. if p_value < 0.05:
  49. print("The residuals are stationary (reject the null hypothesis).")
  50. else:
  51. print("The residuals are not stationary (fail to reject the null hypothesis).")
  52. # Ljung-Box 检验函数
  53. def perform_ljungbox_test(residuals, lags=10):
  54. ljung_box_results = acorr_ljungbox(residuals, lags=[lags], return_df=True)
  55. lb_stat = ljung_box_results['lb_stat'].values[0]
  56. p_value = ljung_box_results['lb_pvalue'].values[0]
  57. print(f"Ljung-Box Statistic: {lb_stat:.4f}")
  58. print(f"Ljung-Box p-value: {p_value:.4f}")
  59. if p_value > 0.05:
  60. print("The residuals are random (fail to reject the null hypothesis of no autocorrelation).")
  61. else:
  62. print("The residuals are not random (reject the null hypothesis of no autocorrelation).")
  63. # 对齐数据
  64. def align_and_adjust(df_resampled, base_year=2024, file_prefix='aligned_weekly'):
  65. """
  66. 对齐并调整日期,适用于任何周度数据。保留缺失值 (NaN) 并调整日期,使其与指定的基准年份对齐。
  67. 参数:
  68. - df_resampled: 周度数据的DataFrame,包含要处理的列。
  69. - base_year: 用于对齐的基准年份(默认2024年)。
  70. - file_prefix: 输出文件的前缀(默认'aligned_weekly')。
  71. """
  72. # 确保索引为 DatetimeIndex
  73. if not isinstance(df_resampled.index, pd.DatetimeIndex):
  74. df_resampled.index = pd.to_datetime(df_resampled.index)
  75. # 创建完整的基准年份的周五时间序列
  76. fridays_base_year = pd.date_range(start=f'{base_year}-01-05', end=f'{base_year}-12-27', freq='W-FRI')
  77. # 提取年份和周数
  78. df_resampled['年份'] = df_resampled.index.year
  79. df_resampled['周数'] = df_resampled.index.isocalendar().week
  80. # 自动获取数据中的年份范围
  81. years_range = range(df_resampled['年份'].min(), df_resampled['年份'].max() + 1)
  82. weeks_range = range(1, 53)
  83. # 创建全年的周数组合
  84. index = pd.MultiIndex.from_product([years_range, weeks_range], names=['年份', '周数'])
  85. # 重建索引,使得数据对齐到完整的年份和周数组合
  86. df_resampled_aligned = df_resampled.set_index(['年份', '周数']).reindex(index).reset_index()
  87. # 保留缺失值为 NaN
  88. df_resampled_aligned['trend'] = df_resampled_aligned['trend'].round(3)
  89. df_resampled_aligned['seasonal'] = df_resampled_aligned['seasonal'].round(3)
  90. df_resampled_aligned['residual'] = df_resampled_aligned['residual'].round(3)
  91. # 使用基准年份的周五时间序列创建周数到日期的映射
  92. week_to_date_map_base_year = {i + 1: date for i, date in enumerate(fridays_base_year)}
  93. # 定义日期调整函数
  94. def adjust_dates(row):
  95. week = row['周数']
  96. year = row['年份']
  97. if week in week_to_date_map_base_year:
  98. base_date = week_to_date_map_base_year[week]
  99. adjusted_date = base_date.replace(year=int(year))
  100. return adjusted_date
  101. return pd.NaT
  102. # 应用日期调整
  103. df_resampled_aligned['日期'] = df_resampled_aligned.apply(adjust_dates, axis=1)
  104. # 移除未来日期(当前日期之后的行)
  105. current_date = pd.Timestamp.today()
  106. df_resampled_aligned = df_resampled_aligned[df_resampled_aligned['日期'] <= current_date]
  107. # 设置调整后的日期为索引
  108. df_resampled_aligned.set_index('日期', inplace=True)
  109. # 检查并提示缺失值
  110. missing_values = df_resampled_aligned[df_resampled_aligned.isna().any(axis=1)]
  111. if not missing_values.empty:
  112. print(f"警告:存在缺失值,缺失的周数为:\n{missing_values[['年份', '周数']]}")
  113. # 保存对齐后的数据到不同的CSV文件
  114. df_resampled_aligned[['trend']].to_csv(f'{file_prefix}_trend.csv', date_format='%Y-%m-%d')
  115. df_resampled_aligned[['seasonal']].to_csv(f'{file_prefix}_seasonal.csv', date_format='%Y-%m-%d')
  116. df_resampled_aligned[['residual']].to_csv(f'{file_prefix}_residual.csv', date_format='%Y-%m-%d')
  117. # 返回处理后的DataFrame
  118. return df_resampled_aligned
  119. # stl 拆分存储
  120. def test_stl_parameters(data, value_col, seasonal, trend, period, seasonal_deg, trend_deg, low_pass_deg, robust):
  121. """
  122. 参数:
  123. - data: 输入的 DataFrame,包含待分解的时间序列数据。
  124. - value_col: 数据中包含时间序列值的列名。
  125. - seasonal: 季节性成分窗口大小。
  126. - trend: 趋势成分窗口大小。
  127. - period: 数据的周期。
  128. - seasonal_deg: STL 分解中的季节性多项式次数。
  129. - trend_deg: STL 分解中的趋势多项式次数。
  130. - low_pass_deg: STL 分解中的低通滤波多项式次数。
  131. - robust: 是否使用稳健方法。
  132. """
  133. stl = STL(
  134. data[value_col],
  135. seasonal=seasonal,
  136. trend=trend,
  137. period=period,
  138. low_pass=None,
  139. seasonal_deg=seasonal_deg,
  140. trend_deg=trend_deg,
  141. low_pass_deg=low_pass_deg,
  142. seasonal_jump=1,
  143. trend_jump=1,
  144. low_pass_jump=1,
  145. robust=robust
  146. )
  147. result = stl.fit()
  148. # Generate new column names based on the original column name
  149. trend_col = f'{value_col}_trend'
  150. seasonal_col = f'{value_col}_seasonal'
  151. residual_col = f'{value_col}_residual'
  152. # Add the decomposition results to the DataFrame with new column names
  153. data[trend_col] = result.trend
  154. data[seasonal_col] = result.seasonal
  155. data[residual_col] = result.resid
  156. # 计算残差标准差、ADF 检验、KPSS 检验、Ljung-Box 检验
  157. residual_std = np.std(data[residual_col])
  158. print(f"Residual Std Dev: {residual_std:.4f}")
  159. print("\nADF Test Results:")
  160. perform_adf_test(data[residual_col])
  161. print("\nKPSS Test Results:")
  162. perform_kpss_test(data[residual_col])
  163. print("\nLjung-Box Test Results:")
  164. perform_ljungbox_test(data[residual_col])
  165. plt.figure(figsize=(14, 8))
  166. plt.subplot(4, 1, 1)
  167. plt.plot(data.index, data[value_col], label='Original Data', color='blue')
  168. plt.title(f'Original Data (Robust={robust}, seasonal={seasonal}, trend={trend}, period={period})')
  169. plt.grid(True)
  170. plt.subplot(4, 1, 2)
  171. plt.plot(data.index, data[trend_col], label='Trend', color='green')
  172. plt.title(f'Trend Component ({trend_col})')
  173. plt.grid(True)
  174. plt.subplot(4, 1, 3)
  175. plt.plot(data.index, data[seasonal_col], label='Seasonal', color='orange')
  176. plt.title(f'Seasonal Component ({seasonal_col})')
  177. plt.grid(True)
  178. plt.subplot(4, 1, 4)
  179. plt.plot(data.index, data[residual_col], label='Residual', color='red')
  180. plt.title(f'Residual Component ({residual_col})')
  181. plt.grid(True)
  182. plt.tight_layout()
  183. plt.show()
  184. '''
  185. params_to_test = [
  186. {
  187. 'data': df1,
  188. 'value_col': '值',
  189. 'seasonal': 53,
  190. 'trend': 55,
  191. 'period': 53,
  192. 'seasonal_deg': 1,
  193. 'trend_deg': 2,
  194. 'low_pass_deg': 1,
  195. 'robust': False
  196. }
  197. # 可以继续添加更多参数组合
  198. ]
  199. for params in params_to_test:
  200. test_stl_parameters(**params)
  201. '''
  202. # 画出所有图
  203. def plot_factors(df, df_name):
  204. """
  205. Plot each column in the DataFrame with the same x-axis (index).
  206. Parameters:
  207. - df: Pandas DataFrame containing the data to be plotted.
  208. - df_name: A string representing the name of the DataFrame (for title purposes).
  209. """
  210. for column in df.columns:
  211. plt.figure(figsize=(10, 5))
  212. plt.plot(df.index, df[column], label=column)
  213. plt.title(f'{df_name}: {column}', fontsize=16)
  214. plt.xlabel('Date')
  215. plt.ylabel('Value')
  216. plt.legend()
  217. plt.grid(True)
  218. plt.xticks(rotation=45)
  219. plt.tight_layout()
  220. plt.show()
  221. '''
  222. plot_factors(aligned_weekly, 'aligned_daily')
  223. '''
  224. # 选取特别col 画图
  225. def plot_factors_by_pattern(df, df_name, pattern=None):
  226. """
  227. 根据指定的列名模式绘制 DataFrame 中的列。
  228. Parameters:
  229. - df: Pandas DataFrame containing the data to be plotted.
  230. - df_name: A string representing the name of the DataFrame (for title purposes).
  231. - pattern: A string representing the pattern for selecting columns to plot (e.g., "trend", "residual").
  232. If None, all columns will be plotted.
  233. """
  234. # 如果给定了 pattern,选择列名中包含该 pattern 的列
  235. if pattern:
  236. columns_to_plot = [col for col in df.columns if pattern in col]
  237. else:
  238. columns_to_plot = df.columns
  239. # 绘制符合条件的列
  240. for column in columns_to_plot:
  241. plt.figure(figsize=(10, 5))
  242. plt.plot(df.index, df[column], label=column)
  243. plt.title(f'{df_name}: {column}', fontsize=16)
  244. plt.xlabel('Date')
  245. plt.ylabel('Value')
  246. plt.legend()
  247. plt.grid(True)
  248. plt.xticks(rotation=45)
  249. plt.tight_layout()
  250. plt.show()
  251. '''
  252. plot_factors_by_pattern(df, 'My DataFrame', pattern='trend')
  253. plot_factors_by_pattern(df, 'My DataFrame')
  254. '''
  255. #空缺值填写
  256. def fill_missing_values(df, fill_methods, return_only_filled=True):
  257. """
  258. 根据每个因子的特性选择不同的填充方式
  259. 参数:
  260. df: 需要处理的 DataFrame
  261. fill_methods: 一个字典,其中键是列名,值是填充方法,如 'mean', 'median', 'ffill', 'bfill', 'interpolate', 'none', 'mean_of_5', 'rolling_mean_5'
  262. return_only_filled: 布尔值, 是否只返回填充过的列, 默认为 True
  263. 返回:
  264. 返回一个新的 DataFrame,只包含指定列并按相应方法填充完毕
  265. """
  266. filled_df = pd.DataFrame() # 创建一个空的 DataFrame 用于存储填充过的因子
  267. for factor, method in fill_methods.items():
  268. if factor in df.columns:
  269. df.loc[:, factor] = pd.to_numeric(df[factor], errors='coerce')
  270. if method == 'mean':
  271. filled_df[factor] = df[factor].fillna(df[factor].mean()).infer_objects(copy=False)
  272. elif method == 'median':
  273. filled_df[factor] = df[factor].fillna(df[factor].median()).infer_objects(copy=False)
  274. elif method == 'ffill':
  275. filled_df[factor] = df[factor].fillna(method='ffill').infer_objects(copy=False)
  276. elif method == 'bfill':
  277. filled_df[factor] = df[factor].fillna(method='bfill').infer_objects(copy=False)
  278. elif method == 'interpolate':
  279. filled_df[factor] = df[factor].infer_objects(copy=False).interpolate(method='linear')
  280. elif method == 'mean_of_5':
  281. filled_df[factor] = df[factor].copy() # 先复制原始数据
  282. for i in range(len(filled_df[factor])):
  283. if pd.isnull(filled_df[factor].iloc[i]): # 检查是否为空
  284. # 获取前后五个非空值,使用 pd.concat 替代 append
  285. surrounding_values = pd.concat([
  286. df[factor].iloc[max(0, i - 5):i].dropna(),
  287. df[factor].iloc[i + 1:min(len(df[factor]), i + 6)].dropna()
  288. ])
  289. if len(surrounding_values) > 0:
  290. # 使用周围非空值的平均值填充
  291. filled_df.loc[filled_df.index[i], factor] = surrounding_values.mean()
  292. elif method == 'rolling_mean_5': # 更平滑一点
  293. # 用滚动窗口的平均值填充
  294. filled_df[factor] = df[factor].fillna(df[factor].rolling(window=5, min_periods=1).mean()).infer_objects(copy=False)
  295. elif method == 'none':
  296. filled_df[factor] = df[factor] # 不做填充,返回原始数据
  297. else:
  298. print(f"未知的填充方法: {method}")
  299. else:
  300. print(f"因子 {factor} 不存在于 DataFrame 中")
  301. # 如果设置了 return_only_filled=False, 则返回所有原始数据+处理过的列
  302. if not return_only_filled:
  303. remaining_df = df.drop(columns=filled_df.columns) # 删除已处理列
  304. return pd.concat([remaining_df, filled_df], axis=1)
  305. return filled_df # 只返回处理过的列
  306. '''
  307. fill_methods = {
  308. '螺纹高炉成本': 'mean', # 使用均值填充
  309. '螺纹表需': 'median', # 使用中位数填充
  310. '30大中城市商品房成交面积/30DMA': 'none' # 不进行填充,保留原始数据
  311. }
  312. # 调用函数进行填充,并只返回被填充的列
  313. filled_data = fill_missing_values(aligned_daily, fill_methods)
  314. # 如果想返回整个DataFrame,包括没填充的列,可以使用:
  315. filled_data_with_all = fill_missing_values(aligned_daily, fill_methods, return_only_filled=False)
  316. :'interpolate'
  317. :'rolling_mean_5'
  318. :'rolling_mean_5'
  319. '''
  320. # daily数据变成weekly
  321. def daily_to_weekly(df_daily, cols_to_process, date_column='日期', method='mean'):
  322. """
  323. 将日度数据转换为周度数据,按周五对齐,并计算每周的平均值。
  324. 参数:
  325. df_daily: 包含日度数据的 DataFrame,索引为日期。
  326. cols_to_process: 需要处理的列的列表。
  327. date_column: 用于对齐的日期列名,默认为 '日期'。
  328. method: 填充每周的计算方式,默认使用 'mean'(平均值),可以根据需要修改。
  329. 返回:
  330. 返回一个新的 DataFrame,转换为周度数据,日期对齐到每周五。
  331. """
  332. # 生成周五为频率的日期范围
  333. weekly_date_range = pd.date_range(start='2016-09-02', end='2024-10-04', freq='W-FRI')
  334. # 创建一个空的 DataFrame,索引为周五的日期范围
  335. df_weekly = pd.DataFrame(index=weekly_date_range)
  336. # 对每个需要处理的列进行周度转换
  337. for column in cols_to_process:
  338. if column in df_daily.columns:
  339. # 按周进行重采样,并计算每周的平均值,忽略缺失值
  340. df_weekly[column] = df_daily[column].resample('W-FRI').apply(lambda x: x.mean() if len(x.dropna()) > 0 else np.nan)
  341. else:
  342. print(f"列 {column} 不存在于 DataFrame 中")
  343. return df_weekly
  344. '''
  345. cols_to_process = ['螺纹表需', '螺纹高炉成本']
  346. # 调用函数,将日度数据转换为周度数据
  347. weekly_data = daily_to_weekly(df_daily, cols_to_process)
  348. '''
  349. def plot_comparison_multiple(df, main_col, compare_cols, start_date=None, end_date=None):
  350. """
  351. 将一个主要指标与多个其他指标进行比较,并且允许选择指定时间范围。
  352. Parameters:
  353. - df: 包含多个指标的 Pandas DataFrame,索引为日期。
  354. - main_col: 主要指标的列名(字符串),该列将与多个其他指标进行对比。
  355. - compare_cols: 需要与主要指标进行比较的其他列的列表。
  356. - start_date: 可选,开始日期,指定要绘制的时间范围。
  357. - end_date: 可选,结束日期,指定要绘制的时间范围。
  358. """
  359. # 如果指定了时间范围,限制数据范围
  360. if start_date and end_date:
  361. df = df.loc[start_date:end_date]
  362. # 归一化数据函数
  363. def normalize(series):
  364. return (series - series.min()) / (series.max() - series.min())
  365. # 归一化主要指标
  366. main_data_normalized = normalize(df[main_col])
  367. # 绘制主要指标与多个其他指标的对比图
  368. for col in compare_cols:
  369. if col in df.columns:
  370. compare_data_normalized = normalize(df[col])
  371. # 绘制图表
  372. plt.figure(figsize=(10, 6))
  373. plt.plot(main_data_normalized.index, main_data_normalized, label=main_col, color='b')
  374. plt.plot(compare_data_normalized.index, compare_data_normalized, label=col, linestyle='--', color='r')
  375. # 添加标题和标签
  376. plt.title(f'Comparison: {main_col} vs {col}', fontsize=16)
  377. plt.xlabel('Date', fontsize=12)
  378. plt.ylabel('Normalized Value', fontsize=12)
  379. # 图例和网格
  380. plt.legend()
  381. plt.grid(True)
  382. plt.xticks(rotation=45)
  383. plt.tight_layout()
  384. plt.show()
  385. else:
  386. print(f"列 '{col}' 不存在于 DataFrame 中。")
  387. '''
  388. plot_comparison_multiple(
  389. filled_data,
  390. main_col='螺纹总库存_trend',
  391. compare_cols=['螺纹总库存_seasonal', '螺纹总库存_residual'],
  392. start_date='2021-01-01',
  393. end_date='2023-01-01'
  394. )
  395. ---
  396. # 预先定义多个比较组合
  397. comparisons = {
  398. '螺纹总库存_trend_vs_日均铁水产量': {
  399. 'main_col': '螺纹总库存_trend',
  400. 'compare_cols': ['日均铁水产量']
  401. },
  402. '螺纹总库存_trend_vs_多个指标': {
  403. 'main_col': '螺纹总库存_trend',
  404. 'compare_cols': ['日均铁水产量', '螺纹总库存']
  405. }
  406. }
  407. # 选择一个组合来进行比较
  408. selected_comparison_key = '螺纹总库存_trend_vs_多个指标' # 你可以从上面打印的组合中选择一个
  409. selected_comparison = comparisons[selected_comparison_key]
  410. # 调用 plot_comparison_multiple 函数,传入选择的组合和时间范围
  411. plot_comparison_multiple(
  412. filled_data,
  413. selected_comparison['main_col'],
  414. selected_comparison['compare_cols'],
  415. start_date='2021-01-01', # 可选的开始时间
  416. end_date='2023-01-01' # 可选的结束时间
  417. )
  418. ---
  419. # 假设你已经提前定义了主要指标和比较指标
  420. main_col = '螺纹总库存_trend'
  421. compare_cols = ['日均铁水产量', '螺纹总库存']
  422. # 现在调用 plot_comparison_multiple 函数时,直接使用这些变量
  423. plot_comparison_multiple(filled_data, main_col, compare_cols,start_date='2021-01-01', end_date='2023-01-01')
  424. '''
  425. def process_outliers(data, column, window=29, std_multiplier=2):
  426. """
  427. 处理数据中的异常波动,使用滚动窗口计算标准差,超出指定标准差倍数的异常值进行平滑处理。
  428. 参数:
  429. - data (pd.DataFrame): 输入数据,必须包含一个日期索引和处理的列。
  430. - column (str): 需要处理的列名。
  431. - window (int): 滑动窗口大小,用于计算标准差,默认为20天。
  432. - std_multiplier (float): 标准差倍数,用于判断异常值,默认为2倍标准差。
  433. 返回:
  434. - pd.DataFrame: 返回处理后的DataFrame,包含异常处理的列。
  435. """
  436. processed_data = data.copy()
  437. # 计算滑动均值和标准差
  438. rolling_mean = processed_data[column].rolling(window=window, min_periods=1).mean()
  439. rolling_std = processed_data[column].rolling(window=window, min_periods=1).std()
  440. # 定义上限和下限
  441. upper_limit = rolling_mean + std_multiplier * rolling_std
  442. lower_limit = rolling_mean - std_multiplier * rolling_std
  443. # 平滑处理超过阈值的异常值
  444. processed_data[column] = np.where(
  445. processed_data[column] > upper_limit, upper_limit,
  446. np.where(processed_data[column] < lower_limit, lower_limit, processed_data[column])
  447. )
  448. return processed_data
  449. '''
  450. processed_data = process_outliers(df, column="WTI连1-连4月差", window=20, std_multiplier=2)
  451. '''
  452. def align_data(xls, sheet_name, date_freq, start_date='2016-09-02', end_date='2024-10-04'):
  453. """
  454. 读取并对齐数据,根据指定频率对齐每日、每周、每月或每年数据。
  455. 参数:
  456. - xls: Excel 文件路径或文件对象
  457. - sheet_name: 表名,如 '日度数据'、'周度数据' 等
  458. - date_freq: 日期频率('D' 表示每日, 'W-FRI' 表示每周五, 'M' 表示每月最后一天, 'A' 表示每年最后一天)
  459. - start_date: 对齐的开始日期
  460. - end_date: 对齐的结束日期
  461. 返回:
  462. - 对齐后的 DataFrame
  463. """
  464. data = pd.read_excel(xls, sheet_name=sheet_name, header=0)
  465. date_range = pd.date_range(start=start_date, end=end_date, freq=date_freq)
  466. aligned_data = pd.DataFrame(index=date_range)
  467. for i in range(0, len(data.columns), 2):
  468. if i + 1 < len(data.columns): # 确保成对的列存在
  469. factor_name = data.columns[i] # 因子名称
  470. df = data[[data.columns[i], data.columns[i + 1]]].copy() # 提取因子的两列
  471. df.columns = ['日期', factor_name] # 重新命名列
  472. df['日期'] = pd.to_datetime(df['日期'], format='%Y-%m-%d', errors='coerce') # 转换日期为 datetime
  473. df.drop_duplicates(subset=['日期'], inplace=True) # 去掉重复日期
  474. df.set_index('日期', inplace=True) # 将日期设置为索引
  475. aligned_data[factor_name] = df.reindex(aligned_data.index)[factor_name] # 对齐到指定频率的日期
  476. return aligned_data
  477. '''
  478. aligned_daily = align_data(xls, sheet_name='日度数据', date_freq='D', start_date='2016-09-02', end_date='2024-10-04')
  479. '''
  480. def reverse_column(df, column_name):
  481. """
  482. 将指定列的数值进行逆序,使得最大值变为最小值,最小值变为最大值。
  483. 参数:
  484. df (pd.DataFrame): 包含要逆序列的 DataFrame
  485. column_name (str): 要逆序的列名
  486. 返回:
  487. pd.Series: 逆序后的列
  488. """
  489. max_value = df[column_name].max()
  490. min_value = df[column_name].min()
  491. return max_value + min_value - df[column_name]
  492. '''
  493. sheet_daily['美国首次申领失业金人数/4WMA_逆序'] = reverse_column(sheet_daily, '美国首次申领失业金人数/4WMA')
  494. '''
  495. '''
  496. def plot_scatter_with_fit(df, main_col, compare_cols, start_date=None, end_date=None):
  497. """
  498. 绘制主列与多个列的散点图及线性拟合线,评估线性关系。
  499. 参数:
  500. df (DataFrame): 输入数据。
  501. main_col (str): 主列名。
  502. compare_cols (list): 要对比的列名列表。
  503. start_date (str): 开始日期(可选),格式 'YYYY-MM-DD'。
  504. end_date (str): 结束日期(可选),格式 'YYYY-MM-DD'。
  505. """
  506. # 过滤日期范围
  507. if start_date:
  508. df = df[df['Date'] >= start_date]
  509. if end_date:
  510. df = df[df['Date'] <= end_date]
  511. # 检查主列是否存在
  512. if main_col not in df.columns:
  513. print(f"主列 '{main_col}' 不存在于 DataFrame 中。")
  514. return
  515. # 绘制主列与多个对比列的散点图和拟合直线
  516. for col in compare_cols:
  517. if col in df.columns:
  518. # 提取主列和对比列数据
  519. x = df[main_col]
  520. y = df[col]
  521. # 计算线性回归
  522. slope, intercept, r_value, p_value, std_err = linregress(x, y)
  523. line = slope * x + intercept # 拟合直线公式
  524. # 绘制散点图和拟合直线
  525. plt.figure(figsize=(10, 6))
  526. plt.scatter(x, y, alpha=0.7, edgecolors='k', label='Data Points')
  527. plt.plot(x, line, color='r', linestyle='--', label=f'Fit Line (R^2={r_value**2:.2f})')
  528. # 添加标题和标签
  529. plt.title(f'{main_col} vs {col}', fontsize=16)
  530. plt.xlabel(main_col, fontsize=12)
  531. plt.ylabel(col, fontsize=12)
  532. # 图例和网格
  533. plt.legend()
  534. plt.grid(True)
  535. plt.tight_layout()
  536. plt.show()
  537. else:
  538. print(f"列 '{col}' 不存在于 DataFrame 中。")
  539. '''
  540. def plot_scatter_with_fit(df, main_col, compare_cols, start_date=None, end_date=None):
  541. """
  542. 绘制主列与多个列的散点图及线性拟合线,剔除超级异常值(基于 3 倍 IQR),评估线性关系。
  543. 参数:
  544. df (DataFrame): 输入数据。
  545. main_col (str): 主列名。
  546. compare_cols (list): 要对比的列名列表。
  547. start_date (str): 开始日期(可选),格式 'YYYY-MM-DD'。
  548. end_date (str): 结束日期(可选),格式 'YYYY-MM-DD'。
  549. """
  550. # 过滤日期范围
  551. if start_date:
  552. df = df[df['Date'] >= start_date]
  553. if end_date:
  554. df = df[df['Date'] <= end_date]
  555. # 检查主列是否存在
  556. if main_col not in df.columns:
  557. print(f"主列 '{main_col}' 不存在于 DataFrame 中。")
  558. return
  559. # 内部函数:剔除异常值
  560. def remove_outliers(series, threshold=3.0):
  561. """
  562. 基于 IQR 剔除异常值。
  563. 参数:
  564. series (Series): 输入数据。
  565. threshold (float): IQR 倍数阈值(默认 3 倍)。
  566. 返回:
  567. Series: 剔除异常值后的数据。
  568. """
  569. q1 = series.quantile(0.25)
  570. q3 = series.quantile(0.75)
  571. iqr = q3 - q1
  572. lower_bound = q1 - threshold * iqr
  573. upper_bound = q3 + threshold * iqr
  574. return series[(series >= lower_bound) & (series <= upper_bound)]
  575. # 绘制主列与多个对比列的散点图和拟合直线
  576. for col in compare_cols:
  577. if col in df.columns:
  578. # 提取主列和对比列数据,并剔除缺失值
  579. valid_data = df[[main_col, col]].dropna()
  580. x = valid_data[main_col]
  581. y = valid_data[col]
  582. # 剔除异常值
  583. x_clean = remove_outliers(x, threshold=3.0)
  584. y_clean = remove_outliers(y, threshold=3.0)
  585. clean_data = valid_data[x.index.isin(x_clean.index) & y.index.isin(y_clean.index)]
  586. # 检查数据点是否足够
  587. if len(clean_data) < 2:
  588. print(f"列 '{col}' 数据不足,无法绘制拟合线。")
  589. continue
  590. x = clean_data[main_col]
  591. y = clean_data[col]
  592. # 计算线性回归
  593. slope, intercept, r_value, _, _ = linregress(x, y)
  594. line = slope * x + intercept # 拟合直线公式
  595. # 绘制散点图和拟合直线
  596. plt.figure(figsize=(10, 6))
  597. plt.scatter(x, y, alpha=0.7, edgecolors='k', label='Data Points')
  598. plt.plot(x, line, color='r', linestyle='--', label=f'Fit Line (R^2={r_value**2:.2f})')
  599. # 添加标题和标签
  600. plt.title(f'{main_col} vs {col}', fontsize=16)
  601. plt.xlabel(main_col, fontsize=12)
  602. plt.ylabel(col, fontsize=12)
  603. # 图例和网格
  604. plt.legend()
  605. plt.grid(True)
  606. plt.tight_layout()
  607. plt.show()
  608. else:
  609. print(f"列 '{col}' 不存在于 DataFrame 中。")
  610. '''
  611. plot_scatter_with_fit(df,
  612. main_col='主指标',
  613. compare_cols=['对比指标1', '对比指标2'],
  614. start_date='2022-01-10',
  615. end_date='2022-03-10')
  616. '''
  617. def plot_feature_importance(booster, X_train, importance_type='weight', title='特征重要性排序', xlabel='特征重要性'):
  618. """
  619. 绘制特征重要性的排序图
  620. :param booster: xgboost 模型的 booster
  621. :param X_train: 训练数据,用于映射特征名称
  622. :param importance_type: 'weight', 'gain' 或 'cover' 用于获取特征重要性
  623. :param title: 图表的标题 (可选)
  624. :param xlabel: X 轴标签 (可选)
  625. """
  626. # 获取特征重要性
  627. feature_importance = booster.get_score(importance_type=importance_type)
  628. # 创建 DataFrame 用于排序
  629. importance_df = pd.DataFrame({
  630. 'feature': list(feature_importance.keys()),
  631. 'importance': list(feature_importance.values())
  632. })
  633. # 将特征名称从 f0, f1 等映射到实际的列名
  634. feature_names = dict(zip([f'f{i}' for i in range(len(X_train.columns))], X_train.columns))
  635. importance_df['feature_name'] = importance_df['feature'].map(feature_names)
  636. # 按重要性降序排序
  637. importance_df_sorted = importance_df.sort_values('importance', ascending=True)
  638. # 绘制水平条形图
  639. plt.figure(figsize=(8, 6)) # 设置适中的图形大小
  640. plt.barh(range(len(importance_df_sorted)), importance_df_sorted['importance'])
  641. plt.yticks(range(len(importance_df_sorted)), importance_df_sorted['feature_name'], fontsize=8) # 调整字体大小
  642. plt.xlabel(xlabel)
  643. plt.title(title)
  644. plt.tight_layout() # 调整布局,避免标签重叠
  645. plt.show()
  646. # 使用示例:
  647. '''
  648. plot_feature_importance(xgb_model.get_booster(), X_train, importance_type='gain', title='特征重要性 (Gain)', xlabel='增益')
  649. plot_feature_importance(xgb_model.get_booster(), X_train, importance_type='weight', title='特征重要性 (Weight)', xlabel='权重')
  650. plot_feature_importance(xgb_model.get_booster(), X_train, importance_type='cover', title='特征重要性 (Cover)', xlabel='覆盖率')
  651. '''
  652. def plot_feature_distribution(df, feature_columns, bins=30, figsize=(18, 12)):
  653. """
  654. 根据传入的 DataFrame 和指定的特征列列表,绘制各个特征的分布图(直方图+核密度估计曲线),
  655. 以便观察数据的分布情况。
  656. 参数:
  657. df: pandas.DataFrame
  658. 包含数据的 DataFrame。
  659. feature_columns: list
  660. 要展示分布的特征列名称列表。
  661. bins: int
  662. 直方图的柱子数量,默认采用 30 个柱子。
  663. figsize: tuple
  664. 整体图形的尺寸,默认为 (18, 12)。
  665. 返回:
  666. None,函数会直接展示绘制的图形。
  667. """
  668. n_features = len(feature_columns)
  669. n_cols = 3 # 每行放置3个图
  670. n_rows = math.ceil(n_features / n_cols)
  671. plt.figure(figsize=figsize)
  672. for idx, col in enumerate(feature_columns):
  673. plt.subplot(n_rows, n_cols, idx + 1)
  674. # 去除空值后绘制直方图,并绘制核密度估计曲线
  675. data = df[col].dropna()
  676. sns.histplot(data, bins=bins, kde=True)
  677. plt.title(f"{col} 分布")
  678. plt.xlabel(col)
  679. plt.tight_layout()
  680. plt.show()