Improving the Accuracy of an Imbalanced COVID-19 Mortality Prediction Model with GANs (Part 2)


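The study uses 11 categorical input features and 2 numerical input features. The target variable is death/recovery. A new column, "diff_sym_hos", was added to record the number of days between when symptoms were first noticed and when the patient was admitted to hospital.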
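A minimal pandas sketch of how a column like "diff_sym_hos" could be derived. The source column names `sym_on` (symptom onset date) and `hosp_vis` (hospital visit date), the date format, and the sample rows are all hypothetical assumptions, not taken from the article; missing or unparseable dates become NaN rather than raising.

```python
import pandas as pd

# Hypothetical column names and toy rows: 'sym_on' = symptom onset date,
# 'hosp_vis' = hospital admission date (assumptions for illustration only)
df = pd.DataFrame({
    "sym_on": ["1/3/2020", "1/15/2020", None],
    "hosp_vis": ["1/8/2020", "1/15/2020", "1/20/2020"],
})

# Parse both columns as dates; missing or unparseable values become NaT
sym_on = pd.to_datetime(df["sym_on"], format="%m/%d/%Y", errors="coerce")
hosp_vis = pd.to_datetime(df["hosp_vis"], format="%m/%d/%Y", errors="coerce")

# Days between symptom onset and hospital admission (NaN where a date is missing)
df["diff_sym_hos"] = (hosp_vis - sym_on).dt.days
print(df)
```

Rows with a missing date end up as NaN and would need imputing or dropping before the column is rescaled with the other numerical features.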
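The study focuses on augmenting the minority class, i.e. death == 1, so the records of that class were extracted from the training data as a subset. The subset is split into categorical and numerical features, which are then passed to the GAN model.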
```python
import pandas as pd
from sklearn.preprocessing import MinMaxScaler

# Keep only the minority class (death == 1) from the training data df
df_minority_data = df.loc[df['death'] == 1]

# Subset the input features, dropping the target variable
df_minority_data_withouttv = df_minority_data.loc[:, df_minority_data.columns != 'death']

# Split into numerical and categorical features
numerical_df = df_minority_data_withouttv.select_dtypes("number")
categorical_df = df_minority_data_withouttv.select_dtypes("object")

# Rescale numerical features to [0, 1]
scaling = MinMaxScaler()
numerical_df_rescaled = scaling.fit_transform(numerical_df)

# One-hot encode the categorical features (columns are named "<feature>_<value>")
get_dummy_df = pd.get_dummies(categorical_df)

# Separate the one-hot columns of each category
location_dummy_col = [col for col in get_dummy_df.columns if col.startswith('location_')]
location_dummy = get_dummy_df[location_dummy_col]
country_dummy_col = [col for col in get_dummy_df.columns if col.startswith('country_')]
country_dummy = get_dummy_df[country_dummy_col]
gender_dummy_col = [col for col in get_dummy_df.columns if col.startswith('gender_')]
gender_dummy = get_dummy_df[gender_dummy_col]
vis_wuhan_dummy_col = [col for col in get_dummy_df.columns if col.startswith('vis_wuhan_')]
vis_wuhan_dummy = get_dummy_df[vis_wuhan_dummy_col]
from_wuhan_dummy_col = [col for col in get_dummy_df.columns if col.startswith('from_wuhan_')]
from_wuhan_dummy = get_dummy_df[from_wuhan_dummy_col]
symptom1_dummy_col = [col for col in get_dummy_df.columns if col.startswith('symptom1_')]
symptom1_dummy = get_dummy_df[symptom1_dummy_col]
symptom2_dummy_col = [col for col in get_dummy_df.columns if col.startswith('symptom2_')]
symptom2_dummy = get_dummy_df[symptom2_dummy_col]
symptom3_dummy_col = [col for col in get_dummy_df.columns if col.startswith('symptom3_')]
symptom3_dummy = get_dummy_df[symptom3_dummy_col]
symptom4_dummy_col = [col for col in get_dummy_df.columns if col.startswith('symptom4_')]
symptom4_dummy = get_dummy_df[symptom4_dummy_col]
symptom5_dummy_col = [col for col in get_dummy_df.columns if col.startswith('symptom5_')]
symptom5_dummy = get_dummy_df[symptom5_dummy_col]
symptom6_dummy_col = [col for col in get_dummy_df.columns if col.startswith('symptom6_')]
symptom6_dummy = get_dummy_df[symptom6_dummy_col]
```
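The article does not show how the real minority-class samples are laid out for the discriminator. One plausible arrangement, sketched here on a made-up two-feature toy DataFrame (all values and the `category_order` list are assumptions), stacks the per-category one-hot blocks in the same order as the generator's output branches, followed by the rescaled numerical block. Manual min-max scaling stands in for `MinMaxScaler` with its default [0, 1] range.

```python
import numpy as np
import pandas as pd

# Toy stand-in for the minority-class subset (values are invented for illustration)
categorical_df = pd.DataFrame({
    "gender": ["male", "female", "male"],
    "symptom1": ["fever", "cough", "fever"],
})
numerical_df = pd.DataFrame({"age": [60.0, 75.0, 90.0], "diff_sym_hos": [2.0, 5.0, 8.0]})

# Min-max scaling to [0, 1], the same result as MinMaxScaler's default feature_range
numerical_rescaled = (
    (numerical_df - numerical_df.min()) / (numerical_df.max() - numerical_df.min())
).to_numpy()

get_dummy_df = pd.get_dummies(categorical_df)

# One block of one-hot columns per category, in the generator's branch order (assumed)
category_order = ["gender", "symptom1"]
blocks = [
    get_dummy_df[[c for c in get_dummy_df.columns if c.startswith(name + "_")]].to_numpy(dtype=float)
    for name in category_order
]

# Real samples: categorical blocks first, numerical block last
real_samples = np.hstack(blocks + [numerical_rescaled])
print(real_samples.shape)
```

Keeping the column order identical to the generator's concatenated output matters: the discriminator compares real and generated vectors position by position.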
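Defining the generator

The generator takes input from the latent space and produces new synthetic samples. The leaky rectified linear unit (LeakyReLU) is used as the activation function in both the generator and the discriminator: unlike ReLU, it does not zero out negative inputs but scales them by a small slope, so those units still pass a gradient.
The slope is set to the commonly recommended value of 0.2 (Keras's own default is 0.3), together with the "he_uniform" weight initializer, which is designed for ReLU-family activations. In addition, batch normalization is applied between layers to standardize the activations of the previous layer (zero mean and unit variance) and stabilize training.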
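To make the two operations concrete, here is a small NumPy sketch (not the Keras implementation) of LeakyReLU with slope 0.2 followed by the normalization step of batch normalization. It omits the learned scale and shift parameters, and uses epsilon = 0.001 to mirror the Keras default; the input batch is made up.

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    # Positive inputs pass through unchanged; negative inputs are scaled by alpha
    return np.where(x > 0, x, alpha * x)

def batch_standardize(x, eps=1e-3):
    # Normalization step of batch normalization: per-feature zero mean and
    # unit variance over the batch (no learned gamma/beta in this sketch)
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    return (x - mean) / np.sqrt(var + eps)

# A toy batch of 3 samples with 2 features each
batch = np.array([[-2.0, 1.0], [0.5, 3.0], [4.0, -1.0]])
activated = leaky_relu(batch)
normalized = batch_standardize(activated)
print(activated)
print(normalized.mean(axis=0), normalized.std(axis=0))
```

During training, Keras uses the current batch's statistics as above; `momentum=0.8` in the generator only controls how quickly the moving mean and variance used at inference time are updated.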
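In the output layers, a softmax activation is used for each categorical variable, and a sigmoid activation for the continuous variables, which keeps generated values in the same [0, 1] range as the MinMaxScaler-rescaled data.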
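A short NumPy illustration of why the two activations suit their outputs. The raw branch outputs below are made-up numbers: softmax turns each categorical branch's outputs into a distribution over that feature's levels (argmax gives back a single category), while sigmoid maps each numerical output into (0, 1).

```python
import numpy as np

def softmax(logits):
    # Subtract the row max for numerical stability, then normalize each row to sum to 1
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def sigmoid(x):
    # Squash any real value into the open interval (0, 1)
    return 1.0 / (1.0 + np.exp(-x))

# Made-up raw outputs: one categorical branch with 3 levels (2 samples),
# and a numerical branch with 3 values (1 sample)
cat_logits = np.array([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])
num_logits = np.array([[-3.0, 0.0, 4.0]])

cat_probs = softmax(cat_logits)
num_values = sigmoid(num_logits)

print(cat_probs.sum(axis=1))     # each row sums to 1
print(cat_probs.argmax(axis=1))  # most likely level per sample
print(num_values)                # all values lie in (0, 1)
```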
```python
from tensorflow.keras.layers import BatchNormalization, Dense, Input, LeakyReLU, concatenate
from tensorflow.keras.models import Model

def define_generator(catsh1, catsh2, catsh3, catsh4, catsh5, catsh6,
                     catsh7, catsh8, catsh9, catsh10, catsh11, numerical):
    # Input noise from the 70-dimensional latent space
    noise = Input(shape=(70,))
    hidden_1 = Dense(8, kernel_initializer="he_uniform")(noise)
    hidden_1 = LeakyReLU(0.2)(hidden_1)
    hidden_1 = BatchNormalization(momentum=0.8)(hidden_1)
    hidden_2 = Dense(16, kernel_initializer="he_uniform")(hidden_1)
    hidden_2 = LeakyReLU(0.2)(hidden_2)
    hidden_2 = BatchNormalization(momentum=0.8)(hidden_2)

    # Branches 1-11: one identical branch per categorical feature
    # (location, country, gender, vis_wuhan, from_wuhan, symptom1-6),
    # each ending in a softmax over that feature's one-hot columns
    def categorical_branch(n_categories):
        branch = Dense(32, kernel_initializer="he_uniform")(hidden_2)
        branch = LeakyReLU(0.2)(branch)
        branch = BatchNormalization(momentum=0.8)(branch)
        branch = Dense(64, kernel_initializer="he_uniform")(branch)
        branch = LeakyReLU(0.2)(branch)
        branch = BatchNormalization(momentum=0.8)(branch)
        return Dense(n_categories, activation="softmax")(branch)

    categorical_outputs = [
        categorical_branch(n)
        for n in (catsh1, catsh2, catsh3, catsh4, catsh5, catsh6,
                  catsh7, catsh8, catsh9, catsh10, catsh11)
    ]

    # Branch 12 for generating numerical data
    branch_12 = Dense(64, kernel_initializer="he_uniform")(hidden_2)
    branch_12 = LeakyReLU(0.2)(branch_12)
    branch_12 = BatchNormalization(momentum=0.8)(branch_12)
    branch_12 = Dense(128, kernel_initializer="he_uniform")(branch_12)
    branch_12 = LeakyReLU(0.2)(branch_12)
    branch_12 = BatchNormalization(momentum=0.8)(branch_12)
    # Output layer 12: sigmoid keeps values in [0, 1], like the rescaled data
    branch_12_output = Dense(numerical, activation="sigmoid")(branch_12)

    # Combined output: the 11 categorical blocks followed by the numerical block
    combined_output = concatenate(categorical_outputs + [branch_12_output])
    return Model(inputs=noise, outputs=combined_output)

generator = define_generator(location_dummy.shape[1], country_dummy.shape[1],
                             gender_dummy.shape[1], vis_wuhan_dummy.shape[1],
                             from_wuhan_dummy.shape[1], symptom1_dummy.shape[1],
                             symptom2_dummy.shape[1], symptom3_dummy.shape[1],
                             symptom4_dummy.shape[1], symptom5_dummy.shape[1],
                             symptom6_dummy.shape[1], numerical_df_rescaled.shape[1])
generator.summary()
```

Defining the discriminator

The discriminator model takes a sample from our data (a vector) and outputs a classification prediction of whether the sample is real or fake. This is a binary classification problem, so a sigmoid activation is used in the output layer and binary cross-entropy is used as the loss function when compiling the model. The model is trained with the Adam optimizer, using a learning rate of 0.0002 and the commonly recommended beta1 momentum of 0.5.
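A minimal Keras sketch of a discriminator matching that description: sigmoid output, binary cross-entropy loss, and Adam with learning rate 0.0002 and beta_1 = 0.5. The hidden layer sizes (128, 64) and the input width of 20 are illustrative assumptions, not the article's values; in the actual pipeline the input width would be the width of the generator's combined output (all one-hot columns plus the numerical columns).

```python
import numpy as np
from tensorflow.keras.layers import Dense, Input, LeakyReLU
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

def define_discriminator(n_inputs):
    # One real or generated sample: one-hot blocks plus rescaled numerical features
    inputs = Input(shape=(n_inputs,))
    # Hidden layer sizes are illustrative assumptions
    x = Dense(128, kernel_initializer="he_uniform")(inputs)
    x = LeakyReLU(0.2)(x)
    x = Dense(64, kernel_initializer="he_uniform")(x)
    x = LeakyReLU(0.2)(x)
    # Single sigmoid unit: estimated probability that the sample is real
    output = Dense(1, activation="sigmoid")(x)
    model = Model(inputs=inputs, outputs=output)
    # Binary cross-entropy with Adam (learning rate 0.0002, beta_1 = 0.5)
    model.compile(loss="binary_crossentropy",
                  optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
                  metrics=["accuracy"])
    return model

discriminator = define_discriminator(20)
scores = discriminator.predict(np.random.default_rng(0).random((4, 20)), verbose=0)
print(scores.shape)
```

When training, real minority-class samples would be labeled 1 and generator outputs 0, which is the labeling that binary cross-entropy on a "probability of real" output expects.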