base (reward shaping, CLIP threshold 0.2):
# Aesthetic gain over the plain (un-rewritten) prompt, averaged per prompt and scaled by 100
diff_aes_scores = (aes_scores - aes_scores_plain).reshape(-1, self.num_images_per_prompt).mean(1)
diff_aes_scores = diff_aes_scores * 100
# Raw per-prompt CLIP score, kept for logging
original_clip_scores = clip_scores.reshape(-1, self.num_images_per_prompt).mean(1)
# Gate the CLIP reward: near-zero below 0.2, otherwise scaled by 20
clip_scores = torch.where(clip_scores < 0.2, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
ratio = -0.5  # weight on the KL penalty
# Combined reward, clamped so the log below stays finite
R = torch.clamp(clip_scores + diff_aes_scores, min=0.001)
R_withKL = ratio * kls + torch.log(R).to(accelerator.device)  # log of the clamped reward (was log(total_score))
Same reward shaping with the CLIP threshold lowered to 0.15 (otherwise identical):
diff_aes_scores = (aes_scores - aes_scores_plain).reshape(-1, self.num_images_per_prompt).mean(1)
diff_aes_scores = diff_aes_scores * 100
original_clip_scores = clip_scores.reshape(-1, self.num_images_per_prompt).mean(1)
# Same gating as above, but with the threshold at 0.15 instead of 0.2
clip_scores = torch.where(clip_scores < 0.15, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
ratio = -0.5
R = torch.clamp(clip_scores + diff_aes_scores, min=0.001)
R_withKL = ratio * kls + torch.log(R).to(accelerator.device)
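The two blocks above differ only in the CLIP gating threshold (0.2 vs 0.15). For reference, here is a minimal sketch that folds the same shaping into one parameterized helper; the function name `compute_reward` and the standalone tensor arguments are my own, not taken from the training code.

```python
import torch

def compute_reward(aes_scores, aes_scores_plain, clip_scores, kls,
                   num_images_per_prompt, clip_threshold=0.2, kl_ratio=-0.5):
    """Sketch of the reward shaping: aesthetic gain + gated CLIP score, plus a KL-penalized log-reward."""
    # Aesthetic gain over the plain prompt, averaged per prompt and scaled by 100
    diff_aes = (aes_scores - aes_scores_plain).reshape(-1, num_images_per_prompt).mean(1) * 100
    # Gate the CLIP score: near-zero below the threshold, otherwise scaled by 20
    gated_clip = torch.where(clip_scores < clip_threshold, 1e-4, 20 * clip_scores)
    gated_clip = gated_clip.reshape(-1, num_images_per_prompt).mean(1)
    # Clamp the combined reward so the log below stays finite
    R = torch.clamp(gated_clip + diff_aes, min=0.001)
    R_with_kl = kl_ratio * kls + torch.log(R)
    return R, R_with_kl
```

With this helper, the base run corresponds to `clip_threshold=0.2` and the second configuration to `clip_threshold=0.15`.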
The results look quite good.
Note that the reward model here is an LCM, while the PPO-trained model is the earlier CompVis/stable-diffusion-v1-4.
For the GFN-trained model I saved several intermediate checkpoints; after picking a good one, I adapt the GFN- and PPO-trained language models onto these two reward models respectively for a side-by-side comparison.
For the data, following the evaluation in the paper, I sampled from three settings (coco, diffusionDB, and mc), drawing 200+ examples from each dataset.
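A minimal sketch of how the per-dataset subsampling could look; the file names and the one-prompt-per-line layout are assumptions for illustration, not the actual data format.

```python
import json
import random

random.seed(0)  # fixed seed so the subsample is reproducible

def sample_prompts(path, n=200):
    """Take a random subset of prompts from one evaluation source (assumed: one prompt per line)."""
    with open(path, "r", encoding="utf-8") as f:
        prompts = [line.strip() for line in f if line.strip()]
    return random.sample(prompts, min(n, len(prompts)))

# Hypothetical file names for the three evaluation settings
eval_sets = {
    name: sample_prompts(f"{name}_prompts.txt", n=200)
    for name in ("coco", "diffusiondb", "mc")
}

with open("eval_prompts.json", "w", encoding="utf-8") as f:
    json.dump(eval_sets, f, ensure_ascii=False, indent=2)
```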

The PPO baseline here uses the model they released (which is somewhat better than my own reproduction). As the figure shows, GFN performs better.
Case study
Original: A plate holds dinner including green beans and potatoes.
GFN rewrite: a plate with a green beans and a potato, a detailed painting by greg rutkowski and thomas kinkade, featured on pix
PPO rewrite: a plate holds dinner including green beans and potatoes, by greg rutkowski, artstation, digital art, 4 k, concept art, sharp
[Figure: PPO image]
