123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856 |
- import pandas as pd
- import numpy as np
- import matplotlib.pyplot as plt
- import random
- import warnings
- from statsmodels.tsa.stattools import acf, q_stat
- from statsmodels.stats.diagnostic import acorr_ljungbox
- from statsmodels.graphics.tsaplots import plot_acf
- from statsmodels.tsa.stattools import adfuller, kpss
- from statsmodels.tsa.seasonal import STL
- from statsmodels.nonparametric.smoothers_lowess import lowess
- import numpy as np
- import matplotlib.pyplot as plt
- from statsmodels.tsa.seasonal import STL
- from scipy.stats import linregress
- from xgboost import XGBRegressor
- from sklearn.metrics import mean_squared_error, r2_score
- from xgboost import plot_importance
- import math
- import seaborn as sns
- plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置字体
- plt.rcParams['axes.unicode_minus'] = False # 正常显示负号
- def perform_kpss_test(residuals):
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter("always")
- kpss_test = kpss(residuals, regression='c')
- kpss_statistic = kpss_test[0]
- kpss_p_value = kpss_test[1]
- kpss_critical_values = kpss_test[3]
- print(f'KPSS Statistic: {kpss_statistic:.4f}')
- print(f'p-value: {kpss_p_value:.4f}')
- if kpss_statistic < kpss_critical_values['5%']:
- print("The residuals are stationary (fail to reject the null hypothesis).")
- else:
- print("The residuals are not stationary (reject the null hypothesis).")
- # Check for any warning and print a message if the KPSS statistic is outside the range
- if len(w) > 0 and issubclass(w[0].category, Warning):
- print("Warning: The test statistic is outside the range of p-values available in the look-up table. "
- "The actual p-value is smaller than the returned value.")
-
-
- # ADF 检验函数
- def perform_adf_test(residuals):
- adf_test = adfuller(residuals)
- adf_statistic = adf_test[0]
- p_value = adf_test[1]
- critical_values = adf_test[4]
- print(f'ADF Statistic: {adf_statistic:.4f}')
- print(f'p-value: {p_value:.4f}')
- if p_value < 0.05:
- print("The residuals are stationary (reject the null hypothesis).")
- else:
- print("The residuals are not stationary (fail to reject the null hypothesis).")
- # Ljung-Box 检验函数
- def perform_ljungbox_test(residuals, lags=10):
- ljung_box_results = acorr_ljungbox(residuals, lags=[lags], return_df=True)
- lb_stat = ljung_box_results['lb_stat'].values[0]
- p_value = ljung_box_results['lb_pvalue'].values[0]
- print(f"Ljung-Box Statistic: {lb_stat:.4f}")
- print(f"Ljung-Box p-value: {p_value:.4f}")
- if p_value > 0.05:
- print("The residuals are random (fail to reject the null hypothesis of no autocorrelation).")
- else:
- print("The residuals are not random (reject the null hypothesis of no autocorrelation).")
- # 对齐数据
- def align_and_adjust(df_resampled, base_year=2024, file_prefix='aligned_weekly'):
- """
- 对齐并调整日期,适用于任何周度数据。保留缺失值 (NaN) 并调整日期,使其与指定的基准年份对齐。
- 参数:
- - df_resampled: 周度数据的DataFrame,包含要处理的列。
- - base_year: 用于对齐的基准年份(默认2024年)。
- - file_prefix: 输出文件的前缀(默认'aligned_weekly')。
- """
- # 确保索引为 DatetimeIndex
- if not isinstance(df_resampled.index, pd.DatetimeIndex):
- df_resampled.index = pd.to_datetime(df_resampled.index)
- # 创建完整的基准年份的周五时间序列
- fridays_base_year = pd.date_range(start=f'{base_year}-01-05', end=f'{base_year}-12-27', freq='W-FRI')
- # 提取年份和周数
- df_resampled['年份'] = df_resampled.index.year
- df_resampled['周数'] = df_resampled.index.isocalendar().week
- # 自动获取数据中的年份范围
- years_range = range(df_resampled['年份'].min(), df_resampled['年份'].max() + 1)
- weeks_range = range(1, 53)
- # 创建全年的周数组合
- index = pd.MultiIndex.from_product([years_range, weeks_range], names=['年份', '周数'])
- # 重建索引,使得数据对齐到完整的年份和周数组合
- df_resampled_aligned = df_resampled.set_index(['年份', '周数']).reindex(index).reset_index()
- # 保留缺失值为 NaN
- df_resampled_aligned['trend'] = df_resampled_aligned['trend'].round(3)
- df_resampled_aligned['seasonal'] = df_resampled_aligned['seasonal'].round(3)
- df_resampled_aligned['residual'] = df_resampled_aligned['residual'].round(3)
- # 使用基准年份的周五时间序列创建周数到日期的映射
- week_to_date_map_base_year = {i + 1: date for i, date in enumerate(fridays_base_year)}
- # 定义日期调整函数
- def adjust_dates(row):
- week = row['周数']
- year = row['年份']
- if week in week_to_date_map_base_year:
- base_date = week_to_date_map_base_year[week]
- adjusted_date = base_date.replace(year=int(year))
- return adjusted_date
- return pd.NaT
- # 应用日期调整
- df_resampled_aligned['日期'] = df_resampled_aligned.apply(adjust_dates, axis=1)
- # 移除未来日期(当前日期之后的行)
- current_date = pd.Timestamp.today()
- df_resampled_aligned = df_resampled_aligned[df_resampled_aligned['日期'] <= current_date]
- # 设置调整后的日期为索引
- df_resampled_aligned.set_index('日期', inplace=True)
- # 检查并提示缺失值
- missing_values = df_resampled_aligned[df_resampled_aligned.isna().any(axis=1)]
- if not missing_values.empty:
- print(f"警告:存在缺失值,缺失的周数为:\n{missing_values[['年份', '周数']]}")
- # 保存对齐后的数据到不同的CSV文件
- df_resampled_aligned[['trend']].to_csv(f'{file_prefix}_trend.csv', date_format='%Y-%m-%d')
- df_resampled_aligned[['seasonal']].to_csv(f'{file_prefix}_seasonal.csv', date_format='%Y-%m-%d')
- df_resampled_aligned[['residual']].to_csv(f'{file_prefix}_residual.csv', date_format='%Y-%m-%d')
- # 返回处理后的DataFrame
- return df_resampled_aligned
- # stl 拆分存储
- def test_stl_parameters(data, value_col, seasonal, trend, period, seasonal_deg, trend_deg, low_pass_deg, robust):
- """
- 参数:
- - data: 输入的 DataFrame,包含待分解的时间序列数据。
- - value_col: 数据中包含时间序列值的列名。
- - seasonal: 季节性成分窗口大小。
- - trend: 趋势成分窗口大小。
- - period: 数据的周期。
- - seasonal_deg: STL 分解中的季节性多项式次数。
- - trend_deg: STL 分解中的趋势多项式次数。
- - low_pass_deg: STL 分解中的低通滤波多项式次数。
- - robust: 是否使用稳健方法。
- """
- stl = STL(
- data[value_col],
- seasonal=seasonal,
- trend=trend,
- period=period,
- low_pass=None,
- seasonal_deg=seasonal_deg,
- trend_deg=trend_deg,
- low_pass_deg=low_pass_deg,
- seasonal_jump=1,
- trend_jump=1,
- low_pass_jump=1,
- robust=robust
- )
-
- result = stl.fit()
- # Generate new column names based on the original column name
- trend_col = f'{value_col}_trend'
- seasonal_col = f'{value_col}_seasonal'
- residual_col = f'{value_col}_residual'
- # Add the decomposition results to the DataFrame with new column names
- data[trend_col] = result.trend
- data[seasonal_col] = result.seasonal
- data[residual_col] = result.resid
- # 计算残差标准差、ADF 检验、KPSS 检验、Ljung-Box 检验
- residual_std = np.std(data[residual_col])
- print(f"Residual Std Dev: {residual_std:.4f}")
- print("\nADF Test Results:")
- perform_adf_test(data[residual_col])
- print("\nKPSS Test Results:")
- perform_kpss_test(data[residual_col])
- print("\nLjung-Box Test Results:")
- perform_ljungbox_test(data[residual_col])
- plt.figure(figsize=(14, 8))
- plt.subplot(4, 1, 1)
- plt.plot(data.index, data[value_col], label='Original Data', color='blue')
- plt.title(f'Original Data (Robust={robust}, seasonal={seasonal}, trend={trend}, period={period})')
- plt.grid(True)
- plt.subplot(4, 1, 2)
- plt.plot(data.index, data[trend_col], label='Trend', color='green')
- plt.title(f'Trend Component ({trend_col})')
- plt.grid(True)
- plt.subplot(4, 1, 3)
- plt.plot(data.index, data[seasonal_col], label='Seasonal', color='orange')
- plt.title(f'Seasonal Component ({seasonal_col})')
- plt.grid(True)
- plt.subplot(4, 1, 4)
- plt.plot(data.index, data[residual_col], label='Residual', color='red')
- plt.title(f'Residual Component ({residual_col})')
- plt.grid(True)
- plt.tight_layout()
- plt.show()
- '''
- params_to_test = [
- {
- 'data': df1,
- 'value_col': '值',
- 'seasonal': 53,
- 'trend': 55,
- 'period': 53,
- 'seasonal_deg': 1,
- 'trend_deg': 2,
- 'low_pass_deg': 1,
- 'robust': False
- }
- # 可以继续添加更多参数组合
- ]
- for params in params_to_test:
- test_stl_parameters(**params)
- '''
-
- # 画出所有图
- def plot_factors(df, df_name):
- """
- Plot each column in the DataFrame with the same x-axis (index).
-
- Parameters:
- - df: Pandas DataFrame containing the data to be plotted.
- - df_name: A string representing the name of the DataFrame (for title purposes).
- """
- for column in df.columns:
- plt.figure(figsize=(10, 5))
- plt.plot(df.index, df[column], label=column)
- plt.title(f'{df_name}: {column}', fontsize=16)
- plt.xlabel('Date')
- plt.ylabel('Value')
- plt.legend()
- plt.grid(True)
- plt.xticks(rotation=45)
- plt.tight_layout()
- plt.show()
- '''
- plot_factors(aligned_weekly, 'aligned_daily')
- '''
- # 选取特别col 画图
- def plot_factors_by_pattern(df, df_name, pattern=None):
- """
- 根据指定的列名模式绘制 DataFrame 中的列。
-
- Parameters:
- - df: Pandas DataFrame containing the data to be plotted.
- - df_name: A string representing the name of the DataFrame (for title purposes).
- - pattern: A string representing the pattern for selecting columns to plot (e.g., "trend", "residual").
- If None, all columns will be plotted.
- """
- # 如果给定了 pattern,选择列名中包含该 pattern 的列
- if pattern:
- columns_to_plot = [col for col in df.columns if pattern in col]
- else:
- columns_to_plot = df.columns
- # 绘制符合条件的列
- for column in columns_to_plot:
- plt.figure(figsize=(10, 5))
- plt.plot(df.index, df[column], label=column)
- plt.title(f'{df_name}: {column}', fontsize=16)
- plt.xlabel('Date')
- plt.ylabel('Value')
- plt.legend()
- plt.grid(True)
- plt.xticks(rotation=45)
- plt.tight_layout()
- plt.show()
- '''
- plot_factors_by_pattern(df, 'My DataFrame', pattern='trend')
- plot_factors_by_pattern(df, 'My DataFrame')
- '''
- #空缺值填写
- def fill_missing_values(df, fill_methods, return_only_filled=True):
- """
- 根据每个因子的特性选择不同的填充方式
-
- 参数:
- df: 需要处理的 DataFrame
- fill_methods: 一个字典,其中键是列名,值是填充方法,如 'mean', 'median', 'ffill', 'bfill', 'interpolate', 'none', 'mean_of_5', 'rolling_mean_5'
- return_only_filled: 布尔值, 是否只返回填充过的列, 默认为 True
-
- 返回:
- 返回一个新的 DataFrame,只包含指定列并按相应方法填充完毕
- """
- filled_df = pd.DataFrame() # 创建一个空的 DataFrame 用于存储填充过的因子
- for factor, method in fill_methods.items():
- if factor in df.columns:
- df.loc[:, factor] = pd.to_numeric(df[factor], errors='coerce')
- if method == 'mean':
- filled_df[factor] = df[factor].fillna(df[factor].mean()).infer_objects(copy=False)
- elif method == 'median':
- filled_df[factor] = df[factor].fillna(df[factor].median()).infer_objects(copy=False)
- elif method == 'ffill':
- filled_df[factor] = df[factor].fillna(method='ffill').infer_objects(copy=False)
- elif method == 'bfill':
- filled_df[factor] = df[factor].fillna(method='bfill').infer_objects(copy=False)
- elif method == 'interpolate':
- filled_df[factor] = df[factor].infer_objects(copy=False).interpolate(method='linear')
- elif method == 'mean_of_5':
- filled_df[factor] = df[factor].copy() # 先复制原始数据
- for i in range(len(filled_df[factor])):
- if pd.isnull(filled_df[factor].iloc[i]): # 检查是否为空
- # 获取前后五个非空值,使用 pd.concat 替代 append
- surrounding_values = pd.concat([
- df[factor].iloc[max(0, i - 5):i].dropna(),
- df[factor].iloc[i + 1:min(len(df[factor]), i + 6)].dropna()
- ])
- if len(surrounding_values) > 0:
- # 使用周围非空值的平均值填充
- filled_df.loc[filled_df.index[i], factor] = surrounding_values.mean()
- elif method == 'rolling_mean_5': # 更平滑一点
- # 用滚动窗口的平均值填充
- filled_df[factor] = df[factor].fillna(df[factor].rolling(window=5, min_periods=1).mean()).infer_objects(copy=False)
- elif method == 'none':
- filled_df[factor] = df[factor] # 不做填充,返回原始数据
- else:
- print(f"未知的填充方法: {method}")
- else:
- print(f"因子 {factor} 不存在于 DataFrame 中")
- # 如果设置了 return_only_filled=False, 则返回所有原始数据+处理过的列
- if not return_only_filled:
- remaining_df = df.drop(columns=filled_df.columns) # 删除已处理列
- return pd.concat([remaining_df, filled_df], axis=1)
-
- return filled_df # 只返回处理过的列
- '''
- fill_methods = {
- '螺纹高炉成本': 'mean', # 使用均值填充
- '螺纹表需': 'median', # 使用中位数填充
- '30大中城市商品房成交面积/30DMA': 'none' # 不进行填充,保留原始数据
- }
- # 调用函数进行填充,并只返回被填充的列
- filled_data = fill_missing_values(aligned_daily, fill_methods)
- # 如果想返回整个DataFrame,包括没填充的列,可以使用:
- filled_data_with_all = fill_missing_values(aligned_daily, fill_methods, return_only_filled=False)
- :'interpolate'
- :'rolling_mean_5'
- :'rolling_mean_5'
- '''
- # daily数据变成weekly
- def daily_to_weekly(df_daily, cols_to_process, date_column='日期', method='mean'):
- """
- 将日度数据转换为周度数据,按周五对齐,并计算每周的平均值。
-
- 参数:
- df_daily: 包含日度数据的 DataFrame,索引为日期。
- cols_to_process: 需要处理的列的列表。
- date_column: 用于对齐的日期列名,默认为 '日期'。
- method: 填充每周的计算方式,默认使用 'mean'(平均值),可以根据需要修改。
- 返回:
- 返回一个新的 DataFrame,转换为周度数据,日期对齐到每周五。
- """
- # 生成周五为频率的日期范围
- weekly_date_range = pd.date_range(start='2016-09-02', end='2024-10-04', freq='W-FRI')
-
- # 创建一个空的 DataFrame,索引为周五的日期范围
- df_weekly = pd.DataFrame(index=weekly_date_range)
- # 对每个需要处理的列进行周度转换
- for column in cols_to_process:
- if column in df_daily.columns:
- # 按周进行重采样,并计算每周的平均值,忽略缺失值
- df_weekly[column] = df_daily[column].resample('W-FRI').apply(lambda x: x.mean() if len(x.dropna()) > 0 else np.nan)
- else:
- print(f"列 {column} 不存在于 DataFrame 中")
- return df_weekly
- '''
- cols_to_process = ['螺纹表需', '螺纹高炉成本']
- # 调用函数,将日度数据转换为周度数据
- weekly_data = daily_to_weekly(df_daily, cols_to_process)
- '''
- def plot_comparison_multiple(df, main_col, compare_cols, start_date=None, end_date=None):
- """
- 将一个主要指标与多个其他指标进行比较,并且允许选择指定时间范围。
-
- Parameters:
- - df: 包含多个指标的 Pandas DataFrame,索引为日期。
- - main_col: 主要指标的列名(字符串),该列将与多个其他指标进行对比。
- - compare_cols: 需要与主要指标进行比较的其他列的列表。
- - start_date: 可选,开始日期,指定要绘制的时间范围。
- - end_date: 可选,结束日期,指定要绘制的时间范围。
- """
-
- # 如果指定了时间范围,限制数据范围
- if start_date and end_date:
- df = df.loc[start_date:end_date]
- # 归一化数据函数
- def normalize(series):
- return (series - series.min()) / (series.max() - series.min())
-
- # 归一化主要指标
- main_data_normalized = normalize(df[main_col])
- # 绘制主要指标与多个其他指标的对比图
- for col in compare_cols:
- if col in df.columns:
- compare_data_normalized = normalize(df[col])
-
- # 绘制图表
- plt.figure(figsize=(10, 6))
- plt.plot(main_data_normalized.index, main_data_normalized, label=main_col, color='b')
- plt.plot(compare_data_normalized.index, compare_data_normalized, label=col, linestyle='--', color='r')
-
- # 添加标题和标签
- plt.title(f'Comparison: {main_col} vs {col}', fontsize=16)
- plt.xlabel('Date', fontsize=12)
- plt.ylabel('Normalized Value', fontsize=12)
-
- # 图例和网格
- plt.legend()
- plt.grid(True)
- plt.xticks(rotation=45)
- plt.tight_layout()
- plt.show()
- else:
- print(f"列 '{col}' 不存在于 DataFrame 中。")
- '''
- plot_comparison_multiple(
- filled_data,
- main_col='螺纹总库存_trend',
- compare_cols=['螺纹总库存_seasonal', '螺纹总库存_residual'],
- start_date='2021-01-01',
- end_date='2023-01-01'
- )
- ---
- # 预先定义多个比较组合
- comparisons = {
- '螺纹总库存_trend_vs_日均铁水产量': {
- 'main_col': '螺纹总库存_trend',
- 'compare_cols': ['日均铁水产量']
- },
- '螺纹总库存_trend_vs_多个指标': {
- 'main_col': '螺纹总库存_trend',
- 'compare_cols': ['日均铁水产量', '螺纹总库存']
- }
- }
- # 选择一个组合来进行比较
- selected_comparison_key = '螺纹总库存_trend_vs_多个指标' # 你可以从上面打印的组合中选择一个
- selected_comparison = comparisons[selected_comparison_key]
- # 调用 plot_comparison_multiple 函数,传入选择的组合和时间范围
- plot_comparison_multiple(
- filled_data,
- selected_comparison['main_col'],
- selected_comparison['compare_cols'],
- start_date='2021-01-01', # 可选的开始时间
- end_date='2023-01-01' # 可选的结束时间
- )
- ---
- # 假设你已经提前定义了主要指标和比较指标
- main_col = '螺纹总库存_trend'
- compare_cols = ['日均铁水产量', '螺纹总库存']
- # 现在调用 plot_comparison_multiple 函数时,直接使用这些变量
- plot_comparison_multiple(filled_data, main_col, compare_cols,start_date='2021-01-01', end_date='2023-01-01')
- '''
- def process_outliers(data, column, window=29, std_multiplier=2):
- """
- 处理数据中的异常波动,使用滚动窗口计算标准差,超出指定标准差倍数的异常值进行平滑处理。
-
- 参数:
- - data (pd.DataFrame): 输入数据,必须包含一个日期索引和处理的列。
- - column (str): 需要处理的列名。
- - window (int): 滑动窗口大小,用于计算标准差,默认为20天。
- - std_multiplier (float): 标准差倍数,用于判断异常值,默认为2倍标准差。
-
- 返回:
- - pd.DataFrame: 返回处理后的DataFrame,包含异常处理的列。
- """
- processed_data = data.copy()
-
- # 计算滑动均值和标准差
- rolling_mean = processed_data[column].rolling(window=window, min_periods=1).mean()
- rolling_std = processed_data[column].rolling(window=window, min_periods=1).std()
-
- # 定义上限和下限
- upper_limit = rolling_mean + std_multiplier * rolling_std
- lower_limit = rolling_mean - std_multiplier * rolling_std
-
- # 平滑处理超过阈值的异常值
- processed_data[column] = np.where(
- processed_data[column] > upper_limit, upper_limit,
- np.where(processed_data[column] < lower_limit, lower_limit, processed_data[column])
- )
-
- return processed_data
- '''
- processed_data = process_outliers(df, column="WTI连1-连4月差", window=20, std_multiplier=2)
- '''
- def align_data(xls, sheet_name, date_freq, start_date='2016-09-02', end_date='2024-10-04'):
- """
- 读取并对齐数据,根据指定频率对齐每日、每周、每月或每年数据。
-
- 参数:
- - xls: Excel 文件路径或文件对象
- - sheet_name: 表名,如 '日度数据'、'周度数据' 等
- - date_freq: 日期频率('D' 表示每日, 'W-FRI' 表示每周五, 'M' 表示每月最后一天, 'A' 表示每年最后一天)
- - start_date: 对齐的开始日期
- - end_date: 对齐的结束日期
-
- 返回:
- - 对齐后的 DataFrame
- """
- data = pd.read_excel(xls, sheet_name=sheet_name, header=0)
- date_range = pd.date_range(start=start_date, end=end_date, freq=date_freq)
- aligned_data = pd.DataFrame(index=date_range)
- for i in range(0, len(data.columns), 2):
- if i + 1 < len(data.columns): # 确保成对的列存在
- factor_name = data.columns[i] # 因子名称
- df = data[[data.columns[i], data.columns[i + 1]]].copy() # 提取因子的两列
- df.columns = ['日期', factor_name] # 重新命名列
- df['日期'] = pd.to_datetime(df['日期'], format='%Y-%m-%d', errors='coerce') # 转换日期为 datetime
- df.drop_duplicates(subset=['日期'], inplace=True) # 去掉重复日期
- df.set_index('日期', inplace=True) # 将日期设置为索引
- aligned_data[factor_name] = df.reindex(aligned_data.index)[factor_name] # 对齐到指定频率的日期
- return aligned_data
- '''
- aligned_daily = align_data(xls, sheet_name='日度数据', date_freq='D', start_date='2016-09-02', end_date='2024-10-04')
- '''
- def reverse_column(df, column_name):
- """
- 将指定列的数值进行逆序,使得最大值变为最小值,最小值变为最大值。
-
- 参数:
- df (pd.DataFrame): 包含要逆序列的 DataFrame
- column_name (str): 要逆序的列名
- 返回:
- pd.Series: 逆序后的列
- """
- max_value = df[column_name].max()
- min_value = df[column_name].min()
- return max_value + min_value - df[column_name]
- '''
- sheet_daily['美国首次申领失业金人数/4WMA_逆序'] = reverse_column(sheet_daily, '美国首次申领失业金人数/4WMA')
- '''
- '''
- def plot_scatter_with_fit(df, main_col, compare_cols, start_date=None, end_date=None):
- """
- 绘制主列与多个列的散点图及线性拟合线,评估线性关系。
-
- 参数:
- df (DataFrame): 输入数据。
- main_col (str): 主列名。
- compare_cols (list): 要对比的列名列表。
- start_date (str): 开始日期(可选),格式 'YYYY-MM-DD'。
- end_date (str): 结束日期(可选),格式 'YYYY-MM-DD'。
- """
- # 过滤日期范围
- if start_date:
- df = df[df['Date'] >= start_date]
- if end_date:
- df = df[df['Date'] <= end_date]
-
- # 检查主列是否存在
- if main_col not in df.columns:
- print(f"主列 '{main_col}' 不存在于 DataFrame 中。")
- return
-
- # 绘制主列与多个对比列的散点图和拟合直线
- for col in compare_cols:
- if col in df.columns:
- # 提取主列和对比列数据
- x = df[main_col]
- y = df[col]
-
- # 计算线性回归
- slope, intercept, r_value, p_value, std_err = linregress(x, y)
- line = slope * x + intercept # 拟合直线公式
-
- # 绘制散点图和拟合直线
- plt.figure(figsize=(10, 6))
- plt.scatter(x, y, alpha=0.7, edgecolors='k', label='Data Points')
- plt.plot(x, line, color='r', linestyle='--', label=f'Fit Line (R^2={r_value**2:.2f})')
-
- # 添加标题和标签
- plt.title(f'{main_col} vs {col}', fontsize=16)
- plt.xlabel(main_col, fontsize=12)
- plt.ylabel(col, fontsize=12)
-
- # 图例和网格
- plt.legend()
- plt.grid(True)
- plt.tight_layout()
- plt.show()
- else:
- print(f"列 '{col}' 不存在于 DataFrame 中。")
- '''
- def plot_scatter_with_fit(df, main_col, compare_cols, start_date=None, end_date=None):
- """
- 绘制主列与多个列的散点图及线性拟合线,剔除超级异常值(基于 3 倍 IQR),评估线性关系。
-
- 参数:
- df (DataFrame): 输入数据。
- main_col (str): 主列名。
- compare_cols (list): 要对比的列名列表。
- start_date (str): 开始日期(可选),格式 'YYYY-MM-DD'。
- end_date (str): 结束日期(可选),格式 'YYYY-MM-DD'。
- """
- # 过滤日期范围
- if start_date:
- df = df[df['Date'] >= start_date]
- if end_date:
- df = df[df['Date'] <= end_date]
-
- # 检查主列是否存在
- if main_col not in df.columns:
- print(f"主列 '{main_col}' 不存在于 DataFrame 中。")
- return
-
- # 内部函数:剔除异常值
- def remove_outliers(series, threshold=3.0):
- """
- 基于 IQR 剔除异常值。
-
- 参数:
- series (Series): 输入数据。
- threshold (float): IQR 倍数阈值(默认 3 倍)。
-
- 返回:
- Series: 剔除异常值后的数据。
- """
- q1 = series.quantile(0.25)
- q3 = series.quantile(0.75)
- iqr = q3 - q1
- lower_bound = q1 - threshold * iqr
- upper_bound = q3 + threshold * iqr
- return series[(series >= lower_bound) & (series <= upper_bound)]
-
- # 绘制主列与多个对比列的散点图和拟合直线
- for col in compare_cols:
- if col in df.columns:
- # 提取主列和对比列数据,并剔除缺失值
- valid_data = df[[main_col, col]].dropna()
- x = valid_data[main_col]
- y = valid_data[col]
-
- # 剔除异常值
- x_clean = remove_outliers(x, threshold=3.0)
- y_clean = remove_outliers(y, threshold=3.0)
- clean_data = valid_data[x.index.isin(x_clean.index) & y.index.isin(y_clean.index)]
-
- # 检查数据点是否足够
- if len(clean_data) < 2:
- print(f"列 '{col}' 数据不足,无法绘制拟合线。")
- continue
-
- x = clean_data[main_col]
- y = clean_data[col]
-
- # 计算线性回归
- slope, intercept, r_value, _, _ = linregress(x, y)
- line = slope * x + intercept # 拟合直线公式
-
- # 绘制散点图和拟合直线
- plt.figure(figsize=(10, 6))
- plt.scatter(x, y, alpha=0.7, edgecolors='k', label='Data Points')
- plt.plot(x, line, color='r', linestyle='--', label=f'Fit Line (R^2={r_value**2:.2f})')
-
- # 添加标题和标签
- plt.title(f'{main_col} vs {col}', fontsize=16)
- plt.xlabel(main_col, fontsize=12)
- plt.ylabel(col, fontsize=12)
-
- # 图例和网格
- plt.legend()
- plt.grid(True)
- plt.tight_layout()
- plt.show()
- else:
- print(f"列 '{col}' 不存在于 DataFrame 中。")
- '''
- plot_scatter_with_fit(df,
- main_col='主指标',
- compare_cols=['对比指标1', '对比指标2'],
- start_date='2022-01-10',
- end_date='2022-03-10')
- '''
- def plot_feature_importance(booster, X_train, importance_type='weight', title='特征重要性排序', xlabel='特征重要性'):
- """
- 绘制特征重要性的排序图
- :param booster: xgboost 模型的 booster
- :param X_train: 训练数据,用于映射特征名称
- :param importance_type: 'weight', 'gain' 或 'cover' 用于获取特征重要性
- :param title: 图表的标题 (可选)
- :param xlabel: X 轴标签 (可选)
- """
- # 获取特征重要性
- feature_importance = booster.get_score(importance_type=importance_type)
- # 创建 DataFrame 用于排序
- importance_df = pd.DataFrame({
- 'feature': list(feature_importance.keys()),
- 'importance': list(feature_importance.values())
- })
- # 将特征名称从 f0, f1 等映射到实际的列名
- feature_names = dict(zip([f'f{i}' for i in range(len(X_train.columns))], X_train.columns))
- importance_df['feature_name'] = importance_df['feature'].map(feature_names)
- # 按重要性降序排序
- importance_df_sorted = importance_df.sort_values('importance', ascending=True)
- # 绘制水平条形图
- plt.figure(figsize=(8, 6)) # 设置适中的图形大小
- plt.barh(range(len(importance_df_sorted)), importance_df_sorted['importance'])
- plt.yticks(range(len(importance_df_sorted)), importance_df_sorted['feature_name'], fontsize=8) # 调整字体大小
- plt.xlabel(xlabel)
- plt.title(title)
- plt.tight_layout() # 调整布局,避免标签重叠
- plt.show()
- # 使用示例:
- '''
- plot_feature_importance(xgb_model.get_booster(), X_train, importance_type='gain', title='特征重要性 (Gain)', xlabel='增益')
- plot_feature_importance(xgb_model.get_booster(), X_train, importance_type='weight', title='特征重要性 (Weight)', xlabel='权重')
- plot_feature_importance(xgb_model.get_booster(), X_train, importance_type='cover', title='特征重要性 (Cover)', xlabel='覆盖率')
- '''
-
- def plot_feature_distribution(df, feature_columns, bins=30, figsize=(18, 12)):
- """
- 根据传入的 DataFrame 和指定的特征列列表,绘制各个特征的分布图(直方图+核密度估计曲线),
- 以便观察数据的分布情况。
- 参数:
- df: pandas.DataFrame
- 包含数据的 DataFrame。
- feature_columns: list
- 要展示分布的特征列名称列表。
- bins: int
- 直方图的柱子数量,默认采用 30 个柱子。
- figsize: tuple
- 整体图形的尺寸,默认为 (18, 12)。
- 返回:
- None,函数会直接展示绘制的图形。
- """
- n_features = len(feature_columns)
- n_cols = 3 # 每行放置3个图
- n_rows = math.ceil(n_features / n_cols)
-
- plt.figure(figsize=figsize)
- for idx, col in enumerate(feature_columns):
- plt.subplot(n_rows, n_cols, idx + 1)
- # 去除空值后绘制直方图,并绘制核密度估计曲线
- data = df[col].dropna()
- sns.histplot(data, bins=bins, kde=True)
- plt.title(f"{col} 分布")
- plt.xlabel(col)
- plt.tight_layout()
- plt.show()
|